// torsh_data/collate/advanced.rs
1//! Advanced collation implementations
2
3use super::{optimized::stack_tensors, Collate};
4use torsh_core::{
5    dtype::TensorElement,
6    error::{Result, TorshError},
7};
8use torsh_tensor::Tensor;
9
10#[cfg(feature = "sparse")]
11use torsh_sparse::{CooTensor, SparseTensor};
12
13#[cfg(not(feature = "std"))]
14use alloc::{boxed::Box, vec::Vec};
15
16#[cfg(feature = "std")]
17use std::sync::Arc;
18
/// Cached collation function that reuses allocated memory
///
/// Keeps a pool of previously allocated `Vec<T>` buffers behind a mutex so
/// repeated batch collation can reuse heap allocations instead of allocating
/// a fresh buffer for every batch.
pub struct CachedCollate<T: TensorElement> {
    // Shared pool of reusable buffers; `Arc` lets clones of this collator
    // share one pool across threads.
    tensor_pool: Arc<parking_lot::Mutex<Vec<Vec<T>>>>,
    // Maximum number of buffers retained in the pool; buffers returned when
    // the pool is full are simply dropped.
    max_pool_size: usize,
}
24
25impl<T: TensorElement> CachedCollate<T> {
26    /// Create a new cached collation function
27    pub fn new(max_pool_size: usize) -> Self {
28        Self {
29            tensor_pool: Arc::new(parking_lot::Mutex::new(Vec::with_capacity(max_pool_size))),
30            max_pool_size,
31        }
32    }
33
34    /// Get a reusable buffer from the pool
35    fn get_buffer(&self, capacity: usize) -> Vec<T> {
36        let mut pool = self.tensor_pool.lock();
37        if let Some(mut buffer) = pool.pop() {
38            buffer.clear();
39            if buffer.capacity() >= capacity {
40                buffer.reserve(capacity - buffer.capacity());
41            }
42            buffer
43        } else {
44            Vec::with_capacity(capacity)
45        }
46    }
47
48    /// Return a buffer to the pool
49    fn return_buffer(&self, buffer: Vec<T>) {
50        let mut pool = self.tensor_pool.lock();
51        if pool.len() < self.max_pool_size {
52            pool.push(buffer);
53        }
54    }
55}
56
57impl<T: TensorElement + Copy> Collate<Tensor<T>> for CachedCollate<T> {
58    type Output = Tensor<T>;
59
60    fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
61        if batch.is_empty() {
62            return Err(TorshError::InvalidArgument(
63                "Cannot collate empty batch".to_string(),
64            ));
65        }
66
67        // Check that all tensors have the same shape
68        let first_shape = batch[0].shape();
69        for tensor in &batch[1..] {
70            if tensor.shape() != first_shape {
71                return Err(TorshError::ShapeMismatch {
72                    expected: first_shape.dims().to_vec(),
73                    got: tensor.shape().dims().to_vec(),
74                });
75            }
76        }
77
78        // Create new shape with batch dimension
79        let original_dims = first_shape.dims();
80        let mut new_dims = Vec::with_capacity(original_dims.len() + 1);
81        new_dims.push(batch.len());
82        new_dims.extend_from_slice(original_dims);
83
84        let tensor_size = batch[0].numel();
85        let total_elements = tensor_size * batch.len();
86
87        // Get a reusable buffer
88        let mut new_data = self.get_buffer(total_elements);
89        new_data.reserve_exact(total_elements);
90
91        // Copy tensor data efficiently
92        for tensor in batch.iter() {
93            let data = tensor.to_vec()?;
94            new_data.extend_from_slice(&data);
95        }
96
97        let result =
98            torsh_tensor::Tensor::from_data(new_data.clone(), new_dims, batch[0].device())?;
99
100        // Return buffer to pool (create a new empty vector to return)
101        self.return_buffer(Vec::with_capacity(new_data.capacity()));
102
103        Ok(result)
104    }
105}
106
/// Dynamic batching collation for variable-size sequences
pub struct DynamicBatchCollate<T: TensorElement> {
    // Value used to fill padded positions.
    padding_value: T,
    // Optional cap on sequence length; longer sequences are truncated.
    max_sequence_length: Option<usize>,
    // When true, sequences are sorted by length before padding to group
    // similar lengths together (original order is restored in the output).
    pack_sequences: bool,
}
113
114impl<T: TensorElement> DynamicBatchCollate<T> {
115    /// Create a new dynamic batch collation function
116    pub fn new(padding_value: T) -> Self {
117        Self {
118            padding_value,
119            max_sequence_length: None,
120            pack_sequences: false,
121        }
122    }
123
124    /// Set maximum sequence length (sequences longer than this will be truncated)
125    pub fn with_max_length(mut self, max_length: usize) -> Self {
126        self.max_sequence_length = Some(max_length);
127        self
128    }
129
130    /// Enable sequence packing to minimize padding
131    pub fn with_packing(mut self, pack: bool) -> Self {
132        self.pack_sequences = pack;
133        self
134    }
135}
136
137impl<
138        T: TensorElement
139            + Copy
140            + std::ops::Add<Output = T>
141            + std::ops::Sub<Output = T>
142            + std::ops::Mul<Output = T>
143            + std::ops::Div<Output = T>
144            + Default,
145    > Collate<Tensor<T>> for DynamicBatchCollate<T>
146{
147    type Output = (Tensor<T>, Tensor<i64>); // (padded_sequences, lengths)
148
149    fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
150        if batch.is_empty() {
151            return Err(TorshError::InvalidArgument(
152                "Cannot collate empty batch".to_string(),
153            ));
154        }
155
156        // Collect sequence lengths
157        let mut lengths = Vec::with_capacity(batch.len());
158        let mut max_length = 0;
159
160        for tensor in &batch {
161            if tensor.ndim() == 0 {
162                return Err(TorshError::InvalidArgument(
163                    "Cannot dynamically batch scalar tensors".to_string(),
164                ));
165            }
166
167            let seq_len = tensor.size(0)?;
168            lengths.push(seq_len as i64);
169            max_length = max_length.max(seq_len);
170        }
171
172        // Apply max length constraint if specified
173        if let Some(max_len) = self.max_sequence_length {
174            max_length = max_length.min(max_len);
175        }
176
177        // If packing is enabled, sort by length to minimize padding
178        let mut batch_with_indices: Vec<_> = batch.into_iter().enumerate().collect();
179        if self.pack_sequences {
180            batch_with_indices.sort_by_key(|(_, tensor)| tensor.size(0).unwrap_or(0));
181        }
182
183        // Get the shape for creating padded tensors
184        let first_tensor = &batch_with_indices[0].1;
185        let mut padded_shape = first_tensor.shape().dims().to_vec();
186        padded_shape[0] = max_length; // Set sequence dimension to max length
187
188        // Create padded batch
189        let batch_size = batch_with_indices.len();
190        let mut padded_batch = Vec::with_capacity(batch_size);
191
192        for (original_idx, tensor) in batch_with_indices {
193            let seq_len = tensor.size(0)?;
194            let actual_len = seq_len.min(max_length);
195
196            if actual_len == max_length {
197                // No padding needed, just truncate if necessary
198                if seq_len > max_length {
199                    let truncated = tensor.narrow(0, 0, max_length)?;
200                    padded_batch.push((original_idx, truncated));
201                } else {
202                    padded_batch.push((original_idx, tensor));
203                }
204            } else {
205                // Need to pad
206                let mut padding_shape = padded_shape.clone();
207                let padding_elements =
208                    (max_length - actual_len) * padding_shape[1..].iter().product::<usize>();
209
210                // Create padding tensor
211                let padding_data = vec![self.padding_value; padding_elements];
212                padding_shape[0] = max_length - actual_len;
213
214                let padding_tensor =
215                    Tensor::from_data(padding_data, padding_shape.clone(), tensor.device())?;
216
217                // Truncate if necessary
218                let tensor_to_pad = if seq_len > max_length {
219                    tensor.narrow(0, 0, max_length)?
220                } else {
221                    tensor
222                };
223
224                // Manual concatenation since Tensor::cat is not working correctly
225                let tensor_data = tensor_to_pad.to_vec()?;
226                let padding_data = padding_tensor.to_vec()?;
227
228                // Combine the data
229                let mut combined_data = tensor_data;
230                combined_data.extend(padding_data);
231
232                // Create new tensor with correct shape
233                let mut final_shape = tensor_to_pad.shape().dims().to_vec();
234                final_shape[0] = max_length; // Set to max_length
235
236                let padded = Tensor::from_data(combined_data, final_shape, tensor_to_pad.device())?;
237                padded_batch.push((original_idx, padded));
238            }
239        }
240
241        // Restore original order if packing was used
242        if self.pack_sequences {
243            padded_batch.sort_by_key(|(idx, _)| *idx);
244        }
245
246        // Extract tensors and stack them
247        let tensors: Vec<_> = padded_batch.into_iter().map(|(_, tensor)| tensor).collect();
248
249        let stacked = stack_tensors(&tensors, 0)?;
250
251        // Create lengths tensor
252        let lengths_tensor = Tensor::from_data(lengths, vec![batch_size], tensors[0].device())?;
253
254        Ok((stacked, lengths_tensor))
255    }
256}
257
/// Wrapper for DynamicBatchCollate that only returns padded sequences (not lengths)
/// This allows compatibility with the CollateBuilder which expects `Tensor<T>` output
pub struct DynamicBatchCollateWrapper<T: TensorElement> {
    // The wrapped collator, which produces (padded, lengths) pairs.
    inner: DynamicBatchCollate<T>,
}
263
264impl<T: TensorElement> DynamicBatchCollateWrapper<T> {
265    pub fn new(padding_value: T) -> Self {
266        Self {
267            inner: DynamicBatchCollate::new(padding_value),
268        }
269    }
270
271    pub fn with_max_length(mut self, max_length: usize) -> Self {
272        self.inner = self.inner.with_max_length(max_length);
273        self
274    }
275
276    pub fn with_packing(mut self, pack: bool) -> Self {
277        self.inner = self.inner.with_packing(pack);
278        self
279    }
280}
281
282impl<
283        T: TensorElement
284            + Copy
285            + std::ops::Add<Output = T>
286            + std::ops::Sub<Output = T>
287            + std::ops::Mul<Output = T>
288            + std::ops::Div<Output = T>
289            + Default,
290    > Collate<Tensor<T>> for DynamicBatchCollateWrapper<T>
291{
292    type Output = Tensor<T>;
293
294    fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
295        // Call the inner collate function and extract only the padded sequences
296        let (padded_sequences, _lengths) = self.inner.collate(batch)?;
297        Ok(padded_sequences)
298    }
299}
300
301/// Bucket sampler for dynamic batching
302/// Groups sequences of similar lengths to minimize padding
303pub struct BucketBatchSampler {
304    lengths: Vec<usize>,
305    batch_size: usize,
306    bucket_boundaries: Vec<usize>,
307    drop_last: bool,
308}
309
310impl BucketBatchSampler {
311    /// Create a new bucket batch sampler
312    pub fn new(lengths: Vec<usize>, batch_size: usize, drop_last: bool) -> Self {
313        // Create bucket boundaries based on length distribution
314        let mut sorted_lengths = lengths.clone();
315        sorted_lengths.sort_unstable();
316
317        let num_buckets = (lengths.len() / batch_size).clamp(1, 10);
318        let mut bucket_boundaries = Vec::with_capacity(num_buckets + 1);
319
320        for i in 0..=num_buckets {
321            let idx = (i * sorted_lengths.len()) / num_buckets;
322            let boundary = if idx >= sorted_lengths.len() {
323                sorted_lengths.last().copied().unwrap_or(0) + 1
324            } else {
325                sorted_lengths[idx]
326            };
327            bucket_boundaries.push(boundary);
328        }
329
330        Self {
331            lengths,
332            batch_size,
333            bucket_boundaries,
334            drop_last,
335        }
336    }
337
338    /// Generate batches grouped by sequence length buckets
339    pub fn generate_batches(&self) -> Vec<Vec<usize>> {
340        // Group indices by bucket
341        let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); self.bucket_boundaries.len() - 1];
342
343        for (idx, &length) in self.lengths.iter().enumerate() {
344            for (bucket_idx, bucket) in buckets.iter_mut().enumerate() {
345                if length >= self.bucket_boundaries[bucket_idx]
346                    && length < self.bucket_boundaries[bucket_idx + 1]
347                {
348                    bucket.push(idx);
349                    break;
350                }
351            }
352        }
353
354        // Shuffle within each bucket and create batches
355        let mut batches = Vec::new();
356
357        for mut bucket in buckets {
358            // ✅ SciRS2 Policy Enhanced - Using scientific shuffle for optimal ML batching
359            use scirs2_core::random::prelude::*;
360            use scirs2_core::random::seq::ScientificSliceRandom;
361
362            let mut rng = thread_rng();
363            bucket.scientific_shuffle(&mut rng);
364
365            for chunk in bucket.chunks(self.batch_size) {
366                if chunk.len() == self.batch_size || !self.drop_last {
367                    batches.push(chunk.to_vec());
368                }
369            }
370        }
371
372        // Enhanced scientific shuffle to optimize ML training batch distribution
373        // ✅ SciRS2 Policy Enhanced - Using scientific shuffle for superior randomness
374        use scirs2_core::random::prelude::*;
375        use scirs2_core::random::seq::ScientificSliceRandom;
376        let mut rng = thread_rng();
377        batches.scientific_shuffle(&mut rng);
378
379        batches
380    }
381}
382
/// Adaptive batch size sampler that adjusts batch size based on sequence lengths
///
/// Batches are sized so that the padded token count
/// (`batch_size * longest_sequence_in_batch`) stays near `target_tokens`.
pub struct AdaptiveBatchSampler {
    // Soft cap on padded tokens (batch size × longest sequence) per batch.
    target_tokens: usize,
    // Hard upper bound on samples per batch.
    max_batch_size: usize,
    // Batches smaller than this are dropped rather than emitted.
    min_batch_size: usize,
    // Sequence length of every sample, indexed by dataset position.
    lengths: Vec<usize>,
}

