1#![allow(dead_code)]
5use crate::error::FfiError;
6use crate::python::tensor::PyTensor;
7use numpy::{PyReadonlyArrayDyn, PyUntypedArrayMethods};
8use pyo3::prelude::*;
9use pyo3::types::{PyAny, PyList};
10use pyo3::Py;
11
12#[pyfunction]
14#[pyo3(signature = (data, dtype=None, requires_grad=false))]
15pub fn tensor(
16 data: &Bound<'_, PyAny>,
17 dtype: Option<&str>,
18 requires_grad: bool,
19) -> PyResult<PyTensor> {
20 PyTensor::new(data, None, dtype, requires_grad)
21}
22
23#[pyfunction]
25#[pyo3(signature = (shape, dtype=None, requires_grad=false))]
26pub fn zeros(shape: Vec<usize>, dtype: Option<&str>, requires_grad: bool) -> PyResult<PyTensor> {
27 let total_elements: usize = shape.iter().product();
28 let data = vec![0.0; total_elements];
29
30 Python::attach(|py| {
31 let py_data = PyList::new(py, &data)?;
32 PyTensor::new(py_data.as_ref(), Some(shape), dtype, requires_grad)
33 })
34}
35
36#[pyfunction]
38#[pyo3(signature = (shape, dtype=None, requires_grad=false))]
39pub fn ones(shape: Vec<usize>, dtype: Option<&str>, requires_grad: bool) -> PyResult<PyTensor> {
40 let total_elements: usize = shape.iter().product();
41 let data = vec![1.0; total_elements];
42
43 Python::attach(|py| {
44 let py_data = PyList::new(py, &data)?;
45 PyTensor::new(py_data.as_ref(), Some(shape), dtype, requires_grad)
46 })
47}
48
49#[pyfunction]
51#[pyo3(signature = (shape, mean=0.0, std=1.0, dtype=None, requires_grad=false))]
52pub fn randn(
53 shape: Vec<usize>,
54 mean: f32,
55 std: f32,
56 dtype: Option<&str>,
57 requires_grad: bool,
58) -> PyResult<PyTensor> {
59 let total_elements: usize = shape.iter().product();
60
61 let mut data = Vec::with_capacity(total_elements);
63 let mut rng_state = 12345u64; for i in 0..total_elements {
66 if i % 2 == 0 {
67 let u1 = lcg_random(&mut rng_state);
69 let u2 = lcg_random(&mut rng_state);
70
71 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
72 let z1 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).sin();
73
74 data.push(mean + std * z0);
75 if i + 1 < total_elements {
76 data.push(mean + std * z1);
77 }
78 }
79 }
80
81 data.truncate(total_elements);
83
84 Python::attach(|py| {
85 let py_data = PyList::new(py, &data)?;
86 PyTensor::new(py_data.as_ref(), Some(shape), dtype, requires_grad)
87 })
88}
89
90#[pyfunction]
92#[pyo3(signature = (shape, low=0.0, high=1.0, dtype=None, requires_grad=false))]
93pub fn rand(
94 shape: Vec<usize>,
95 low: f32,
96 high: f32,
97 dtype: Option<&str>,
98 requires_grad: bool,
99) -> PyResult<PyTensor> {
100 let total_elements: usize = shape.iter().product();
101 let mut data = Vec::with_capacity(total_elements);
102 let mut rng_state = 54321u64; for _ in 0..total_elements {
105 let random_val = lcg_random(&mut rng_state);
106 data.push(low + (high - low) * random_val);
107 }
108
109 Python::attach(|py| {
110 let py_data = PyList::new(py, &data)?;
111 PyTensor::new(py_data.as_ref(), Some(shape), dtype, requires_grad)
112 })
113}
114
115#[pyfunction]
117#[pyo3(signature = (n, dtype=None, requires_grad=false))]
118pub fn eye(n: usize, dtype: Option<&str>, requires_grad: bool) -> PyResult<PyTensor> {
119 let mut data = vec![0.0; n * n];
120
121 for i in 0..n {
122 data[i * n + i] = 1.0;
123 }
124
125 Python::attach(|py| {
126 let py_data = PyList::new(py, &data)?;
127 PyTensor::new(py_data.as_ref(), Some(vec![n, n]), dtype, requires_grad)
128 })
129}
130
131#[pyfunction]
133#[pyo3(signature = (shape, fill_value, dtype=None, requires_grad=false))]
134pub fn full(
135 shape: Vec<usize>,
136 fill_value: f32,
137 dtype: Option<&str>,
138 requires_grad: bool,
139) -> PyResult<PyTensor> {
140 let total_elements: usize = shape.iter().product();
141 let data = vec![fill_value; total_elements];
142
143 Python::attach(|py| {
144 let py_data = PyList::new(py, &data)?;
145 PyTensor::new(py_data.as_ref(), Some(shape), dtype, requires_grad)
146 })
147}
148
149#[pyfunction]
151pub fn from_numpy(array: PyReadonlyArrayDyn<f32>) -> PyResult<PyTensor> {
152 let data: Vec<f32> = array.as_array().iter().cloned().collect();
153 let shape: Vec<usize> = array.shape().to_vec();
154
155 Python::attach(|py| {
156 let py_data = PyList::new(py, &data)?;
157 PyTensor::new(py_data.as_ref(), Some(shape), Some("f32"), false)
158 })
159}
160
161#[pyfunction]
163pub fn to_numpy(tensor: &PyTensor, py: Python) -> PyResult<Py<PyAny>> {
164 tensor.to_numpy_internal(py)
165}
166
167#[pyfunction]
169#[pyo3(signature = (start, end, steps, dtype=None, requires_grad=false))]
170pub fn linspace(
171 start: f32,
172 end: f32,
173 steps: usize,
174 dtype: Option<&str>,
175 requires_grad: bool,
176) -> PyResult<PyTensor> {
177 if steps == 0 {
178 return Err(FfiError::InvalidParameter {
179 parameter: "steps".to_string(),
180 value: "0".to_string(),
181 }
182 .into());
183 }
184
185 let mut data = Vec::with_capacity(steps);
186
187 if steps == 1 {
188 data.push(start);
189 } else {
190 let step_size = (end - start) / (steps - 1) as f32;
191 for i in 0..steps {
192 data.push(start + i as f32 * step_size);
193 }
194 }
195
196 Python::attach(|py| {
197 let py_data = PyList::new(py, &data)?;
198 PyTensor::new(py_data.as_ref(), Some(vec![steps]), dtype, requires_grad)
199 })
200}
201
202#[pyfunction]
204#[pyo3(signature = (start, end, step=1.0, dtype=None, requires_grad=false))]
205pub fn arange(
206 start: f32,
207 end: f32,
208 step: f32,
209 dtype: Option<&str>,
210 requires_grad: bool,
211) -> PyResult<PyTensor> {
212 if step == 0.0 {
213 return Err(FfiError::InvalidParameter {
214 parameter: "step".to_string(),
215 value: "0.0".to_string(),
216 }
217 .into());
218 }
219
220 let mut data = Vec::new();
221 let mut current = start;
222
223 if step > 0.0 {
224 while current < end {
225 data.push(current);
226 current += step;
227 }
228 } else {
229 while current > end {
230 data.push(current);
231 current += step;
232 }
233 }
234
235 Python::attach(|py| {
236 let py_data = PyList::new(py, &data)?;
237 PyTensor::new(
238 py_data.as_ref(),
239 Some(vec![data.len()]),
240 dtype,
241 requires_grad,
242 )
243 })
244}
245
246#[pyfunction]
248#[pyo3(signature = (tensors, dim=0))]
249pub fn stack(tensors: Vec<PyTensor>, dim: i32) -> PyResult<PyTensor> {
250 if tensors.is_empty() {
251 return Err(FfiError::InvalidParameter {
252 parameter: "tensors".to_string(),
253 value: "empty list".to_string(),
254 }
255 .into());
256 }
257
258 let first_shape = &tensors[0].shape();
260 for tensor in &tensors[1..] {
261 if tensor.shape() != *first_shape {
262 return Err(FfiError::ShapeMismatch {
263 expected: first_shape.clone(),
264 actual: tensor.shape(),
265 }
266 .into());
267 }
268 }
269
270 if dim != 0 {
272 return Err(FfiError::UnsupportedOperation {
273 operation: format!("stack with dim={} not yet implemented", dim),
274 }
275 .into());
276 }
277
278 let mut stacked_data = Vec::new();
279 for tensor in &tensors {
280 stacked_data.extend_from_slice(&tensor.data);
281 }
282
283 let mut new_shape = vec![tensors.len()];
284 new_shape.extend_from_slice(first_shape);
285
286 Python::attach(|py| {
287 let py_data = PyList::new(py, &stacked_data)?;
288 PyTensor::new(py_data.as_ref(), Some(new_shape), Some("f32"), false)
289 })
290}
291
292#[pyfunction]
294#[pyo3(signature = (tensors, dim=0))]
295pub fn cat(tensors: Vec<PyTensor>, dim: i32) -> PyResult<PyTensor> {
296 if tensors.is_empty() {
297 return Err(FfiError::InvalidParameter {
298 parameter: "tensors".to_string(),
299 value: "empty list".to_string(),
300 }
301 .into());
302 }
303
304 if dim != 0 {
306 return Err(FfiError::UnsupportedOperation {
307 operation: format!("cat with dim={} not yet implemented", dim),
308 }
309 .into());
310 }
311
312 let mut concatenated_data = Vec::new();
313 let mut total_size = 0;
314
315 for tensor in &tensors {
316 concatenated_data.extend_from_slice(&tensor.data);
317 total_size += tensor.data.len();
318 }
319
320 Python::attach(|py| {
321 let py_data = PyList::new(py, &concatenated_data)?;
322 PyTensor::new(py_data.as_ref(), Some(vec![total_size]), Some("f32"), false)
323 })
324}
325
326#[pyfunction]
328pub fn cuda_is_available() -> bool {
329 cfg!(feature = "cuda")
331}
332
333#[pyfunction]
335pub fn cuda_device_count() -> usize {
336 if cuda_is_available() {
338 1
339 } else {
340 0
341 }
342}
343
344#[pyfunction]
346pub fn manual_seed(_seed: u64) {
347 }
350
/// Advance a 64-bit linear congruential generator and return a uniform sample in `(0, 1)`.
///
/// The state update is the LCG recurrence `state = state * 1103515245 + 12345 (mod 2^64)`.
///
/// The raw state is not used directly as the output. The first state reached from a small
/// seed (such as the fixed seeds in `randn`/`rand`) is tiny relative to 2^64, so
/// `state / u64::MAX` made the first draw about 1e-6. That division could also round to
/// exactly `0.0` or `1.0`.
///
/// Instead, the state goes through the SplitMix64 finalizer (a bijective bit mixer). Its
/// top 23 bits select the centre of one of 2^23 equal-width bins. The result is therefore
/// never `0.0`, which would make `ln(u)` infinite in Box–Muller, and never `1.0`.
fn lcg_random(state: &mut u64) -> f32 {
    *state = state.wrapping_mul(1103515245).wrapping_add(12345);

    // SplitMix64 output finalizer (Stafford "Mix13" variant).
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
    z ^= z >> 31;

    // A 23-bit integer plus 0.5 is exactly representable in f32, and dividing by 2^23 is
    // exact, so the result lies in [2^-24, 1 - 2^-24].
    ((z >> 41) as f32 + 0.5) / 8_388_608.0
}
356
#[cfg(test)]
mod tests {
    use super::*;
    use pyo3::Python;

