1use super::core::PyTensor;
4use crate::{device::PyDevice, dtype::PyDType, error::PyResult, py_result};
5use pyo3::prelude::*;
6
7pub fn register_creation_functions(m: &Bound<'_, PyModule>) -> PyResult<()> {
9 use pyo3::wrap_pyfunction;
10
11 #[pyfunction]
12 fn tensor(
13 data: &Bound<'_, PyAny>,
14 dtype: Option<PyDType>,
15 device: Option<PyDevice>,
16 requires_grad: Option<bool>,
17 ) -> PyResult<PyTensor> {
18 PyTensor::new(data, dtype, device, requires_grad)
19 }
20
21 #[pyfunction]
22 fn zeros(
23 size: Vec<usize>,
24 _dtype: Option<PyDType>,
25 _device: Option<PyDevice>,
26 requires_grad: Option<bool>,
27 ) -> PyResult<PyTensor> {
28 let tensor_result = py_result!(torsh_tensor::creation::zeros(&size))?;
29 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
30 Ok(PyTensor { tensor })
31 }
32
33 #[pyfunction]
34 fn ones(
35 size: Vec<usize>,
36 _dtype: Option<PyDType>,
37 _device: Option<PyDevice>,
38 requires_grad: Option<bool>,
39 ) -> PyResult<PyTensor> {
40 let tensor_result = py_result!(torsh_tensor::creation::ones(&size))?;
41 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
42 Ok(PyTensor { tensor })
43 }
44
45 #[pyfunction]
46 fn randn(
47 size: Vec<usize>,
48 _dtype: Option<PyDType>,
49 _device: Option<PyDevice>,
50 requires_grad: Option<bool>,
51 ) -> PyResult<PyTensor> {
52 let tensor_result = py_result!(torsh_tensor::creation::randn(&size))?;
53 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
54 Ok(PyTensor { tensor })
55 }
56
57 #[pyfunction]
58 fn rand(
59 size: Vec<usize>,
60 _dtype: Option<PyDType>,
61 _device: Option<PyDevice>,
62 requires_grad: Option<bool>,
63 ) -> PyResult<PyTensor> {
64 let tensor_result = py_result!(torsh_tensor::creation::rand(&size))?;
65 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
66 Ok(PyTensor { tensor })
67 }
68
69 #[pyfunction]
70 fn empty(
71 size: Vec<usize>,
72 _dtype: Option<PyDType>,
73 _device: Option<PyDevice>,
74 requires_grad: Option<bool>,
75 ) -> PyResult<PyTensor> {
76 let tensor_result = py_result!(torsh_tensor::creation::zeros(&size))?;
78 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
79 Ok(PyTensor { tensor })
80 }
81
82 #[pyfunction]
83 fn full(
84 size: Vec<usize>,
85 fill_value: f32,
86 _dtype: Option<PyDType>,
87 _device: Option<PyDevice>,
88 requires_grad: Option<bool>,
89 ) -> PyResult<PyTensor> {
90 let tensor_result = py_result!(torsh_tensor::creation::full(&size, fill_value))?;
91 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
92 Ok(PyTensor { tensor })
93 }
94
95 #[pyfunction]
96 fn eye(
97 n: usize,
98 _m: Option<usize>,
99 _dtype: Option<PyDType>,
100 _device: Option<PyDevice>,
101 requires_grad: Option<bool>,
102 ) -> PyResult<PyTensor> {
103 let tensor_result = py_result!(torsh_tensor::creation::eye(n))?;
105 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
106 Ok(PyTensor { tensor })
107 }
108
109 #[pyfunction]
110 fn arange(
111 start: f32,
112 end: Option<f32>,
113 step: Option<f32>,
114 _dtype: Option<PyDType>,
115 _device: Option<PyDevice>,
116 requires_grad: Option<bool>,
117 ) -> PyResult<PyTensor> {
118 let (start, end) = if let Some(end) = end {
119 (start, end)
120 } else {
121 (0.0, start)
122 };
123 let step = step.unwrap_or(1.0);
124 let tensor_result = py_result!(torsh_tensor::creation::arange(start, end, step))?;
125 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
126 Ok(PyTensor { tensor })
127 }
128
129 #[pyfunction]
130 fn linspace(
131 start: f32,
132 end: f32,
133 steps: usize,
134 _dtype: Option<PyDType>,
135 _device: Option<PyDevice>,
136 requires_grad: Option<bool>,
137 ) -> PyResult<PyTensor> {
138 let tensor_result = py_result!(torsh_tensor::creation::linspace(start, end, steps))?;
139 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
140 Ok(PyTensor { tensor })
141 }
142
143 #[pyfunction]
145 fn zeros_like(
146 input: &PyTensor,
147 dtype: Option<PyDType>,
148 device: Option<PyDevice>,
149 requires_grad: Option<bool>,
150 ) -> PyResult<PyTensor> {
151 let _dtype = dtype;
152 let _device = device;
153 let tensor_result = py_result!(torsh_tensor::creation::zeros_like(&input.tensor))?;
154 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
155 Ok(PyTensor { tensor })
156 }
157
158 #[pyfunction]
159 fn ones_like(
160 input: &PyTensor,
161 dtype: Option<PyDType>,
162 device: Option<PyDevice>,
163 requires_grad: Option<bool>,
164 ) -> PyResult<PyTensor> {
165 let _dtype = dtype;
166 let _device = device;
167 let tensor_result = py_result!(torsh_tensor::creation::ones_like(&input.tensor))?;
168 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
169 Ok(PyTensor { tensor })
170 }
171
172 #[pyfunction]
173 fn full_like(
174 input: &PyTensor,
175 fill_value: f32,
176 dtype: Option<PyDType>,
177 device: Option<PyDevice>,
178 requires_grad: Option<bool>,
179 ) -> PyResult<PyTensor> {
180 let _dtype = dtype;
181 let _device = device;
182 let shape = input.tensor.shape().dims().to_vec();
183 let tensor_result = py_result!(torsh_tensor::creation::full(&shape, fill_value))?;
184 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
185 Ok(PyTensor { tensor })
186 }
187
188 #[pyfunction]
189 fn empty_like(
190 input: &PyTensor,
191 dtype: Option<PyDType>,
192 device: Option<PyDevice>,
193 requires_grad: Option<bool>,
194 ) -> PyResult<PyTensor> {
195 let _dtype = dtype;
196 let _device = device;
197 let tensor_result = py_result!(torsh_tensor::creation::zeros_like(&input.tensor))?;
199 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
200 Ok(PyTensor { tensor })
201 }
202
203 #[pyfunction]
204 fn randn_like(
205 input: &PyTensor,
206 dtype: Option<PyDType>,
207 device: Option<PyDevice>,
208 requires_grad: Option<bool>,
209 ) -> PyResult<PyTensor> {
210 let _dtype = dtype;
211 let _device = device;
212 let shape = input.tensor.shape().dims().to_vec();
213 let tensor_result = py_result!(torsh_tensor::creation::randn(&shape))?;
214 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
215 Ok(PyTensor { tensor })
216 }
217
218 #[pyfunction]
219 fn rand_like(
220 input: &PyTensor,
221 dtype: Option<PyDType>,
222 device: Option<PyDevice>,
223 requires_grad: Option<bool>,
224 ) -> PyResult<PyTensor> {
225 let _dtype = dtype;
226 let _device = device;
227 let shape = input.tensor.shape().dims().to_vec();
228 let tensor_result = py_result!(torsh_tensor::creation::rand(&shape))?;
229 let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
230 Ok(PyTensor { tensor })
231 }
232
233 m.add_function(wrap_pyfunction!(tensor, m)?)?;
235 m.add_function(wrap_pyfunction!(zeros, m)?)?;
236 m.add_function(wrap_pyfunction!(ones, m)?)?;
237 m.add_function(wrap_pyfunction!(randn, m)?)?;
238 m.add_function(wrap_pyfunction!(rand, m)?)?;
239 m.add_function(wrap_pyfunction!(empty, m)?)?;
240 m.add_function(wrap_pyfunction!(full, m)?)?;
241 m.add_function(wrap_pyfunction!(eye, m)?)?;
242 m.add_function(wrap_pyfunction!(arange, m)?)?;
243 m.add_function(wrap_pyfunction!(linspace, m)?)?;
244
245 m.add_function(wrap_pyfunction!(zeros_like, m)?)?;
247 m.add_function(wrap_pyfunction!(ones_like, m)?)?;
248 m.add_function(wrap_pyfunction!(full_like, m)?)?;
249 m.add_function(wrap_pyfunction!(empty_like, m)?)?;
250 m.add_function(wrap_pyfunction!(randn_like, m)?)?;
251 m.add_function(wrap_pyfunction!(rand_like, m)?)?;
252
253 Ok(())
254}