// torsh_data/collate/optimized.rs
#[allow(unused_imports)]
5use super::stacking::TensorStacker;
6use crate::collate::Collate;
7use torsh_core::{
8 dtype::TensorElement,
9 error::{Result, TorshError},
10};
11use torsh_tensor::Tensor;
12
13#[cfg(not(feature = "std"))]
14use alloc::vec::Vec;
15
16#[cfg(feature = "std")]
18use scirs2_core::parallel_ops::*;
19
20pub fn stack_tensors<T: TensorElement + Copy>(
22 tensors: &[Tensor<T>],
23 dim: usize,
24) -> Result<Tensor<T>> {
25 if tensors.is_empty() {
26 return Err(TorshError::InvalidArgument(
27 "Cannot stack empty tensor list".to_string(),
28 ));
29 }
30
31 let first_shape = tensors[0].shape();
33 for tensor in &tensors[1..] {
34 if tensor.shape() != first_shape {
35 return Err(TorshError::ShapeMismatch {
36 expected: first_shape.dims().to_vec(),
37 got: tensor.shape().dims().to_vec(),
38 });
39 }
40 }
41
42 let original_dims = first_shape.dims();
44 let mut new_dims = Vec::with_capacity(original_dims.len() + 1);
45
46 if dim == 0 {
48 new_dims.push(tensors.len());
49 new_dims.extend_from_slice(original_dims);
50 } else {
51 new_dims.extend_from_slice(&original_dims[..dim.min(original_dims.len())]);
53 new_dims.push(tensors.len());
54 if dim < original_dims.len() {
55 new_dims.extend_from_slice(&original_dims[dim..]);
56 }
57 }
58
59 let tensor_size = tensors[0].numel();
63 let total_elements = new_dims.iter().product::<usize>();
64 let mut new_data = Vec::with_capacity(total_elements);
65 unsafe { new_data.set_len(total_elements) };
67
68 #[cfg(feature = "std")]
70 {
71 if tensors.len() > 4 && tensor_size > 1000 {
72 let parallel_data: std::result::Result<Vec<Vec<T>>, TorshError> =
74 tensors.par_iter().map(|tensor| tensor.to_vec()).collect();
75 let parallel_data = parallel_data?;
76 for (i, data) in parallel_data.into_iter().enumerate() {
77 let start_idx = i * tensor_size;
78 let end_idx = start_idx + tensor_size;
79 new_data[start_idx..end_idx].copy_from_slice(&data);
80 }
81 } else {
82 for (i, tensor) in tensors.iter().enumerate() {
84 let data = tensor.to_vec()?;
85 let start_idx = i * tensor_size;
86 let end_idx = start_idx + tensor_size;
87 new_data[start_idx..end_idx].copy_from_slice(&data);
88 }
89 }
90 }
91
92 #[cfg(not(feature = "std"))]
93 {
94 for (i, tensor) in tensors.iter().enumerate() {
96 let data = tensor.to_vec()?;
97 let start_idx = i * tensor_size;
98 let end_idx = start_idx + tensor_size;
99 new_data[start_idx..end_idx].copy_from_slice(&data);
100 }
101 }
102
103 let result = torsh_tensor::Tensor::from_data(new_data, new_dims, tensors[0].device())?;
104
105 Ok(result)
106}
107
/// Stacks `tensors` along `dim`, routing very large batches through the
/// memory-mapped path when the `mmap-support` feature is enabled; all other
/// batches fall through to [`stack_tensors`].
///
/// # Errors
/// Returns `InvalidArgument` for an empty tensor list; otherwise propagates
/// errors from the underlying stacking routine.
#[cfg(feature = "std")]
pub fn stack_tensors_fast<T: TensorElement + Copy>(
    tensors: &[Tensor<T>],
    dim: usize,
) -> Result<Tensor<T>> {
    if tensors.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot stack empty tensor list".to_string(),
        ));
    }

    #[cfg(feature = "mmap-support")]
    {
        // Batches larger than this take the memory-mapped staging route.
        const MMAP_THRESHOLD: usize = 100;
        if tensors.len() > MMAP_THRESHOLD {
            return stack_tensors_mmap(tensors, dim);
        }
    }

    stack_tensors(tensors, dim)
}
131
/// Stacks tensors by staging the concatenated data in a memory-mapped
/// temporary file before building the result tensor. Intended for very
/// large batches.
///
/// # Errors
/// - `InvalidArgument` if `tensors` is empty.
/// - `ShapeMismatch` if any tensor's shape differs from the first one's.
/// - `IoError` if the temporary file or the mapping cannot be created.
#[cfg(all(feature = "std", feature = "mmap-support"))]
pub fn stack_tensors_mmap<T: TensorElement + Copy>(
    tensors: &[Tensor<T>],
    dim: usize,
) -> Result<Tensor<T>> {
    // Guard before indexing tensors[0]; previously an empty slice panicked here.
    if tensors.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot stack empty tensor list".to_string(),
        ));
    }

    let first_shape = tensors[0].shape();
    for tensor in &tensors[1..] {
        if tensor.shape() != first_shape {
            return Err(TorshError::ShapeMismatch {
                expected: first_shape.dims().to_vec(),
                got: tensor.shape().dims().to_vec(),
            });
        }
    }

    // Insert the stack axis at `dim` (clamped), same layout rule as stack_tensors.
    let original_dims = first_shape.dims();
    let mut new_dims = Vec::with_capacity(original_dims.len() + 1);
    if dim == 0 {
        new_dims.push(tensors.len());
        new_dims.extend_from_slice(original_dims);
    } else {
        new_dims.extend_from_slice(&original_dims[..dim.min(original_dims.len())]);
        new_dims.push(tensors.len());
        if dim < original_dims.len() {
            new_dims.extend_from_slice(&original_dims[dim..]);
        }
    }

    let tensor_size = tensors[0].numel();
    let total_size = tensor_size * tensors.len() * std::mem::size_of::<T>();

    // Back the staging buffer with a temporary file sized for all elements.
    let mut temp_file =
        tempfile::NamedTempFile::new().map_err(|e| TorshError::IoError(e.to_string()))?;
    temp_file
        .as_file_mut()
        .set_len(total_size as u64)
        .map_err(|e| TorshError::IoError(e.to_string()))?;

    // SAFETY: the temp file is exclusively ours and has been sized to
    // `total_size` bytes, so mapping it mutably is sound.
    let mut mmap = unsafe {
        memmap2::MmapOptions::new()
            .map_mut(temp_file.as_file())
            .map_err(|e| TorshError::IoError(e.to_string()))?
    };

    // Materialize every tensor's contents in parallel.
    let all_data: std::result::Result<Vec<Vec<T>>, TorshError> =
        tensors.par_iter().map(|tensor| tensor.to_vec()).collect();
    let all_data = all_data?;

    // Write through the mutable mapping. The previous code cast the pointer
    // from `as_ptr()` (a *const) to *mut and wrote through it — undefined
    // behavior; `as_mut_ptr()` is the sound way to obtain a writable pointer.
    let mmap_ptr = mmap.as_mut_ptr() as *mut T;
    for (i, data) in all_data.iter().enumerate() {
        // SAFETY: the mapping holds tensors.len() * tensor_size elements of T
        // (page alignment satisfies T's alignment requirement) and each
        // chunk's destination range is disjoint from all others.
        unsafe {
            let dst = mmap_ptr.add(i * tensor_size);
            std::ptr::copy_nonoverlapping(data.as_ptr(), dst, tensor_size);
        }
    }

    // SAFETY: all tensor_size * tensors.len() elements were initialized by
    // the copies above, and the pointer/length describe the mapped region.
    let data_vec = unsafe {
        std::slice::from_raw_parts(mmap_ptr as *const T, tensor_size * tensors.len()).to_vec()
    };
    torsh_tensor::Tensor::from_data(data_vec, new_dims, tensors[0].device())
}
207
/// Zero-sized collate strategy that stacks batch samples using the
/// optimized stacking routines (parallel and, when enabled, memory-mapped).
#[cfg(feature = "std")]
#[derive(Debug, Clone, Copy)]
pub struct OptimizedCollate;
212
#[cfg(feature = "std")]
impl<T: TensorElement + Copy> Collate<Tensor<T>> for OptimizedCollate {
    type Output = Tensor<T>;

