// torsh_data/collate/optimized.rs
//! Optimized collation implementations

// Used in both std and no_std feature branches
#[allow(unused_imports)]
use super::stacking::TensorStacker;
use crate::collate::Collate;
use torsh_core::{
    dtype::TensorElement,
    error::{Result, TorshError},
};
use torsh_tensor::Tensor;

#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

// ✅ SciRS2 POLICY: Use scirs2_core::parallel_ops instead of rayon::prelude
#[cfg(feature = "std")]
use scirs2_core::parallel_ops::*;

20/// Stack tensors along a new dimension (optimized version)
21pub fn stack_tensors<T: TensorElement + Copy>(
22    tensors: &[Tensor<T>],
23    dim: usize,
24) -> Result<Tensor<T>> {
25    if tensors.is_empty() {
26        return Err(TorshError::InvalidArgument(
27            "Cannot stack empty tensor list".to_string(),
28        ));
29    }
30
31    // Check that all tensors have the same shape
32    let first_shape = tensors[0].shape();
33    for tensor in &tensors[1..] {
34        if tensor.shape() != first_shape {
35            return Err(TorshError::ShapeMismatch {
36                expected: first_shape.dims().to_vec(),
37                got: tensor.shape().dims().to_vec(),
38            });
39        }
40    }
41
42    // Create new shape with additional dimension at the specified position
43    let original_dims = first_shape.dims();
44    let mut new_dims = Vec::with_capacity(original_dims.len() + 1);
45
46    // Insert batch dimension at the specified position
47    if dim == 0 {
48        new_dims.push(tensors.len());
49        new_dims.extend_from_slice(original_dims);
50    } else {
51        // Insert at position dim
52        new_dims.extend_from_slice(&original_dims[..dim.min(original_dims.len())]);
53        new_dims.push(tensors.len());
54        if dim < original_dims.len() {
55            new_dims.extend_from_slice(&original_dims[dim..]);
56        }
57    }
58
59    // Optimized stacking: pre-allocate without unnecessary initialization
60    // Use with_capacity + unsafe set_len for better performance when we know
61    // we'll immediately overwrite all values
62    let tensor_size = tensors[0].numel();
63    let total_elements = new_dims.iter().product::<usize>();
64    let mut new_data = Vec::with_capacity(total_elements);
65    // SAFETY: We immediately fill all elements below, so uninitialized memory is never read
66    unsafe { new_data.set_len(total_elements) };
67
68    // Use parallel processing for large batches when std feature is available
69    #[cfg(feature = "std")]
70    {
71        if tensors.len() > 4 && tensor_size > 1000 {
72            // Parallel data collection for large tensors
73            let parallel_data: std::result::Result<Vec<Vec<T>>, TorshError> =
74                tensors.par_iter().map(|tensor| tensor.to_vec()).collect();
75            let parallel_data = parallel_data?;
76            for (i, data) in parallel_data.into_iter().enumerate() {
77                let start_idx = i * tensor_size;
78                let end_idx = start_idx + tensor_size;
79                new_data[start_idx..end_idx].copy_from_slice(&data);
80            }
81        } else {
82            // Sequential copy for small tensors/batches
83            for (i, tensor) in tensors.iter().enumerate() {
84                let data = tensor.to_vec()?;
85                let start_idx = i * tensor_size;
86                let end_idx = start_idx + tensor_size;
87                new_data[start_idx..end_idx].copy_from_slice(&data);
88            }
89        }
90    }
91
92    #[cfg(not(feature = "std"))]
93    {
94        // Sequential copy for no_std
95        for (i, tensor) in tensors.iter().enumerate() {
96            let data = tensor.to_vec()?;
97            let start_idx = i * tensor_size;
98            let end_idx = start_idx + tensor_size;
99            new_data[start_idx..end_idx].copy_from_slice(&data);
100        }
101    }
102
103    let result = torsh_tensor::Tensor::from_data(new_data, new_dims, tensors[0].device())?;
104
105    Ok(result)
106}
107
/// Fast stack tensors using memory mapping for very large batches
///
/// Dispatches to the memory-mapped stacker when the `mmap-support`
/// feature is enabled and the batch is large; otherwise falls back to
/// the regular optimized [`stack_tensors`].
#[cfg(feature = "std")]
pub fn stack_tensors_fast<T: TensorElement + Copy>(
    tensors: &[Tensor<T>],
    dim: usize,
) -> Result<Tensor<T>> {
    if tensors.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot stack empty tensor list".to_string(),
        ));
    }

    #[cfg(feature = "mmap-support")]
    {
        // Batches beyond this size are routed through the mmap path.
        const MMAP_BATCH_THRESHOLD: usize = 100;
        if tensors.len() > MMAP_BATCH_THRESHOLD {
            return stack_tensors_mmap(tensors, dim);
        }
    }

    stack_tensors(tensors, dim)
}

