From ce13573956dafd2e1817b4532447d1992c353d3e Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Thu, 13 Nov 2025 17:19:21 -0500 Subject: [PATCH 1/6] Add roma to dependencies --- environment.yml | 2 ++ settings.ini | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 89948b86..65fdecac 100644 --- a/environment.yml +++ b/environment.yml @@ -20,3 +20,5 @@ dependencies: - tqdm - pyvista>=0.45.0 - vtk>9.4.0 + - pip: + - roma diff --git a/settings.ini b/settings.ini index 66823fff..6e26d393 100644 --- a/settings.ini +++ b/settings.ini @@ -27,7 +27,7 @@ language = English status = 3 user = eigenvivek requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia -pip_requirements = torch +pip_requirements = torch roma conda_requirements = pytorch dev_requirements = nbdev black flake8 ipykernel ipywidgets jupyterlab jupyterlab_execute_time jupyterlab-code-formatter isort readme_nb = index.ipynb From b4ff46b609377a904d0d6fd48a12a5beb18f932b Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Thu, 13 Nov 2025 17:19:38 -0500 Subject: [PATCH 2/6] Add rotation matrix check in inverse --- diffdrr/pose.py | 14 +++++++++----- notebooks/api/06_pose.ipynb | 14 +++++++++----- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/diffdrr/pose.py b/diffdrr/pose.py index 482005c9..2dddf99e 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -8,6 +8,7 @@ import torch from einops import rearrange +from roma import is_rotation_matrix class RigidTransform(torch.nn.Module): @@ -43,11 +44,14 @@ def translation(self): return self.matrix[..., :3, 3] def inverse(self): - R = self.matrix[..., :3, :3] - t = self.matrix[..., :3, 3] - Rinv = R.mT - tinv = -torch.einsum("bij, bj -> bi", Rinv, t) - matrix = make_matrix(Rinv, tinv) + if is_rotation_matrix(self.matrix[..., :3, :3]): + R = self.matrix[..., :3, :3] + t = self.matrix[..., :3, 3] + Rinv = R.mT + tinv = -torch.einsum("bij, bj -> bi", Rinv, t) + matrix = make_matrix(Rinv, tinv) + else: + matrix = self.matrix.inverse() return RigidTransform(matrix) def compose(self, T): diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index 4db67290..4469a805 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -92,6 +92,7 @@ "import torch\n", "\n", "from einops import rearrange\n", + "from roma import is_rotation_matrix\n", "\n", "\n", "class RigidTransform(torch.nn.Module):\n", @@ -127,11 +128,14 @@ " return self.matrix[..., :3, 3]\n", "\n", " def inverse(self):\n", - " R = self.matrix[..., :3, :3]\n", - " t = self.matrix[..., :3, 3]\n", - " Rinv = R.mT\n", - " tinv = -torch.einsum(\"bij, bj -> bi\", Rinv, t)\n", - " matrix = make_matrix(Rinv, tinv)\n", + " if is_rotation_matrix(self.matrix[..., :3, :3]):\n", + " R = self.matrix[..., :3, :3]\n", + " t = self.matrix[..., :3, 3]\n", + " Rinv = R.mT\n", + " tinv = -torch.einsum(\"bij, bj -> bi\", Rinv, t)\n", + " matrix = make_matrix(Rinv, tinv)\n", + " else:\n", + " matrix = self.matrix.inverse()\n", " return RigidTransform(matrix)\n", "\n", " def compose(self, T):\n", From cdf1e952e0cf3ad237db53ff502f08cf8c0cb792 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Thu, 13 Nov 2025 17:21:26 -0500 Subject: [PATCH 3/6] Make Rigid.Transform__getitem__ return a RigidTransform --- diffdrr/pose.py | 2 +- notebooks/api/06_pose.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diffdrr/pose.py b/diffdrr/pose.py index 2dddf99e..b346a751 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -28,7 +28,7 @@ def __len__(self): return len(self.matrix) def __getitem__(self, idx): - return self.matrix[idx] + return RigidTransform(self.matrix[idx]) def forward(self, x): """Apply (a batch) of rigid transforms to a pointcloud.""" diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index 4469a805..ed47ef0b 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -112,7 +112,7 @@ " return len(self.matrix)\n", "\n", " def __getitem__(self, idx):\n", - " return self.matrix[idx]\n", + " return RigidTransform(self.matrix[idx])\n", "\n", " def forward(self, x):\n", " \"\"\"Apply (a batch) of rigid transforms to a pointcloud.\"\"\"\n", From 37f87ee586577249362d5a6377af5cc015cee088 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Thu, 13 Nov 2025 17:24:35 -0500 Subject: [PATCH 4/6] Change check from rotation to orthonormal --- diffdrr/pose.py | 7 ++++--- notebooks/api/06_pose.ipynb | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/diffdrr/pose.py b/diffdrr/pose.py index b346a751..1ebf1bdb 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -8,7 +8,7 @@ import torch from einops import rearrange -from roma import is_rotation_matrix +from roma import is_orthonormal_matrix class RigidTransform(torch.nn.Module): @@ -18,11 +18,12 @@ class RigidTransform(torch.nn.Module): inversion, and conversions to various representations of SE(3). """ - def __init__(self, matrix): + def __init__(self, matrix, eps=1e-6): super().__init__() if matrix.dim() == 2: matrix = matrix.unsqueeze(0) self.register_buffer("matrix", matrix) + self.eps = eps def __len__(self): return len(self.matrix) @@ -44,7 +45,7 @@ def translation(self): return self.matrix[..., :3, 3] def inverse(self): - if is_rotation_matrix(self.matrix[..., :3, :3]): + if is_orthonormal_matrix(self.matrix[..., :3, :3], self.eps): R = self.matrix[..., :3, :3] t = self.matrix[..., :3, 3] Rinv = R.mT diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index ed47ef0b..c8d9027b 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -92,7 +92,7 @@ "import torch\n", "\n", "from einops import rearrange\n", - "from roma import is_rotation_matrix\n", + "from roma import is_orthonormal_matrix\n", "\n", "\n", "class RigidTransform(torch.nn.Module):\n", @@ -102,11 +102,12 @@ " inversion, and conversions to various representations of SE(3).\n", " \"\"\"\n", "\n", - " def __init__(self, matrix):\n", + " def __init__(self, matrix, eps=1e-6):\n", " super().__init__()\n", " if matrix.dim() == 2:\n", " matrix = matrix.unsqueeze(0)\n", " self.register_buffer(\"matrix\", matrix)\n", + " self.eps = eps\n", "\n", " def __len__(self):\n", " return len(self.matrix)\n", @@ -128,7 +129,7 @@ " return self.matrix[..., :3, 3]\n", "\n", " def inverse(self):\n", - " if is_rotation_matrix(self.matrix[..., :3, :3]):\n", + " if is_orthonormal_matrix(self.matrix[..., :3, :3], self.eps):\n", " R = self.matrix[..., :3, :3]\n", " t = self.matrix[..., :3, 3]\n", " Rinv = R.mT\n", From d16dc69b0c253d625c4245b83a0e59707b755b56 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Fri, 14 Nov 2025 23:52:58 -0500 Subject: [PATCH 5/6] Add roma to requirements --- settings.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/settings.ini b/settings.ini index 6e26d393..89fe1898 100644 --- a/settings.ini +++ b/settings.ini @@ -26,7 +26,7 @@ keywords = nbdev jupyter notebook python language = English status = 3 user = eigenvivek -requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia +requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia roma pip_requirements = torch roma conda_requirements = pytorch dev_requirements = nbdev black flake8 ipykernel ipywidgets jupyterlab jupyterlab_execute_time jupyterlab-code-formatter isort From b2273a819c214596f92ec7595ae4155e32b5b927 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Sat, 15 Nov 2025 21:12:03 -0500 Subject: [PATCH 6/6] Roll back Rigid.Transform__getitem__ return a RigidTransform --- diffdrr/pose.py | 2 +- notebooks/api/06_pose.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diffdrr/pose.py b/diffdrr/pose.py index 1ebf1bdb..f8096417 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -29,7 +29,7 @@ def __len__(self): return len(self.matrix) def __getitem__(self, idx): - return RigidTransform(self.matrix[idx]) + return self.matrix[idx] def forward(self, x): """Apply (a batch) of rigid transforms to a pointcloud.""" diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index c8d9027b..b565ea2e 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -113,7 +113,7 @@ " return len(self.matrix)\n", "\n", " def __getitem__(self, idx):\n", - " return RigidTransform(self.matrix[idx])\n", + " return self.matrix[idx]\n", "\n", " def forward(self, x):\n", " \"\"\"Apply (a batch) of rigid transforms to a pointcloud.\"\"\"\n",