diff --git a/kornia/geometry/quaternion.py b/kornia/geometry/quaternion.py index 3d4d190cbfe..30b6b5a6de7 100644 --- a/kornia/geometry/quaternion.py +++ b/kornia/geometry/quaternion.py @@ -12,6 +12,7 @@ euler_from_quaternion, normalize_quaternion, quaternion_from_euler, + quaternion_to_axis_angle, quaternion_to_rotation_matrix, rotation_matrix_to_quaternion, ) @@ -262,7 +263,7 @@ def from_matrix(cls, matrix: Tensor) -> "Quaternion": @classmethod def from_euler(cls, roll: Tensor, pitch: Tensor, yaw: Tensor) -> "Quaternion": - """Create a quaternion from euler angles. + """Create a quaternion from Euler angles. Args: roll: the roll euler angle. @@ -281,10 +282,7 @@ def from_euler(cls, roll: Tensor, pitch: Tensor, yaw: Tensor) -> "Quaternion": return cls(q) def to_euler(self) -> Tuple[Tensor, Tensor, Tensor]: - """Create a quaternion from euler angles. - - Args: - matrix: the rotation matrix to convert of shape :math:`(B, 3, 3)`. + """Convert the quaternion to a triple of Euler angles (roll, pitch, yaw). Example: >>> q = Quaternion(tensor([2., 0., 1., 1.])) @@ -314,6 +312,17 @@ def from_axis_angle(cls, axis_angle: Tensor) -> "Quaternion": """ return cls(axis_angle_to_quaternion(axis_angle)) + def to_axis_angle(self) -> Tensor: + """Converts the quaternion to an axis-angle representation. + + Example: + >>> q = Quaternion.identity() + >>> axis_angle = q.to_axis_angle() + >>> axis_angle + tensor([0., 0., 0.], grad_fn=) + """ + return quaternion_to_axis_angle(self.data) + @classmethod def identity( cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Dtype = None diff --git a/tests/geometry/test_quaternion.py b/tests/geometry/test_quaternion.py index 1bb98bac9f8..46a2ce48b3f 100644 --- a/tests/geometry/test_quaternion.py +++ b/tests/geometry/test_quaternion.py @@ -208,6 +208,39 @@ def test_axis_angle(self, device, dtype, batch_size): q2 = q2.to(device, dtype) self.assert_close(q1, q2) + @pytest.mark.parametrize("batch_size", (None, 1, 2, 5)) + def test_to_axis_angle(self, device, dtype, batch_size): + # batch_s = 5 + # random_coefs = Quaternion.random(batch_s).data + random_coefs = torch.tensor( + [ + [2.5398e-04, -2.2677e-01, -8.3897e-01, 4.9467e-01], + [-1.7005e-01, -1.0974e-01, 3.7635e-01, -9.0410e-01], + [9.1273e-01, 4.8935e-02, -6.2994e-03, 4.0558e-01], + [-9.8316e-01, 5.4078e-03, 1.4471e-01, 1.1145e-01], + [4.5794e-02, -7.0831e-01, 6.7577e-01, 1.9883e-01], + ], + device=device, + dtype=dtype, + ) + + q = Quaternion(random_coefs) + axis_angle_actual = q.to_axis_angle() + + axis_angle_expected = torch.tensor( + [ + [-0.7123, -2.6353, 1.5538], + [0.3118, -1.0693, 2.5687], + [0.1008, -0.0130, 0.8356], + [-0.0109, -0.2911, -0.2242], + [-2.1626, 2.0632, 0.6071], + ], + device=device, + dtype=dtype, + ) + + self.assert_close(axis_angle_expected, axis_angle_actual, 1e-4, 1e-4) + @pytest.mark.parametrize("batch_size", (None, 1, 2, 5)) def test_slerp(self, device, dtype, batch_size): for axis in torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]):