From 962d24a1521ef855c0b87ba28b78f4983cdd6c37 Mon Sep 17 00:00:00 2001 From: Angus Stewart Date: Mon, 1 Jul 2024 22:25:32 +0100 Subject: [PATCH 1/3] feat: add reshape methods --- src/matrix.rs | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/storage.rs | 2 +- src/tensor.rs | 42 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 92 insertions(+), 2 deletions(-) diff --git a/src/matrix.rs b/src/matrix.rs index 215be90..1a01880 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -47,6 +47,19 @@ impl DynamicMatrix { Self::fill(shape, T::one()) } + pub fn reshape(&self, shape: &Shape) -> Result, ShapeError> { + if shape.order() != 2 { + return Err(ShapeError::new("Shape must have order of 2")); + } + let result = self.tensor.reshape(shape)?; + Ok(DynamicMatrix { tensor: result }) + } + + pub fn flatten(&self) -> DynamicVector { + let data = self.tensor.raw().iter().cloned().collect::>(); + DynamicVector::new(&data).unwrap() + } + pub fn sum(&self, axes: Axes) -> DynamicVector { let result = self.tensor.sum(axes); DynamicVector::from_tensor(result).unwrap() @@ -328,6 +341,43 @@ mod tests { assert_eq!(matrix[coord![1, 1].unwrap()], 1.0); } + #[test] + fn test_reshape() { + let shape = shape![2, 2].unwrap(); + let data = vec![1.0, 2.0, 3.0, 4.0]; + let matrix = DynamicMatrix::new(&shape, &data).unwrap(); + let new_shape = shape![4, 1].unwrap(); + let reshaped_matrix = matrix.reshape(&new_shape).unwrap(); + assert_eq!(reshaped_matrix.shape(), &new_shape); + assert_eq!(reshaped_matrix[coord![0, 0].unwrap()], 1.0); + assert_eq!(reshaped_matrix[coord![1, 0].unwrap()], 2.0); + assert_eq!(reshaped_matrix[coord![2, 0].unwrap()], 3.0); + assert_eq!(reshaped_matrix[coord![3, 0].unwrap()], 4.0); + } + + #[test] + fn test_reshape_fail() { + let shape = shape![2, 2].unwrap(); + let data = vec![1.0, 2.0, 3.0, 4.0]; + let matrix = DynamicMatrix::new(&shape, &data).unwrap(); + let new_shape = shape![3, 2].unwrap(); + let result = matrix.reshape(&new_shape); + assert!(result.is_err()); + } + + #[test] + fn test_flatten() { + let shape = shape![2, 2].unwrap(); + let data = vec![1.0, 2.0, 3.0, 4.0]; + let matrix = DynamicMatrix::new(&shape, &data).unwrap(); + let flattened_vector = matrix.flatten(); + assert_eq!(flattened_vector.shape(), &shape![4].unwrap()); + assert_eq!(flattened_vector[0], 1.0); + assert_eq!(flattened_vector[1], 2.0); + assert_eq!(flattened_vector[2], 3.0); + assert_eq!(flattened_vector[3], 4.0); + } + #[test] fn test_size() { let shape = shape![2, 2].unwrap(); diff --git a/src/storage.rs b/src/storage.rs index b51fbdf..c5b58f6 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -4,7 +4,7 @@ use crate::coordinate::Coordinate; use crate::error::ShapeError; use crate::shape::Shape; -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub struct DynamicStorage { data: Vec, } diff --git a/src/tensor.rs b/src/tensor.rs index 595dc26..480c375 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -43,7 +43,21 @@ impl Tensor { Tensor::fill(shape, T::one()) } + pub fn reshape(&self, shape: &Shape) -> Result, ShapeError> { + if self.shape.size() != shape.size() { + return Err(ShapeError::new("Data length does not match shape size")); + } + Ok(Tensor { + data: self.data.clone(), + shape: shape.clone(), + }) + } + // Properties + pub fn raw(&self) -> &DynamicStorage { + &self.data + } + pub fn shape(&self) -> &Shape { &self.shape } @@ -51,6 +65,7 @@ impl Tensor { self.shape.size() } + // Access methods pub fn get(&self, coord: &Coordinate) -> Result<&T, ShapeError> { Ok(&self.data[self.data.flatten(coord, &self.shape)?]) } @@ -66,7 +81,7 @@ impl Tensor { Ok(()) } - // // Reduction operations + // Reduction operations pub fn sum(&self, axes: Axes) -> Tensor { let all_axes = (0..self.shape.order()).collect::>(); let remaining_axes = all_axes @@ -594,6 +609,31 @@ mod tests { assert_eq!(tensor.data, DynamicStorage::new(vec![1.0; shape.size()])); } + #[test] + fn test_reshape_tensor() { + let shape = shape![2, 3].unwrap(); + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let tensor = Tensor::new(&shape, &data).unwrap(); + + let new_shape = shape![3, 2].unwrap(); + let reshaped_tensor = tensor.reshape(&new_shape).unwrap(); + + assert_eq!(reshaped_tensor.shape(), &new_shape); + assert_eq!(reshaped_tensor.data, DynamicStorage::new(data)); + } + + #[test] + fn test_reshape_tensor_shape_mismatch() { + let shape = shape![2, 3].unwrap(); + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let tensor = Tensor::new(&shape, &data).unwrap(); + + let new_shape = shape![4, 2].unwrap(); + let result = tensor.reshape(&new_shape); + + assert!(result.is_err()); + } + #[test] fn test_fill_tensor() { let shape = shape![2, 3].unwrap(); From 59501f089ecc9e7291eeea5c186106e893208083 Mon Sep 17 00:00:00 2001 From: Angus Stewart Date: Sat, 13 Jul 2024 17:58:35 +0100 Subject: [PATCH 2/3] refactor: use reshape interface --- src/matrix.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/matrix.rs b/src/matrix.rs index 1a01880..bcf51c3 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -56,8 +56,9 @@ impl DynamicMatrix { } pub fn flatten(&self) -> DynamicVector { - let data = self.tensor.raw().iter().cloned().collect::>(); - DynamicVector::new(&data).unwrap() + let flattened_shape = Shape::new(vec![self.tensor.size()]).unwrap(); + let result = self.tensor.reshape(&flattened_shape).unwrap(); + DynamicVector::from_tensor(result).unwrap() } pub fn sum(&self, axes: Axes) -> DynamicVector { From b1bfe5c29974b084d910296b8e67d83e7fda5ec4 Mon Sep 17 00:00:00 2001 From: Angus Stewart Date: Sun, 14 Jul 2024 12:33:57 +0100 Subject: [PATCH 3/3] refactor!: make reshape interface consistent --- src/matrix.rs | 8 ++++++-- src/vector.rs | 4 ++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/matrix.rs b/src/matrix.rs index bcf51c3..6d9d6c6 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -47,7 +47,7 @@ impl DynamicMatrix { Self::fill(shape, T::one()) } - pub fn reshape(&self, shape: &Shape) -> Result, ShapeError> { + pub fn redimension(&self, shape: &Shape) -> Result, ShapeError> { if shape.order() != 2 { return Err(ShapeError::new("Shape must have order of 2")); } @@ -55,6 +55,10 @@ impl DynamicMatrix { Ok(DynamicMatrix { tensor: result }) } + pub fn reshape(&self, shape: &Shape) -> Result, ShapeError> { + self.tensor.reshape(shape) + } + pub fn flatten(&self) -> DynamicVector { let flattened_shape = Shape::new(vec![self.tensor.size()]).unwrap(); let result = self.tensor.reshape(&flattened_shape).unwrap(); @@ -348,7 +352,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let new_shape = shape![4, 1].unwrap(); - let reshaped_matrix = matrix.reshape(&new_shape).unwrap(); + let reshaped_matrix = matrix.redimension(&new_shape).unwrap(); assert_eq!(reshaped_matrix.shape(), &new_shape); assert_eq!(reshaped_matrix[coord![0, 0].unwrap()], 1.0); assert_eq!(reshaped_matrix[coord![1, 0].unwrap()], 2.0); diff --git a/src/vector.rs b/src/vector.rs index ff2b939..a6c8c69 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -42,6 +42,10 @@ impl DynamicVector { Self::fill(shape, T::one()) } + pub fn reshape(&self, shape: &Shape) -> Result, ShapeError> { + self.tensor.reshape(shape) + } + pub fn sum(&self) -> DynamicVector { let result = self.tensor.sum(vec![]); DynamicVector::from_tensor(result).unwrap()