diff --git a/python/gps/algorithm/policy_opt/tf_model_example.py b/python/gps/algorithm/policy_opt/tf_model_example.py index 5259aefcb..479c2a4cf 100644 --- a/python/gps/algorithm/policy_opt/tf_model_example.py +++ b/python/gps/algorithm/policy_opt/tf_model_example.py @@ -119,7 +119,7 @@ def multi_modal_network(dim_input=27, dim_output=7, batch_size=25, network_confi else: x_idx = x_idx + list(range(i, i+dim)) i += dim - + dim_input = i nn_input, action, precision = get_input_layer(dim_input, dim_output) state_input = nn_input[:, 0:x_idx[-1]+1] @@ -193,7 +193,7 @@ def multi_modal_network_fp(dim_input=27, dim_output=7, batch_size=25, network_co else: x_idx = x_idx + list(range(i, i+dim)) i += dim - + dim_input = i nn_input, action, precision = get_input_layer(dim_input, dim_output) state_input = nn_input[:, 0:x_idx[-1]+1]