impl AdaptiveBatchSampler {
    /// Create a new adaptive batch sampler
    pub fn new(
        lengths: Vec<usize>,
        target_tokens: usize,
        max_batch_size: usize,
        min_batch_size: usize,
    ) -> Self {
        Self {
            target_tokens,
            max_batch_size,
            min_batch_size,
            lengths,
        }
    }

    /// Generate batches with adaptive batch sizes
    ///
    /// Indices are processed in ascending length order so similar lengths
    /// share a batch (minimizing padding). A batch is closed once adding
    /// the next sequence would push the padded token count over
    /// `target_tokens` or the sample count over `max_batch_size`. Closed or
    /// trailing batches below `min_batch_size` are dropped.
    pub fn generate_batches(&self) -> Vec<Vec<usize>> {
        let mut indices: Vec<usize> = (0..self.lengths.len()).collect();

        // Sort by length to process similar lengths together.
        indices.sort_by_key(|&i| self.lengths[i]);

        let mut batches = Vec::new();
        let mut current_batch: Vec<usize> = Vec::new();

        for idx in indices {
            let length = self.lengths[idx];
            let batch_size = current_batch.len();
            // Padded size if `idx` joined: every row is padded up to the
            // longest sequence in the batch.
            let longest = current_batch
                .iter()
                .map(|&i| self.lengths[i])
                .max()
                .unwrap_or(0)
                .max(length);
            let tokens_if_added = (batch_size + 1) * longest;

            if tokens_if_added > self.target_tokens || batch_size >= self.max_batch_size {
                // Close the current batch if it meets the minimum size.
                // The `!is_empty` guard fixes the original emitting empty
                // batches when `min_batch_size == 0`. (The original also
                // tracked a `_current_tokens` counter that was never read;
                // it has been removed.)
                if !current_batch.is_empty() && batch_size >= self.min_batch_size {
                    batches.push(current_batch);
                }

                // Start a new batch with this index.
                current_batch = vec![idx];
            } else {
                current_batch.push(idx);
            }
        }

        // Emit the final batch if it is non-empty and meets the minimum.
        if !current_batch.is_empty() && current_batch.len() >= self.min_batch_size {
            batches.push(current_batch);
        }

        batches
    }
}
455
/// Padding collation for variable-length sequences
pub struct PadCollate<T: TensorElement> {
    // Fill value for padded positions.
    #[allow(dead_code)]
    padding_value: T,
}

