Skip to main content

torsh_data/collate/
core.rs

1//! Core collation trait and basic implementations
2
3use torsh_core::{
4    dtype::TensorElement,
5    error::{Result, TorshError},
6};
7use torsh_tensor::Tensor;
8
9#[cfg(not(feature = "std"))]
10use alloc::{boxed::Box, vec::Vec};
11
12/// Trait for collating a batch of samples
13pub trait Collate<T> {
14    /// Output type after collation
15    type Output;
16
17    /// Collate a batch of samples
18    fn collate(&self, batch: Vec<T>) -> Result<Self::Output>;
19
20    /// Get the expected batch size (returns None for variable batch sizes)
21    fn expected_batch_size(&self) -> Option<usize> {
22        None
23    }
24
25    /// Check if this collate function supports empty batches
26    fn supports_empty_batch(&self) -> bool {
27        false
28    }
29
30    /// Validate batch before collation (optional hook)
31    fn validate_batch(&self, batch: &[T]) -> Result<()> {
32        if batch.is_empty() && !self.supports_empty_batch() {
33            return Err(TorshError::InvalidArgument(
34                "Cannot collate empty batch".to_string(),
35            ));
36        }
37        Ok(())
38    }
39}
40
/// Default collation strategy.
///
/// This is a stateless unit type: all values are interchangeable, so it is
/// `Copy`, comparable, and constructible through [`Default`] in addition to
/// being written literally as `DefaultCollate`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DefaultCollate;
44
45impl<T: TensorElement + Copy> Collate<Tensor<T>> for DefaultCollate {
46    type Output = Tensor<T>;
47
48    fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
49        self.validate_batch(&batch)?;
50        super::stacking::TensorStacker::new().stack(&batch, 0)
51    }
52}
53
54// Common implementations for tuple types used in datasets
55impl<T: TensorElement + Copy> Collate<(Tensor<T>, usize)> for DefaultCollate {
56    type Output = (Tensor<T>, Vec<usize>);
57
58    fn collate(&self, batch: Vec<(Tensor<T>, usize)>) -> Result<Self::Output> {
59        self.validate_batch(&batch)?;
60
61        let (tensors, labels): (Vec<Tensor<T>>, Vec<usize>) = batch.into_iter().unzip();
62        let stacked_tensors = super::stacking::TensorStacker::new().stack(&tensors, 0)?;
63
64        Ok((stacked_tensors, labels))
65    }
66}
67
68impl<T: TensorElement + Copy> Collate<(Tensor<T>, String)> for DefaultCollate {
69    type Output = (Tensor<T>, Vec<String>);
70
71    fn collate(&self, batch: Vec<(Tensor<T>, String)>) -> Result<Self::Output> {
72        self.validate_batch(&batch)?;
73
74        let (tensors, strings): (Vec<Tensor<T>>, Vec<String>) = batch.into_iter().unzip();
75        let stacked_tensors = super::stacking::TensorStacker::new().stack(&tensors, 0)?;
76
77        Ok((stacked_tensors, strings))
78    }
79}
80
81// Implementations for common non-tensor types
82impl Collate<usize> for DefaultCollate {
83    type Output = Vec<usize>;
84
85    fn collate(&self, batch: Vec<usize>) -> Result<Self::Output> {
86        self.validate_batch(&batch)?;
87        Ok(batch)
88    }
89}
90
91impl Collate<String> for DefaultCollate {
92    type Output = Vec<String>;
93
94    fn collate(&self, batch: Vec<String>) -> Result<Self::Output> {
95        self.validate_batch(&batch)?;
96        Ok(batch)
97    }
98}
99
100impl Collate<f32> for DefaultCollate {
101    type Output = Vec<f32>;
102
103    fn collate(&self, batch: Vec<f32>) -> Result<Self::Output> {
104        self.validate_batch(&batch)?;
105        Ok(batch)
106    }
107}
108
109impl Collate<i32> for DefaultCollate {
110    type Output = Vec<i32>;
111
112    fn collate(&self, batch: Vec<i32>) -> Result<Self::Output> {
113        self.validate_batch(&batch)?;
114        Ok(batch)
115    }
116}
117
118impl<T: TensorElement + Copy> Collate<Vec<Tensor<T>>> for DefaultCollate {
119    type Output = Vec<Tensor<T>>;
120
121    fn collate(&self, batch: Vec<Vec<Tensor<T>>>) -> Result<Self::Output> {
122        self.validate_batch(&batch)?;
123
124        if batch.is_empty() {
125            return Ok(Vec::new());
126        }
127
128        // Check that all samples have the same number of tensors
129        let num_tensors = batch[0].len();
130        for sample in &batch {
131            if sample.len() != num_tensors {
132                return Err(TorshError::InvalidArgument(
133                    "All samples must have the same number of tensors".to_string(),
134                ));
135            }
136        }
137
138        // Group tensors by position and stack them
139        let mut result = Vec::with_capacity(num_tensors);
140        for tensor_idx in 0..num_tensors {
141            let tensors_to_stack: Vec<Tensor<T>> = batch
142                .iter()
143                .map(|sample| sample[tensor_idx].clone())
144                .collect();
145
146            // Stack tensors at this position
147            let stacked = super::stacking::TensorStacker::new().stack(&tensors_to_stack, 0)?;
148            result.push(stacked);
149        }
150
151        Ok(result)
152    }
153}
154
/// Adapter that wraps a plain function or closure so it can be used as a
/// collate strategy.
pub struct CollateFn<F> {
    /// The wrapped callable.
    f: F,
}

impl<F> CollateFn<F> {
    /// Wraps `f` in a new `CollateFn`.
    pub fn new(f: F) -> Self {
        CollateFn { f }
    }
}
165
166impl<T, O, F> Collate<T> for CollateFn<F>
167where
168    F: Fn(Vec<T>) -> Result<O>,
169{
170    type Output = O;
171
172    fn collate(&self, batch: Vec<T>) -> Result<Self::Output> {
173        (self.f)(batch)
174    }
175}
176
/// Convenience function to create the default collate function.
///
/// Returns a [`DefaultCollate`] value. The type parameter `T` appears in
/// neither the arguments nor the return type, so it cannot be inferred and
/// must be supplied by the caller (e.g. `collate_fn::<f32>()`); it has no
/// effect on the returned value.
pub fn collate_fn<T>() -> DefaultCollate {
    DefaultCollate
}