def get_model_memory_usage(batch_size, model):
    """Estimate the memory, in gigabytes, needed to train ``model``.

    The estimate is: (bytes per scalar for the backend float type) times
    (activation scalars per batch + trainable params + non-trainable params),
    rounded to 3 decimals, plus the recursively-computed usage of any nested
    sub-models.

    Parameters
    ----------
    batch_size : int
        Number of samples per training batch; scales the activation term.
    model : keras.Model
        A built Keras/tf.keras model whose layers expose ``output_shape``.

    Returns
    -------
    float
        Approximate memory requirement in GB.
    """
    import numpy as np
    # Prefer standalone Keras; fall back to the TensorFlow-bundled backend.
    try:
        from keras import backend as K
    except ImportError:
        from tensorflow.keras import backend as K

    shapes_mem_count = 0
    internal_model_mem_count = 0
    for layer in model.layers:
        layer_type = layer.__class__.__name__
        if layer_type == 'Model':
            # Nested sub-model: recurse. Its result is already in GB, so it
            # is added to the final figure separately, not to the scalar count.
            internal_model_mem_count += get_model_memory_usage(batch_size, layer)
        single_layer_mem = 1
        out_shape = layer.output_shape
        if type(out_shape) is list:
            # Multi-output layers report a list of shapes; use the first.
            out_shape = out_shape[0]
        for s in out_shape:
            if s is None:
                # None marks an unknown (typically batch) dimension — skip it;
                # the batch factor is applied once via batch_size below.
                continue
            single_layer_mem *= s
        shapes_mem_count += single_layer_mem

    trainable_count = np.sum(
        [K.count_params(p) for p in model.trainable_weights])
    non_trainable_count = np.sum(
        [K.count_params(p) for p in model.non_trainable_weights])

    # Bytes per scalar, from the backend's default float precision.
    number_size = 4.0
    if K.floatx() == 'float16':
        number_size = 2.0
    if K.floatx() == 'float64':
        number_size = 8.0

    total_memory = number_size * (batch_size * shapes_mem_count
                                  + trainable_count + non_trainable_count)
    gbytes = (np.round(total_memory / (1024.0 ** 3), 3)
              + internal_model_mem_count)
    return gbytes