impl<T: TensorElement> PadCollate<T> {
    /// Create a new padding collation function
    ///
    /// `padding_value` is the element written into padded positions.
    pub fn new(padding_value: T) -> Self {
        Self { padding_value }
    }
}
468
469impl<T: TensorElement + Copy> Collate<Tensor<T>> for PadCollate<T> {
470    type Output = Tensor<T>;
471
472    fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
473        if batch.is_empty() {
474            return Err(TorshError::InvalidArgument(
475                "Cannot collate empty batch".to_string(),
476            ));
477        }
478
479        // Find maximum dimensions
480        let ndim = batch[0].ndim();
481        let mut max_dims = vec![0; ndim];
482
483        for tensor in &batch {
484            if tensor.ndim() != ndim {
485                return Err(TorshError::InvalidArgument(
486                    "All tensors must have the same number of dimensions".to_string(),
487                ));
488            }
489
490            for (i, max_dim) in max_dims.iter_mut().enumerate().take(ndim) {
491                let size = tensor.size(i as i32)?;
492                if size > *max_dim {
493                    *max_dim = size;
494                }
495            }
496        }
497
498        // Create padded tensors
499        let batch_size = batch.len();
500        let mut padded_batch = Vec::with_capacity(batch_size);
501
502        for tensor in batch {
503            // For each tensor, pad to match max_dims
504            let shape_ref = tensor.shape();
505            let current_shape = shape_ref.dims();
506            let padded_tensor = tensor;
507
508            // Check if padding is needed
509            let needs_padding = current_shape
510                .iter()
511                .zip(max_dims.iter())
512                .any(|(&current, &max)| current < max);
513
514            if needs_padding {
515                // For now, just use the tensor as-is since we don't have full broadcasting yet
516                // In a full implementation, we'd properly pad with padding_value
517                // For this placeholder, we'll just use the original tensor
518            }
519
520            padded_batch.push(padded_tensor);
521        }
522
523        // Stack the padded tensors
524        stack_tensors(&padded_batch, 0)
525    }
526}
527
/// Sparse tensor collation function
#[cfg(feature = "sparse")]
pub struct SparseCollate;

