1- #include < tensor-array/core/tensor.hh>
2- #include < pybind11/pybind11.h>
3- #include < pybind11/numpy.h>
4- #include < pybind11/operators.h>
5-
6- using namespace tensor_array ::value;
7-
8- template <typename T>
9- TensorBase convert_numpy_to_tensor_base (pybind11::array_t <T> py_buf)
10- {
11- pybind11::buffer_info info = py_buf.request ();
12- std::vector<unsigned int > shape_vec (info.ndim );
13- std::transform
14- (
15- info.shape .cbegin (),
16- info.shape .cend (),
17- shape_vec.begin (),
18- [](pybind11::size_t dim)
19- {
20- return static_cast <unsigned int >(dim);
21- }
22- );
23- return TensorBase (typeid (T), shape_vec, info.ptr );
24- }
25-
26- std::string tensor_to_string (const Tensor t)
27- {
28- std::ostringstream osstream;
29- osstream << t;
30- return osstream.str ();
31- }
32-
33- PYBIND11_MODULE (tensor2, m)
34- {
35- pybind11::class_<Tensor>(m, " Tensor" )
36- .def (pybind11::init ())
37- .def (pybind11::init (&convert_numpy_to_tensor_base<float >))
38- .def (pybind11::self + pybind11::self)
39- .def (pybind11::self - pybind11::self)
40- .def (pybind11::self * pybind11::self)
41- .def (pybind11::self / pybind11::self)
42- .def (pybind11::self += pybind11::self)
43- .def (pybind11::self -= pybind11::self)
44- .def (pybind11::self *= pybind11::self)
45- .def (pybind11::self /= pybind11::self)
46- .def (pybind11::self == pybind11::self)
47- .def (pybind11::self != pybind11::self)
48- .def (pybind11::self >= pybind11::self)
49- .def (pybind11::self <= pybind11::self)
50- .def (pybind11::self > pybind11::self)
51- .def (pybind11::self < pybind11::self)
52- .def (+pybind11::self)
53- .def (-pybind11::self)
54- .def (" __matmul__" , &matmul)
55- .def (" __repr__" , &tensor_to_string);
1+ #include < tensor-array/core/tensor.hh>
2+ #include < pybind11/pybind11.h>
3+ #include < pybind11/numpy.h>
4+ #include < pybind11/operators.h>
5+
6+ using namespace tensor_array ::value;
7+
8+ template <typename T>
9+ TensorBase convert_numpy_to_tensor_base (pybind11::array_t <T> py_buf)
10+ {
11+ pybind11::buffer_info info = py_buf.request ();
12+ std::vector<unsigned int > shape_vec (info.ndim );
13+ std::transform
14+ (
15+ info.shape .cbegin (),
16+ info.shape .cend (),
17+ shape_vec.begin (),
18+ [](pybind11::size_t dim)
19+ {
20+ return static_cast <unsigned int >(dim);
21+ }
22+ );
23+ return TensorBase (typeid (T), shape_vec, info.ptr );
24+ }
25+
26+ pybind11::array convert_tensor_to_numpy (const Tensor& tensor)
27+ {
28+ const TensorBase& base_tensor = tensor.get_buffer ();
29+ std::vector<pybind11::size_t > shape_vec (base_tensor.shape ().size ());
30+ std::transform
31+ (
32+ base_tensor.shape ().begin (),
33+ base_tensor.shape ().end (),
34+ shape_vec.begin (),
35+ [](unsigned int dim)
36+ {
37+ return static_cast <pybind11::size_t >(dim);
38+ }
39+ );
40+ pybind11::array arr = pybind11::array ();
41+ return arr;
42+ }
43+
44+ Tensor python_tuple_slice (const Tensor& t, pybind11::tuple tuple_slice)
45+ {
46+ std::vector<Tensor::Slice> t_slices;
47+ for (size_t i = 0 ; i < tuple_slice.size (); i++)
48+ {
49+ ssize_t start, stop, step;
50+ ssize_t length;
51+ pybind11::slice py_slice = tuple_slice[i].cast <pybind11::slice>();
52+ if (!py_slice.compute (t.get_buffer ().shape ().begin ()[i], &start, &stop, &step, &length))
53+ throw std::runtime_error (" Invalid slice" );
54+ t_slices.insert
55+ (
56+ t_slices.begin () + i,
57+ Tensor::Slice
58+ {
59+ static_cast <int >(start),
60+ static_cast <int >(stop),
61+ static_cast <int >(step)
62+ }
63+ );
64+ }
65+
66+ #ifdef __GNUC__
67+ struct
68+ {
69+ const Tensor::Slice* it;
70+ std::size_t sz;
71+ } test;
72+ test.it = t_slices.data ();
73+ test.sz = t_slices.size ();
74+ std::initializer_list<Tensor::Slice>& t_slice_list = reinterpret_cast <std::initializer_list<Tensor::Slice>&>(test);
75+ #endif
76+ return t[t_slice_list];
77+ }
78+
79+ Tensor python_slice (const Tensor& t, pybind11::slice py_slice)
80+ {
81+ std::vector<Tensor::Slice> t_slices;
82+ ssize_t start, stop, step;
83+ ssize_t length;
84+ if (!py_slice.compute (t.get_buffer ().shape ().begin ()[0 ], &start, &stop, &step, &length))
85+ throw std::runtime_error (" Invalid slice" );
86+ return t
87+ [
88+ {
89+ Tensor::Slice
90+ {
91+ static_cast <int >(start),
92+ static_cast <int >(stop),
93+ static_cast <int >(step)
94+ }
95+ }
96+ ];
97+ }
98+
99+ Tensor python_index (const Tensor& t, unsigned int i)
100+ {
101+ return t[i];
102+ }
103+
104+ std::size_t python_len (const Tensor& t)
105+ {
106+ std::initializer_list<unsigned int > shape_list = t.get_buffer ().shape ();
107+ return shape_list.size () != 0 ? shape_list.begin ()[0 ]: 1U ;
108+ }
109+
110+ std::string tensor_to_string (const Tensor& t)
111+ {
112+ std::ostringstream osstream;
113+ osstream << t;
114+ return osstream.str ();
115+ }
116+
117+ PYBIND11_MODULE (tensor2, m)
118+ {
119+ pybind11::class_<Tensor>(m, " Tensor" )
120+ .def (pybind11::init ())
121+ .def (pybind11::init (&convert_numpy_to_tensor_base<float >))
122+ .def (pybind11::self + pybind11::self)
123+ .def (pybind11::self - pybind11::self)
124+ .def (pybind11::self * pybind11::self)
125+ .def (pybind11::self / pybind11::self)
126+ .def (pybind11::self += pybind11::self)
127+ .def (pybind11::self -= pybind11::self)
128+ .def (pybind11::self *= pybind11::self)
129+ .def (pybind11::self /= pybind11::self)
130+ .def (pybind11::self == pybind11::self)
131+ .def (pybind11::self != pybind11::self)
132+ .def (pybind11::self >= pybind11::self)
133+ .def (pybind11::self <= pybind11::self)
134+ .def (pybind11::self > pybind11::self)
135+ .def (pybind11::self < pybind11::self)
136+ .def (+pybind11::self)
137+ .def (-pybind11::self)
138+ .def (hash (pybind11::self))
139+ .def (" transpose" , &Tensor::transpose)
140+ .def (" calc_grad" , &Tensor::calc_grad)
141+ .def (" add" , &add)
142+ .def (" multiply" , &multiply)
143+ .def (" divide" , ÷)
144+ .def (" matmul" , &matmul)
145+ .def (" condition" , &condition)
146+ .def (" __getitem__" , &python_index)
147+ .def (" __getitem__" , &python_slice)
148+ .def (" __getitem__" , &python_tuple_slice)
149+ .def (" __len__" , &python_len)
150+ .def (" __matmul__" , &matmul)
151+ .def (" __repr__" , &tensor_to_string);
56152}
0 commit comments