/// Memory-mapped tensor stacking for very large batches
///
/// Gathers tensor data in parallel, stages it in a memory-mapped temporary
/// file, and then materializes the stacked tensor. Inputs must all share
/// one shape; a batch axis of length `tensors.len()` is inserted at `dim`
/// (clamped to the number of original dimensions).
///
/// # Errors
/// Returns `InvalidArgument` for an empty list, `ShapeMismatch` for
/// inconsistent shapes, and `IoError` if the temp file or mapping fails.
#[cfg(all(feature = "std", feature = "mmap-support"))]
pub fn stack_tensors_mmap<T: TensorElement + Copy>(
    tensors: &[Tensor<T>],
    dim: usize,
) -> Result<Tensor<T>> {
    // Guard before indexing tensors[0] below (this public fn was
    // previously reachable with an empty slice and would panic).
    if tensors.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot stack empty tensor list".to_string(),
        ));
    }

    // Check that all tensors have the same shape
    let first_shape = tensors[0].shape();
    for tensor in &tensors[1..] {
        if tensor.shape() != first_shape {
            return Err(TorshError::ShapeMismatch {
                expected: first_shape.dims().to_vec(),
                got: tensor.shape().dims().to_vec(),
            });
        }
    }

    // Output shape: batch axis inserted at `dim`, clamped so an
    // out-of-range `dim` appends a trailing axis.
    let original_dims = first_shape.dims();
    let insert_at = dim.min(original_dims.len());
    let mut new_dims = Vec::with_capacity(original_dims.len() + 1);
    new_dims.extend_from_slice(&original_dims[..insert_at]);
    new_dims.push(tensors.len());
    new_dims.extend_from_slice(&original_dims[insert_at..]);

    // Strided layout: with outer = prod(dims[..insert_at]) and
    // inner = prod(dims[insert_at..]), slice `o` of tensor `i` lands at
    // element offset (o * n + i) * inner. (The previous implementation
    // copied tensors contiguously, which is only correct for dim == 0.)
    let n = tensors.len();
    let outer: usize = original_dims[..insert_at].iter().product();
    let inner: usize = original_dims[insert_at..].iter().product();
    let tensor_size = tensors[0].numel();
    let total_size = tensor_size * n * std::mem::size_of::<T>();

    // Create and size a temporary file to back the memory mapping.
    let mut temp_file =
        tempfile::NamedTempFile::new().map_err(|e| TorshError::IoError(e.to_string()))?;
    temp_file
        .as_file_mut()
        .set_len(total_size as u64)
        .map_err(|e| TorshError::IoError(e.to_string()))?;

    // SAFETY: the temp file stays alive and exclusively owned for the
    // whole lifetime of the mapping.
    let mut mmap = unsafe {
        memmap2::MmapOptions::new()
            .map_mut(temp_file.as_file())
            .map_err(|e| TorshError::IoError(e.to_string()))?
    };

    // Parallel collection of tensor data; the copy below stays sequential
    // so only one writer touches the mapped region.
    let all_data: std::result::Result<Vec<Vec<T>>, TorshError> =
        tensors.par_iter().map(|tensor| tensor.to_vec()).collect();
    let all_data = all_data?;

    // Mapped pages are page-aligned, which satisfies T's alignment.
    let mmap_ptr = mmap.as_mut_ptr() as *mut T;
    for (i, data) in all_data.iter().enumerate() {
        for o in 0..outer {
            // SAFETY: (o * n + i) * inner + inner <= n * tensor_size
            // elements, exactly the region sized above; the source range
            // is likewise within `data` (length tensor_size).
            unsafe {
                std::ptr::copy_nonoverlapping(
                    data.as_ptr().add(o * inner),
                    mmap_ptr.add((o * n + i) * inner),
                    inner,
                );
            }
        }
    }

    // SAFETY: the loop above initialized all n * tensor_size elements.
    let data_vec =
        unsafe { std::slice::from_raw_parts(mmap_ptr as *const T, tensor_size * n).to_vec() };
    Tensor::from_data(data_vec, new_dims, tensors[0].device())
}

