From 736d06961203f89d7c5ba0720df5152c48c16a6e Mon Sep 17 00:00:00 2001 From: Angus Stewart Date: Sun, 30 Jun 2024 22:57:19 +0100 Subject: [PATCH 1/2] feat: add PartialEq trait --- src/matrix.rs | 145 +++++-------------------- src/tensor.rs | 286 ++++++++++++-------------------------------------- src/vector.rs | 149 +++++--------------------- 3 files changed, 123 insertions(+), 457 deletions(-) diff --git a/src/matrix.rs b/src/matrix.rs index 215be90..92fa556 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -10,6 +10,7 @@ use crate::tensor::DynamicTensor; use crate::vector::DynamicVector; use num::{Float, Num}; +#[derive(Debug, PartialEq)] pub struct DynamicMatrix { tensor: DynamicTensor, } @@ -263,11 +264,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data).unwrap(); let matrix = DynamicMatrix::from_tensor(tensor).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 2.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 3.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 4.0); + assert_eq!(matrix, DynamicMatrix::new(&shape, &data).unwrap()); } #[test] @@ -283,49 +280,28 @@ mod tests { fn test_fill() { let shape = shape![2, 2].unwrap(); let matrix = DynamicMatrix::fill(&shape, 3.0).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 3.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 3.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 3.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 3.0); + assert_eq!(matrix, DynamicMatrix::new(&shape, &[3.0; 4]).unwrap()); } #[test] fn test_eye() { let shape = shape![3, 3].unwrap(); let matrix = DynamicMatrix::::eye(&shape).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 0.0); - assert_eq!(matrix[coord![0, 2].unwrap()], 0.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 0.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 1.0); - assert_eq!(matrix[coord![1, 2].unwrap()], 0.0); - assert_eq!(matrix[coord![2, 0].unwrap()], 0.0); - assert_eq!(matrix[coord![2, 1].unwrap()], 0.0); - assert_eq!(matrix[coord![2, 2].unwrap()], 1.0); + assert_eq!(matrix, DynamicMatrix::new(&shape, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).unwrap()); } #[test] fn test_zeros() { let shape = shape![2, 2].unwrap(); let matrix = DynamicMatrix::::zeros(&shape).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 0.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 0.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 0.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 0.0); + assert_eq!(matrix, DynamicMatrix::new(&shape, &[0.0; 4]).unwrap()); } #[test] fn test_ones() { let shape = shape![2, 2].unwrap(); let matrix = DynamicMatrix::::ones(&shape).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 1.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 1.0); + assert_eq!(matrix, DynamicMatrix::new(&shape, &[1.0; 4]).unwrap()); } #[test] @@ -349,11 +325,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let mut matrix = DynamicMatrix::new(&shape, &data).unwrap(); matrix[coord![1, 0].unwrap()] = 5.0; - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 2.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 5.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 4.0); + assert_eq!(matrix, DynamicMatrix::new(&shape, &[1.0, 2.0, 5.0, 4.0]).unwrap()); } #[test] @@ -362,11 +334,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let mut matrix = DynamicMatrix::new(&shape, &data).unwrap(); matrix.set(&coord![1, 0].unwrap(), 5.0).unwrap(); - assert_eq!(matrix.shape(), &shape); - assert_eq!(matrix[coord![0, 0].unwrap()], 1.0); - assert_eq!(matrix[coord![0, 1].unwrap()], 2.0); - assert_eq!(matrix[coord![1, 0].unwrap()], 5.0); - assert_eq!(matrix[coord![1, 1].unwrap()], 4.0); + assert_eq!(matrix, DynamicMatrix::new(&shape, &[1.0, 2.0, 5.0, 4.0]).unwrap()); } #[test] @@ -375,8 +343,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix.sum(vec![0, 1]); - assert_eq!(result[0], 10.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[10.0]).unwrap()); } #[test] @@ -385,8 +352,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix.mean(vec![0, 1]); - assert_eq!(result[0], 2.5); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[2.5]).unwrap()); } #[test] @@ -395,8 +361,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix.var(vec![0, 1]); - assert_eq!(result[0], 1.25); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[1.25]).unwrap()); } #[test] @@ -405,8 +370,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix.min(vec![0, 1]); - assert_eq!(result[0], 1.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[1.0]).unwrap()); } #[test] @@ -415,8 +379,7 @@ mod tests { let data = vec![-1.0, -2.0, -3.0, -4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix.max(vec![0, 1]); - assert_eq!(result[0], -1.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[-1.0]).unwrap()); } #[test] @@ -427,11 +390,7 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1.matmul(&matrix2); - assert_eq!(result.shape(), &shape); - assert_eq!(result[coord![0, 0].unwrap()], 10.0); - assert_eq!(result[coord![0, 1].unwrap()], 13.0); - assert_eq!(result[coord![1, 0].unwrap()], 22.0); - assert_eq!(result[coord![1, 1].unwrap()], 29.0); + assert_eq!(result, DynamicMatrix::new(&shape, &[10.0, 13.0, 22.0, 29.0]).unwrap()); } #[test] @@ -442,9 +401,7 @@ mod tests { let vector_data = vec![1.0, 2.0]; let vector = DynamicVector::new(&vector_data).unwrap(); let result = matrix.vecmul(&vector); - assert_eq!(result.shape(), &shape![2].unwrap()); - assert_eq!(result[0], 5.0); - assert_eq!(result[1], 11.0); + assert_eq!(result, DynamicVector::new(&[5.0, 11.0]).unwrap()); } #[test] @@ -453,11 +410,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix + 2.0; - assert_eq!(result[coord![0, 0].unwrap()], 3.0); - assert_eq!(result[coord![0, 1].unwrap()], 4.0); - assert_eq!(result[coord![1, 0].unwrap()], 5.0); - assert_eq!(result[coord![1, 1].unwrap()], 6.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[3.0, 4.0, 5.0, 6.0]).unwrap()); } #[test] @@ -468,11 +421,7 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1 + matrix2; - assert_eq!(result[coord![0, 0].unwrap()], 3.0); - assert_eq!(result[coord![0, 1].unwrap()], 5.0); - assert_eq!(result[coord![1, 0].unwrap()], 7.0); - assert_eq!(result[coord![1, 1].unwrap()], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap()); } #[test] @@ -483,11 +432,7 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = matrix + tensor; - assert_eq!(result[coord![0, 0].unwrap()], 3.0); - assert_eq!(result[coord![0, 1].unwrap()], 5.0); - assert_eq!(result[coord![1, 0].unwrap()], 7.0); - assert_eq!(result[coord![1, 1].unwrap()], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap()); } #[test] @@ -496,11 +441,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix - 2.0; - assert_eq!(result[coord![0, 0].unwrap()], -1.0); - assert_eq!(result[coord![0, 1].unwrap()], 0.0); - assert_eq!(result[coord![1, 0].unwrap()], 1.0); - assert_eq!(result[coord![1, 1].unwrap()], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[-1.0, 0.0, 1.0, 2.0]).unwrap()); } #[test] @@ -511,11 +452,7 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1 - matrix2; - assert_eq!(result[coord![0, 0].unwrap()], -1.0); - assert_eq!(result[coord![0, 1].unwrap()], -1.0); - assert_eq!(result[coord![1, 0].unwrap()], -1.0); - assert_eq!(result[coord![1, 1].unwrap()], -1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[-1.0; 4]).unwrap()); } #[test] @@ -526,11 +463,7 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = matrix - tensor; - assert_eq!(result[coord![0, 0].unwrap()], -1.0); - assert_eq!(result[coord![0, 1].unwrap()], -1.0); - assert_eq!(result[coord![1, 0].unwrap()], -1.0); - assert_eq!(result[coord![1, 1].unwrap()], -1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[-1.0; 4]).unwrap()); } #[test] @@ -539,11 +472,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix * 2.0; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 4.0); - assert_eq!(result[coord![1, 0].unwrap()], 6.0); - assert_eq!(result[coord![1, 1].unwrap()], 8.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[2.0, 4.0, 6.0, 8.0]).unwrap()); } #[test] @@ -554,11 +483,7 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1 * matrix2; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 6.0); - assert_eq!(result[coord![1, 0].unwrap()], 12.0); - assert_eq!(result[coord![1, 1].unwrap()], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap()); } #[test] @@ -569,11 +494,7 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = matrix * tensor; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 6.0); - assert_eq!(result[coord![1, 0].unwrap()], 12.0); - assert_eq!(result[coord![1, 1].unwrap()], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap()); } #[test] @@ -582,11 +503,7 @@ mod tests { let data = vec![4.0, 6.0, 8.0, 10.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix / 2.0; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 3.0); - assert_eq!(result[coord![1, 0].unwrap()], 4.0); - assert_eq!(result[coord![1, 1].unwrap()], 5.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap()); } #[test] @@ -597,11 +514,7 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1 / matrix2; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 2.0); - assert_eq!(result[coord![1, 0].unwrap()], 2.0); - assert_eq!(result[coord![1, 1].unwrap()], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[2.0; 4]).unwrap()); } #[test] @@ -612,11 +525,7 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = matrix / tensor; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 2.0); - assert_eq!(result[coord![1, 0].unwrap()], 2.0); - assert_eq!(result[coord![1, 1].unwrap()], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[2.0; 4]).unwrap()); } #[test] diff --git a/src/tensor.rs b/src/tensor.rs index 595dc26..3de9f73 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -11,7 +11,7 @@ use crate::shape::Shape; use crate::storage::DynamicStorage; use crate::vector::DynamicVector; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct DynamicTensor { data: DynamicStorage, shape: Shape, @@ -572,7 +572,6 @@ mod tests { let data = vec![1.0, 2.0, 3.0]; // Mismatched data length let result = Tensor::new(&shape, &data); - assert!(result.is_err()); } @@ -580,34 +579,27 @@ mod tests { fn test_zeros_tensor() { let shape = shape![2, 3].unwrap(); let tensor: Tensor = Tensor::zeros(&shape); - - assert_eq!(tensor.shape(), &shape); - assert_eq!(tensor.data, DynamicStorage::new(vec![0.0; shape.size()])); + assert_eq!(tensor, DynamicTensor::new(&shape, &[0.0; 6]).unwrap()); } #[test] fn test_ones_tensor() { let shape = shape![2, 3].unwrap(); let tensor: Tensor = Tensor::ones(&shape); - - assert_eq!(tensor.shape(), &shape); - assert_eq!(tensor.data, DynamicStorage::new(vec![1.0; shape.size()])); + assert_eq!(tensor, DynamicTensor::new(&shape, &[1.0; 6]).unwrap()); } #[test] fn test_fill_tensor() { let shape = shape![2, 3].unwrap(); let tensor: Tensor = Tensor::fill(&shape, 7.0); - - assert_eq!(tensor.shape(), &shape); - assert_eq!(tensor.data, DynamicStorage::new(vec![7.0; shape.size()])); + assert_eq!(tensor, DynamicTensor::new(&shape, &[7.0; 6]).unwrap()); } #[test] fn test_tensor_shape() { let shape = shape![2, 3].unwrap(); let tensor: Tensor = Tensor::zeros(&shape); - assert_eq!(tensor.shape(), &shape); } @@ -615,7 +607,6 @@ mod tests { fn test_tensor_size() { let shape = shape![2, 3].unwrap(); let tensor: Tensor = Tensor::zeros(&shape); - assert_eq!(tensor.size(), 6); } @@ -641,11 +632,7 @@ mod tests { tensor.set(&coord![0, 1].unwrap(), 6.0).unwrap(); tensor.set(&coord![1, 0].unwrap(), 7.0).unwrap(); tensor.set(&coord![1, 1].unwrap(), 8.0).unwrap(); - - assert_eq!(*tensor.get(&coord![0, 0].unwrap()).unwrap(), 5.0); - assert_eq!(*tensor.get(&coord![0, 1].unwrap()).unwrap(), 6.0); - assert_eq!(*tensor.get(&coord![1, 0].unwrap()).unwrap(), 7.0); - assert_eq!(*tensor.get(&coord![1, 1].unwrap()).unwrap(), 8.0); + assert_eq!(tensor, DynamicTensor::new(&shape, &[5.0, 6.0, 7.0, 8.0]).unwrap()); } #[test] @@ -677,9 +664,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![15.0])); + assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[15.0]).unwrap()); } #[test] @@ -689,9 +674,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![21.0])); + assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[21.0]).unwrap()); } #[test] @@ -703,9 +686,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![78.0])); + assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[78.0]).unwrap()); } #[test] @@ -715,9 +696,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![15.0])); + assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[15.0]).unwrap()); } #[test] @@ -727,9 +706,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![5.0, 7.0, 9.0])); + assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[5.0, 7.0, 9.0]).unwrap()); } #[test] @@ -741,12 +718,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0]); - - assert_eq!(result.shape(), &shape![2, 3].unwrap()); - assert_eq!( - result.data, - DynamicStorage::new(vec![8.0, 10.0, 12.0, 14.0, 16.0, 18.0]) - ); + assert_eq!(result, DynamicTensor::new(&shape![2, 3].unwrap(), &[8.0, 10.0, 12.0, 14.0, 16.0, 18.0]).unwrap()); } #[test] @@ -756,9 +728,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0, 1]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![21.0])); + assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[21.0]).unwrap()); } #[test] @@ -770,9 +740,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0, 1]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![22.0, 26.0, 30.0])); + assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[22.0, 26.0, 30.0]).unwrap()); } #[test] @@ -782,9 +750,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![3.5])); + assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[3.5]).unwrap()); } #[test] @@ -794,9 +760,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![2.5, 3.5, 4.5])); + assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[2.5, 3.5, 4.5]).unwrap()); } #[test] @@ -808,12 +772,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0]); - - assert_eq!(result.shape(), &shape![2, 3].unwrap()); - assert_eq!( - result.data, - DynamicStorage::new(vec![4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) - ); + assert_eq!(result, DynamicTensor::new(&shape![2, 3].unwrap(), &[4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap()); } #[test] @@ -823,9 +782,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0, 1]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![3.5])); + assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[3.5]).unwrap()); } #[test] @@ -837,9 +794,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0, 1]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![5.5, 6.5, 7.5])); + assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[5.5, 6.5, 7.5]).unwrap()); } #[test] @@ -849,9 +804,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![9.0])); + assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[9.0]).unwrap()); } #[test] @@ -861,9 +814,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![2.25, 2.25, 2.25])); + assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[2.25; 3]).unwrap()); } #[test] @@ -875,12 +826,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0]); - - assert_eq!(result.shape(), &shape![2, 3].unwrap()); - assert_eq!( - result.data, - DynamicStorage::new(vec![9.0, 9.0, 9.0, 9.0, 9.0, 9.0]) - ); + assert_eq!(result, DynamicTensor::new(&shape![2, 3].unwrap(), &[9.0; 6]).unwrap()); } #[test] @@ -890,9 +836,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0, 1]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![9.0])); + assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[9.0]).unwrap()); } #[test] @@ -904,9 +848,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0, 1]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![45.0, 45.0, 45.0])); + assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[45.0; 3]).unwrap()); } #[test] @@ -916,9 +858,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.max(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![5.0])); + assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[5.0]).unwrap()); } #[test] @@ -928,9 +868,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.max(vec![0]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![1.0, 5.0, 3.0])); + assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[1.0, 5.0, 3.0]).unwrap()); } #[test] @@ -942,9 +880,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.max(vec![0, 1]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![7.0, 11.0, 9.0])); + assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[7.0, 11.0, 9.0]).unwrap()); } #[test] @@ -954,9 +890,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.min(vec![]); - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![-4.0])); + assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[-4.0]).unwrap()); } #[test] @@ -966,9 +900,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.min(vec![0]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![-4.0, -2.0, -6.0])); + assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[-4.0, -2.0, -6.0]).unwrap()); } #[test] @@ -980,9 +912,7 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.min(vec![0, 1]); - - assert_eq!(result.shape(), &shape![3].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![-10.0, -8.0, -12.0])); + assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[-10.0, -8.0, -12.0]).unwrap()); } #[test] @@ -996,12 +926,7 @@ mod tests { let tensor2 = Tensor::new(&shape2, &data2).unwrap(); let result = tensor1.prod(&tensor2); - - assert_eq!(result.shape(), &shape![3, 2].unwrap()); - assert_eq!( - result.data, - DynamicStorage::new(vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0]) - ); + assert_eq!(result, DynamicTensor::new(&shape![3, 2].unwrap(), &[4.0, 5.0, 8.0, 10.0, 12.0, 15.0]).unwrap()); } #[test] @@ -1015,34 +940,19 @@ mod tests { let tensor2 = Tensor::new(&shape2, &data2).unwrap(); let result = tensor1.prod(&tensor2); - - assert_eq!(result.shape(), &shape![2, 2, 2].unwrap()); - assert_eq!( - result.data, - DynamicStorage::new(vec![5.0, 6.0, 10.0, 12.0, 15.0, 18.0, 20.0, 24.0]) - ); + assert_eq!(result, DynamicTensor::new(&shape![2, 2, 2].unwrap(), &[5.0, 6.0, 10.0, 12.0, 15.0, 18.0, 20.0, 24.0]).unwrap()); } #[test] fn test_tensor_prod_2d_2d() { - let shape1 = shape![2, 2].unwrap(); + let shape = shape![2, 2].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; - let tensor1 = Tensor::new(&shape1, &data1).unwrap(); - - let shape2 = shape![2, 2].unwrap(); + let tensor1 = Tensor::new(&shape, &data1).unwrap(); let data2 = vec![5.0, 6.0, 7.0, 8.0]; - let tensor2 = Tensor::new(&shape2, &data2).unwrap(); + let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1.prod(&tensor2); - - assert_eq!(result.shape(), &shape![2, 2, 2, 2].unwrap()); - assert_eq!( - result.data, - DynamicStorage::new(vec![ - 5.0, 6.0, 7.0, 8.0, 10.0, 12.0, 14.0, 16.0, 15.0, 18.0, 21.0, 24.0, 20.0, 24.0, - 28.0, 32.0 - ]) - ); + assert_eq!(result, DynamicTensor::new(&shape![2, 2, 2, 2].unwrap(), &[5.0, 6.0, 7.0, 8.0, 10.0, 12.0, 14.0, 16.0, 15.0, 18.0, 21.0, 24.0, 20.0, 24.0, 28.0, 32.0]).unwrap()); } #[test] @@ -1052,9 +962,7 @@ mod tests { let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 + 3.0; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![4.0, 5.0, 6.0, 7.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[4.0, 5.0, 6.0, 7.0]).unwrap()); } #[test] @@ -1066,22 +974,17 @@ mod tests { let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1 + tensor2; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![6.0, 8.0, 10.0, 12.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[6.0, 8.0, 10.0, 12.0]).unwrap()); } #[test] fn test_sub_tensor() { let shape = shape![4].unwrap(); let data1 = vec![5.0, 6.0, 7.0, 8.0]; - let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 - 3.0; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![2.0, 3.0, 4.0, 5.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap()); } #[test] @@ -1093,50 +996,37 @@ mod tests { let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1 - tensor2; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![4.0, 4.0, 4.0, 4.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[4.0; 4]).unwrap()); } #[test] fn test_mul_tensor() { let shape = shape![4].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; - let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 * 2.0; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![2.0, 4.0, 6.0, 8.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[2.0, 4.0, 6.0, 8.0]).unwrap()); } #[test] fn test_div_tensor() { let shape = shape![4].unwrap(); let data1 = vec![4.0, 6.0, 8.0, 10.0]; - let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 / 2.0; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![2.0, 3.0, 4.0, 5.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap()); } #[test] fn test_vec_vec_mul_single() { let shape = shape![1].unwrap(); - let data1 = vec![2.0]; - let data2 = vec![5.0]; - - let tensor1 = Tensor::new(&shape, &data1).unwrap(); - let tensor2 = Tensor::new(&shape, &data2).unwrap(); + let tensor1 = Tensor::new(&shape, &[2.0]).unwrap(); + let tensor2 = Tensor::new(&shape, &[5.0]).unwrap(); let result = tensor1 * tensor2; - - assert_eq!(result.shape(), &shape![1].unwrap()); - assert_eq!(result.data, DynamicStorage::new(vec![10.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[10.0]).unwrap()); } #[test] @@ -1144,33 +1034,23 @@ mod tests { let shape = shape![4].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; - let tensor1 = Tensor::new(&shape, &data1).unwrap(); let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1 * tensor2; - - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![2.0, 6.0, 12.0, 20.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap()); } #[test] fn test_matrix_matrix_mul() { - let shape1 = shape![2, 3].unwrap(); - let shape2 = shape![2, 3].unwrap(); + let shape = shape![2, 3].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; let data2 = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]; - - let tensor1 = Tensor::new(&shape1, &data1).unwrap(); - let tensor2 = Tensor::new(&shape2, &data2).unwrap(); + let tensor1 = Tensor::new(&shape, &data1).unwrap(); + let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1 * tensor2; - - assert_eq!(result.shape(), &shape![2, 3].unwrap()); - assert_eq!( - result.data, - DynamicStorage::new(vec![7.0, 16.0, 27.0, 40.0, 55.0, 72.0]) - ); + assert_eq!(result, DynamicTensor::new(&shape, &[7.0, 16.0, 27.0, 40.0, 55.0, 72.0]).unwrap()); } #[test] @@ -1180,12 +1060,9 @@ mod tests { let data2 = vec![2.0, 3.0, 4.0, 5.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let vector = DynamicVector::new(&data2).unwrap(); + let result = tensor + vector; - assert_eq!(result[0], 3.0); - assert_eq!(result[1], 5.0); - assert_eq!(result[2], 7.0); - assert_eq!(result[3], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[3.0, 5.0, 7.0, 9.0]).unwrap()); } #[test] @@ -1195,12 +1072,9 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let vector = DynamicVector::new(&data2).unwrap(); + let result = tensor - vector; - assert_eq!(result[0], 1.0); - assert_eq!(result[1], 1.0); - assert_eq!(result[2], 1.0); - assert_eq!(result[3], 1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[1.0; 4]).unwrap()); } #[test] @@ -1210,12 +1084,9 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let vector = DynamicVector::new(&data2).unwrap(); + let result = tensor * vector; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 6.0); - assert_eq!(result[2], 12.0); - assert_eq!(result[3], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0, 6.0, 12.0, 20.0]).unwrap()); } #[test] @@ -1225,12 +1096,9 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let vector = DynamicVector::new(&data2).unwrap(); + let result = tensor / vector; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 2.0); - assert_eq!(result[2], 2.0); - assert_eq!(result[3], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0; 4]).unwrap()); } #[test] @@ -1240,12 +1108,9 @@ mod tests { let data2 = vec![2.0, 3.0, 4.0, 5.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let matrix = DynamicMatrix::new(&shape, &data2).unwrap(); + let result = tensor + matrix; - assert_eq!(result[coord![0, 0].unwrap()], 3.0); - assert_eq!(result[coord![0, 1].unwrap()], 5.0); - assert_eq!(result[coord![1, 0].unwrap()], 7.0); - assert_eq!(result[coord![1, 1].unwrap()], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap()); } #[test] @@ -1255,12 +1120,9 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let matrix = DynamicMatrix::new(&shape, &data2).unwrap(); + let result = tensor - matrix; - assert_eq!(result[coord![0, 0].unwrap()], 1.0); - assert_eq!(result[coord![0, 1].unwrap()], 1.0); - assert_eq!(result[coord![1, 0].unwrap()], 1.0); - assert_eq!(result[coord![1, 1].unwrap()], 1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[1.0; 4]).unwrap()); } #[test] @@ -1270,12 +1132,9 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let matrix = DynamicMatrix::new(&shape, &data2).unwrap(); + let result = tensor * matrix; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 6.0); - assert_eq!(result[coord![1, 0].unwrap()], 12.0); - assert_eq!(result[coord![1, 1].unwrap()], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap()); } #[test] @@ -1285,12 +1144,9 @@ mod tests { let data2 = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data1).unwrap(); let matrix = DynamicMatrix::new(&shape, &data2).unwrap(); + let result = tensor / matrix; - assert_eq!(result[coord![0, 0].unwrap()], 2.0); - assert_eq!(result[coord![0, 1].unwrap()], 2.0); - assert_eq!(result[coord![1, 0].unwrap()], 2.0); - assert_eq!(result[coord![1, 1].unwrap()], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicMatrix::new(&shape, &[2.0; 4]).unwrap()); } #[test] @@ -1337,8 +1193,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.pow(2.0); - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![1.0, 4.0, 9.0, 16.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[1.0, 4.0, 9.0, 16.0]).unwrap()); } #[test] @@ -1347,8 +1202,7 @@ mod tests { let data = vec![1.0, 4.0, 9.0, 16.0]; let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.pow(0.5); - assert_eq!(result.shape(), &shape); - assert_eq!(result.data, DynamicStorage::new(vec![1.0, 2.0, 3.0, 4.0])); + assert_eq!(result, DynamicTensor::new(&shape, &[1.0, 2.0, 3.0, 4.0]).unwrap()); } #[test] @@ -1357,10 +1211,6 @@ mod tests { let data = vec![1.0, 2.0, 4.0, 8.0]; let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.pow(-1.0); - assert_eq!(result.shape(), &shape); - assert_eq!( - result.data, - DynamicStorage::new(vec![1.0, 0.5, 0.25, 0.125]) - ); + assert_eq!(result, DynamicTensor::new(&shape, &[1.0, 0.5, 0.25, 0.125]).unwrap()); } } diff --git a/src/vector.rs b/src/vector.rs index ff2b939..27eb526 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -9,6 +9,7 @@ use crate::tensor::DynamicTensor; use num::Float; use num::Num; +#[derive(Debug, PartialEq)] pub struct DynamicVector { tensor: DynamicTensor, } @@ -251,11 +252,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let tensor = DynamicTensor::new(&shape, &data).unwrap(); let vector = DynamicVector::from_tensor(tensor).unwrap(); - assert_eq!(vector.shape(), &shape); - assert_eq!(vector[0], 1.0); - assert_eq!(vector[1], 2.0); - assert_eq!(vector[2], 3.0); - assert_eq!(vector[3], 4.0); + assert_eq!(vector, DynamicVector::new(&data).unwrap()); } #[test] @@ -271,33 +268,21 @@ mod tests { fn test_fill() { let shape = shape![4].unwrap(); let vector = DynamicVector::fill(&shape, 3.0).unwrap(); - assert_eq!(vector.shape(), &shape); - assert_eq!(vector[0], 3.0); - assert_eq!(vector[1], 3.0); - assert_eq!(vector[2], 3.0); - assert_eq!(vector[3], 3.0); + assert_eq!(vector, DynamicVector::new(&[3.0; 4]).unwrap()); } #[test] fn test_zeros() { let shape = shape![4].unwrap(); let vector = DynamicVector::::zeros(&shape).unwrap(); - assert_eq!(vector.shape(), &shape); - assert_eq!(vector[0], 0.0); - assert_eq!(vector[1], 0.0); - assert_eq!(vector[2], 0.0); - assert_eq!(vector[3], 0.0); + assert_eq!(vector, DynamicVector::new(&[0.0; 4]).unwrap()); } #[test] fn test_ones() { let shape = shape![4].unwrap(); let vector = DynamicVector::::ones(&shape).unwrap(); - assert_eq!(vector.shape(), &shape); - assert_eq!(vector[0], 1.0); - assert_eq!(vector[1], 1.0); - assert_eq!(vector[2], 1.0); - assert_eq!(vector[3], 1.0); + assert_eq!(vector, DynamicVector::new(&[1.0; 4]).unwrap()); } #[test] @@ -335,8 +320,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.sum(); - assert_eq!(result[0], 10.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[10.0]).unwrap()); } #[test] @@ -344,8 +328,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.mean(); - assert_eq!(result[0], 2.5); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[2.5]).unwrap()); } #[test] @@ -353,8 +336,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.var(); - assert_eq!(result[0], 1.25); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[1.25]).unwrap()); } #[test] @@ -362,8 +344,7 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.min(); - assert_eq!(result[0], 1.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[1.0]).unwrap()); } #[test] @@ -371,8 +352,7 @@ mod tests { let data = vec![-1.0, -2.0, -3.0, -4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.max(); - assert_eq!(result[0], -1.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[-1.0]).unwrap()); } #[test] @@ -382,8 +362,7 @@ mod tests { let vector1 = DynamicVector::new(&data1).unwrap(); let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1.vecmul(&vector2); - assert_eq!(result[0], 40.0); - assert_eq!(result.shape(), &shape![1].unwrap()); + assert_eq!(result, DynamicVector::new(&[40.0]).unwrap()); } #[test] @@ -393,9 +372,7 @@ mod tests { let vector = DynamicVector::new(&data_vector).unwrap(); let matrix = DynamicMatrix::new(&shape![2, 2].unwrap(), &data_matrix).unwrap(); let result = vector.matmul(&matrix); - assert_eq!(result.shape(), &shape![2].unwrap()); - assert_eq!(result[0], 7.0); - assert_eq!(result[1], 10.0); + assert_eq!(result, DynamicVector::new(&[7.0, 10.0]).unwrap()); } #[test] @@ -406,20 +383,10 @@ mod tests { let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1.prod(&vector2); - let expected_data = vec![ + let expected_tensor = DynamicTensor::new(&shape![4, 4].unwrap(), &[ 2.0, 3.0, 4.0, 5.0, 4.0, 6.0, 8.0, 10.0, 6.0, 9.0, 12.0, 15.0, 8.0, 12.0, 16.0, 20.0, - ]; - let expected_shape = shape![4, 4].unwrap(); - let expected_tensor = DynamicTensor::new(&expected_shape, &expected_data).unwrap(); - - assert_eq!(result.shape(), &expected_shape); - for i in 0..result.shape()[0] { - for j in 0..result.shape()[1] { - let x = result.get(&coord![i, j].unwrap()).unwrap(); - let y = expected_tensor.get(&coord![i, j].unwrap()).unwrap(); - assert_eq!(*x, *y); - } - } + ]).unwrap(); + assert_eq!(result, expected_tensor); } #[test] @@ -427,26 +394,17 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector + 2.0; - assert_eq!(result[0], 3.0); - assert_eq!(result[1], 4.0); - assert_eq!(result[2], 5.0); - assert_eq!(result[3], 6.0); - assert_eq!(result.shape(), &shape![4].unwrap()); + assert_eq!(result, DynamicVector::new(&[3.0, 4.0, 5.0, 6.0]).unwrap()); } #[test] fn test_add_vector() { - let shape = shape![4].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; let vector1 = DynamicVector::new(&data1).unwrap(); let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1 + vector2; - assert_eq!(result[0], 3.0); - assert_eq!(result[1], 5.0); - assert_eq!(result[2], 7.0); - assert_eq!(result[3], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[3.0, 5.0, 7.0, 9.0]).unwrap()); } #[test] @@ -457,39 +415,25 @@ mod tests { let vector = DynamicVector::new(&data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = vector + tensor; - assert_eq!(result[0], 3.0); - assert_eq!(result[1], 5.0); - assert_eq!(result[2], 7.0); - assert_eq!(result[3], 9.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[3.0, 5.0, 7.0, 9.0]).unwrap()); } #[test] fn test_sub_scalar() { - let shape = shape![4].unwrap(); let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector - 2.0; - assert_eq!(result[0], -1.0); - assert_eq!(result[1], 0.0); - assert_eq!(result[2], 1.0); - assert_eq!(result[3], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[-1.0, 0.0, 1.0, 2.0]).unwrap()); } #[test] fn test_sub_vector() { - let shape = shape![4].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; let vector1 = DynamicVector::new(&data1).unwrap(); let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1 - vector2; - assert_eq!(result[0], -1.0); - assert_eq!(result[1], -1.0); - assert_eq!(result[2], -1.0); - assert_eq!(result[3], -1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[-1.0; 4]).unwrap()); } #[test] @@ -500,39 +444,25 @@ mod tests { let vector = DynamicVector::new(&data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = vector - tensor; - assert_eq!(result[0], -1.0); - assert_eq!(result[1], -1.0); - assert_eq!(result[2], -1.0); - assert_eq!(result[3], -1.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[-1.0; 4]).unwrap()); } #[test] fn test_mul_scalar() { - let shape = shape![4].unwrap(); let data = vec![1.0, 2.0, 3.0, 4.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector * 2.0; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 4.0); - assert_eq!(result[2], 6.0); - assert_eq!(result[3], 8.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0, 4.0, 6.0, 8.0]).unwrap()); } #[test] fn test_mul_vector() { - let shape = shape![4].unwrap(); let data1 = vec![1.0, 2.0, 3.0, 4.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; let vector1 = DynamicVector::new(&data1).unwrap(); let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1 * vector2; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 6.0); - assert_eq!(result[2], 12.0); - assert_eq!(result[3], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0, 6.0, 12.0, 20.0]).unwrap()); } #[test] @@ -543,39 +473,25 @@ mod tests { let vector = DynamicVector::new(&data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = vector * tensor; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 6.0); - assert_eq!(result[2], 12.0); - assert_eq!(result[3], 20.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0, 6.0, 12.0, 20.0]).unwrap()); } #[test] fn test_div_scalar() { - let shape = shape![4].unwrap(); let data = vec![4.0, 6.0, 8.0, 10.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector / 2.0; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 3.0); - assert_eq!(result[2], 4.0); - assert_eq!(result[3], 5.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0, 3.0, 4.0, 5.0]).unwrap()); } #[test] fn test_div_vector() { - let shape = shape![4].unwrap(); let data1 = vec![4.0, 6.0, 8.0, 10.0]; let data2 = vec![2.0, 3.0, 4.0, 5.0]; let vector1 = DynamicVector::new(&data1).unwrap(); let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1 / vector2; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 2.0); - assert_eq!(result[2], 2.0); - assert_eq!(result[3], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0; 4]).unwrap()); } #[test] @@ -586,23 +502,14 @@ mod tests { let vector = DynamicVector::new(&data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = vector / tensor; - assert_eq!(result[0], 2.0); - assert_eq!(result[1], 2.0); - assert_eq!(result[2], 2.0); - assert_eq!(result[3], 2.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[2.0; 4]).unwrap()); } #[test] fn test_pow_vector() { - let shape = shape![4].unwrap(); let data = vec![2.0, 3.0, 4.0, 5.0]; let vector = DynamicVector::new(&data).unwrap(); let result = vector.pow(2.0); - assert_eq!(result[0], 4.0); - assert_eq!(result[1], 9.0); - assert_eq!(result[2], 16.0); - assert_eq!(result[3], 25.0); - assert_eq!(result.shape(), &shape); + assert_eq!(result, DynamicVector::new(&[4.0, 9.0, 16.0, 25.0]).unwrap()); } } From 18327e49512ef41c252f17d76561f0eb5c3fdeff Mon Sep 17 00:00:00 2001 From: Angus Stewart Date: Sun, 30 Jun 2024 22:59:08 +0100 Subject: [PATCH 2/2] style: fmt fixes --- src/matrix.rs | 60 +++++++++++--- src/tensor.rs | 212 ++++++++++++++++++++++++++++++++++++++++---------- src/vector.rs | 11 ++- 3 files changed, 228 insertions(+), 55 deletions(-) diff --git a/src/matrix.rs b/src/matrix.rs index 92fa556..47dd097 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -287,7 +287,10 @@ mod tests { fn test_eye() { let shape = shape![3, 3].unwrap(); let matrix = DynamicMatrix::::eye(&shape).unwrap(); - assert_eq!(matrix, DynamicMatrix::new(&shape, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).unwrap()); + assert_eq!( + matrix, + DynamicMatrix::new(&shape, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).unwrap() + ); } #[test] @@ -325,7 +328,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let mut matrix = DynamicMatrix::new(&shape, &data).unwrap(); matrix[coord![1, 0].unwrap()] = 5.0; - assert_eq!(matrix, DynamicMatrix::new(&shape, &[1.0, 2.0, 5.0, 4.0]).unwrap()); + assert_eq!( + matrix, + DynamicMatrix::new(&shape, &[1.0, 2.0, 5.0, 4.0]).unwrap() + ); } #[test] @@ -334,7 +340,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let mut matrix = DynamicMatrix::new(&shape, &data).unwrap(); matrix.set(&coord![1, 0].unwrap(), 5.0).unwrap(); - assert_eq!(matrix, DynamicMatrix::new(&shape, &[1.0, 2.0, 5.0, 4.0]).unwrap()); + assert_eq!( + matrix, + DynamicMatrix::new(&shape, &[1.0, 2.0, 5.0, 4.0]).unwrap() + ); } #[test] @@ -390,7 +399,10 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1.matmul(&matrix2); - assert_eq!(result, DynamicMatrix::new(&shape, &[10.0, 13.0, 22.0, 29.0]).unwrap()); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[10.0, 13.0, 22.0, 29.0]).unwrap() + ); } #[test] @@ -410,7 +422,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix + 2.0; - assert_eq!(result, DynamicMatrix::new(&shape, &[3.0, 4.0, 5.0, 6.0]).unwrap()); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[3.0, 4.0, 5.0, 6.0]).unwrap() + ); } #[test] @@ -421,7 +436,10 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1 + matrix2; - assert_eq!(result, DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap()); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap() + ); } #[test] @@ -432,7 +450,10 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = matrix + tensor; - assert_eq!(result, DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap()); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap() + ); } #[test] @@ -441,7 +462,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix - 2.0; - assert_eq!(result, DynamicMatrix::new(&shape, &[-1.0, 0.0, 1.0, 2.0]).unwrap()); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[-1.0, 0.0, 1.0, 2.0]).unwrap() + ); } #[test] @@ -472,7 +496,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix * 2.0; - assert_eq!(result, DynamicMatrix::new(&shape, &[2.0, 4.0, 6.0, 8.0]).unwrap()); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[2.0, 4.0, 6.0, 8.0]).unwrap() + ); } #[test] @@ -483,7 +510,10 @@ mod tests { let matrix1 = DynamicMatrix::new(&shape, &data1).unwrap(); let matrix2 = DynamicMatrix::new(&shape, &data2).unwrap(); let result = matrix1 * matrix2; - assert_eq!(result, DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap()); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap() + ); } #[test] @@ -494,7 +524,10 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data1).unwrap(); let tensor = DynamicTensor::new(&shape, &data2).unwrap(); let result = matrix * tensor; - assert_eq!(result, DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap()); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap() + ); } #[test] @@ -503,7 +536,10 @@ mod tests { let data = vec![4.0, 6.0, 8.0, 10.0]; let matrix = DynamicMatrix::new(&shape, &data).unwrap(); let result = matrix / 2.0; - assert_eq!(result, DynamicMatrix::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap()); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap() + ); } #[test] diff --git a/src/tensor.rs b/src/tensor.rs index 3de9f73..5944e9d 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -632,7 +632,10 @@ mod tests { tensor.set(&coord![0, 1].unwrap(), 6.0).unwrap(); tensor.set(&coord![1, 0].unwrap(), 7.0).unwrap(); tensor.set(&coord![1, 1].unwrap(), 8.0).unwrap(); - assert_eq!(tensor, DynamicTensor::new(&shape, &[5.0, 6.0, 7.0, 8.0]).unwrap()); + assert_eq!( + tensor, + DynamicTensor::new(&shape, &[5.0, 6.0, 7.0, 8.0]).unwrap() + ); } #[test] @@ -664,7 +667,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![]); - assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[15.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[15.0]).unwrap() + ); } #[test] @@ -674,7 +680,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![]); - assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[21.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[21.0]).unwrap() + ); } #[test] @@ -686,7 +695,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![]); - assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[78.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[78.0]).unwrap() + ); } #[test] @@ -696,7 +708,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0]); - assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[15.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[15.0]).unwrap() + ); } #[test] @@ -706,7 +721,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0]); - assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[5.0, 7.0, 9.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[5.0, 7.0, 9.0]).unwrap() + ); } #[test] @@ -718,7 +736,11 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0]); - assert_eq!(result, DynamicTensor::new(&shape![2, 3].unwrap(), &[8.0, 10.0, 12.0, 14.0, 16.0, 18.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![2, 3].unwrap(), &[8.0, 10.0, 12.0, 14.0, 16.0, 18.0]) + .unwrap() + ); } #[test] @@ -728,7 +750,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0, 1]); - assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[21.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[21.0]).unwrap() + ); } #[test] @@ -740,7 +765,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.sum(vec![0, 1]); - assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[22.0, 26.0, 30.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[22.0, 26.0, 30.0]).unwrap() + ); } #[test] @@ -750,7 +778,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![]); - assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[3.5]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[3.5]).unwrap() + ); } #[test] @@ -760,7 +791,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0]); - assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[2.5, 3.5, 4.5]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[2.5, 3.5, 4.5]).unwrap() + ); } #[test] @@ -772,7 +806,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0]); - assert_eq!(result, DynamicTensor::new(&shape![2, 3].unwrap(), &[4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![2, 3].unwrap(), &[4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap() + ); } #[test] @@ -782,7 +819,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0, 1]); - assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[3.5]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[3.5]).unwrap() + ); } #[test] @@ -794,7 +834,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.mean(vec![0, 1]); - assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[5.5, 6.5, 7.5]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[5.5, 6.5, 7.5]).unwrap() + ); } #[test] @@ -804,7 +847,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![]); - assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[9.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[9.0]).unwrap() + ); } #[test] @@ -814,7 +860,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0]); - assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[2.25; 3]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[2.25; 3]).unwrap() + ); } #[test] @@ -826,7 +875,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0]); - assert_eq!(result, DynamicTensor::new(&shape![2, 3].unwrap(), &[9.0; 6]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![2, 3].unwrap(), &[9.0; 6]).unwrap() + ); } #[test] @@ -836,7 +888,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0, 1]); - assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[9.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[9.0]).unwrap() + ); } #[test] @@ -848,7 +903,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.var(vec![0, 1]); - assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[45.0; 3]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[45.0; 3]).unwrap() + ); } #[test] @@ -858,7 +916,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.max(vec![]); - assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[5.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[5.0]).unwrap() + ); } #[test] @@ -868,7 +929,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.max(vec![0]); - assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[1.0, 5.0, 3.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[1.0, 5.0, 3.0]).unwrap() + ); } #[test] @@ -880,7 +944,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.max(vec![0, 1]); - assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[7.0, 11.0, 9.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[7.0, 11.0, 9.0]).unwrap() + ); } #[test] @@ -890,7 +957,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.min(vec![]); - assert_eq!(result, DynamicTensor::new(&shape![1].unwrap(), &[-4.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![1].unwrap(), &[-4.0]).unwrap() + ); } #[test] @@ -900,7 +970,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.min(vec![0]); - assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[-4.0, -2.0, -6.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[-4.0, -2.0, -6.0]).unwrap() + ); } #[test] @@ -912,7 +985,10 @@ mod tests { let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.min(vec![0, 1]); - assert_eq!(result, DynamicTensor::new(&shape![3].unwrap(), &[-10.0, -8.0, -12.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![3].unwrap(), &[-10.0, -8.0, -12.0]).unwrap() + ); } #[test] @@ -926,7 +1002,10 @@ mod tests { let tensor2 = Tensor::new(&shape2, &data2).unwrap(); let result = tensor1.prod(&tensor2); - assert_eq!(result, DynamicTensor::new(&shape![3, 2].unwrap(), &[4.0, 5.0, 8.0, 10.0, 12.0, 15.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape![3, 2].unwrap(), &[4.0, 5.0, 8.0, 10.0, 12.0, 15.0]).unwrap() + ); } #[test] @@ -940,7 +1019,14 @@ mod tests { let tensor2 = Tensor::new(&shape2, &data2).unwrap(); let result = tensor1.prod(&tensor2); - assert_eq!(result, DynamicTensor::new(&shape![2, 2, 2].unwrap(), &[5.0, 6.0, 10.0, 12.0, 15.0, 18.0, 20.0, 24.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new( + &shape![2, 2, 2].unwrap(), + &[5.0, 6.0, 10.0, 12.0, 15.0, 18.0, 20.0, 24.0] + ) + .unwrap() + ); } #[test] @@ -952,7 +1038,17 @@ mod tests { let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1.prod(&tensor2); - assert_eq!(result, DynamicTensor::new(&shape![2, 2, 2, 2].unwrap(), &[5.0, 6.0, 7.0, 8.0, 10.0, 12.0, 14.0, 16.0, 15.0, 18.0, 21.0, 24.0, 20.0, 24.0, 28.0, 32.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new( + &shape![2, 2, 2, 2].unwrap(), + &[ + 5.0, 6.0, 7.0, 8.0, 10.0, 12.0, 14.0, 16.0, 15.0, 18.0, 21.0, 24.0, 20.0, 24.0, + 28.0, 32.0 + ] + ) + .unwrap() + ); } #[test] @@ -962,7 +1058,10 @@ mod tests { let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 + 3.0; - assert_eq!(result, DynamicTensor::new(&shape, &[4.0, 5.0, 6.0, 7.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape, &[4.0, 5.0, 6.0, 7.0]).unwrap() + ); } #[test] @@ -974,7 +1073,10 @@ mod tests { let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1 + tensor2; - assert_eq!(result, DynamicTensor::new(&shape, &[6.0, 8.0, 10.0, 12.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape, &[6.0, 8.0, 10.0, 12.0]).unwrap() + ); } #[test] @@ -984,7 +1086,10 @@ mod tests { let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 - 3.0; - assert_eq!(result, DynamicTensor::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap() + ); } #[test] @@ -1006,7 +1111,10 @@ mod tests { let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 * 2.0; - assert_eq!(result, DynamicTensor::new(&shape, &[2.0, 4.0, 6.0, 8.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape, &[2.0, 4.0, 6.0, 8.0]).unwrap() + ); } #[test] @@ -1016,7 +1124,10 @@ mod tests { let tensor1 = Tensor::new(&shape, &data1).unwrap(); let result = tensor1 / 2.0; - assert_eq!(result, DynamicTensor::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape, &[2.0, 3.0, 4.0, 5.0]).unwrap() + ); } #[test] @@ -1038,7 +1149,10 @@ mod tests { let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1 * tensor2; - assert_eq!(result, DynamicTensor::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap() + ); } #[test] @@ -1050,7 +1164,10 @@ mod tests { let tensor2 = Tensor::new(&shape, &data2).unwrap(); let result = tensor1 * tensor2; - assert_eq!(result, DynamicTensor::new(&shape, &[7.0, 16.0, 27.0, 40.0, 55.0, 72.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape, &[7.0, 16.0, 27.0, 40.0, 55.0, 72.0]).unwrap() + ); } #[test] @@ -1110,7 +1227,10 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data2).unwrap(); let result = tensor + matrix; - assert_eq!(result, DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap()); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[3.0, 5.0, 7.0, 9.0]).unwrap() + ); } #[test] @@ -1134,7 +1254,10 @@ mod tests { let matrix = DynamicMatrix::new(&shape, &data2).unwrap(); let result = tensor * matrix; - assert_eq!(result, DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap()); + assert_eq!( + result, + DynamicMatrix::new(&shape, &[2.0, 6.0, 12.0, 20.0]).unwrap() + ); } #[test] @@ -1193,7 +1316,10 @@ mod tests { let data = vec![1.0, 2.0, 3.0, 4.0]; let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.pow(2.0); - assert_eq!(result, DynamicTensor::new(&shape, &[1.0, 4.0, 9.0, 16.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape, &[1.0, 4.0, 9.0, 16.0]).unwrap() + ); } #[test] @@ -1202,7 +1328,10 @@ mod tests { let data = vec![1.0, 4.0, 9.0, 16.0]; let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.pow(0.5); - assert_eq!(result, DynamicTensor::new(&shape, &[1.0, 2.0, 3.0, 4.0]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape, &[1.0, 2.0, 3.0, 4.0]).unwrap() + ); } #[test] @@ -1211,6 +1340,9 @@ mod tests { let data = vec![1.0, 2.0, 4.0, 8.0]; let tensor = Tensor::new(&shape, &data).unwrap(); let result = tensor.pow(-1.0); - assert_eq!(result, DynamicTensor::new(&shape, &[1.0, 0.5, 0.25, 0.125]).unwrap()); + assert_eq!( + result, + DynamicTensor::new(&shape, &[1.0, 0.5, 0.25, 0.125]).unwrap() + ); } } diff --git a/src/vector.rs b/src/vector.rs index 27eb526..200f97c 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -383,9 +383,14 @@ mod tests { let vector2 = DynamicVector::new(&data2).unwrap(); let result = vector1.prod(&vector2); - let expected_tensor = DynamicTensor::new(&shape![4, 4].unwrap(), &[ - 2.0, 3.0, 4.0, 5.0, 4.0, 6.0, 8.0, 10.0, 6.0, 9.0, 12.0, 15.0, 8.0, 12.0, 16.0, 20.0, - ]).unwrap(); + let expected_tensor = DynamicTensor::new( + &shape![4, 4].unwrap(), + &[ + 2.0, 3.0, 4.0, 5.0, 4.0, 6.0, 8.0, 10.0, 6.0, 9.0, 12.0, 15.0, 8.0, 12.0, 16.0, + 20.0, + ], + ) + .unwrap(); assert_eq!(result, expected_tensor); }