    /// Assert that `actual` and `expected` have the same length and agree element-wise
    /// within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, b)) in actual.iter().zip(expected).enumerate() {
            assert!((a - b).abs() < tol, "element {}: {} vs {}", i, a, b);
        }
    }

    #[test]
    fn test_zeros() {
        Python::initialize();
        let result = zeros(vec![2, 3], None, false).unwrap();
        assert_eq!(result.shape(), vec![2, 3]);
        assert_eq!(result.data, vec![0.0; 6]);
    }

    #[test]
    fn test_ones() {
        Python::initialize();
        let result = ones(vec![2, 2], None, false).unwrap();
        assert_eq!(result.shape(), vec![2, 2]);
        assert_eq!(result.data, vec![1.0; 4]);
    }

    #[test]
    fn test_full() {
        Python::initialize();
        let result = full(vec![2, 2], 3.5, None, false).unwrap();
        assert_eq!(result.shape(), vec![2, 2]);
        assert_eq!(result.data, vec![3.5; 4]);
    }

    #[test]
    fn test_eye() {
        Python::initialize();
        let result = eye(3, None, false).unwrap();
        assert_eq!(result.shape(), vec![3, 3]);

        let expected = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        assert_eq!(result.data, expected);
    }

    #[test]
    fn test_linspace() {
        Python::initialize();
        let result = linspace(0.0, 1.0, 5, None, false).unwrap();
        assert_eq!(result.shape(), vec![5]);
        assert_close(&result.data, &[0.0, 0.25, 0.5, 0.75, 1.0], 1e-6);
    }

    #[test]
    fn test_linspace_single_step() {
        Python::initialize();
        let result = linspace(2.0, 9.0, 1, None, false).unwrap();
        assert_eq!(result.shape(), vec![1]);
        assert_eq!(result.data, vec![2.0]);
    }

    #[test]
    fn test_linspace_zero_steps_is_error() {
        Python::initialize();
        assert!(linspace(0.0, 1.0, 0, None, false).is_err());
    }

    #[test]
    fn test_arange() {
        Python::initialize();
        let result = arange(0.0, 5.0, 1.0, None, false).unwrap();
        assert_eq!(result.shape(), vec![5]);
        assert_eq!(result.data, vec![0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_arange_negative_step() {
        Python::initialize();
        let result = arange(5.0, 0.0, -1.0, None, false).unwrap();
        assert_eq!(result.shape(), vec![5]);
        assert_eq!(result.data, vec![5.0, 4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_arange_zero_step_is_error() {
        Python::initialize();
        assert!(arange(0.0, 1.0, 0.0, None, false).is_err());
    }

    #[test]
    fn test_randn() {
        Python::initialize();
        let result = randn(vec![100], 0.0, 1.0, None, false).unwrap();
        assert_eq!(result.shape(), vec![100]);
        assert!(result.data.iter().all(|x| x.is_finite()));

        // Loose statistical sanity checks on a small sample.
        let mean = result.data.iter().sum::<f32>() / result.data.len() as f32;
        assert!(mean.abs() < 0.5);

        let variance: f32 = result.data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>()
            / result.data.len() as f32;
        let std_dev = variance.sqrt();
        assert!((std_dev - 1.0).abs() < 0.5);
    }

    #[test]
    fn test_rand_within_bounds() {
        Python::initialize();
        let result = rand(vec![200], -2.0, 3.0, None, false).unwrap();
        assert_eq!(result.shape(), vec![200]);
        assert!(result.data.iter().all(|&x| (-2.0..=3.0).contains(&x)));
    }

    #[test]
    fn test_stack_dim0() {
        Python::initialize();
        let a = zeros(vec![2, 3], None, false).unwrap();
        let b = ones(vec![2, 3], None, false).unwrap();
        let result = stack(vec![a, b], 0).unwrap();
        assert_eq!(result.shape(), vec![2, 2, 3]);

        let mut expected: Vec<f32> = vec![0.0; 6];
        expected.extend(vec![1.0; 6]);
        assert_eq!(result.data, expected);
    }

    #[test]
    fn test_stack_rejects_bad_input() {
        Python::initialize();
        assert!(stack(Vec::new(), 0).is_err());

        let a = zeros(vec![2], None, false).unwrap();
        let b = zeros(vec![3], None, false).unwrap();
        assert!(stack(vec![a, b], 0).is_err());
    }

    #[test]
    fn test_cat_1d() {
        Python::initialize();
        let a = arange(0.0, 3.0, 1.0, None, false).unwrap();
        let b = arange(3.0, 5.0, 1.0, None, false).unwrap();
        let result = cat(vec![a, b], 0).unwrap();
        assert_eq!(result.shape(), vec![5]);
        assert_eq!(result.data, vec![0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_cat_empty_is_error() {
        Python::initialize();
        assert!(cat(Vec::new(), 0).is_err());
    }

    #[test]
    fn test_cuda_device_count_matches_availability() {
        let expected = if cuda_is_available() { 1 } else { 0 };
        assert_eq!(cuda_device_count(), expected);
    }
}