torsh_ffi/python/dataloader.rs

1//! Python bindings for ToRSh data loaders
2
3// Framework infrastructure - components designed for future use
4#![allow(dead_code)]
5use crate::python::tensor::PyTensor;
6use pyo3::prelude::*;
7use pyo3::types::{PyAny, PyList};
8use pyo3::{Bound, Py};
9use torsh_core::DType;
10use torsh_data::{
11    collate::DefaultCollate,
12    dataloader::{simple_dataloader, simple_random_dataloader, DataLoader},
13    dataset::TensorDataset,
14    sampler::{BatchingSampler, RandomSampler, SequentialSampler},
15};
16use torsh_tensor::Tensor;
17
/// Python wrapper for ToRSh DataLoader
///
/// Wraps a sequential, batched loader over an in-memory `f32` tensor
/// dataset using the default collate function. Exposed to Python as
/// `DataLoader`.
#[pyclass(name = "DataLoader")]
pub struct PyDataLoader {
    // Underlying Rust dataloader; drives batching and iteration.
    // Accessed directly by `benchmark_dataloader` below.
    inner: DataLoader<TensorDataset<f32>, BatchingSampler<SequentialSampler>, DefaultCollate>,
}
23
24#[pymethods]
25impl PyDataLoader {
26    /// Create a new DataLoader from a tensor dataset
27    #[new]
28    fn new(
29        dataset: PyTensor,
30        batch_size: Option<usize>,
31        shuffle: Option<bool>,
32        num_workers: Option<usize>,
33        drop_last: Option<bool>,
34    ) -> PyResult<Self> {
35        let batch_size = batch_size.unwrap_or(1);
36        let shuffle = shuffle.unwrap_or(false);
37        let _num_workers = num_workers.unwrap_or(0);
38        let _drop_last = drop_last.unwrap_or(false);
39
40        // Extract the tensor from PyTensor
41        let tensor_data = dataset.data.clone();
42        let tensor_shape = dataset.shape.clone();
43
44        // Create tensor dataset
45        let tensor = Tensor::from_vec(tensor_data, &tensor_shape).map_err(|e| {
46            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
47                "Failed to create tensor: {}",
48                e
49            ))
50        })?;
51
52        let tensor_dataset = TensorDataset::from_tensor(tensor);
53
54        // Create the dataloader (use sequential for now, shuffle will be handled differently)
55        let dataloader = simple_dataloader(tensor_dataset, batch_size, shuffle);
56
57        match dataloader {
58            Ok(dl) => Ok(PyDataLoader { inner: dl }),
59            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
60                "Failed to create dataloader: {}",
61                e
62            ))),
63        }
64    }
65
66    /// Get the number of batches in the dataloader
67    fn __len__(&self) -> usize {
68        self.inner.len()
69    }
70
71    /// Check if the dataloader is empty
72    fn is_empty(&self) -> bool {
73        self.inner.is_empty()
74    }
75
76    /// Create an iterator over the dataloader
77    fn __iter__(_slf: PyRef<'_, Self>) -> PyDataLoaderIterator {
78        // Create a simplified iterator that doesn't rely on private fields
79        PyDataLoaderIterator {
80            batch_size: 32, // Default batch size
81            current_batch: 0,
82            total_batches: 10, // Default, should be calculated properly
83        }
84    }
85}
86
/// Python iterator for DataLoader
///
/// Simple counter-based iterator state; see `__next__` for the batch it
/// produces (currently placeholder data).
#[pyclass(name = "DataLoaderIterator")]
pub struct PyDataLoaderIterator {
    // Number of elements in each emitted (placeholder) batch.
    batch_size: usize,
    // Batches emitted so far; iteration stops at `total_batches`.
    current_batch: usize,
    // Total batches this iterator will emit.
    total_batches: usize,
}
94
95#[pymethods]
96impl PyDataLoaderIterator {
97    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
98        slf
99    }
100
101    fn __next__(&mut self) -> PyResult<Option<PyTensor>> {
102        if self.current_batch < self.total_batches {
103            self.current_batch += 1;
104
105            // Create a simple dummy batch for now
106            let batch_data = vec![0.0f32; self.batch_size];
107            let batch_shape = vec![self.batch_size];
108
109            Python::attach(|py| {
110                let py_list = PyList::new(py, &batch_data)?;
111                let py_tensor = PyTensor::new(&py_list, Some(batch_shape), Some("f32"), false)?;
112                Ok(Some(py_tensor))
113            })
114        } else {
115            Ok(None) // End of iteration
116        }
117    }
118}
119
/// Python wrapper for random DataLoader
///
/// Same as `PyDataLoader` but batches are drawn via a `RandomSampler`.
/// Exposed to Python as `RandomDataLoader`.
#[pyclass(name = "RandomDataLoader")]
pub struct PyRandomDataLoader {
    // Underlying Rust dataloader with random sampling order.
    inner: DataLoader<TensorDataset<f32>, BatchingSampler<RandomSampler>, DefaultCollate>,
}
125
126#[pymethods]
127impl PyRandomDataLoader {
128    /// Create a new random DataLoader from a tensor dataset
129    #[new]
130    fn new(
131        dataset: PyTensor,
132        batch_size: Option<usize>,
133        generator_seed: Option<u64>,
134        num_workers: Option<usize>,
135        drop_last: Option<bool>,
136    ) -> PyResult<Self> {
137        let batch_size = batch_size.unwrap_or(1);
138        let _num_workers = num_workers.unwrap_or(0);
139        let _drop_last = drop_last.unwrap_or(false);
140
141        // Extract the tensor from PyTensor
142        let tensor_data = dataset.data.clone();
143        let tensor_shape = dataset.shape.clone();
144
145        // Create tensor dataset
146        let tensor = Tensor::from_vec(tensor_data, &tensor_shape).map_err(|e| {
147            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
148                "Failed to create tensor: {}",
149                e
150            ))
151        })?;
152
153        let tensor_dataset = TensorDataset::from_tensor(tensor);
154
155        // Create the random dataloader
156        let dataloader = simple_random_dataloader(tensor_dataset, batch_size, generator_seed);
157
158        match dataloader {
159            Ok(dl) => Ok(PyRandomDataLoader { inner: dl }),
160            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
161                "Failed to create random dataloader: {}",
162                e
163            ))),
164        }
165    }
166
167    /// Get the number of batches in the dataloader
168    fn __len__(&self) -> usize {
169        self.inner.len()
170    }
171
172    /// Check if the dataloader is empty
173    fn is_empty(&self) -> bool {
174        self.inner.is_empty()
175    }
176
177    /// Create an iterator over the random dataloader
178    fn __iter__(_slf: PyRef<'_, Self>) -> PyRandomDataLoaderIterator {
179        // Create a simplified iterator that doesn't rely on private fields
180        PyRandomDataLoaderIterator {
181            batch_size: 32, // Default batch size
182            current_batch: 0,
183            total_batches: 10, // Default, should be calculated properly
184        }
185    }
186}
187
/// Python iterator for Random DataLoader
///
/// Identical in shape to `DataLoaderIterator`; counter-based state only.
#[pyclass(name = "RandomDataLoaderIterator")]
pub struct PyRandomDataLoaderIterator {
    // Number of elements in each emitted (placeholder) batch.
    batch_size: usize,
    // Batches emitted so far; iteration stops at `total_batches`.
    current_batch: usize,
    // Total batches this iterator will emit.
    total_batches: usize,
}
195
196#[pymethods]
197impl PyRandomDataLoaderIterator {
198    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
199        slf
200    }
201
202    fn __next__(&mut self) -> PyResult<Option<PyTensor>> {
203        if self.current_batch < self.total_batches {
204            self.current_batch += 1;
205
206            // Create a simple dummy batch for now
207            let batch_data = vec![0.0f32; self.batch_size];
208            let batch_shape = vec![self.batch_size];
209
210            Python::attach(|py| {
211                let py_list = PyList::new(py, &batch_data)?;
212                let py_tensor = PyTensor::new(&py_list, Some(batch_shape), Some("f32"), false)?;
213                Ok(Some(py_tensor))
214            })
215        } else {
216            Ok(None) // End of iteration
217        }
218    }
219}
220
221/// Create a simple dataloader from a list of tensors
222#[pyfunction]
223pub fn create_dataloader(
224    tensors: &Bound<'_, PyList>,
225    batch_size: Option<usize>,
226    shuffle: Option<bool>,
227) -> PyResult<PyDataLoader> {
228    let batch_size = batch_size.unwrap_or(1);
229    let shuffle = shuffle.unwrap_or(false);
230
231    // Convert Python list to Vec<PyTensor>
232    let mut tensor_list = Vec::new();
233    for item in tensors.iter() {
234        let py_tensor: PyTensor = item.extract()?;
235        tensor_list.push(py_tensor);
236    }
237
238    if tensor_list.is_empty() {
239        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
240            "Cannot create dataloader from empty tensor list",
241        ));
242    }
243
244    // For simplicity, we'll use the first tensor as the dataset
245    // In a real implementation, you might want to concatenate tensors or handle differently
246    let first_tensor = &tensor_list[0];
247    PyDataLoader::new(
248        first_tensor.clone(),
249        Some(batch_size),
250        Some(shuffle),
251        None,
252        None,
253    )
254}
255
256/// Create a dataset from numpy-like arrays
257#[pyfunction]
258pub fn create_dataset_from_array(py: Python, array: Py<PyAny>) -> PyResult<PyTensor> {
259    // This is a simplified implementation - in practice you'd want to handle
260    // different array types (numpy, list, etc.) and convert them to tensors
261
262    // Try to extract as a list of lists (2D array)
263    if let Ok(outer_list) = array.bind(py).extract::<Vec<Vec<f32>>>() {
264        let rows = outer_list.len();
265        let cols = if rows > 0 { outer_list[0].len() } else { 0 };
266
267        let mut data = Vec::with_capacity(rows * cols);
268        for row in outer_list {
269            if row.len() != cols {
270                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
271                    "All rows must have the same length",
272                ));
273            }
274            data.extend(row);
275        }
276
277        let shape_vec = vec![rows, cols];
278        Ok(PyTensor::from_raw(data, shape_vec, DType::F32, false))
279    } else if let Ok(flat_list) = array.bind(py).extract::<Vec<f32>>() {
280        // Handle 1D array
281        let len = flat_list.len();
282        let shape_vec = vec![len];
283        Ok(PyTensor::from_raw(flat_list, shape_vec, DType::F32, false))
284    } else {
285        Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
286            "Array must be a list of numbers or list of lists",
287        ))
288    }
289}
290
/// Helper function to create a dataloader builder with advanced options
///
/// Fluent builder mirroring PyTorch's DataLoader keyword arguments; call
/// `build()` to produce a `PyDataLoader`.
#[pyclass(name = "DataLoaderBuilder")]
pub struct PyDataLoaderBuilder {
    // Tensor whose buffer backs the dataset.
    dataset: PyTensor,
    // Samples per batch (defaults to 1 in `new`).
    batch_size: usize,
    // Whether to draw batches in shuffled order.
    shuffle: bool,
    // Stored for API parity; not consumed by `build` yet.
    num_workers: usize,
    // Stored for API parity; not consumed by `build` yet.
    pin_memory: bool,
    // Stored for API parity; not consumed by `build` yet.
    drop_last: bool,
    // Optional RNG seed for reproducible shuffling.
    generator_seed: Option<u64>,
}
302
303#[pymethods]
304impl PyDataLoaderBuilder {
305    #[new]
306    fn new(dataset: PyTensor) -> Self {
307        Self {
308            dataset,
309            batch_size: 1,
310            shuffle: false,
311            num_workers: 0,
312            pin_memory: false,
313            drop_last: false,
314            generator_seed: None,
315        }
316    }
317
318    fn batch_size(mut slf: PyRefMut<Self>, batch_size: usize) -> PyRefMut<Self> {
319        slf.batch_size = batch_size;
320        slf
321    }
322
323    fn shuffle(mut slf: PyRefMut<Self>, shuffle: bool) -> PyRefMut<Self> {
324        slf.shuffle = shuffle;
325        slf
326    }
327
328    fn num_workers(mut slf: PyRefMut<Self>, num_workers: usize) -> PyRefMut<Self> {
329        slf.num_workers = num_workers;
330        slf
331    }
332
333    fn pin_memory(mut slf: PyRefMut<Self>, pin_memory: bool) -> PyRefMut<Self> {
334        slf.pin_memory = pin_memory;
335        slf
336    }
337
338    fn drop_last(mut slf: PyRefMut<Self>, drop_last: bool) -> PyRefMut<Self> {
339        slf.drop_last = drop_last;
340        slf
341    }
342
343    fn generator(mut slf: PyRefMut<Self>, seed: u64) -> PyRefMut<Self> {
344        slf.generator_seed = Some(seed);
345        slf
346    }
347
348    fn build(&self) -> PyResult<PyDataLoader> {
349        if self.shuffle {
350            PyRandomDataLoader::new(
351                self.dataset.clone(),
352                Some(self.batch_size),
353                self.generator_seed,
354                Some(self.num_workers),
355                Some(self.drop_last),
356            )
357            .map(|_random_dl| {
358                // Convert to regular PyDataLoader - this is simplified
359                // In practice you'd want better type handling
360                PyDataLoader::new(
361                    self.dataset.clone(),
362                    Some(self.batch_size),
363                    Some(false),
364                    Some(self.num_workers),
365                    Some(self.drop_last),
366                )
367                .unwrap()
368            })
369        } else {
370            PyDataLoader::new(
371                self.dataset.clone(),
372                Some(self.batch_size),
373                Some(self.shuffle),
374                Some(self.num_workers),
375                Some(self.drop_last),
376            )
377        }
378    }
379}
380
381/// Utility functions for data loading
382#[pyfunction]
383pub fn get_dataloader_info(dataloader: &PyDataLoader) -> PyResult<String> {
384    Ok(format!(
385        "DataLoader(batches={}, empty={})",
386        dataloader.__len__(),
387        dataloader.is_empty()
388    ))
389}
390
391#[pyfunction]
392pub fn benchmark_dataloader(dataloader: &PyDataLoader, num_epochs: Option<usize>) -> PyResult<f64> {
393    let num_epochs = num_epochs.unwrap_or(1);
394    let start = std::time::Instant::now();
395
396    for _epoch in 0..num_epochs {
397        for batch_result in dataloader.inner.iter() {
398            // Just consume the batches to measure iteration speed
399            if batch_result.is_err() {
400                return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
401                    "Error during benchmarking",
402                ));
403            }
404        }
405    }
406
407    let elapsed = start.elapsed();
408    Ok(elapsed.as_secs_f64())
409}
410
411// Types are already defined in this module and are pub, so no need to re-export