Skip to main content

torsh_python/tensor/
creation.rs

1//! Tensor creation functions - zeros, ones, randn, etc.
2
3use super::core::PyTensor;
4use crate::{device::PyDevice, dtype::PyDType, error::PyResult, py_result};
5use pyo3::prelude::*;
6
7/// Register simplified tensor creation functions
8pub fn register_creation_functions(m: &Bound<'_, PyModule>) -> PyResult<()> {
9    use pyo3::wrap_pyfunction;
10
11    #[pyfunction]
12    fn tensor(
13        data: &Bound<'_, PyAny>,
14        dtype: Option<PyDType>,
15        device: Option<PyDevice>,
16        requires_grad: Option<bool>,
17    ) -> PyResult<PyTensor> {
18        PyTensor::new(data, dtype, device, requires_grad)
19    }
20
21    #[pyfunction]
22    fn zeros(
23        size: Vec<usize>,
24        _dtype: Option<PyDType>,
25        _device: Option<PyDevice>,
26        requires_grad: Option<bool>,
27    ) -> PyResult<PyTensor> {
28        let tensor_result = py_result!(torsh_tensor::creation::zeros(&size))?;
29        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
30        Ok(PyTensor { tensor })
31    }
32
33    #[pyfunction]
34    fn ones(
35        size: Vec<usize>,
36        _dtype: Option<PyDType>,
37        _device: Option<PyDevice>,
38        requires_grad: Option<bool>,
39    ) -> PyResult<PyTensor> {
40        let tensor_result = py_result!(torsh_tensor::creation::ones(&size))?;
41        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
42        Ok(PyTensor { tensor })
43    }
44
45    #[pyfunction]
46    fn randn(
47        size: Vec<usize>,
48        _dtype: Option<PyDType>,
49        _device: Option<PyDevice>,
50        requires_grad: Option<bool>,
51    ) -> PyResult<PyTensor> {
52        let tensor_result = py_result!(torsh_tensor::creation::randn(&size))?;
53        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
54        Ok(PyTensor { tensor })
55    }
56
57    #[pyfunction]
58    fn rand(
59        size: Vec<usize>,
60        _dtype: Option<PyDType>,
61        _device: Option<PyDevice>,
62        requires_grad: Option<bool>,
63    ) -> PyResult<PyTensor> {
64        let tensor_result = py_result!(torsh_tensor::creation::rand(&size))?;
65        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
66        Ok(PyTensor { tensor })
67    }
68
69    #[pyfunction]
70    fn empty(
71        size: Vec<usize>,
72        _dtype: Option<PyDType>,
73        _device: Option<PyDevice>,
74        requires_grad: Option<bool>,
75    ) -> PyResult<PyTensor> {
76        // Use zeros as a fallback since empty is not available
77        let tensor_result = py_result!(torsh_tensor::creation::zeros(&size))?;
78        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
79        Ok(PyTensor { tensor })
80    }
81
82    #[pyfunction]
83    fn full(
84        size: Vec<usize>,
85        fill_value: f32,
86        _dtype: Option<PyDType>,
87        _device: Option<PyDevice>,
88        requires_grad: Option<bool>,
89    ) -> PyResult<PyTensor> {
90        let tensor_result = py_result!(torsh_tensor::creation::full(&size, fill_value))?;
91        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
92        Ok(PyTensor { tensor })
93    }
94
95    #[pyfunction]
96    fn eye(
97        n: usize,
98        _m: Option<usize>,
99        _dtype: Option<PyDType>,
100        _device: Option<PyDevice>,
101        requires_grad: Option<bool>,
102    ) -> PyResult<PyTensor> {
103        // torsh_tensor::creation::eye only takes one parameter
104        let tensor_result = py_result!(torsh_tensor::creation::eye(n))?;
105        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
106        Ok(PyTensor { tensor })
107    }
108
109    #[pyfunction]
110    fn arange(
111        start: f32,
112        end: Option<f32>,
113        step: Option<f32>,
114        _dtype: Option<PyDType>,
115        _device: Option<PyDevice>,
116        requires_grad: Option<bool>,
117    ) -> PyResult<PyTensor> {
118        let (start, end) = if let Some(end) = end {
119            (start, end)
120        } else {
121            (0.0, start)
122        };
123        let step = step.unwrap_or(1.0);
124        let tensor_result = py_result!(torsh_tensor::creation::arange(start, end, step))?;
125        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
126        Ok(PyTensor { tensor })
127    }
128
129    #[pyfunction]
130    fn linspace(
131        start: f32,
132        end: f32,
133        steps: usize,
134        _dtype: Option<PyDType>,
135        _device: Option<PyDevice>,
136        requires_grad: Option<bool>,
137    ) -> PyResult<PyTensor> {
138        let tensor_result = py_result!(torsh_tensor::creation::linspace(start, end, steps))?;
139        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
140        Ok(PyTensor { tensor })
141    }
142
143    // "_like" functions - create tensors with same shape as input
144    #[pyfunction]
145    fn zeros_like(
146        input: &PyTensor,
147        dtype: Option<PyDType>,
148        device: Option<PyDevice>,
149        requires_grad: Option<bool>,
150    ) -> PyResult<PyTensor> {
151        let _dtype = dtype;
152        let _device = device;
153        let tensor_result = py_result!(torsh_tensor::creation::zeros_like(&input.tensor))?;
154        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
155        Ok(PyTensor { tensor })
156    }
157
158    #[pyfunction]
159    fn ones_like(
160        input: &PyTensor,
161        dtype: Option<PyDType>,
162        device: Option<PyDevice>,
163        requires_grad: Option<bool>,
164    ) -> PyResult<PyTensor> {
165        let _dtype = dtype;
166        let _device = device;
167        let tensor_result = py_result!(torsh_tensor::creation::ones_like(&input.tensor))?;
168        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
169        Ok(PyTensor { tensor })
170    }
171
172    #[pyfunction]
173    fn full_like(
174        input: &PyTensor,
175        fill_value: f32,
176        dtype: Option<PyDType>,
177        device: Option<PyDevice>,
178        requires_grad: Option<bool>,
179    ) -> PyResult<PyTensor> {
180        let _dtype = dtype;
181        let _device = device;
182        let shape = input.tensor.shape().dims().to_vec();
183        let tensor_result = py_result!(torsh_tensor::creation::full(&shape, fill_value))?;
184        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
185        Ok(PyTensor { tensor })
186    }
187
188    #[pyfunction]
189    fn empty_like(
190        input: &PyTensor,
191        dtype: Option<PyDType>,
192        device: Option<PyDevice>,
193        requires_grad: Option<bool>,
194    ) -> PyResult<PyTensor> {
195        let _dtype = dtype;
196        let _device = device;
197        // Use zeros_like as fallback since empty is not critical
198        let tensor_result = py_result!(torsh_tensor::creation::zeros_like(&input.tensor))?;
199        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
200        Ok(PyTensor { tensor })
201    }
202
203    #[pyfunction]
204    fn randn_like(
205        input: &PyTensor,
206        dtype: Option<PyDType>,
207        device: Option<PyDevice>,
208        requires_grad: Option<bool>,
209    ) -> PyResult<PyTensor> {
210        let _dtype = dtype;
211        let _device = device;
212        let shape = input.tensor.shape().dims().to_vec();
213        let tensor_result = py_result!(torsh_tensor::creation::randn(&shape))?;
214        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
215        Ok(PyTensor { tensor })
216    }
217
218    #[pyfunction]
219    fn rand_like(
220        input: &PyTensor,
221        dtype: Option<PyDType>,
222        device: Option<PyDevice>,
223        requires_grad: Option<bool>,
224    ) -> PyResult<PyTensor> {
225        let _dtype = dtype;
226        let _device = device;
227        let shape = input.tensor.shape().dims().to_vec();
228        let tensor_result = py_result!(torsh_tensor::creation::rand(&shape))?;
229        let tensor = tensor_result.requires_grad_(requires_grad.unwrap_or(false));
230        Ok(PyTensor { tensor })
231    }
232
233    // Register all functions
234    m.add_function(wrap_pyfunction!(tensor, m)?)?;
235    m.add_function(wrap_pyfunction!(zeros, m)?)?;
236    m.add_function(wrap_pyfunction!(ones, m)?)?;
237    m.add_function(wrap_pyfunction!(randn, m)?)?;
238    m.add_function(wrap_pyfunction!(rand, m)?)?;
239    m.add_function(wrap_pyfunction!(empty, m)?)?;
240    m.add_function(wrap_pyfunction!(full, m)?)?;
241    m.add_function(wrap_pyfunction!(eye, m)?)?;
242    m.add_function(wrap_pyfunction!(arange, m)?)?;
243    m.add_function(wrap_pyfunction!(linspace, m)?)?;
244
245    // Register "_like" functions
246    m.add_function(wrap_pyfunction!(zeros_like, m)?)?;
247    m.add_function(wrap_pyfunction!(ones_like, m)?)?;
248    m.add_function(wrap_pyfunction!(full_like, m)?)?;
249    m.add_function(wrap_pyfunction!(empty_like, m)?)?;
250    m.add_function(wrap_pyfunction!(randn_like, m)?)?;
251    m.add_function(wrap_pyfunction!(rand_like, m)?)?;
252
253    Ok(())
254}