IMP logo
IMP Reference Guide  develop.330bebda01,2025/01/21
The Integrative Modeling Platform
write_output.py
1 """@namespace IMP.spatiotemporal.write_output
2  Functions to write spatiotemporal graph information to files.
3 """
4 
5 import numpy as np
6 try:
7  from graphviz import Digraph
8 except ImportError:
9  Digraph = None
10 try:
11  try:
12  from matplotlib import colormaps as cm # matplotlib 3.7+
13  except ImportError:
14  from matplotlib import cm
15  from matplotlib import colors as clr
16 except ImportError:
17  cm = None
18  clr = None
19 
20 
21 # Text / probability output
22 
def write_cdf(out_cdf, cdf_fn, graph_prob):
    """
    Write the cumulative distribution function (cdf) of the path
    probabilities to a text file.

    The probabilities are sorted from most to least likely and then
    accumulated, so entry k of the output is the summed probability of the
    k most likely paths. Nothing is written when out_cdf is false.

    @param out_cdf: bool, writes cdf if true
    @param cdf_fn: str, filename of cdf
    @param graph_prob: list of probabilities for each path, (path_prob
           from score_graph())
    """
    if not out_cdf:
        return
    # reverse the ascending sort so the most likely path comes first
    descending = np.sort(graph_prob)[::-1]
    np.savetxt(cdf_fn, np.cumsum(descending))
35 
36 
def write_pdf(out_pdf, pdf_fn, graph_prob):
    """
    Write the probability distribution function (pdf) of the path
    probabilities to a text file, ordered from most to least likely.

    Nothing is written when out_pdf is false.

    @param out_pdf: bool, writes pdf if true
    @param pdf_fn: str, filename of pdf
    @param graph_prob: list of probabilities for each path, (path_prob
           from score_graph())
    """
    if not out_pdf:
        return
    # np.sort is ascending; reversing gives most-likely-first order
    np.savetxt(pdf_fn, np.sort(graph_prob)[::-1])
48 
49 
def write_labeled_pdf(out_labeled_pdf, labeled_pdf_fn, graph, graph_prob):
    """
    Function to output the labeled probability distribution function (pdf).

    Writes a tab-separated file with a header line followed by one line per
    path, ordered from most to least likely. Each line lists the states of
    the path as ``label_time|`` entries, then a tab and the path
    probability. Nothing is written when out_labeled_pdf is false.

    @param out_labeled_pdf: bool, writes labeled_pdf if true
    @param labeled_pdf_fn: str, filename of labeled_pdf
    @param graph: list of graphNode objects visited for each path,
           (all_paths from score_graph())
    @param graph_prob: list of probabilities for each path, (path_prob
           from score_graph())
    """
    if not out_labeled_pdf:
        return
    # Sort once, most likely path first, rather than re-sorting the whole
    # probability list for every output line.
    order = np.flip(np.argsort(graph_prob), axis=0)
    # 'with' guarantees the file is closed even if a node accessor raises.
    with open(labeled_pdf_fn, 'w') as fh:
        fh.write('#\tPath\t\tpdf\n')
        for pdf_index in order:
            path = graph[pdf_index]
            # all labels / times along this path, each terminated by '|'
            all_labels = ''.join(
                node.get_label() + '_' + node.get_time() + '|'
                for node in path)
            fh.write(all_labels + '\t' + str(graph_prob[pdf_index]) + '\n')
76 
77 
def write_final_npaths(npaths, npath_fn, graph_scores, graph_prob):
    """
    Function to output a file with all states for each of the n most likely
    paths.

    The k-th most likely path (k = 1 is the most likely) is written to
    ``npath_fn + str(k) + ".txt"``, one ``label_time`` entry per line. Only
    the first state encountered at each time point is kept. If npaths
    exceeds the number of available paths, only the available paths are
    written.

    @param npaths: int, number of paths to output
    @param npath_fn: str, name of the file for all paths (used as a prefix)
    @param graph_scores: list of tuples, where the first object is the path
           (list of graphNode objects for each state along the trajectory),
           and the second object is the score of the path, which can be used
           to calculate the probability. (path_scores from score_graph())
    @param graph_prob: list of probabilities for each path, (path_prob from
           score_graph())
    """
    # Sort once; ascending order, so the most likely path is last.
    order = np.argsort(graph_prob)
    # Never request more paths than exist (indexing past the end would
    # otherwise raise IndexError after some files were already written).
    n_out = min(npaths, len(order))
    for rank in range(1, n_out + 1):
        m = order[-rank]
        path = []
        seen_times = set()
        # go to that index and grab the path
        for node in graph_scores[m][0]:
            time = node.get_time()
            # append times not yet in the path (track the times themselves,
            # not the 'label_time' strings stored in `path`)
            if time in seen_times:
                continue
            seen_times.add(time)
            path.append(node.get_label() + '_' + time)

        # save to new file
        with open(npath_fn + str(rank) + ".txt", "w") as fh:
            fh.writelines(statename + "\n" for statename in path)
