1 """@namespace IMP.spatiotemporal.write_output
2 Functions to write spatiotemporal graph information to files.
7 from graphviz
import Digraph
12 from matplotlib
import colormaps
as cm
14 from matplotlib
import cm
15 from matplotlib
import colors
as clr
25 Function to output the cumulative distribution function (cdf)
27 @param out_cdf: bool, writes cdf if true
28 @param cdf_fn: str, filename of cdf
29 @param graph_prob: list of probabilities for each path, (path_prob
33 cdf = np.cumsum(np.flip(np.sort(graph_prob), axis=0))
34 np.savetxt(cdf_fn, cdf)
39 Function to output the probability distribution function (pdf)
40 @param out_pdf: bool, writes pdf if true
41 @param pdf_fn: str, filename of pdf
42 @param graph_prob: list of probabilities for each path, (path_prob
46 pdf = np.flip(np.sort(graph_prob), axis=0)
47 np.savetxt(pdf_fn, pdf)
52 Function to output the labeled probability distribution function (pdf)
53 @param out_labeled_pdf: bool, writes labeled_pdf if true
54 @param labeled_pdf_fn: str, filename of labeled_pdf
55 @param graph: list of graphNode objects visited for each path,
56 (all_paths from score_graph())
57 @param graph_prob: list of probabilities for each path, (path_prob
62 new = open(labeled_pdf_fn,
'w')
63 new.write(
'#\tPath\t\tpdf\n')
65 for i
in range(0, len(graph_prob)):
67 pdf_index = np.flip(np.argsort(graph_prob), axis=0)[i]
68 path = graph[pdf_index]
72 all_labels += node.get_label() +
'_' + node.get_time() +
'|'
74 new.write(all_labels +
'\t' + str(graph_prob[pdf_index]) +
'\n')
80 Function to output a file with all states for each of the n most likely
83 @param npaths: int, number of paths to output
84 @param npath_fn: str, name of the file for all paths
85 @param graph_scores: list of tuples, where the first object is the path
86 (list of graphNode objects for each state along the trajectory),
87 and the second object is the score of the path, which can be used
88 to calculate the probability. (path_scores from score_graph())
89 @param graph_prob: list of probabilities for each path, (path_prob from
93 for i
in range(-1, -1 * npaths - 1, -1):
96 m = np.argsort(graph_prob)[i]
98 for node
in graph_scores[m][0]:
100 if node.get_time()
not in path:
101 path.append(node.get_label() +
'_' + node.get_time())
104 with open(npath_fn + str(abs(i)) +
".txt",
"w")
as fh:
105 for statename
in path:
106 fh.write(statename +
"\n")
111 fontname=
"Helvetica", fontsize=
"18", penscale=0.6,
112 arrowsize=1.2, height=
"0.6", width=
"0.6"):
113 """Draw a DAG representation in graphviz and return the resulting Digraph.
114 Takes a list of graphNodes and initializes the nodes and edges.
115 Coloring is expected to be a list of RGBA strings specifying how to color
116 each node. Expected to be same length as nodes.
118 @param nodes: list of graphNode objects
119 @param coloring: list of RGBA strings to specify the color of each node.
120 Expected to be the same length as nodes
121 @param draw_label: bool, whether or not to draw graph labels
122 @param fontname: string, name of font for graph labels
123 @param fontsize: string, size of font for graph labels
124 @param penscale: float, size of pen
125 @param arrowsize: float, size of arrows
126 @param height: string, height of nodes
127 @param width: string, width of nodes
128 @return dot: Digraph object to be rendered
133 "graphviz not available, will not be able to draw graph")
136 dot = Digraph(format=
"eps", engine=
"dot")
137 dot.attr(ratio=
"1.5")
140 for ni, node
in enumerate(nodes):
141 if coloring
is not None:
147 dot.node(str(node), label=node.get_label(), style=
"filled",
148 fillcolor=color, fontname=fontname, fontsize=fontsize,
149 height=height, width=width)
151 dot.node(str(node), label=
' ', style=
"filled",
152 fillcolor=color, fontname=fontname, fontsize=fontsize,
153 height=height, width=width)
155 for ni, node
in enumerate(nodes):
156 edges = node.get_edges()
160 arrowsize=str(arrowsize),
162 penwidth=str(penscale))
169 def draw_dag(dag_fn, nodes, paths, path_prob, keys,
171 heatmap=
True, colormap=
"Purples", penscale=0.6, arrowsize=1.2,
172 fontname=
"Helvetica", fontsize=
"18", height=
"0.6", width=
"0.6",
175 Function to render the DAG with heatmap information.
176 @param dag_fn: string, filename path
177 @param nodes: list of graphNode objects for which the graph will be drawn
178 @param paths: list of lists containing all paths visited by the graphNode
180 @param path_prob: list of probabilities for each path, (path_prob from
182 @param keys: states visited in the graph (list of keys to the state_dict)
183 @param heatmap: Boolean to determine whether or not to write the dag with
184 a heatmap based on the probability of each state (default: True)
185 @param colormap: string, colormap used by the dag to represent probability.
186 Chooses from those available in matplotlib
187 (https://matplotlib.org/stable/users/explain/colors/colormaps.html)
188 (default: "Purples").
189 @param penscale: float, size of the pen used to draw arrows on the dag
190 @param arrowsize: float, size of arrows connecting states on the dag
191 @param fontname: string, font used for the labels on the dag
192 @param fontsize: string, font size used for the labels on the dag
193 @param height: string, height of each node on the dag
194 @param width: string, width of each node on the dag
195 @param draw_label: Boolean to determine whether or not to draw state
202 if cm
is None or clr
is None:
204 "matplotlib not available, will not be able to draw graph")
207 default_cmap = cm.get_cmap(colormap)
210 coloring = np.zeros(len(nodes), dtype=float)
211 for path, p
in zip(paths, path_prob):
213 coloring[int(n.get_index())] += 1 * p
217 b = np.array([t == n.get_time()
for n
in nodes])
218 coloring[b] /= coloring[b].sum()
221 cmap_colors = [clr.to_hex(default_cmap(color))
222 for color
in coloring]
225 nodes, coloring=cmap_colors, draw_label=draw_label,
226 fontname=fontname, fontsize=fontsize, penscale=penscale,
227 arrowsize=arrowsize, height=height, width=width)
233 nodes, coloring=
None, draw_label=draw_label, fontname=fontname,
234 fontsize=fontsize, penscale=penscale, arrowsize=arrowsize,
235 height=height, width=width)
def write_labeled_pdf
Function to output the labeled probability distribution function (pdf)
def write_pdf
Function to output the probability distribution function (pdf)
def draw_dag_in_graphviz
Draw a DAG representation in graphviz and return the resulting Digraph.
def draw_dag
Function to render the DAG with heatmap information.
The general base class for IMP exceptions.
def write_final_npaths
Function to output a file with all states for each of the n most likely paths.
def write_cdf
Function to output the cumulative distribution function (cdf)