@@ -186,9 +186,46 @@ PYBIND11_MODULE(tensor2, m)
186186 pybind11::arg (" dtype" ) = S_INT_32
187187 );
188188
189+ m.def (
190+ " add" ,
191+ &tensor_array::value::add,
192+ pybind11::arg (" value_1" ),
193+ pybind11::arg (" value_2" )
194+ );
195+
196+ m.def (
197+ " multiply" ,
198+ &tensor_array::value::multiply,
199+ pybind11::arg (" value_1" ),
200+ pybind11::arg (" value_2" )
201+ );
202+
203+ m.def (
204+ " divide" ,
205+ &tensor_array::value::divide,
206+ pybind11::arg (" value_1" ),
207+ pybind11::arg (" value_2" )
208+ );
209+
210+ m.def (
211+ " matmul" ,
212+ &tensor_array::value::matmul,
213+ pybind11::arg (" value_1" ),
214+ pybind11::arg (" value_2" )
215+ );
216+
217+ m.def (
218+ " condition" ,
219+ &tensor_array::value::condition,
220+ pybind11::arg (" condition_value" ),
221+ pybind11::arg (" value_if_true" ),
222+ pybind11::arg (" value_if_false" )
223+ );
224+
189225 pybind11::class_<Tensor>(m, " Tensor" )
190226 .def (pybind11::init ())
191227 .def (pybind11::init (&tensor_copying))
228+ .def (pybind11::init (&convert_numpy_to_tensor_base<int >))
192229 .def (pybind11::init (&convert_numpy_to_tensor_base<float >))
193230 .def (pybind11::self + pybind11::self)
194231 .def (pybind11::self - pybind11::self)
@@ -207,33 +244,27 @@ PYBIND11_MODULE(tensor2, m)
207244 .def (+pybind11::self)
208245 .def (-pybind11::self)
209246 .def (hash (pybind11::self))
210- .def (" transpose" , &Tensor::transpose)
211- .def (" calc_grad" , &Tensor::calc_grad)
212- .def (" get_grad" , &Tensor::get_grad)
213- .def (" sin" , &Tensor::sin)
214- .def (" sin" , &Tensor::sin)
215- .def (" cos" , &Tensor::cos)
216- .def (" tan" , &Tensor::tan)
217- .def (" sinh" , &Tensor::sinh)
218- .def (" cosh" , &Tensor::cosh)
219- .def (" tanh" , &Tensor::tanh)
220- .def (" log" , &Tensor::log)
221- .def (" clone" , &Tensor::clone)
247+ .def (" transpose" , &tensor_array::value::Tensor::transpose)
248+ .def (" calc_grad" , &tensor_array::value::Tensor::calc_grad)
249+ .def (" get_grad" , &tensor_array::value::Tensor::get_grad)
250+ .def (" sin" , &tensor_array::value::Tensor::sin)
251+ .def (" cos" , &tensor_array::value::Tensor::cos)
252+ .def (" tan" , &tensor_array::value::Tensor::tan)
253+ .def (" sinh" , &tensor_array::value::Tensor::sinh)
254+ .def (" cosh" , &tensor_array::value::Tensor::cosh)
255+ .def (" tanh" , &tensor_array::value::Tensor::tanh)
256+ .def (" log" , &tensor_array::value::Tensor::log)
257+ .def (" clone" , &tensor_array::value::Tensor::clone)
222258 .def (" cast" , &tensor_cast_1)
223- .def (" add" , &add)
224- .def (" multiply" , &multiply)
225- .def (" divide" , ÷)
226- .def (" matmul" , &matmul)
227- .def (" condition" , &condition)
228259 .def (" numpy" , &convert_tensor_to_numpy)
229260 .def (" shape" , &tensor_shape)
230261 .def (" dtype" , &tensor_type)
231262 .def (" __getitem__" , &python_index)
232263 .def (" __getitem__" , &python_slice)
233264 .def (" __getitem__" , &python_tuple_slice)
234265 .def (" __len__" , &python_len)
235- .def (" __matmul__" , &matmul)
236- .def (" __rmatmul__" , &matmul)
266+ .def (" __matmul__" , &tensor_array::value:: matmul)
267+ .def (" __rmatmul__" , &tensor_array::value:: matmul)
237268 .def (" __repr__" , &tensor_to_string)
238269 .def (" __copy__" , &tensor_copying);
239270}
0 commit comments