/// Optimized collation function for high-performance scenarios
///
/// Zero-sized collator whose [`Collate`] impls stack batches via the
/// optimized stacking routines in this module; very large batches may be
/// routed through memory-mapped stacking when that feature is enabled.
#[cfg(feature = "std")]
#[derive(Debug, Clone, Copy)]
pub struct OptimizedCollate;

#[cfg(feature = "std")]
impl<T: TensorElement + Copy> Collate<Tensor<T>> for OptimizedCollate {
    type Output = Tensor<T>;

    /// Stack a non-empty batch of tensors along a new leading axis.
    fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
        // An empty batch has no shape to stack into, so report it as an
        // argument error; otherwise defer to the fast stacking path.
        if batch.is_empty() {
            Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            ))
        } else {
            stack_tensors_fast(&batch, 0)
        }
    }
}

#[cfg(feature = "std")]
impl<T: TensorElement + Copy> Collate<Vec<Tensor<T>>> for OptimizedCollate {
    type Output = Vec<Tensor<T>>;

    /// Collate a batch of multi-tensor samples into one stacked tensor per
    /// sample position.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if the batch is empty or the samples do
    /// not all contain the same number of tensors.
    fn collate(&self, batch: Vec<Vec<Tensor<T>>>) -> Result<Self::Output> {
        if batch.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            ));
        }

        let num_tensors = batch[0].len();
        // Reject ragged samples up front instead of panicking on
        // `sample[i]` inside the parallel loop below.
        if let Some(sample) = batch.iter().find(|s| s.len() != num_tensors) {
            return Err(TorshError::InvalidArgument(format!(
                "All samples must contain the same number of tensors (expected {}, got {})",
                num_tensors,
                sample.len()
            )));
        }

        // Stack each tensor position across the batch in parallel and
        // collect directly, short-circuiting on the first error.
        (0..num_tensors)
            .into_par_iter()
            .map(|i| {
                let column: Vec<Tensor<T>> =
                    batch.iter().map(|sample| sample[i].clone()).collect();
                stack_tensors_fast(&column, 0)
            })
            .collect::<Result<Vec<_>>>()
    }
}

/// Optimized collation function factory
///
/// Returns the zero-sized [`OptimizedCollate`] collator. The type
/// parameter `T` is not used by the constructor itself; it lets call
/// sites name the element type they intend to collate.
#[cfg(feature = "std")]
pub fn optimized_collate_fn<T>() -> OptimizedCollate {
    OptimizedCollate
}

/// For no_std environments, provide a fallback OptimizedCollate that uses the TensorStacker
///
/// Mirrors the std collator's interface but performs all stacking
/// sequentially via [`TensorStacker`], since parallel ops are unavailable
/// without std.
#[cfg(not(feature = "std"))]
#[derive(Debug, Clone, Copy)]
pub struct OptimizedCollate;

270#[cfg(not(feature = "std"))]
271impl<T: TensorElement + Copy> Collate<Tensor<T>> for OptimizedCollate {
272    type Output = Tensor<T>;
273
274    fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
275        TensorStacker::new().stack(&batch, 0)
276    }
277}
278
279#[cfg(not(feature = "std"))]
280impl<T: TensorElement + Copy> Collate<Vec<Tensor<T>>> for OptimizedCollate {
281    type Output = Vec<Tensor<T>>;
282
283    fn collate(&self, batch: Vec<Vec<Tensor<T>>>) -> Result<Self::Output> {
284        if batch.is_empty() {
285            return Err(TorshError::InvalidArgument(
286                "Cannot collate empty batch".to_string(),
287            ));
288        }
289
290        let num_tensors = batch[0].len();
291        let mut collated = Vec::with_capacity(num_tensors);
292        let stacker = TensorStacker::new();
293
294        // Collate each tensor position across the batch
295        for i in 0..num_tensors {
296            let tensors: Vec<Tensor<T>> = batch.iter().map(|sample| sample[i].clone()).collect();
297            collated.push(stacker.stack(&tensors, 0)?);
298        }
299
300        Ok(collated)
301    }
302}
303
/// No_std optimized collation function factory
///
/// Returns the zero-sized [`OptimizedCollate`] collator. The type
/// parameter `T` is not used by the constructor itself; it lets call
/// sites name the element type they intend to collate.
#[cfg(not(feature = "std"))]
pub fn optimized_collate_fn<T>() -> OptimizedCollate {
    OptimizedCollate
}