Skip to content
This repository was archived by the owner on Jul 15, 2025. It is now read-only.

Commit 887fc19

Browse files
authored
Sparse tensor (#3)
1 parent 77aa450 commit 887fc19

32 files changed

+8376
-93
lines changed

ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java

Lines changed: 374 additions & 52 deletions
Large diffs are not rendered by default.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.ndarray;
16+
17+
/**
18+
* Interface for Sparse Arrays
19+
*
20+
* @param <T> the type that the array contains
21+
* @param <U> the type of dense NdArray
22+
*/
23+
public interface SparseNdArray<T, U extends NdArray<T>> extends NdArray<T> {
24+
/**
25+
* Gets the Indices
26+
*
27+
* <p>Indices are a A 2-D long array of shape {@code [N, ndims]}, that specifies the indices of
28+
* the elements in the sparse array that contain nonzero values (elements are zero-indexed).
29+
*
30+
* <p>For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of
31+
* coordinates {@code [1,3]} and {@code [2,4]} have nonzero values.
32+
*
33+
* @return the Indices
34+
*/
35+
LongNdArray getIndices();
36+
37+
/**
38+
* Gets the values.
39+
*
40+
* <p>Values are a 1-D array of any type and shape {@code [N]}, that supplies the values for each
41+
* element in indices.
42+
*
43+
* <p>For example, given {@code indices=[[1,3], [2,4]]}, and {@code values=[18, 3.6]} specifies
44+
* that element {@code [1,3]} of the sparse array has a value of {@code 18}, and element {@code
45+
* [2,4]} of the sparse array has a value of {@code 3.6}.
46+
*
47+
* @return the values
48+
*/
49+
U getValues();
50+
}

ndarray/src/main/java/org/tensorflow/ndarray/impl/AbstractNdArray.java

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,38 @@
11
/*
2-
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
33
4-
Licensed under the Apache License, Version 2.0 (the "License");
5-
you may not use this file except in compliance with the License.
6-
You may obtain a copy of the License at
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
77
8-
http://www.apache.org/licenses/LICENSE-2.0
8+
http://www.apache.org/licenses/LICENSE-2.0
99
10-
Unless required by applicable law or agreed to in writing, software
11-
distributed under the License is distributed on an "AS IS" BASIS,
12-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
See the License for the specific language governing permissions and
14-
limitations under the License.
15-
=======================================================================
16-
*/
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
=======================================================================
16+
*/
1717
package org.tensorflow.ndarray.impl;
1818

19-
import java.util.Iterator;
2019
import org.tensorflow.ndarray.NdArray;
2120
import org.tensorflow.ndarray.NdArraySequence;
2221
import org.tensorflow.ndarray.Shape;
2322
import org.tensorflow.ndarray.impl.dimension.DimensionalSpace;
2423

24+
import java.util.Iterator;
25+
import java.util.Objects;
26+
2527
@SuppressWarnings("unchecked")
2628
public abstract class AbstractNdArray<T, U extends NdArray<T>> implements NdArray<T> {
2729

30+
protected final DimensionalSpace dimensions;
31+
32+
protected AbstractNdArray(DimensionalSpace dimensions) {
33+
this.dimensions = dimensions;
34+
}
35+
2836
public abstract U slice(long position, DimensionalSpace dimensions);
2937

3038
public DimensionalSpace dimensions() {
@@ -39,7 +47,7 @@ public Shape shape() {
3947
@Override
4048
public NdArraySequence<U> scalars() {
4149
// negative if this array is a scalar, should be handled in `elements(dimIdx)`
42-
return (NdArraySequence<U>)elements(shape().numDimensions() - 1);
50+
return (NdArraySequence<U>) elements(shape().numDimensions() - 1);
4351
}
4452

4553
@Override
@@ -55,11 +63,7 @@ public boolean equals(Object obj) {
5563
if (!(obj instanceof NdArray)) {
5664
return false;
5765
}
58-
return slowEquals((NdArray<?>)obj);
59-
}
60-
61-
protected AbstractNdArray(DimensionalSpace dimensions) {
62-
this.dimensions = dimensions;
66+
return slowEquals((NdArray<?>) obj);
6367
}
6468

6569
protected void slowCopyTo(NdArray<T> array) {
@@ -77,16 +81,19 @@ protected int slowHashCode() {
7781
}
7882

7983
protected boolean slowEquals(NdArray<?> array) {
80-
if (!shape().equals(array.shape())) { // this guarantees also that we have the same number of scalar values
84+
if (!shape()
85+
.equals(
86+
array.shape())) { // this guarantees also that we have the same number of scalar values
8187
return false;
8288
}
83-
for (Iterator<? extends NdArray<?>> thisIter = scalars().iterator(), otherIter = array.scalars().iterator(); thisIter.hasNext();) {
84-
if (!thisIter.next().getObject().equals(otherIter.next().getObject())) {
89+
for (Iterator<? extends NdArray<?>> thisIter = scalars().iterator(),
90+
otherIter = array.scalars().iterator();
91+
thisIter.hasNext(); ) {
92+
// Use Object.equals to handle nulls.
93+
if (!Objects.equals(thisIter.next().getObject(), otherIter.next().getObject())) {
8594
return false;
8695
}
8796
}
8897
return true;
8998
}
90-
91-
protected final DimensionalSpace dimensions;
9299
}

0 commit comments

Comments
 (0)