    /// Stacks every sample tensor along a new leading batch axis.
    ///
    /// # Errors
    /// Returns `InvalidArgument` for an empty batch; otherwise propagates
    /// errors from the stacking routine.
    fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
        if !batch.is_empty() {
            stack_tensors_fast(&batch, 0)
        } else {
            Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            ))
        }
    }
}
228
#[cfg(feature = "std")]
impl<T: TensorElement + Copy> Collate<Vec<Tensor<T>>> for OptimizedCollate {
    type Output = Vec<Tensor<T>>;

    /// Collates a batch of multi-tensor samples field-by-field: `output[i]`
    /// is the stack of every sample's i-th tensor along a new leading axis.
    /// Field positions are processed in parallel.
    ///
    /// NOTE(review): assumes every sample has at least as many tensors as
    /// `batch[0]` — a shorter sample would panic on `sample[i]`; TODO confirm.
    ///
    /// # Errors
    /// Returns `InvalidArgument` for an empty batch; otherwise propagates
    /// errors from the stacking routine.
    fn collate(&self, batch: Vec<Vec<Tensor<T>>>) -> Result<Self::Output> {
        if batch.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            ));
        }

        let num_tensors = batch[0].len();

        // Collect straight into the output. The previous version collected
        // into one Vec and then pushed each element into a second
        // pre-allocated Vec — a pointless extra pass and allocation.
        (0..num_tensors)
            .into_par_iter()
            .map(|i| {
                let tensors: Vec<Tensor<T>> =
                    batch.iter().map(|sample| sample[i].clone()).collect();
                stack_tensors_fast(&tensors, 0)
            })
            .collect()
    }
}
258
/// Convenience constructor for [`OptimizedCollate`].
///
/// The type parameter `T` is unused at runtime; presumably kept so call
/// sites can name the element type — TODO confirm it is still needed.
#[cfg(feature = "std")]
pub fn optimized_collate_fn<T>() -> OptimizedCollate {
    OptimizedCollate
}
264
/// Fallback collate strategy for `no_std` builds; delegates all stacking to
/// the sequential [`TensorStacker`].
#[cfg(not(feature = "std"))]
#[derive(Debug, Clone, Copy)]
pub struct OptimizedCollate;
269
270#[cfg(not(feature = "std"))]
271impl<T: TensorElement + Copy> Collate<Tensor<T>> for OptimizedCollate {
272 type Output = Tensor<T>;
273
274 fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
275 TensorStacker::new().stack(&batch, 0)
276 }
277}
278
279#[cfg(not(feature = "std"))]
280impl<T: TensorElement + Copy> Collate<Vec<Tensor<T>>> for OptimizedCollate {
281 type Output = Vec<Tensor<T>>;
282
283 fn collate(&self, batch: Vec<Vec<Tensor<T>>>) -> Result<Self::Output> {
284 if batch.is_empty() {
285 return Err(TorshError::InvalidArgument(
286 "Cannot collate empty batch".to_string(),
287 ));
288 }
289
290 let num_tensors = batch[0].len();
291 let mut collated = Vec::with_capacity(num_tensors);
292 let stacker = TensorStacker::new();
293
294 for i in 0..num_tensors {
296 let tensors: Vec<Tensor<T>> = batch.iter().map(|sample| sample[i].clone()).collect();
297 collated.push(stacker.stack(&tensors, 0)?);
298 }
299
300 Ok(collated)
301 }
302}
303
/// Convenience constructor for [`OptimizedCollate`].
///
/// The type parameter `T` is unused at runtime; presumably kept so call
/// sites can name the element type — TODO confirm it is still needed.
#[cfg(not(feature = "std"))]
pub fn optimized_collate_fn<T>() -> OptimizedCollate {
    OptimizedCollate
}