Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@
'diffdrr.pose.RigidTransform.__getitem__': ('api/pose.html#rigidtransform.__getitem__', 'diffdrr/pose.py'),
'diffdrr.pose.RigidTransform.__init__': ('api/pose.html#rigidtransform.__init__', 'diffdrr/pose.py'),
'diffdrr.pose.RigidTransform.__len__': ('api/pose.html#rigidtransform.__len__', 'diffdrr/pose.py'),
'diffdrr.pose.RigidTransform.__matmul__': ('api/pose.html#rigidtransform.__matmul__', 'diffdrr/pose.py'),
'diffdrr.pose.RigidTransform.__new__': ('api/pose.html#rigidtransform.__new__', 'diffdrr/pose.py'),
'diffdrr.pose.RigidTransform.compose': ('api/pose.html#rigidtransform.compose', 'diffdrr/pose.py'),
'diffdrr.pose.RigidTransform.convert': ('api/pose.html#rigidtransform.convert', 'diffdrr/pose.py'),
'diffdrr.pose.RigidTransform.forward': ('api/pose.html#rigidtransform.forward', 'diffdrr/pose.py'),
Expand Down
17 changes: 14 additions & 3 deletions diffdrr/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@ class RigidTransform(torch.nn.Module):
inversion, and conversions to various representations of SE(3).
"""

def __new__(cls, matrix, eps=1e-6):
if isinstance(matrix, cls):
return matrix
return super().__new__(cls)

def __init__(self, matrix, eps=1e-6):
if isinstance(matrix, type(self)):
return

super().__init__()
if matrix.dim() == 2:
matrix = matrix.unsqueeze(0)
Expand All @@ -29,7 +37,10 @@ def __len__(self):
return len(self.matrix)

def __getitem__(self, idx):
return self.matrix[idx]
return type(self)(self.matrix[idx])

def __matmul__(self, T):
return T.compose(self)

def forward(self, x):
"""Apply (a batch) of rigid transforms to a pointcloud."""
Expand All @@ -53,11 +64,11 @@ def inverse(self):
matrix = make_matrix(Rinv, tinv)
else:
matrix = self.matrix.inverse()
return RigidTransform(matrix)
return type(self)(matrix)

def compose(self, T):
matrix = torch.einsum("bij, bjk -> bik", T.matrix, self.matrix)
return RigidTransform(matrix)
return type(self)(matrix)

def convert(self, parameterization, convention=None, degrees=False):
translation = -self.inverse().translation
Expand Down
17 changes: 14 additions & 3 deletions notebooks/api/06_pose.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,15 @@
" inversion, and conversions to various representations of SE(3).\n",
" \"\"\"\n",
"\n",
" def __new__(cls, matrix, eps=1e-6):\n",
" if isinstance(matrix, cls):\n",
" return matrix\n",
" return super().__new__(cls)\n",
"\n",
" def __init__(self, matrix, eps=1e-6):\n",
" if isinstance(matrix, type(self)):\n",
" return \n",
"\n",
" super().__init__()\n",
" if matrix.dim() == 2:\n",
" matrix = matrix.unsqueeze(0)\n",
Expand All @@ -113,7 +121,10 @@
" return len(self.matrix)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.matrix[idx]\n",
" return type(self)(self.matrix[idx])\n",
"\n",
" def __matmul__(self, T):\n",
" return T.compose(self)\n",
"\n",
" def forward(self, x):\n",
" \"\"\"Apply (a batch) of rigid transforms to a pointcloud.\"\"\"\n",
Expand All @@ -137,11 +148,11 @@
" matrix = make_matrix(Rinv, tinv)\n",
" else:\n",
" matrix = self.matrix.inverse()\n",
" return RigidTransform(matrix)\n",
" return type(self)(matrix)\n",
"\n",
" def compose(self, T):\n",
" matrix = torch.einsum(\"bij, bjk -> bik\", T.matrix, self.matrix)\n",
" return RigidTransform(matrix)\n",
" return type(self)(matrix)\n",
"\n",
" def convert(self, parameterization, convention=None, degrees=False):\n",
" translation = -self.inverse().translation\n",
Expand Down
Loading