diff --git a/diffdrr/_modidx.py b/diffdrr/_modidx.py index db423cd6..97b7b223 100644 --- a/diffdrr/_modidx.py +++ b/diffdrr/_modidx.py @@ -87,6 +87,8 @@ 'diffdrr.pose.RigidTransform.__getitem__': ('api/pose.html#rigidtransform.__getitem__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.__init__': ('api/pose.html#rigidtransform.__init__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.__len__': ('api/pose.html#rigidtransform.__len__', 'diffdrr/pose.py'), + 'diffdrr.pose.RigidTransform.__matmul__': ('api/pose.html#rigidtransform.__matmul__', 'diffdrr/pose.py'), + 'diffdrr.pose.RigidTransform.__new__': ('api/pose.html#rigidtransform.__new__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.compose': ('api/pose.html#rigidtransform.compose', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.convert': ('api/pose.html#rigidtransform.convert', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.forward': ('api/pose.html#rigidtransform.forward', 'diffdrr/pose.py'), diff --git a/diffdrr/pose.py b/diffdrr/pose.py index f8096417..92c9dcb2 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -18,7 +18,15 @@ class RigidTransform(torch.nn.Module): inversion, and conversions to various representations of SE(3). """ + def __new__(cls, matrix, eps=1e-6): + if isinstance(matrix, cls): + return matrix + return super().__new__(cls) + def __init__(self, matrix, eps=1e-6): + if isinstance(matrix, type(self)): + return + super().__init__() if matrix.dim() == 2: matrix = matrix.unsqueeze(0) @@ -29,7 +37,10 @@ def __len__(self): return len(self.matrix) def __getitem__(self, idx): - return self.matrix[idx] + return type(self)(self.matrix[idx]) + + def __matmul__(self, T): + return T.compose(self) def forward(self, x): """Apply (a batch) of rigid transforms to a pointcloud.""" @@ -53,11 +64,11 @@ def inverse(self): matrix = make_matrix(Rinv, tinv) else: matrix = self.matrix.inverse() - return RigidTransform(matrix) + return type(self)(matrix) def compose(self, T): matrix = torch.einsum("bij, bjk -> bik", T.matrix, self.matrix) - return RigidTransform(matrix) + return type(self)(matrix) def convert(self, parameterization, convention=None, degrees=False): translation = -self.inverse().translation diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index b565ea2e..8555a35e 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -102,7 +102,15 @@ " inversion, and conversions to various representations of SE(3).\n", " \"\"\"\n", "\n", + " def __new__(cls, matrix, eps=1e-6):\n", + " if isinstance(matrix, cls):\n", + " return matrix\n", + " return super().__new__(cls)\n", + "\n", " def __init__(self, matrix, eps=1e-6):\n", + " if isinstance(matrix, type(self)):\n", + " return \n", + "\n", " super().__init__()\n", " if matrix.dim() == 2:\n", " matrix = matrix.unsqueeze(0)\n", @@ -113,7 +121,10 @@ " return len(self.matrix)\n", "\n", " def __getitem__(self, idx):\n", - " return self.matrix[idx]\n", + " return type(self)(self.matrix[idx])\n", + "\n", + " def __matmul__(self, T):\n", + " return T.compose(self)\n", "\n", " def forward(self, x):\n", " \"\"\"Apply (a batch) of rigid transforms to a pointcloud.\"\"\"\n", @@ -137,11 +148,11 @@ " matrix = make_matrix(Rinv, tinv)\n", " else:\n", " matrix = self.matrix.inverse()\n", - " return RigidTransform(matrix)\n", + " return type(self)(matrix)\n", "\n", " def compose(self, T):\n", " matrix = torch.einsum(\"bij, bjk -> bik\", T.matrix, self.matrix)\n", - " return RigidTransform(matrix)\n", + " return type(self)(matrix)\n", "\n", " def convert(self, parameterization, convention=None, degrees=False):\n", " translation = -self.inverse().translation\n",