Skip to content
This repository was archived by the owner on Jul 15, 2025. It is now read-only.

Commit 77aa450

Browse files
authored
Fix bug when slicing on a segmented dimension (#2)
1 parent b284d2b commit 77aa450

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ public DimensionalSpace from(int dimensionStart) {
144144
throw new IndexOutOfBoundsException();
145145
}
146146
Dimension[] newDimensions = Arrays.copyOfRange(dimensions, dimensionStart, dimensions.length);
147-
if (segmentationIdx > dimensionStart) {
147+
if (segmentationIdx >= dimensionStart) {
148148
return new DimensionalSpace(newDimensions, segmentationIdx - dimensionStart);
149149
}
150150
return new DimensionalSpace(newDimensions);

ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import java.nio.BufferUnderflowException;
3535
import org.junit.jupiter.api.Test;
3636
import org.tensorflow.ndarray.buffer.DataBuffer;
37+
import org.tensorflow.ndarray.index.Indices;
3738

3839
public abstract class NdArrayTestBase<T> {
3940

@@ -335,4 +336,26 @@ public void equalsAndHashCode() {
335336
assertNotEquals(array1, array4);
336337
assertNotEquals(array1.hashCode(), array4.hashCode());
337338
}
339+
340+
@Test
341+
public void iterateScalarsOnSegmentedElements() {
342+
NdArray<T> originalTensor = allocate(Shape.of(2, 3));
343+
344+
originalTensor
345+
.setObject(valueOf(0L), 0, 0)
346+
.setObject(valueOf(1L), 0, 1)
347+
.setObject(valueOf(2L), 0, 2)
348+
.setObject(valueOf(3L), 1, 0)
349+
.setObject(valueOf(4L), 1, 1)
350+
.setObject(valueOf(5L), 1, 2);
351+
352+
NdArray<T> slice = originalTensor.slice(Indices.all(), Indices.sliceFrom(1));
353+
assertEquals(Shape.of(2, 2), slice.shape());
354+
355+
slice.elements(0).forEachIndexed((eCoord, e) -> {
356+
e.scalars().forEachIndexed((sCoord, s) -> {
357+
assertEquals(valueOf((eCoord[0] * originalTensor.shape().get(1)) + sCoord[0] + 1), s.getObject());
358+
});
359+
});
360+
}
338361
}

0 commit comments

Comments
 (0)