1#![allow(dead_code)]
5use crate::python::tensor::PyTensor;
6use pyo3::prelude::*;
7use pyo3::types::{PyAny, PyList};
8use pyo3::{Bound, Py};
9use torsh_core::DType;
10use torsh_data::{
11 collate::DefaultCollate,
12 dataloader::{simple_dataloader, simple_random_dataloader, DataLoader},
13 dataset::TensorDataset,
14 sampler::{BatchingSampler, RandomSampler, SequentialSampler},
15};
16use torsh_tensor::Tensor;
17
/// Python-facing wrapper around a sequential, batching data loader over an
/// in-memory `f32` tensor dataset. Exposed to Python as `DataLoader`.
#[pyclass(name = "DataLoader")]
pub struct PyDataLoader {
    // Concrete loader type is fixed at compile time: sequential sampling,
    // batched, with the default collate function.
    inner: DataLoader<TensorDataset<f32>, BatchingSampler<SequentialSampler>, DefaultCollate>,
}
23
24#[pymethods]
25impl PyDataLoader {
26 #[new]
28 fn new(
29 dataset: PyTensor,
30 batch_size: Option<usize>,
31 shuffle: Option<bool>,
32 num_workers: Option<usize>,
33 drop_last: Option<bool>,
34 ) -> PyResult<Self> {
35 let batch_size = batch_size.unwrap_or(1);
36 let shuffle = shuffle.unwrap_or(false);
37 let _num_workers = num_workers.unwrap_or(0);
38 let _drop_last = drop_last.unwrap_or(false);
39
40 let tensor_data = dataset.data.clone();
42 let tensor_shape = dataset.shape.clone();
43
44 let tensor = Tensor::from_vec(tensor_data, &tensor_shape).map_err(|e| {
46 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
47 "Failed to create tensor: {}",
48 e
49 ))
50 })?;
51
52 let tensor_dataset = TensorDataset::from_tensor(tensor);
53
54 let dataloader = simple_dataloader(tensor_dataset, batch_size, shuffle);
56
57 match dataloader {
58 Ok(dl) => Ok(PyDataLoader { inner: dl }),
59 Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
60 "Failed to create dataloader: {}",
61 e
62 ))),
63 }
64 }
65
66 fn __len__(&self) -> usize {
68 self.inner.len()
69 }
70
71 fn is_empty(&self) -> bool {
73 self.inner.is_empty()
74 }
75
76 fn __iter__(_slf: PyRef<'_, Self>) -> PyDataLoaderIterator {
78 PyDataLoaderIterator {
80 batch_size: 32, current_batch: 0,
82 total_batches: 10, }
84 }
85}
86
/// Iterator state handed to Python by `PyDataLoader.__iter__`.
/// Exposed to Python as `DataLoaderIterator`.
#[pyclass(name = "DataLoaderIterator")]
pub struct PyDataLoaderIterator {
    // Number of elements emitted per batch.
    batch_size: usize,
    // Index of the next batch to emit (0-based).
    current_batch: usize,
    // Iteration stops once current_batch reaches this count.
    total_batches: usize,
}
94
95#[pymethods]
96impl PyDataLoaderIterator {
97 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
98 slf
99 }
100
101 fn __next__(&mut self) -> PyResult<Option<PyTensor>> {
102 if self.current_batch < self.total_batches {
103 self.current_batch += 1;
104
105 let batch_data = vec![0.0f32; self.batch_size];
107 let batch_shape = vec![self.batch_size];
108
109 Python::attach(|py| {
110 let py_list = PyList::new(py, &batch_data)?;
111 let py_tensor = PyTensor::new(&py_list, Some(batch_shape), Some("f32"), false)?;
112 Ok(Some(py_tensor))
113 })
114 } else {
115 Ok(None) }
117 }
118}
119
/// Python-facing wrapper around a randomly-sampled, batching data loader
/// over an in-memory `f32` tensor dataset. Exposed to Python as
/// `RandomDataLoader`.
#[pyclass(name = "RandomDataLoader")]
pub struct PyRandomDataLoader {
    // Same shape as PyDataLoader's inner type, but with random sampling.
    inner: DataLoader<TensorDataset<f32>, BatchingSampler<RandomSampler>, DefaultCollate>,
}
125
126#[pymethods]
127impl PyRandomDataLoader {
128 #[new]
130 fn new(
131 dataset: PyTensor,
132 batch_size: Option<usize>,
133 generator_seed: Option<u64>,
134 num_workers: Option<usize>,
135 drop_last: Option<bool>,
136 ) -> PyResult<Self> {
137 let batch_size = batch_size.unwrap_or(1);
138 let _num_workers = num_workers.unwrap_or(0);
139 let _drop_last = drop_last.unwrap_or(false);
140
141 let tensor_data = dataset.data.clone();
143 let tensor_shape = dataset.shape.clone();
144
145 let tensor = Tensor::from_vec(tensor_data, &tensor_shape).map_err(|e| {
147 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
148 "Failed to create tensor: {}",
149 e
150 ))
151 })?;
152
153 let tensor_dataset = TensorDataset::from_tensor(tensor);
154
155 let dataloader = simple_random_dataloader(tensor_dataset, batch_size, generator_seed);
157
158 match dataloader {
159 Ok(dl) => Ok(PyRandomDataLoader { inner: dl }),
160 Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
161 "Failed to create random dataloader: {}",
162 e
163 ))),
164 }
165 }
166
167 fn __len__(&self) -> usize {
169 self.inner.len()
170 }
171
172 fn is_empty(&self) -> bool {
174 self.inner.is_empty()
175 }
176
177 fn __iter__(_slf: PyRef<'_, Self>) -> PyRandomDataLoaderIterator {
179 PyRandomDataLoaderIterator {
181 batch_size: 32, current_batch: 0,
183 total_batches: 10, }
185 }
186}
187
/// Iterator state handed to Python by `PyRandomDataLoader.__iter__`.
/// Exposed to Python as `RandomDataLoaderIterator`.
#[pyclass(name = "RandomDataLoaderIterator")]
pub struct PyRandomDataLoaderIterator {
    // Number of elements emitted per batch.
    batch_size: usize,
    // Index of the next batch to emit (0-based).
    current_batch: usize,
    // Iteration stops once current_batch reaches this count.
    total_batches: usize,
}
195
196#[pymethods]
197impl PyRandomDataLoaderIterator {
198 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
199 slf
200 }
201
202 fn __next__(&mut self) -> PyResult<Option<PyTensor>> {
203 if self.current_batch < self.total_batches {
204 self.current_batch += 1;
205
206 let batch_data = vec![0.0f32; self.batch_size];
208 let batch_shape = vec![self.batch_size];
209
210 Python::attach(|py| {
211 let py_list = PyList::new(py, &batch_data)?;
212 let py_tensor = PyTensor::new(&py_list, Some(batch_shape), Some("f32"), false)?;
213 Ok(Some(py_tensor))
214 })
215 } else {
216 Ok(None) }
218 }
219}
220
221#[pyfunction]
223pub fn create_dataloader(
224 tensors: &Bound<'_, PyList>,
225 batch_size: Option<usize>,
226 shuffle: Option<bool>,
227) -> PyResult<PyDataLoader> {
228 let batch_size = batch_size.unwrap_or(1);
229 let shuffle = shuffle.unwrap_or(false);
230
231 let mut tensor_list = Vec::new();
233 for item in tensors.iter() {
234 let py_tensor: PyTensor = item.extract()?;
235 tensor_list.push(py_tensor);
236 }
237
238 if tensor_list.is_empty() {
239 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
240 "Cannot create dataloader from empty tensor list",
241 ));
242 }
243
244 let first_tensor = &tensor_list[0];
247 PyDataLoader::new(
248 first_tensor.clone(),
249 Some(batch_size),
250 Some(shuffle),
251 None,
252 None,
253 )
254}
255
256#[pyfunction]
258pub fn create_dataset_from_array(py: Python, array: Py<PyAny>) -> PyResult<PyTensor> {
259 if let Ok(outer_list) = array.bind(py).extract::<Vec<Vec<f32>>>() {
264 let rows = outer_list.len();
265 let cols = if rows > 0 { outer_list[0].len() } else { 0 };
266
267 let mut data = Vec::with_capacity(rows * cols);
268 for row in outer_list {
269 if row.len() != cols {
270 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
271 "All rows must have the same length",
272 ));
273 }
274 data.extend(row);
275 }
276
277 let shape_vec = vec![rows, cols];
278 Ok(PyTensor::from_raw(data, shape_vec, DType::F32, false))
279 } else if let Ok(flat_list) = array.bind(py).extract::<Vec<f32>>() {
280 let len = flat_list.len();
282 let shape_vec = vec![len];
283 Ok(PyTensor::from_raw(flat_list, shape_vec, DType::F32, false))
284 } else {
285 Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
286 "Array must be a list of numbers or list of lists",
287 ))
288 }
289}
290
/// Fluent builder for `PyDataLoader`, mirroring the common PyTorch options.
/// Exposed to Python as `DataLoaderBuilder`.
#[pyclass(name = "DataLoaderBuilder")]
pub struct PyDataLoaderBuilder {
    // Source tensor the loader will be built over.
    dataset: PyTensor,
    // Samples per batch (default 1).
    batch_size: usize,
    // Whether to shuffle sample order (default false).
    shuffle: bool,
    // NOTE(review): stored but not forwarded by build() — confirm intent.
    num_workers: usize,
    // NOTE(review): stored but not forwarded by build() — confirm intent.
    pin_memory: bool,
    // Whether to drop a trailing partial batch (default false).
    drop_last: bool,
    // Optional RNG seed for reproducible shuffling.
    generator_seed: Option<u64>,
}
302
303#[pymethods]
304impl PyDataLoaderBuilder {
305 #[new]
306 fn new(dataset: PyTensor) -> Self {
307 Self {
308 dataset,
309 batch_size: 1,
310 shuffle: false,
311 num_workers: 0,
312 pin_memory: false,
313 drop_last: false,
314 generator_seed: None,
315 }
316 }
317
318 fn batch_size(mut slf: PyRefMut<Self>, batch_size: usize) -> PyRefMut<Self> {
319 slf.batch_size = batch_size;
320 slf
321 }
322
323 fn shuffle(mut slf: PyRefMut<Self>, shuffle: bool) -> PyRefMut<Self> {
324 slf.shuffle = shuffle;
325 slf
326 }
327
328 fn num_workers(mut slf: PyRefMut<Self>, num_workers: usize) -> PyRefMut<Self> {
329 slf.num_workers = num_workers;
330 slf
331 }
332
333 fn pin_memory(mut slf: PyRefMut<Self>, pin_memory: bool) -> PyRefMut<Self> {
334 slf.pin_memory = pin_memory;
335 slf
336 }
337
338 fn drop_last(mut slf: PyRefMut<Self>, drop_last: bool) -> PyRefMut<Self> {
339 slf.drop_last = drop_last;
340 slf
341 }
342
343 fn generator(mut slf: PyRefMut<Self>, seed: u64) -> PyRefMut<Self> {
344 slf.generator_seed = Some(seed);
345 slf
346 }
347
348 fn build(&self) -> PyResult<PyDataLoader> {
349 if self.shuffle {
350 PyRandomDataLoader::new(
351 self.dataset.clone(),
352 Some(self.batch_size),
353 self.generator_seed,
354 Some(self.num_workers),
355 Some(self.drop_last),
356 )
357 .map(|_random_dl| {
358 PyDataLoader::new(
361 self.dataset.clone(),
362 Some(self.batch_size),
363 Some(false),
364 Some(self.num_workers),
365 Some(self.drop_last),
366 )
367 .unwrap()
368 })
369 } else {
370 PyDataLoader::new(
371 self.dataset.clone(),
372 Some(self.batch_size),
373 Some(self.shuffle),
374 Some(self.num_workers),
375 Some(self.drop_last),
376 )
377 }
378 }
379}
380
381#[pyfunction]
383pub fn get_dataloader_info(dataloader: &PyDataLoader) -> PyResult<String> {
384 Ok(format!(
385 "DataLoader(batches={}, empty={})",
386 dataloader.__len__(),
387 dataloader.is_empty()
388 ))
389}
390
391#[pyfunction]
392pub fn benchmark_dataloader(dataloader: &PyDataLoader, num_epochs: Option<usize>) -> PyResult<f64> {
393 let num_epochs = num_epochs.unwrap_or(1);
394 let start = std::time::Instant::now();
395
396 for _epoch in 0..num_epochs {
397 for batch_result in dataloader.inner.iter() {
398 if batch_result.is_err() {
400 return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
401 "Error during benchmarking",
402 ));
403 }
404 }
405 }
406
407 let elapsed = start.elapsed();
408 Ok(elapsed.as_secs_f64())
409}
410
411