IMP Reference Guide develop.50fdd7fa33, 2025/09/05
The Integrative Modeling Platform
predict.py
import sys
import os
import time
from IMP import ArgumentParser

__doc__ = "Run ML prediction given an input pickle."

# Non-voxel columns in the input dataframe
other_columns = ["EMDB", "resolution", "pdbname", "chain", "resid",
                 "resname", "ss"]
data_feats = ["EMDB", "resolution", "pdbname", "chain", "resid",
              "resname", "ss"]
other_features = ["resolution"]

# Learning parameters
param_BATCH_SIZE = 64
param_MAX_EPOCHS = 100
param_VAL_SPLIT = 0.25
param_LOG_LRATE = -4  # -5 # -6 --> -1 (DIB: changed from -4 to -5)
# 1 --> MAX_EPOCHS, but should be much less than MAX_EPOCHS
param_PATIENCE = 10  # 5
param_DROPOUT = 0.0  # 0.1 # 0 --> 1

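# Aside (not part of the original module): the param_* values above are not
# referenced anywhere in this prediction script; they appear to record the
# training configuration. A hypothetical trainer might consume them as:
#
#     optimizer = keras.optimizers.Adam(learning_rate=10**param_LOG_LRATE)
#     early_stop = keras.callbacks.EarlyStopping(patience=param_PATIENCE)
#     model.fit(x, y, batch_size=param_BATCH_SIZE, epochs=param_MAX_EPOCHS,
#               validation_split=param_VAL_SPLIT, callbacks=[early_stop])
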
def send_chunk(df, counter, chunk_size=20000):
    # Return the slice of ``df`` starting at row ``counter``.
    # (Helper; not called elsewhere in this script.)
    df_chunk = df[counter:counter+chunk_size]
    return df_chunk


def get_correct_probs(df, probs):
    # For each residue, pick the predicted probability of its *true* type:
    # ``df`` supplies the true residue names, and ``probs`` holds one
    # ``p_<resname>`` column per residue type, in the same row order as ``df``.
    import pandas as pd
    prob_list = []
    resnames = ["p_" + res for res in df["resname"].values]
    i = 0
    for j, row in probs.iterrows():
        # print(j, i, resnames[i], row[resnames[i]])
        prob_list.append(row[resnames[i]])
        i += 1
    return pd.Series(prob_list)

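# Illustration (hypothetical data, not part of the module): for two residues
# with true names ["ALA", "GLY"], get_correct_probs picks one entry per row:
#
#     df = pd.DataFrame({"resname": ["ALA", "GLY"]})
#     probs = pd.DataFrame({"p_ALA": [0.9, 0.2], "p_GLY": [0.1, 0.8]})
#     get_correct_probs(df, probs)  # -> pd.Series([0.9, 0.8])
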
def predict(pickle_file, chunk_size):
    from tensorflow import keras
    import tensorflow as tf
    import pandas as pd
    from .data_v2_for_parts import (reshape_df, split_image_and_other_features,
                                    ResidueVoxelDataset)

    # Set dynamic memory growth, so GPU memory is allocated on demand
    # rather than all reserved up front
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    t0 = time.time()
    ###############################
    # Import input data
    spl_input_file = pickle_file
    print("Begin", flush=True)
    indat = ResidueVoxelDataset(
        spl_input_file, target="resname", train_features=["resolution"])
    print("Time to open data:" + str(time.time() - t0), flush=True)
    print("Database Size:", len(indat.data_df))

    zb = indat.get_binarizer()
    pred_classes = ["p_" + c for c in zb.classes_]  # Make column labels

    print(indat.binarize_ss(indat.data_df))

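    # Aside (assumption, inferred from the usage above): get_binarizer() is
    # taken to return a fitted sklearn-style LabelBinarizer, so the one-hot
    # output columns map back to residue-type names via zb.classes_:
    #
    #     from sklearn.preprocessing import LabelBinarizer
    #     zb = LabelBinarizer().fit(["ALA", "GLY", "HIS"])
    #     ["p_" + c for c in zb.classes_]  # -> ['p_ALA', 'p_GLY', 'p_HIS']
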
    ##################################
    # Set up multi-GPU strategy
    strategy = tf.distribute.MirroredStrategy()
    print("Strategy:")
    print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    with strategy.scope():
        # Load checkpoint:
        from . import get_data_path
        checkpoint_path = get_data_path(
            "finalmodel_no_amino_weights/finalmodel.h5")
        # checkpoint_path = None
        # Hard-coded mode switches; evaluate_mode selects the chunked
        # prediction branch below
        test_mode = False
        evaluate_mode = True
        if checkpoint_path is not None and test_mode:
            # Load the model and score it on a held-out split:
            print("Loading Model .....")
            model = keras.models.load_model(checkpoint_path)
            model.summary()
            train_data, test_data, val_data = indat.get_train_test_val_sets(
                0.01, 0.98, 0.01)
            score = model.evaluate(
                [test_data["image_features"], test_data["other_features"]],
                zb.transform(test_data["target"]))
            print("test results: ", score)

        elif checkpoint_path is not None and evaluate_mode:
            # Load model:
            print("Loading Model .....")
            model = keras.models.load_model(checkpoint_path)
            model.summary()
            data_set_size = indat.data_df.shape[0]
            df_preds = []
            # Predict chunk by chunk to bound memory use
            for row_num in range(0, data_set_size, chunk_size):
                print(indat.data_df["H"].head(5))
                print(indat.data_df["S"].head(5))
                Ximageall, Ximageother = split_image_and_other_features(
                    indat.data_df[row_num:row_num+chunk_size],
                    other_columns=["resolution", "H", "S"])
                Xinall = reshape_df(Ximageall, (14, 14, 14))
                probsall = model.predict([Xinall, Ximageother])
                df_preds.append(pd.DataFrame(data=probsall,
                                             columns=pred_classes))
            df_pred = pd.concat(df_preds)
        else:
            print("No model to load; please train your model first")
            sys.exit()

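    # Aside (illustrative; shapes assumed from the calls above): each voxel
    # row is taken to hold 14*14*14 = 2744 values, which reshape_df folds
    # into a 4-D array before prediction, roughly:
    #
    #     import numpy as np
    #     Xflat = np.zeros((100, 14 * 14 * 14))   # 100 residues, flattened
    #     Xin = Xflat.reshape((-1, 14, 14, 14))   # (100, 14, 14, 14)
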
    # df_pred = pd.DataFrame(data=probsall, columns=pred_classes)
    df_oc = indat.data_df[data_feats]
    print("---")
    # print(len(Xinall), len(df_oc), len(indat.data_df))
    print("OC columns", len(df_oc.columns), len(data_feats))

    # Use the same prefix as the input .pkl
    base_name = os.path.splitext(pickle_file)[0]  # e.g., "3j5r_ML_side"
    output_file = f"{base_name}_ML_prob.dat"
    df_oc.loc[:, ("c_prob",)] = get_correct_probs(indat.data_df, df_pred)
    # df_oc["vavg"] = np.average(Ximageall, axis=1)
    # df_oc["vstd"] = np.std(Ximageall, axis=1)
    df_oc.reset_index(drop=True, inplace=True)
    df_pred.reset_index(drop=True, inplace=True)
    df_out = pd.concat([df_oc, df_pred], axis=1)
    df_out.to_csv(output_file, sep=" ", index=False)
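    # Note (layout read off the code above): the output is a space-separated
    # table with one row per residue,
    #
    #     EMDB resolution pdbname chain resid resname ss c_prob p_<type> ...
    #
    # where c_prob is the predicted probability of the true residue type and
    # the trailing p_<type> columns give the probability of every class.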


def parse_args():
    parser = ArgumentParser(
        description="Run ML prediction given an input pickle")
    parser.add_argument("pickle_file", help="Input Python pickle file")
    parser.add_argument("chunk_size", type=int,
                        help="Chunk size for image splitting")
    return parser.parse_args()


def main():
    args = parse_args()
    predict(args.pickle_file, args.chunk_size)


if __name__ == '__main__':
    main()
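
A minimal invocation sketch: because predict() uses relative imports, the
script is normally run as a module from its parent package (package name
elided here). Following the "3j5r_ML_side" example in the comments above:

    python -m <package>.predict 3j5r_ML_side.pkl 20000

This loads the bundled finalmodel.h5 checkpoint and writes the per-residue
probability table to 3j5r_ML_side_ML_prob.dat.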
See also: get_data_path(), which returns the full path to one of this
module's data files (used above to locate the finalmodel.h5 checkpoint).