diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py
index b87327f9648..a9cad919248 100644
--- a/kornia/augmentation/container/augment.py
+++ b/kornia/augmentation/container/augment.py
@@ -498,15 +498,22 @@ def __call__(
         if len(inputs) == 1 and isinstance(inputs[0], dict):
             original_keys, in_data_keys, inputs, invalid_data = self._preproc_dict_data(inputs[0])
         else:
-            in_data_keys = kwargs["data_keys"] if "data_keys" in kwargs else self.data_keys
+            in_data_keys = kwargs.get("data_keys", self.data_keys)
         data_keys = self.transform_op.preproc_datakeys(in_data_keys)
 
-        if len(data_keys) > 1 and data_keys.index(DataKey.INPUT):
+        if len(data_keys) > 1 and DataKey.INPUT in data_keys:
             # NOTE: we may update it later for more supports of drawing boxes, etc.
             idx = data_keys.index(DataKey.INPUT)
             if output_type == "tensor":
                 self._output_image = _output_image
-                self._output_image[idx] = self._detach_tensor_to_cpu(_output_image[idx])
+                if isinstance(_output_image, dict):
+                    self._output_image[original_keys[idx]] = self._detach_tensor_to_cpu(
+                        _output_image[original_keys[idx]]
+                    )
+                else:
+                    self._output_image[idx] = self._detach_tensor_to_cpu(_output_image[idx])
+            elif isinstance(_output_image, dict):
+                self._output_image[original_keys[idx]] = _output_image[original_keys[idx]]
             else:
                 self._output_image[idx] = _output_image[idx]
         else:
diff --git a/tests/augmentation/test_augmentation_mix.py b/tests/augmentation/test_augmentation_mix.py
index 3bb6573362a..50ad9a81696 100644
--- a/tests/augmentation/test_augmentation_mix.py
+++ b/tests/augmentation/test_augmentation_mix.py
@@ -520,6 +520,28 @@ def test_data_keys(self, wrapper, device, dtype):
         self.assert_close(image_out, image_out2)
         self.assert_close(mask_out, mask_out2)
 
+    @pytest.mark.parametrize("wrapper", [AugmentationSequential])
+    def test_dict_input(self, wrapper, device, dtype):
+        torch.manual_seed(22)
+        image = torch.rand(4, 3, 10, 10, device=device, dtype=dtype)
+        mask = torch.randint(0, 2, (4, 10, 10), device=device, dtype=dtype)
+
+        f = wrapper(RandomTransplantation(p=1), data_keys=None)
+        torch.manual_seed(22)
+        dict_input = {"image": image, "mask": mask}
+        aug_dict_output = f(dict_input)
+        torch.manual_seed(22)
+        dict_input2 = {"mask": mask, "image": image}
+        aug_dict_output2 = f(dict_input2)
+
+        image_out = aug_dict_output["image"]
+        mask_out = aug_dict_output["mask"]
+        image_out2 = aug_dict_output2["image"]
+        mask_out2 = aug_dict_output2["mask"]
+
+        self.assert_close(image_out, image_out2)
+        self.assert_close(mask_out, mask_out2)
+
     @pytest.mark.parametrize("n_spatial", [2, 3])
     def test_sequential(self, n_spatial, device, dtype):
         torch.manual_seed(22)
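
Minimal usage sketch (not part of the patch) of the dict-input path exercised by the new test. It assumes a kornia build that ships RandomTransplantation and dict-input support in AugmentationSequential, and mirrors the tensors used in test_dict_input above: with data_keys=None the container infers the data keys from the dict keys (e.g. "image" -> INPUT, "mask" -> MASK), and with the fix the INPUT entry is addressed by its original key rather than a positional index, so the key order of the input dict no longer matters.

import torch

from kornia.augmentation import AugmentationSequential, RandomTransplantation

image = torch.rand(4, 3, 10, 10)
mask = torch.randint(0, 2, (4, 10, 10), dtype=torch.float32)

# data_keys=None: keys are resolved from the dict ("image" -> INPUT, "mask" -> MASK).
aug = AugmentationSequential(RandomTransplantation(p=1.0), data_keys=None)

out = aug({"image": image, "mask": mask})
# With the patch above, a dict whose keys arrive in a different order is handled
# the same way, because the INPUT entry is looked up by its original key.
out2 = aug({"mask": mask, "image": image})

print(out["image"].shape)   # the augmented image stays (4, 3, 10, 10)
print(out2["mask"].shape)   # the mask entry is returned under its original key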