1818
1919import static org .junit .jupiter .api .Assertions .assertEquals ;
2020import static org .junit .jupiter .api .Assertions .assertTrue ;
21+ import static org .junit .jupiter .api .Assertions .assertArrayEquals ;
2122
2223import org .junit .jupiter .api .Test ;
2324import org .tensorflow .ndarray .index .Indices ;
@@ -43,6 +44,122 @@ public void testNullConversions(){
4344 assertTrue (Indices .slice (null , null ).endMask (),
4445 "Passed null for slice end but didn't set end mask" );
4546 }
47+
48+ @ Test
49+ public void testIndices (){
50+
51+ String [][] indexData = new String [5 ][4 ];
52+ for (int i =0 ; i < 5 ; i ++)
53+ for (int j =0 ; j < 4 ; j ++)
54+ indexData [i ][j ] = "(" +j +", " +i +")" ;
55+
56+ NdArray <String > matrix2d = StdArrays .ndCopyOf (indexData );
57+ assertEquals (2 , matrix2d .rank ());
58+
59+ /*
60+ |(0, 0), (1, 0), (2, 0), (3, 0)|
61+ |(0, 1), (1, 1), (2, 1), (3, 1)|
62+ |(0, 2), (1, 2), (2, 2), (3, 2)|
63+ |(0, 3), (1, 3), (2, 3), (3, 3)|
64+ |(0, 4), (1, 4), (2, 4), (3, 4)|
65+ */
66+
67+ NdArray <String > same1 = matrix2d .slice (Indices .all ());
68+ String [][] same1j = StdArrays .array2dCopyOf (same1 , String .class );
69+ assertEquals (2 , same1 .rank ());
70+ assertEquals (same1 , matrix2d );
71+
72+ NdArray <String > same2 = matrix2d .slice (Indices .ellipsis ());
73+ String [][] same2j = StdArrays .array2dCopyOf (same2 , String .class );
74+ assertEquals (2 , same2 .rank ());
75+ assertEquals (matrix2d , same2 );
76+
77+ // All rows, column 1
78+ NdArray <String > same3 = matrix2d .slice (Indices .all (), Indices .at (1 ));
79+ assertEquals (1 , same3 .rank ());
80+ String [] same3j = StdArrays .array1dCopyOf (same3 , String .class );
81+ assertArrayEquals (new String [] { "(1, 0)" , "(1, 1)" , "(1, 2)" , "(1, 3)" , "(1, 4)" }, same3j );
82+
83+ // row 2, all columns
84+ NdArray <String > same4 = matrix2d .slice (Indices .at (2 ), Indices .all ());
85+ assertEquals (1 , same4 .rank ());
86+ String [] same4j = StdArrays .array1dCopyOf (same4 , String .class );
87+ assertArrayEquals (new String [] {"(0, 2)" , "(1, 2)" , "(2, 2)" , "(3, 2)" }, same4j );
88+ assertEquals (NdArrays .vectorOfObjects ("(0, 2)" , "(1, 2)" , "(2, 2)" , "(3, 2)" ), same4 );
89+
90+ // row 2, column 1
91+ NdArray <String > same5 = matrix2d .slice (Indices .at (2 ), Indices .at (1 ));
92+ assertEquals (0 , same5 .rank ());
93+ assertTrue (same5 .shape ().isScalar ());
94+ // Don't use an index
95+ String same5j = same5 .getObject ();
96+ assertEquals ("(1, 2)" , same5j );
97+
98+ // rows 1 to 2, all columns
99+ NdArray <String > same6 = matrix2d .slice (Indices .slice (1 ,3 ));
100+ assertEquals (2 , same6 .rank ());
101+ String [][] same6j = StdArrays .array2dCopyOf (same6 , String .class );
102+ assertArrayEquals (
103+ new String [][]
104+ {
105+ {"(0, 1)" , "(1, 1)" , "(2, 1)" , "(3, 1)" },
106+ {"(0, 2)" , "(1, 2)" , "(2, 2)" , "(3, 2)" }
107+ },
108+ same6j
109+ );
110+
111+ // Exception in thread "main" java.nio.BufferOverflowException
112+ // all rows, columns 1 to 2
113+ NdArray <String > same7 = matrix2d .slice (Indices .all (), Indices .slice (1 ,3 ));
114+ assertEquals (2 , same7 .rank ());
115+ assertEquals (Shape .of (5 ,2 ), same7 .shape ());
116+ assertEquals (10 , same7 .size ());
117+ NdArray <String > r7_0 = same7 .get (0 );
118+ NdArray <String > r7_1 = same7 .get (1 );
119+ NdArray <String > r7_2 = same7 .get (2 );
120+ NdArray <String > r7_3 = same7 .get (3 );
121+ NdArray <String > r7_4 = same7 .get (4 );
122+ assertEquals (1 , r7_0 .rank ());
123+ assertEquals (Shape .of (2 ), r7_0 .shape ());
124+ assertEquals (2 , r7_0 .size ());
125+ // TODO: I get a (0,0) which is not what I expected
126+ System .out .println (r7_0 .getObject ());
127+ //assertEquals("(1,0)", r7_0.getObject());
128+ assertEquals ( "(1, 0)" , r7_0 .getObject (0 ));
129+ assertEquals ( "(2, 0)" , r7_0 .getObject (1 ));
130+ assertEquals ( "(1, 1)" , r7_1 .getObject (0 ));
131+ assertEquals ( "(2, 1)" , r7_1 .getObject (1 ));
132+ assertEquals ( "(1, 2)" , r7_2 .getObject (0 ));
133+ assertEquals ( "(2, 2)" , r7_2 .getObject (1 ));
134+ assertEquals ( "(1, 3)" , r7_3 .getObject (0 ));
135+ assertEquals ( "(2, 3)" , r7_3 .getObject (1 ));
136+ assertEquals ( "(1, 4)" , r7_4 .getObject (0 ));
137+ assertEquals ( "(2, 4)" , r7_4 .getObject (1 ));
138+ String [][] expectedr7 = new String [][]
139+ {
140+ {"(1, 0)" , "(2, 0)" },
141+ {"(1, 1)" , "(2, 1)" },
142+ {"(1, 2)" , "(2, 2)" },
143+ {"(1, 3)" , "(2, 3)" },
144+ {"(1, 4)" , "(2, 4)" }
145+ };
146+ //String[][] lArray = new String[5][2];
147+ String [][] lArray = new String [5 ][];
148+ lArray [0 ] = new String [2 ];
149+ lArray [1 ] = new String [2 ];
150+ lArray [2 ] = new String [2 ];
151+ lArray [3 ] = new String [2 ];
152+ lArray [4 ] = new String [2 ];
153+ StdArrays .copyFrom (same7 , lArray );
154+ assertArrayEquals ( expectedr7 , lArray );
155+ String [][] same7j = StdArrays .array2dCopyOf (same7 , String .class );
156+ assertArrayEquals ( expectedr7 , same7j );
157+
158+ /*
159+ */
160+
161+ assertEquals (0 , 0 );
162+ }
46163
47164 @ Test
48165 public void testNewaxis (){
0 commit comments