diff --git a/diffdrr/pose.py b/diffdrr/pose.py index 482005c9..f8096417 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -8,6 +8,7 @@ import torch from einops import rearrange +from roma import is_orthonormal_matrix class RigidTransform(torch.nn.Module): @@ -17,11 +18,12 @@ class RigidTransform(torch.nn.Module): inversion, and conversions to various representations of SE(3). """ - def __init__(self, matrix): + def __init__(self, matrix, eps=1e-6): super().__init__() if matrix.dim() == 2: matrix = matrix.unsqueeze(0) self.register_buffer("matrix", matrix) + self.eps = eps def __len__(self): return len(self.matrix) @@ -43,11 +45,14 @@ def translation(self): return self.matrix[..., :3, 3] def inverse(self): - R = self.matrix[..., :3, :3] - t = self.matrix[..., :3, 3] - Rinv = R.mT - tinv = -torch.einsum("bij, bj -> bi", Rinv, t) - matrix = make_matrix(Rinv, tinv) + if is_orthonormal_matrix(self.matrix[..., :3, :3], self.eps): + R = self.matrix[..., :3, :3] + t = self.matrix[..., :3, 3] + Rinv = R.mT + tinv = -torch.einsum("bij, bj -> bi", Rinv, t) + matrix = make_matrix(Rinv, tinv) + else: + matrix = self.matrix.inverse() return RigidTransform(matrix) def compose(self, T): diff --git a/environment.yml b/environment.yml index 89948b86..65fdecac 100644 --- a/environment.yml +++ b/environment.yml @@ -20,3 +20,5 @@ dependencies: - tqdm - pyvista>=0.45.0 - vtk>9.4.0 + - pip: + - roma diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index 4db67290..b565ea2e 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -92,6 +92,7 @@ "import torch\n", "\n", "from einops import rearrange\n", + "from roma import is_orthonormal_matrix\n", "\n", "\n", "class RigidTransform(torch.nn.Module):\n", @@ -101,11 +102,12 @@ " inversion, and conversions to various representations of SE(3).\n", " \"\"\"\n", "\n", - " def __init__(self, matrix):\n", + " def __init__(self, matrix, eps=1e-6):\n", " super().__init__()\n", " if matrix.dim() == 2:\n", " matrix = matrix.unsqueeze(0)\n", " self.register_buffer(\"matrix\", matrix)\n", + " self.eps = eps\n", "\n", " def __len__(self):\n", " return len(self.matrix)\n", @@ -127,11 +129,14 @@ " return self.matrix[..., :3, 3]\n", "\n", " def inverse(self):\n", - " R = self.matrix[..., :3, :3]\n", - " t = self.matrix[..., :3, 3]\n", - " Rinv = R.mT\n", - " tinv = -torch.einsum(\"bij, bj -> bi\", Rinv, t)\n", - " matrix = make_matrix(Rinv, tinv)\n", + " if is_orthonormal_matrix(self.matrix[..., :3, :3], self.eps):\n", + " R = self.matrix[..., :3, :3]\n", + " t = self.matrix[..., :3, 3]\n", + " Rinv = R.mT\n", + " tinv = -torch.einsum(\"bij, bj -> bi\", Rinv, t)\n", + " matrix = make_matrix(Rinv, tinv)\n", + " else:\n", + " matrix = self.matrix.inverse()\n", " return RigidTransform(matrix)\n", "\n", " def compose(self, T):\n", diff --git a/settings.ini b/settings.ini index 66823fff..89fe1898 100644 --- a/settings.ini +++ b/settings.ini @@ -26,8 +26,8 @@ keywords = nbdev jupyter notebook python language = English status = 3 user = eigenvivek -requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia -pip_requirements = torch +requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia roma +pip_requirements = torch roma conda_requirements = pytorch dev_requirements = nbdev black flake8 ipykernel ipywidgets jupyterlab jupyterlab_execute_time jupyterlab-code-formatter isort readme_nb = index.ipynb