From cdd397ed9e27910978dd29bc50274662920fe3f7 Mon Sep 17 00:00:00 2001 From: fhoekstra <32362869+fhoekstra@users.noreply.github.com> Date: Fri, 29 Sep 2017 11:56:32 +0200 Subject: [PATCH] Infer dim_input in tf_model_example I found this fixed an error that was otherwise raised by get_input_layer in my attempt at a box2d_arm_img experiment. I am not sure where in the initialization process of the experiment the dim_input is passed along, but found out that it was zero here. Adding this 1 line per multi_modal_network fixes the issue by inferring dim_input from the given obs_include hyperparameter. --- python/gps/algorithm/policy_opt/tf_model_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/gps/algorithm/policy_opt/tf_model_example.py b/python/gps/algorithm/policy_opt/tf_model_example.py index 5259aefcb..479c2a4cf 100644 --- a/python/gps/algorithm/policy_opt/tf_model_example.py +++ b/python/gps/algorithm/policy_opt/tf_model_example.py @@ -119,7 +119,7 @@ def multi_modal_network(dim_input=27, dim_output=7, batch_size=25, network_confi else: x_idx = x_idx + list(range(i, i+dim)) i += dim - + dim_input = i nn_input, action, precision = get_input_layer(dim_input, dim_output) state_input = nn_input[:, 0:x_idx[-1]+1] @@ -193,7 +193,7 @@ def multi_modal_network_fp(dim_input=27, dim_output=7, batch_size=25, network_co else: x_idx = x_idx + list(range(i, i+dim)) i += dim - + dim_input = i nn_input, action, precision = get_input_layer(dim_input, dim_output) state_input = nn_input[:, 0:x_idx[-1]+1]