import os
import time

import pandas as pd

from IMP import ArgumentParser
6 __doc__ =
"Run ML prediction given an input pickle."
9 other_columns = [
"EMDB",
"resolution",
"pdbname",
"chain",
"resid",
11 data_feats = [
"EMDB",
"resolution",
"pdbname",
"chain",
"resid",
13 other_features = [
"resolution"]
17 param_MAX_EPOCHS = 100
18 param_VAL_SPLIT = 0.25
25 def send_chunk(df, counter, chunk_size=20000):
26 df_chunk = df[counter:counter+chunk_size]
30 def get_correct_probs(df, probs):
33 resnames = [
"p_"+res
for res
in df[
"resname"].values]
35 for j, row
in probs.iterrows():
37 prob_list.append(row[resnames[i]])
39 return pd.Series(prob_list)
42 def predict(pickle_file, chunk_size):
43 from tensorflow
import keras
44 import tensorflow
as tf
46 from .data_v2_for_parts
import (reshape_df, split_image_and_other_features,
50 gpus = tf.config.experimental.list_physical_devices(
'GPU')
52 tf.config.experimental.set_memory_growth(gpu,
True)
57 spl_input_file = pickle_file
58 print(
"Begin", flush=
True)
59 indat = ResidueVoxelDataset(
60 spl_input_file, target=
"resname", train_features=[
"resolution"])
61 print(
"Time to open data:"+str(time.time()-t0), flush=
True)
62 print(
"Database Size:", len(indat.data_df))
64 zb = indat.get_binarizer()
65 pred_classes = [
"p_"+c
for c
in zb.classes_]
67 print(indat.binarize_ss(indat.data_df))
71 strategy = tf.distribute.MirroredStrategy()
73 print(
'Number of devices: {}'.format(strategy.num_replicas_in_sync))
75 with strategy.scope():
77 from .
import get_data_path
79 "finalmodel_no_amino_weights/finalmodel.h5")
83 if checkpoint_path
is not None and test_mode:
85 print(
"Loading Model .....")
86 model = keras.models.load_model(checkpoint_path)
88 train_data, test_data, val_data = indat.get_train_test_val_sets(
90 score = model.evaluate(
91 [test_data[
"image_features"], test_data[
"other_features"]],
92 zb.transform(test_data[
"target"]))
93 print(
"test results: ", score)
95 elif checkpoint_path
is not None and evaluate_mode:
97 print(
"Loading Model .....")
98 model = keras.models.load_model(checkpoint_path)
100 data_set_size = indat.data_df.shape[0]
102 for row_num
in range(0, data_set_size, chunk_size):
103 print(indat.data_df[
"H"].head(5))
104 print(indat.data_df[
"S"].head(5))
105 Ximageall, Ximageother = split_image_and_other_features(
106 indat.data_df[row_num:row_num+chunk_size],
107 other_columns=[
"resolution",
"H",
"S"])
108 Xinall = reshape_df(Ximageall, (14, 14, 14))
109 probsall = model.predict([Xinall, Ximageother])
110 df_preds.append(pd.DataFrame(data=probsall,
111 columns=pred_classes))
112 df_pred = pd.concat(df_preds)
114 print(
"No model to load, please train your model first")
118 df_oc = indat.data_df[data_feats]
121 print(
"OC columns", len(df_oc.columns), len(data_feats))
124 base_name = os.path.splitext(pickle_file)[0]
125 output_file = f
"{base_name}_ML_prob.dat"
126 df_oc.loc[:, (
"c_prob",)] = get_correct_probs(indat.data_df, df_pred)
129 df_oc.reset_index(drop=
True, inplace=
True)
130 df_pred.reset_index(drop=
True, inplace=
True)
131 df_out = pd.concat([df_oc, df_pred], axis=1)
132 df_out.to_csv(output_file, sep=
" ", index=
False)
136 parser = ArgumentParser(
137 description=
"Run ML prediction given an input pickle")
138 parser.add_argument(
"pickle_file", help=
"Input Python pickle file")
139 parser.add_argument(
"chunk_size", type=int,
140 help=
"Chunk size for image splitting")
141 return parser.parse_args()
146 predict(args.pickle_file, args.chunk_size)
149 if __name__ ==
'__main__':
def get_data_path
Return the full path to one of this module's data files.