From b2c86900753bdc80fe92268c1f352232c1a5740d Mon Sep 17 00:00:00 2001 From: lisab00 Date: Tue, 12 Sep 2023 17:06:20 +0200 Subject: [PATCH 1/3] hduva test file --- .../vae/compos/encoder_xydt_elevator.py | 7 ++ .../compos/vae/utils_request_chain_builder.py | 4 +- .../vae/zoo_vae_builders_classif_topic.py | 45 ++++++++++- tests/test_mk_exp_hduva.py | 74 +++++++++++++++++++ 4 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 tests/test_mk_exp_hduva.py diff --git a/domainlab/compos/vae/compos/encoder_xydt_elevator.py b/domainlab/compos/vae/compos/encoder_xydt_elevator.py index f4c5e3546..b0c83abc6 100644 --- a/domainlab/compos/vae/compos/encoder_xydt_elevator.py +++ b/domainlab/compos/vae/compos/encoder_xydt_elevator.py @@ -85,6 +85,13 @@ def __init__(self, device, topic_dim, zd_dim, super().__init__(net_infer_zd_topic, net_infer_zx, net_infer_zy) +class XYDTEncoderArgUser(XYDTEncoderElevator): + + @store_args + def __init__(self, net_infer_zd_topic, net_infer_zx, net_infer_zy): + super().__init__(net_infer_zd_topic, net_infer_zx, net_infer_zy) + + # To remove class XYDTEncoderConvBnReluPool(XYDTEncoderElevator): """ diff --git a/domainlab/compos/vae/utils_request_chain_builder.py b/domainlab/compos/vae/utils_request_chain_builder.py index c740d5e94..fe5d5b5b6 100644 --- a/domainlab/compos/vae/utils_request_chain_builder.py +++ b/domainlab/compos/vae/utils_request_chain_builder.py @@ -1,6 +1,7 @@ from domainlab.compos.vae.zoo_vae_builders_classif import ( NodeVAEBuilderArg, NodeVAEBuilderUser, NodeVAEBuilderImgAlex, NodeVAEBuilderImgConvBnPool) -from domainlab.compos.vae.zoo_vae_builders_classif_topic import NodeVAEBuilderImgTopic +from domainlab.compos.vae.zoo_vae_builders_classif_topic import ( + NodeVAEBuilderImgTopic, NodeVAEBuilderImgTopicUser) class VAEChainNodeGetter(object): @@ -24,6 +25,7 @@ def __call__(self): """ if self.topic_dim is not None: chain = NodeVAEBuilderImgTopic(None) + chain = NodeVAEBuilderImgTopicUser(chain) else: chain = NodeVAEBuilderImgConvBnPool(None) chain = NodeVAEBuilderImgAlex(chain) diff --git a/domainlab/compos/vae/zoo_vae_builders_classif_topic.py b/domainlab/compos/vae/zoo_vae_builders_classif_topic.py index b8164224e..bdda79073 100644 --- a/domainlab/compos/vae/zoo_vae_builders_classif_topic.py +++ b/domainlab/compos/vae/zoo_vae_builders_classif_topic.py @@ -3,8 +3,8 @@ """ from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv_gated_conv import \ DecoderConcatLatentFCReshapeConvGatedConv -from domainlab.compos.vae.compos.encoder_xydt_elevator import XYDTEncoderArg -from domainlab.compos.vae.zoo_vae_builders_classif import NodeVAEBuilderArg +from domainlab.compos.vae.compos.encoder_xydt_elevator import (XYDTEncoderArg, XYDTEncoderArgUser) +from domainlab.compos.vae.zoo_vae_builders_classif import NodeVAEBuilderArg, NodeVAEBuilderUser class NodeVAEBuilderImgTopic(NodeVAEBuilderArg): @@ -14,8 +14,7 @@ def is_myjob(self, request): :param request: """ - self.args = request.args - flag = True # @FIXME + flag = hasattr(request, "args") self.config_img(flag, request) return flag @@ -46,3 +45,41 @@ def build_decoder(self, topic_dim): i_c=self.i_c, i_w=self.i_w, i_h=self.i_h) return decoder + + +class NodeVAEBuilderImgTopicUser(NodeVAEBuilderUser): + """NodeVAEBuilderImgTopic if user input does not come from command line""" + + def is_myjob(self, request): + """is_myjob. + + :param request: + """ + flag = not hasattr(request, "args") + self.request = request + self.config_img(flag, request) + return flag + + + def build_encoder(self, device, topic_dim): + """build_encoder. + + :param device: + :param topic_dim: + """ + encoder = XYDTEncoderArgUser(self.request.net_class_d, + self.request.net_x, + self.request.net_class_y) + return encoder + + + def build_decoder(self, topic_dim): + """build_decoder. + + :param topic_dim: + """ + decoder = DecoderConcatLatentFCReshapeConvGatedConv( + z_dim=self.zd_dim+self.zx_dim+self.zy_dim+topic_dim, + i_c=self.i_c, i_w=self.i_w, + i_h=self.i_h) + return decoder diff --git a/tests/test_mk_exp_hduva.py b/tests/test_mk_exp_hduva.py new file mode 100644 index 000000000..8ba5aafdf --- /dev/null +++ b/tests/test_mk_exp_hduva.py @@ -0,0 +1,74 @@ +""" +make an experiment +""" +from domainlab.compos.vae.compos.encoder import LSEncoderConvBnReluPool +from domainlab.mk_exp import mk_exp +from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault +from domainlab.tasks.task_dset import mk_task_dset +from domainlab.models.model_hduva import mk_hduva +from domainlab.tasks.utils_task import ImSize +from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter +from domainlab.compos.pcr.request import RequestVAEBuilderNN + + +def test_mk_exp_hduva(): + """ + test mk experiment API with "hduva" model and trainers "mldg", "diva" + """ + + mk_exp_hduva(trainer="mldg") + mk_exp_hduva(trainer="diva") + + +def mk_exp_hduva(trainer="mldg"): + """ + execute experiment with "hduva" model and custom trainer + """ + + # specify domain generalization task + task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") + task.add_domain(name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1)) + task.add_domain(name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3)) + task.add_domain(name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5)) + + # specify backbone to use + zy_dim = 10 + zd_dim = 3 + list_str_y = [f"class{i}" for i in range(task.dim_y)] + list_d_tr = ["domain2", "domain3"] + gamma_d = 1e5 + gamma_y = 7e5 + beta_d = 1e3 + beta_x = 1e3 + beta_y = 1e3 + beta_t = 1e3 + device = "cpu" + zx_dim = 0 + topic_dim = 3 + net_class_y = LSEncoderConvBnReluPool(zy_dim, + task.isize.c, task.isize.w, task.isize.h, + conv_stride=1) + net_x = LSEncoderConvBnReluPool(zx_dim, + task.isize.c, task.isize.w, task.isize.h, + conv_stride=1) + net_class_d = LSEncoderConvBnReluPool(zd_dim, + task.isize.c, task.isize.w, task.isize.h, + conv_stride=1) + + request = RequestVAEBuilderNN(task.isize.c, task.isize.h, task.isize.w, + net_class_y, net_x, net_class_d) + chain_node_builder = VAEChainNodeGetter(request, topic_dim)() + + # specify model to use + model = mk_hduva()(chain_node_builder, zy_dim, zd_dim, list_str_y, list_d_tr, gamma_d, gamma_y, + beta_d, beta_x, beta_y, beta_t, device, zx_dim, topic_dim) + + # make trainer for model + exp = mk_exp(task, model, trainer=trainer, test_domain="domain1", batchsize=32) + exp.execute(num_epochs=3) From 81c41e1183726d372258abf71e913c03607b7eff Mon Sep 17 00:00:00 2001 From: lisab00 Date: Wed, 13 Sep 2023 14:57:47 +0200 Subject: [PATCH 2/3] . --- domainlab/compos/vae/compos/encoder_xydt_elevator.py | 4 ++-- tests/test_mk_exp_hduva.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/domainlab/compos/vae/compos/encoder_xydt_elevator.py b/domainlab/compos/vae/compos/encoder_xydt_elevator.py index b0c83abc6..1cc6b0f9b 100644 --- a/domainlab/compos/vae/compos/encoder_xydt_elevator.py +++ b/domainlab/compos/vae/compos/encoder_xydt_elevator.py @@ -88,8 +88,8 @@ def __init__(self, device, topic_dim, zd_dim, class XYDTEncoderArgUser(XYDTEncoderElevator): @store_args - def __init__(self, net_infer_zd_topic, net_infer_zx, net_infer_zy): - super().__init__(net_infer_zd_topic, net_infer_zx, net_infer_zy) + def __init__(self, net_class_d, net_x, net_class_y): + super().__init__(net_class_d, net_x, net_class_y) # To remove diff --git a/tests/test_mk_exp_hduva.py b/tests/test_mk_exp_hduva.py index 8ba5aafdf..aa6fd8f29 100644 --- a/tests/test_mk_exp_hduva.py +++ b/tests/test_mk_exp_hduva.py @@ -57,12 +57,13 @@ def mk_exp_hduva(trainer="mldg"): net_x = LSEncoderConvBnReluPool(zx_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1) + # FIXME net_class_d = LSEncoderConvBnReluPool(zd_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1) - request = RequestVAEBuilderNN(task.isize.c, task.isize.h, task.isize.w, - net_class_y, net_x, net_class_d) + request = RequestVAEBuilderNN(net_class_d, net_x, net_class_y, + task.isize.c, task.isize.h, task.isize.w) chain_node_builder = VAEChainNodeGetter(request, topic_dim)() # specify model to use From 575b82503503cd6b3e7dabd3a8ebb2a7732fc7f9 Mon Sep 17 00:00:00 2001 From: lisab00 Date: Wed, 13 Sep 2023 16:01:30 +0200 Subject: [PATCH 3/3] . --- domainlab/compos/vae/compos/encoder_xydt_elevator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/domainlab/compos/vae/compos/encoder_xydt_elevator.py b/domainlab/compos/vae/compos/encoder_xydt_elevator.py index 1cc6b0f9b..2f10d3dd3 100644 --- a/domainlab/compos/vae/compos/encoder_xydt_elevator.py +++ b/domainlab/compos/vae/compos/encoder_xydt_elevator.py @@ -86,7 +86,9 @@ def __init__(self, device, topic_dim, zd_dim, class XYDTEncoderArgUser(XYDTEncoderElevator): - + """ + This class only reimplemented constructor of parent class + """ @store_args def __init__(self, net_class_d, net_x, net_class_y): super().__init__(net_class_d, net_x, net_class_y)