#[cfg(feature = "sparse")]
impl Collate<CooTensor> for SparseCollate {
    type Output = CooTensor;

    /// Combine a batch of COO tensors into one sparse tensor that carries
    /// all non-zero elements along a new batch dimension.
    fn collate(&self, batch: Vec<CooTensor>) -> Result<Self::Output> {
        match batch.as_slice() {
            [] => Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            )),
            // Delegate the actual stacking of non-zero entries.
            _ => collate_sparse_tensors(&batch),
        }
    }
}
548
/// Stack sparse tensors along a new batch dimension
///
/// All inputs must share the same shape; the resulting shape is
/// `[batch, ..input dims..]`.
#[cfg(feature = "sparse")]
pub fn collate_sparse_tensors(tensors: &[CooTensor]) -> Result<CooTensor> {
    if tensors.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot collate empty sparse tensor batch".to_string(),
        ));
    }

    // Check that all tensors have the same shape (except batch dimension)
    let first_shape = tensors[0].shape();
    for tensor in &tensors[1..] {
        if tensor.shape() != first_shape {
            return Err(TorshError::ShapeMismatch {
                expected: first_shape.dims().to_vec(),
                got: tensor.shape().dims().to_vec(),
            });
        }
    }

    // Calculate new shape with batch dimension
    let original_dims = first_shape.dims();
    let mut new_dims = Vec::with_capacity(original_dims.len() + 1);
    new_dims.push(tensors.len());
    new_dims.extend_from_slice(original_dims);

    // For COO format, we need to:
    // 1. Collect all indices and values
    // 2. Adjust indices to account for batch dimension
    // 3. Create new COO tensor

    let mut all_row_indices = Vec::new();
    let mut all_col_indices = Vec::new();
    let mut all_values = Vec::new();
    let mut _total_nnz = 0;

    for (batch_idx, tensor) in tensors.iter().enumerate() {
        // NOTE(review): the source row indices are bound but never used —
        // every non-zero below is re-indexed as (batch_idx, original col),
        // so the row coordinate of 2-D inputs is silently discarded. That
        // is only lossless for effectively 1-D inputs; confirm intended.
        let _row_indices = tensor.row_indices();
        let col_indices = tensor.col_indices();
        let values = tensor.values();

        // Adjust indices to include batch dimension
        for i in 0..tensor.nnz() {
            all_row_indices.push(batch_idx);
            all_col_indices.push(col_indices[i]);
        }

        all_values.extend_from_slice(values);
        _total_nnz += tensor.nnz();
    }

    // Create new COO tensor
    // NOTE(review): `new_dims` has `original_dims.len() + 1` axes, but a
    // COO tensor addressed by (row, col) pairs can represent only two —
    // verify that `CooTensor::new` accepts this shape.
    let shape = torsh_core::Shape::new(new_dims);
    CooTensor::new(all_row_indices, all_col_indices, all_values, shape)
}
604
/// Collation function for mixed dense and sparse tensors
#[cfg(feature = "sparse")]
pub struct MixedCollate;

#[cfg(feature = "sparse")]
impl Collate<Box<dyn SparseTensor>> for MixedCollate {
    type Output = Box<dyn SparseTensor>;

    /// Normalize every input to COO form, then collate them as sparse
    /// tensors, boxing the result back behind the trait object.
    fn collate(&self, batch: Vec<Box<dyn SparseTensor>>) -> Result<Self::Output> {
        if batch.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            ));
        }

        // Convert each tensor to COO, short-circuiting on the first error.
        let coo_tensors = batch
            .into_iter()
            .map(|tensor| tensor.to_coo())
            .collect::<Result<Vec<_>>>()?;

        let collated = collate_sparse_tensors(&coo_tensors)?;
        Ok(Box::new(collated))
    }
}