From 5526ac87ebf3983ee72e536344f2c5e36a45b868 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Sat, 15 Nov 2025 21:41:29 -0500 Subject: [PATCH 1/4] If input is RigidTransform, return input --- diffdrr/_modidx.py | 1 + diffdrr/pose.py | 8 ++++++++ notebooks/api/06_pose.ipynb | 8 ++++++++ 3 files changed, 17 insertions(+) diff --git a/diffdrr/_modidx.py b/diffdrr/_modidx.py index db423cd6..e7d92eda 100644 --- a/diffdrr/_modidx.py +++ b/diffdrr/_modidx.py @@ -87,6 +87,7 @@ 'diffdrr.pose.RigidTransform.__getitem__': ('api/pose.html#rigidtransform.__getitem__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.__init__': ('api/pose.html#rigidtransform.__init__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.__len__': ('api/pose.html#rigidtransform.__len__', 'diffdrr/pose.py'), + 'diffdrr.pose.RigidTransform.__new__': ('api/pose.html#rigidtransform.__new__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.compose': ('api/pose.html#rigidtransform.compose', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.convert': ('api/pose.html#rigidtransform.convert', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.forward': ('api/pose.html#rigidtransform.forward', 'diffdrr/pose.py'), diff --git a/diffdrr/pose.py b/diffdrr/pose.py index f8096417..3e6d39d3 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -18,7 +18,15 @@ class RigidTransform(torch.nn.Module): inversion, and conversions to various representations of SE(3). """ + def __new__(cls, matrix, eps=1e-6): + if isinstance(matrix, cls): + return matrix + return super().__new__(cls) + def __init__(self, matrix, eps=1e-6): + if isinstance(matrix, type(self)): + return + super().__init__() if matrix.dim() == 2: matrix = matrix.unsqueeze(0) diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index b565ea2e..e648681b 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -102,7 +102,15 @@ " inversion, and conversions to various representations of SE(3).\n", " \"\"\"\n", "\n", + " def __new__(cls, matrix, eps=1e-6):\n", + " if isinstance(matrix, cls):\n", + " return matrix\n", + " return super().__new__(cls)\n", + "\n", " def __init__(self, matrix, eps=1e-6):\n", + " if isinstance(matrix, type(self)):\n", + " return \n", + "\n", " super().__init__()\n", " if matrix.dim() == 2:\n", " matrix = matrix.unsqueeze(0)\n", From 2aacd25e36082c17ec9b7212d6da002b5a3853a2 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Sat, 15 Nov 2025 21:42:13 -0500 Subject: [PATCH 2/4] Remove hardcoded class name --- diffdrr/pose.py | 4 ++-- notebooks/api/06_pose.ipynb | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/diffdrr/pose.py b/diffdrr/pose.py index 3e6d39d3..84904481 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -61,11 +61,11 @@ def inverse(self): matrix = make_matrix(Rinv, tinv) else: matrix = self.matrix.inverse() - return RigidTransform(matrix) + return type(self)(matrix) def compose(self, T): matrix = torch.einsum("bij, bjk -> bik", T.matrix, self.matrix) - return RigidTransform(matrix) + return type(self)(matrix) def convert(self, parameterization, convention=None, degrees=False): translation = -self.inverse().translation diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index e648681b..af96f535 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -145,11 +145,11 @@ " matrix = make_matrix(Rinv, tinv)\n", " else:\n", " matrix = self.matrix.inverse()\n", - " return RigidTransform(matrix)\n", + " return type(self)(matrix)\n", "\n", " def compose(self, T):\n", " matrix = torch.einsum(\"bij, bjk -> bik\", T.matrix, self.matrix)\n", - " return RigidTransform(matrix)\n", + " return type(self)(matrix)\n", "\n", " def convert(self, parameterization, convention=None, degrees=False):\n", " translation = -self.inverse().translation\n", From 5701db71878dcc0e621ad13a79f0605c091a1f52 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Sat, 15 Nov 2025 22:59:34 -0500 Subject: [PATCH 3/4] Make slicing return a RigidTransform --- diffdrr/pose.py | 2 +- notebooks/api/06_pose.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diffdrr/pose.py b/diffdrr/pose.py index 84904481..0c57b8ab 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -37,7 +37,7 @@ def __len__(self): return len(self.matrix) def __getitem__(self, idx): - return self.matrix[idx] + return type(self)(self.matrix[idx]) def forward(self, x): """Apply (a batch) of rigid transforms to a pointcloud.""" diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index af96f535..e0a0e205 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -121,7 +121,7 @@ " return len(self.matrix)\n", "\n", " def __getitem__(self, idx):\n", - " return self.matrix[idx]\n", + " return type(self)(self.matrix[idx])\n", "\n", " def forward(self, x):\n", " \"\"\"Apply (a batch) of rigid transforms to a pointcloud.\"\"\"\n", From df559ec175c3d08b27410e24433979d7036f2443 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Sat, 15 Nov 2025 23:01:32 -0500 Subject: [PATCH 4/4] Add a matmul overload --- diffdrr/_modidx.py | 1 + diffdrr/pose.py | 3 +++ notebooks/api/06_pose.ipynb | 3 +++ 3 files changed, 7 insertions(+) diff --git a/diffdrr/_modidx.py b/diffdrr/_modidx.py index e7d92eda..97b7b223 100644 --- a/diffdrr/_modidx.py +++ b/diffdrr/_modidx.py @@ -87,6 +87,7 @@ 'diffdrr.pose.RigidTransform.__getitem__': ('api/pose.html#rigidtransform.__getitem__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.__init__': ('api/pose.html#rigidtransform.__init__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.__len__': ('api/pose.html#rigidtransform.__len__', 'diffdrr/pose.py'), + 'diffdrr.pose.RigidTransform.__matmul__': ('api/pose.html#rigidtransform.__matmul__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.__new__': ('api/pose.html#rigidtransform.__new__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.compose': ('api/pose.html#rigidtransform.compose', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.convert': ('api/pose.html#rigidtransform.convert', 'diffdrr/pose.py'), diff --git a/diffdrr/pose.py b/diffdrr/pose.py index 0c57b8ab..92c9dcb2 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -39,6 +39,9 @@ def __len__(self): def __getitem__(self, idx): return type(self)(self.matrix[idx]) + def __matmul__(self, T): + return T.compose(self) + def forward(self, x): """Apply (a batch) of rigid transforms to a pointcloud.""" x_pad = torch.nn.functional.pad(x, (0, 1), value=1.0) diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index e0a0e205..8555a35e 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -123,6 +123,9 @@ " def __getitem__(self, idx):\n", " return type(self)(self.matrix[idx])\n", "\n", + " def __matmul__(self, T):\n", + " return T.compose(self)\n", + "\n", " def forward(self, x):\n", " \"\"\"Apply (a batch) of rigid transforms to a pointcloud.\"\"\"\n", " x_pad = torch.nn.functional.pad(x, (0, 1), value=1.0)\n",