107 
108 
109 # Rendering DAG
def draw_dag_in_graphviz(nodes, coloring=None, draw_label=True,
                         fontname="Helvetica", fontsize="18", penscale=0.6,
                         arrowsize=1.2, height="0.6", width="0.6"):
    """Draw a DAG representation in graphviz and return the resulting Digraph.

    Every graphNode becomes a filled graphviz node keyed by str(node), and
    every entry of node.get_edges() becomes a black edge from that node.
    The Digraph is created with format "eps" and the "dot" layout engine.

    @param nodes: list of graphNode objects
    @param coloring: list of RGBA strings to specify the color of each node.
           Expected to be the same length as nodes; when None every node is
           filled white ("#ffffff")
    @param draw_label: bool, whether or not to draw graph labels; when false
           each node gets a single-space label
    @param fontname: string, name of font for graph labels
    @param fontsize: string, size of font for graph labels
    @param penscale: float, size of pen
    @param arrowsize: float, size of arrows
    @param height: string, height of nodes
    @param width: string, width of nodes
    @return dot: Digraph object to be rendered
    @raise Exception: if graphviz could not be imported
    """
    if Digraph is None:
        raise Exception(
            "graphviz not available, will not be able to draw graph")

    # create a dot object for the graph
    dot = Digraph(format="eps", engine="dot")
    dot.attr(ratio="1.5")
    dot.attr(rotate="0")

    # one styled graphviz node per graphNode
    for position, state in enumerate(nodes):
        fill = "#ffffff" if coloring is None else coloring[position]
        text = state.get_label() if draw_label else ' '
        dot.node(str(state), label=text, style="filled",
                 fillcolor=fill, fontname=fontname, fontsize=fontsize,
                 height=height, width=width)

    # graphviz attributes are strings; convert the numeric sizes once
    arrow_attr = str(arrowsize)
    pen_attr = str(penscale)
    for state in nodes:
        for successor in state.get_edges():
            dot.edge(str(state),
                     str(successor),
                     arrowsize=arrow_attr,
                     color="black",
                     penwidth=pen_attr)

    return dot
165 
166 
167 # first set of parameters are required and determine the connectivity
168 # of the map
def draw_dag(dag_fn, nodes, paths, path_prob, keys,
             # 2nd set of parameters are for rendering the heatmap
             heatmap=True, colormap="Purples", penscale=0.6, arrowsize=1.2,
             fontname="Helvetica", fontsize="18", height="0.6", width="0.6",
             draw_label=True):
    """
    Function to render the DAG with heatmap information.

    When heatmap is true, each node is colored by the summed probability of
    the paths that visit it. That value is normalized over all nodes that
    share the same time point and then mapped through the chosen matplotlib
    colormap. A time point whose nodes carry no probability at all is left
    un-normalized (value 0) instead of being turned into NaN. The graph is
    rendered with Digraph.render(dag_fn).

    @param dag_fn: string, filename path
    @param nodes: list of graphNode objects for which the graph will be drawn
    @param paths: list of lists containing all paths visited by the graphNode
           objects
    @param path_prob: list of probabilities for each path, (path_prob from
           score_graph())
    @param keys: states visited in the graph (list of keys to the state_dict)
    @param heatmap: Boolean to determine whether or not to write the dag with
           a heatmap based on the probability of each state (default: True)
    @param colormap: string, colormap used by the dag to represent probability.
           Chooses from those available in matplotlib
           (https://matplotlib.org/stable/users/explain/colors/colormaps.html)
           (default: "Purples").
    @param penscale: float, size of the pen used to draw arrows on the dag
    @param arrowsize: float, size of arrows connecting states on the dag
    @param fontname: string, font used for the labels on the dag
    @param fontsize: string, font size used for the labels on the dag
    @param height: string, height of each node on the dag
    @param width: string, width of each node on the dag
    @param draw_label: Boolean to determine whether or not to draw state
           labels on the dag
    @raise Exception: if heatmap is requested but matplotlib is unavailable
    """
    cmap_colors = None

    # determines if heatmap will be overlaid on top of DAG
    if heatmap:
        if cm is None or clr is None:
            raise Exception(
                "matplotlib not available, will not be able to draw graph")

        default_cmap = cm.get_cmap(colormap)

        # accumulate, per node, the probability of every path visiting it
        coloring = np.zeros(len(nodes), dtype=float)
        for path, p in zip(paths, path_prob):
            for n in path:
                coloring[int(n.get_index())] += p

        # normalize probability among nodes sharing a time point
        for t in keys:
            # explicit bool dtype: an empty list would otherwise become a
            # float array, which numpy rejects as an index
            b = np.array([t == n.get_time() for n in nodes], dtype=bool)
            total = coloring[b].sum()
            # dividing by a zero total would yield NaN, which the colormap
            # renders with its "bad" color instead of the low end
            if total != 0:
                coloring[b] /= total

        # convert probability to colors
        cmap_colors = [clr.to_hex(default_cmap(color))
                       for color in coloring]

    # coloring=None (no heatmap) makes every node white
    dot = draw_dag_in_graphviz(
        nodes, coloring=cmap_colors, draw_label=draw_label,
        fontname=fontname, fontsize=fontsize, penscale=penscale,
        arrowsize=arrowsize, height=height, width=width)
    dot.render(dag_fn)
def write_labeled_pdf
Function to output the labeled probability distribution function (pdf)
Definition: write_output.py:50
def write_pdf
Function to output the probability distribution function (pdf)
Definition: write_output.py:37
def draw_dag_in_graphviz
Draw a DAG representation in graphviz and return the resulting Digraph.
def draw_dag
Function to render the DAG with heatmap information.
The general base class for IMP exceptions.
Definition: exception.h:48
def write_final_npaths
Function to output a file with all states for each of the n most likely paths.
Definition: write_output.py:78
def write_cdf
Function to output the cumulative distribution function (cdf)
Definition: write_output.py:23