// torsh_data/collate/core.rs
use torsh_core::{
    dtype::TensorElement,
    error::{Result, TorshError},
};
use torsh_tensor::Tensor;

#[cfg(not(feature = "std"))]
use alloc::{
    boxed::Box,
    string::{String, ToString},
    vec::Vec,
};
11
12pub trait Collate<T> {
14 type Output;
16
17 fn collate(&self, batch: Vec<T>) -> Result<Self::Output>;
19
20 fn expected_batch_size(&self) -> Option<usize> {
22 None
23 }
24
25 fn supports_empty_batch(&self) -> bool {
27 false
28 }
29
30 fn validate_batch(&self, batch: &[T]) -> Result<()> {
32 if batch.is_empty() && !self.supports_empty_batch() {
33 return Err(TorshError::InvalidArgument(
34 "Cannot collate empty batch".to_string(),
35 ));
36 }
37 Ok(())
38 }
39}
40
/// Stateless collator providing the default [`Collate`] implementations in
/// this module.
///
/// It is a unit struct, so it is free to copy and carries no configuration.
#[derive(Debug, Clone, Copy)]
pub struct DefaultCollate;
44
45impl<T: TensorElement + Copy> Collate<Tensor<T>> for DefaultCollate {
46 type Output = Tensor<T>;
47
48 fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
49 self.validate_batch(&batch)?;
50 super::stacking::TensorStacker::new().stack(&batch, 0)
51 }
52}
53
54impl<T: TensorElement + Copy> Collate<(Tensor<T>, usize)> for DefaultCollate {
56 type Output = (Tensor<T>, Vec<usize>);
57
58 fn collate(&self, batch: Vec<(Tensor<T>, usize)>) -> Result<Self::Output> {
59 self.validate_batch(&batch)?;
60
61 let (tensors, labels): (Vec<Tensor<T>>, Vec<usize>) = batch.into_iter().unzip();
62 let stacked_tensors = super::stacking::TensorStacker::new().stack(&tensors, 0)?;
63
64 Ok((stacked_tensors, labels))
65 }
66}
67
68impl<T: TensorElement + Copy> Collate<(Tensor<T>, String)> for DefaultCollate {
69 type Output = (Tensor<T>, Vec<String>);
70
71 fn collate(&self, batch: Vec<(Tensor<T>, String)>) -> Result<Self::Output> {
72 self.validate_batch(&batch)?;
73
74 let (tensors, strings): (Vec<Tensor<T>>, Vec<String>) = batch.into_iter().unzip();
75 let stacked_tensors = super::stacking::TensorStacker::new().stack(&tensors, 0)?;
76
77 Ok((stacked_tensors, strings))
78 }
79}
80
81impl Collate<usize> for DefaultCollate {
83 type Output = Vec<usize>;
84
85 fn collate(&self, batch: Vec<usize>) -> Result<Self::Output> {
86 self.validate_batch(&batch)?;
87 Ok(batch)
88 }
89}
90
91impl Collate<String> for DefaultCollate {
92 type Output = Vec<String>;
93
94 fn collate(&self, batch: Vec<String>) -> Result<Self::Output> {
95 self.validate_batch(&batch)?;
96 Ok(batch)
97 }
98}
99
100impl Collate<f32> for DefaultCollate {
101 type Output = Vec<f32>;
102
103 fn collate(&self, batch: Vec<f32>) -> Result<Self::Output> {
104 self.validate_batch(&batch)?;
105 Ok(batch)
106 }
107}
108
109impl Collate<i32> for DefaultCollate {
110 type Output = Vec<i32>;
111
112 fn collate(&self, batch: Vec<i32>) -> Result<Self::Output> {
113 self.validate_batch(&batch)?;
114 Ok(batch)
115 }
116}
117
118impl<T: TensorElement + Copy> Collate<Vec<Tensor<T>>> for DefaultCollate {
119 type Output = Vec<Tensor<T>>;
120
121 fn collate(&self, batch: Vec<Vec<Tensor<T>>>) -> Result<Self::Output> {
122 self.validate_batch(&batch)?;
123
124 if batch.is_empty() {
125 return Ok(Vec::new());
126 }
127
128 let num_tensors = batch[0].len();
130 for sample in &batch {
131 if sample.len() != num_tensors {
132 return Err(TorshError::InvalidArgument(
133 "All samples must have the same number of tensors".to_string(),
134 ));
135 }
136 }
137
138 let mut result = Vec::with_capacity(num_tensors);
140 for tensor_idx in 0..num_tensors {
141 let tensors_to_stack: Vec<Tensor<T>> = batch
142 .iter()
143 .map(|sample| sample[tensor_idx].clone())
144 .collect();
145
146 let stacked = super::stacking::TensorStacker::new().stack(&tensors_to_stack, 0)?;
148 result.push(stacked);
149 }
150
151 Ok(result)
152 }
153}
154
/// Adapter that turns a plain function or closure into a collator.
///
/// The wrapped callable is stored as-is in the private field `f`.
pub struct CollateFn<F> {
    f: F,
}

impl<F> CollateFn<F> {
    /// Wraps `f` in a `CollateFn`.
    pub fn new(f: F) -> Self {
        Self { f }
    }
}
165
166impl<T, O, F> Collate<T> for CollateFn<F>
167where
168 F: Fn(Vec<T>) -> Result<O>,
169{
170 type Output = O;
171
172 fn collate(&self, batch: Vec<T>) -> Result<Self::Output> {
173 (self.f)(batch)
174 }
175}
176
177pub fn collate_fn<T>() -> DefaultCollate {
179 DefaultCollate
180}