1use super::{optimized::stack_tensors, Collate};
4use torsh_core::{
5 dtype::TensorElement,
6 error::{Result, TorshError},
7};
8use torsh_tensor::Tensor;
9
10#[cfg(feature = "sparse")]
11use torsh_sparse::{CooTensor, SparseTensor};
12
13#[cfg(not(feature = "std"))]
14use alloc::{boxed::Box, vec::Vec};
15
16#[cfg(feature = "std")]
17use std::sync::Arc;
18
/// Collate function that recycles intermediate staging buffers through a
/// small shared pool to cut per-batch allocations.
pub struct CachedCollate<T: TensorElement> {
    /// Pool of reusable staging buffers, shared across threads/clones.
    tensor_pool: Arc<parking_lot::Mutex<Vec<Vec<T>>>>,
    /// Maximum number of buffers the pool retains; excess buffers are dropped.
    max_pool_size: usize,
}
24
25impl<T: TensorElement> CachedCollate<T> {
26 pub fn new(max_pool_size: usize) -> Self {
28 Self {
29 tensor_pool: Arc::new(parking_lot::Mutex::new(Vec::with_capacity(max_pool_size))),
30 max_pool_size,
31 }
32 }
33
34 fn get_buffer(&self, capacity: usize) -> Vec<T> {
36 let mut pool = self.tensor_pool.lock();
37 if let Some(mut buffer) = pool.pop() {
38 buffer.clear();
39 if buffer.capacity() >= capacity {
40 buffer.reserve(capacity - buffer.capacity());
41 }
42 buffer
43 } else {
44 Vec::with_capacity(capacity)
45 }
46 }
47
48 fn return_buffer(&self, buffer: Vec<T>) {
50 let mut pool = self.tensor_pool.lock();
51 if pool.len() < self.max_pool_size {
52 pool.push(buffer);
53 }
54 }
55}
56
57impl<T: TensorElement + Copy> Collate<Tensor<T>> for CachedCollate<T> {
58 type Output = Tensor<T>;
59
60 fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
61 if batch.is_empty() {
62 return Err(TorshError::InvalidArgument(
63 "Cannot collate empty batch".to_string(),
64 ));
65 }
66
67 let first_shape = batch[0].shape();
69 for tensor in &batch[1..] {
70 if tensor.shape() != first_shape {
71 return Err(TorshError::ShapeMismatch {
72 expected: first_shape.dims().to_vec(),
73 got: tensor.shape().dims().to_vec(),
74 });
75 }
76 }
77
78 let original_dims = first_shape.dims();
80 let mut new_dims = Vec::with_capacity(original_dims.len() + 1);
81 new_dims.push(batch.len());
82 new_dims.extend_from_slice(original_dims);
83
84 let tensor_size = batch[0].numel();
85 let total_elements = tensor_size * batch.len();
86
87 let mut new_data = self.get_buffer(total_elements);
89 new_data.reserve_exact(total_elements);
90
91 for tensor in batch.iter() {
93 let data = tensor.to_vec()?;
94 new_data.extend_from_slice(&data);
95 }
96
97 let result =
98 torsh_tensor::Tensor::from_data(new_data.clone(), new_dims, batch[0].device())?;
99
100 self.return_buffer(Vec::with_capacity(new_data.capacity()));
102
103 Ok(result)
104 }
105}
106
/// Collate strategy for variable-length sequences: pads along dim 0 up to
/// the batch maximum (optionally capped), with optional length-sorted
/// packing during assembly.
pub struct DynamicBatchCollate<T: TensorElement> {
    /// Value used to fill padded positions.
    padding_value: T,
    /// Optional hard cap on the padded length; longer samples are truncated.
    max_sequence_length: Option<usize>,
    /// When true, samples are sorted by length while padding and restored to
    /// their original order before stacking.
    pack_sequences: bool,
}
113
114impl<T: TensorElement> DynamicBatchCollate<T> {
115 pub fn new(padding_value: T) -> Self {
117 Self {
118 padding_value,
119 max_sequence_length: None,
120 pack_sequences: false,
121 }
122 }
123
124 pub fn with_max_length(mut self, max_length: usize) -> Self {
126 self.max_sequence_length = Some(max_length);
127 self
128 }
129
130 pub fn with_packing(mut self, pack: bool) -> Self {
132 self.pack_sequences = pack;
133 self
134 }
135}
136
impl<
        T: TensorElement
            + Copy
            + std::ops::Add<Output = T>
            + std::ops::Sub<Output = T>
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>
            + Default,
    > Collate<Tensor<T>> for DynamicBatchCollate<T>
{
    /// Padded batch tensor plus a 1-D `i64` tensor of per-sample sequence
    /// lengths, in the original batch order.
    type Output = (Tensor<T>, Tensor<i64>);

    /// Pads (and, when `max_sequence_length` is set, truncates) each sample
    /// along dim 0 to a common length, stacks the results, and returns the
    /// stack together with the recorded sequence lengths.
    ///
    /// # Errors
    /// Returns `InvalidArgument` for an empty batch or for scalar (0-dim)
    /// samples; propagates tensor-construction errors.
    fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
        if batch.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            ));
        }

        // Record each sample's dim-0 length and the longest one seen.
        let mut lengths = Vec::with_capacity(batch.len());
        let mut max_length = 0;

        for tensor in &batch {
            if tensor.ndim() == 0 {
                return Err(TorshError::InvalidArgument(
                    "Cannot dynamically batch scalar tensors".to_string(),
                ));
            }

            let seq_len = tensor.size(0)?;
            // NOTE(review): lengths record the pre-truncation size, so an
            // entry can exceed the padded dimension when
            // `max_sequence_length` truncates — confirm this is intended.
            lengths.push(seq_len as i64);
            max_length = max_length.max(seq_len);
        }

        // Cap the padded length when a maximum was configured.
        if let Some(max_len) = self.max_sequence_length {
            max_length = max_length.min(max_len);
        }

        // Remember original positions so batch order can be restored after
        // packing sorts by length.
        let mut batch_with_indices: Vec<_> = batch.into_iter().enumerate().collect();
        if self.pack_sequences {
            batch_with_indices.sort_by_key(|(_, tensor)| tensor.size(0).unwrap_or(0));
        }

        // Target shape: first tensor's dims with dim 0 forced to max_length.
        // Trailing dims are assumed equal across the batch (not validated).
        let first_tensor = &batch_with_indices[0].1;
        let mut padded_shape = first_tensor.shape().dims().to_vec();
        padded_shape[0] = max_length;

        let batch_size = batch_with_indices.len();
        let mut padded_batch = Vec::with_capacity(batch_size);

        for (original_idx, tensor) in batch_with_indices {
            let seq_len = tensor.size(0)?;
            let actual_len = seq_len.min(max_length);

            if actual_len == max_length {
                // Sample already reaches the target length: truncate if it
                // exceeds the cap, otherwise keep it as-is.
                if seq_len > max_length {
                    let truncated = tensor.narrow(0, 0, max_length)?;
                    padded_batch.push((original_idx, truncated));
                } else {
                    padded_batch.push((original_idx, tensor));
                }
            } else {
                // Sample is short: build a padding block covering the missing
                // dim-0 rows (all trailing dims taken from padded_shape).
                let mut padding_shape = padded_shape.clone();
                let padding_elements =
                    (max_length - actual_len) * padding_shape[1..].iter().product::<usize>();

                let padding_data = vec![self.padding_value; padding_elements];
                padding_shape[0] = max_length - actual_len;

                let padding_tensor =
                    Tensor::from_data(padding_data, padding_shape.clone(), tensor.device())?;

                // Defensive: in this branch seq_len < max_length, so the
                // narrow arm should be unreachable.
                let tensor_to_pad = if seq_len > max_length {
                    tensor.narrow(0, 0, max_length)?
                } else {
                    tensor
                };

                // Concatenate the flat (row-major) data: sample rows first,
                // padding rows after.
                let tensor_data = tensor_to_pad.to_vec()?;
                let padding_data = padding_tensor.to_vec()?;

                let mut combined_data = tensor_data;
                combined_data.extend(padding_data);

                let mut final_shape = tensor_to_pad.shape().dims().to_vec();
                final_shape[0] = max_length;

                let padded = Tensor::from_data(combined_data, final_shape, tensor_to_pad.device())?;
                padded_batch.push((original_idx, padded));
            }
        }

        // Undo the length sort so outputs align with the input order.
        if self.pack_sequences {
            padded_batch.sort_by_key(|(idx, _)| *idx);
        }

        let tensors: Vec<_> = padded_batch.into_iter().map(|(_, tensor)| tensor).collect();

        let stacked = stack_tensors(&tensors, 0)?;

        let lengths_tensor = Tensor::from_data(lengths, vec![batch_size], tensors[0].device())?;

        Ok((stacked, lengths_tensor))
    }
}
257
/// Convenience wrapper around `DynamicBatchCollate` that exposes only the
/// padded tensor, discarding the per-sample length tensor.
pub struct DynamicBatchCollateWrapper<T: TensorElement> {
    /// Collator that performs the actual padding/truncation.
    inner: DynamicBatchCollate<T>,
}
263
264impl<T: TensorElement> DynamicBatchCollateWrapper<T> {
265 pub fn new(padding_value: T) -> Self {
266 Self {
267 inner: DynamicBatchCollate::new(padding_value),
268 }
269 }
270
271 pub fn with_max_length(mut self, max_length: usize) -> Self {
272 self.inner = self.inner.with_max_length(max_length);
273 self
274 }
275
276 pub fn with_packing(mut self, pack: bool) -> Self {
277 self.inner = self.inner.with_packing(pack);
278 self
279 }
280}
281
282impl<
283 T: TensorElement
284 + Copy
285 + std::ops::Add<Output = T>
286 + std::ops::Sub<Output = T>
287 + std::ops::Mul<Output = T>
288 + std::ops::Div<Output = T>
289 + Default,
290 > Collate<Tensor<T>> for DynamicBatchCollateWrapper<T>
291{
292 type Output = Tensor<T>;
293
294 fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
295 let (padded_sequences, _lengths) = self.inner.collate(batch)?;
297 Ok(padded_sequences)
298 }
299}
300
/// Batch sampler that groups samples of similar length into buckets so
/// batches need less padding.
pub struct BucketBatchSampler {
    /// Sequence length of every sample, indexed by sample id.
    lengths: Vec<usize>,
    /// Number of samples per emitted batch.
    batch_size: usize,
    /// Half-open bucket edges over sample length; bucket i covers
    /// [boundaries[i], boundaries[i + 1]).
    bucket_boundaries: Vec<usize>,
    /// When true, tail batches smaller than `batch_size` are discarded.
    drop_last: bool,
}
309
impl BucketBatchSampler {
    /// Builds a sampler whose bucket boundaries are quantiles of the sorted
    /// sample lengths (between 1 and 10 buckets).
    ///
    /// # Panics
    /// Panics if `batch_size` is zero (division by zero when sizing the
    /// bucket count).
    pub fn new(lengths: Vec<usize>, batch_size: usize, drop_last: bool) -> Self {
        // Quantile boundaries come from a sorted copy of the lengths.
        let mut sorted_lengths = lengths.clone();
        sorted_lengths.sort_unstable();

        let num_buckets = (lengths.len() / batch_size).clamp(1, 10);
        let mut bucket_boundaries = Vec::with_capacity(num_buckets + 1);

        for i in 0..=num_buckets {
            let idx = (i * sorted_lengths.len()) / num_buckets;
            // The final boundary is max + 1 so the longest sample still
            // falls inside the last half-open bucket range.
            let boundary = if idx >= sorted_lengths.len() {
                sorted_lengths.last().copied().unwrap_or(0) + 1
            } else {
                sorted_lengths[idx]
            };
            bucket_boundaries.push(boundary);
        }

        Self {
            lengths,
            batch_size,
            bucket_boundaries,
            drop_last,
        }
    }

    /// Assigns each sample to the bucket covering its length, shuffles each
    /// bucket, chunks buckets into batches, then shuffles the batch order.
    ///
    /// Randomized: two calls generally return different batchings.
    pub fn generate_batches(&self) -> Vec<Vec<usize>> {
        // One bucket per half-open boundary range [b[i], b[i + 1]).
        let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); self.bucket_boundaries.len() - 1];

        for (idx, &length) in self.lengths.iter().enumerate() {
            // First matching bucket wins; duplicate boundaries just yield
            // empty ranges that are skipped.
            for (bucket_idx, bucket) in buckets.iter_mut().enumerate() {
                if length >= self.bucket_boundaries[bucket_idx]
                    && length < self.bucket_boundaries[bucket_idx + 1]
                {
                    bucket.push(idx);
                    break;
                }
            }
        }

        let mut batches = Vec::new();

        for mut bucket in buckets {
            use scirs2_core::random::prelude::*;
            use scirs2_core::random::seq::ScientificSliceRandom;

            // Shuffle within the bucket so batch membership varies per call.
            let mut rng = thread_rng();
            bucket.scientific_shuffle(&mut rng);

            // Short tail chunks are kept unless drop_last is set.
            for chunk in bucket.chunks(self.batch_size) {
                if chunk.len() == self.batch_size || !self.drop_last {
                    batches.push(chunk.to_vec());
                }
            }
        }

        // Shuffle batch order so buckets are not consumed consecutively.
        use scirs2_core::random::prelude::*;
        use scirs2_core::random::seq::ScientificSliceRandom;
        let mut rng = thread_rng();
        batches.scientific_shuffle(&mut rng);

        batches
    }
}
382
/// Batch sampler that adapts the batch size so each batch's padded token
/// count (`batch_size * longest_sequence`) stays near a target.
pub struct AdaptiveBatchSampler {
    /// Approximate upper bound on padded tokens per batch.
    target_tokens: usize,
    /// Hard cap on the number of samples in one batch.
    max_batch_size: usize,
    /// Batches smaller than this are discarded rather than emitted.
    min_batch_size: usize,
    /// Sequence length of every sample, indexed by sample id.
    lengths: Vec<usize>,
}

impl AdaptiveBatchSampler {
    /// Creates a sampler over `lengths` with the given token budget and
    /// batch-size bounds.
    pub fn new(
        lengths: Vec<usize>,
        target_tokens: usize,
        max_batch_size: usize,
        min_batch_size: usize,
    ) -> Self {
        Self {
            target_tokens,
            max_batch_size,
            min_batch_size,
            lengths,
        }
    }

    /// Greedily packs sample indices (shortest first) into batches whose
    /// padded token count stays within `target_tokens`, capped at
    /// `max_batch_size` samples per batch.
    ///
    /// Deterministic: the same inputs always produce the same batches.
    pub fn generate_batches(&self) -> Vec<Vec<usize>> {
        let mut indices: Vec<usize> = (0..self.lengths.len()).collect();

        // Shortest-first keeps batch members similar in length, minimizing
        // padding waste under the "pad to longest" cost model below.
        indices.sort_by_key(|&i| self.lengths[i]);

        let mut batches = Vec::new();
        let mut current_batch: Vec<usize> = Vec::new();

        for idx in indices {
            let length = self.lengths[idx];
            let batch_size = current_batch.len();
            // Padded token count if `idx` joined: every member is padded to
            // the longest sequence in the batch.
            let tokens_if_added = (batch_size + 1)
                * length.max(
                    current_batch
                        .iter()
                        .map(|&i| self.lengths[i])
                        .max()
                        .unwrap_or(0),
                );

            if tokens_if_added > self.target_tokens || batch_size >= self.max_batch_size {
                // NOTE(review): a flushed batch below min_batch_size is
                // dropped, so its samples are silently skipped.
                if batch_size >= self.min_batch_size {
                    batches.push(current_batch);
                }

                current_batch = vec![idx];
            } else {
                current_batch.push(idx);
            }
        }

        // Emit the trailing batch if it meets the minimum size.
        if current_batch.len() >= self.min_batch_size {
            batches.push(current_batch);
        }

        batches
    }
}
455
/// Collate function intended to pad ragged tensors up to a common shape
/// before stacking them into a batch.
pub struct PadCollate<T: TensorElement> {
    /// Fill value for padded positions.
    #[allow(dead_code)]
    padding_value: T,
}
461
462impl<T: TensorElement> PadCollate<T> {
463 pub fn new(padding_value: T) -> Self {
465 Self { padding_value }
466 }
467}
468
469impl<T: TensorElement + Copy> Collate<Tensor<T>> for PadCollate<T> {
470 type Output = Tensor<T>;
471
472 fn collate(&self, batch: Vec<Tensor<T>>) -> Result<Self::Output> {
473 if batch.is_empty() {
474 return Err(TorshError::InvalidArgument(
475 "Cannot collate empty batch".to_string(),
476 ));
477 }
478
479 let ndim = batch[0].ndim();
481 let mut max_dims = vec![0; ndim];
482
483 for tensor in &batch {
484 if tensor.ndim() != ndim {
485 return Err(TorshError::InvalidArgument(
486 "All tensors must have the same number of dimensions".to_string(),
487 ));
488 }
489
490 for (i, max_dim) in max_dims.iter_mut().enumerate().take(ndim) {
491 let size = tensor.size(i as i32)?;
492 if size > *max_dim {
493 *max_dim = size;
494 }
495 }
496 }
497
498 let batch_size = batch.len();
500 let mut padded_batch = Vec::with_capacity(batch_size);
501
502 for tensor in batch {
503 let shape_ref = tensor.shape();
505 let current_shape = shape_ref.dims();
506 let padded_tensor = tensor;
507
508 let needs_padding = current_shape
510 .iter()
511 .zip(max_dims.iter())
512 .any(|(¤t, &max)| current < max);
513
514 if needs_padding {
515 }
519
520 padded_batch.push(padded_tensor);
521 }
522
523 stack_tensors(&padded_batch, 0)
525 }
526}
527
/// Collate function for sparse COO tensors; batches them via
/// `collate_sparse_tensors`.
#[cfg(feature = "sparse")]
pub struct SparseCollate;
531
#[cfg(feature = "sparse")]
impl Collate<CooTensor> for SparseCollate {
    type Output = CooTensor;

    /// Batches COO tensors by delegating to `collate_sparse_tensors`,
    /// rejecting empty batches up front.
    fn collate(&self, batch: Vec<CooTensor>) -> Result<Self::Output> {
        match batch.is_empty() {
            true => Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            )),
            false => collate_sparse_tensors(&batch),
        }
    }
}
548
/// Batches a slice of identically-shaped COO tensors into a single COO
/// tensor whose leading dimension is the batch index.
///
/// # Errors
/// Returns `InvalidArgument` for an empty slice and `ShapeMismatch` when any
/// tensor's shape differs from the first one's.
#[cfg(feature = "sparse")]
pub fn collate_sparse_tensors(tensors: &[CooTensor]) -> Result<CooTensor> {
    if tensors.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot collate empty sparse tensor batch".to_string(),
        ));
    }

    // Every tensor must share the first tensor's shape.
    let first_shape = tensors[0].shape();
    for tensor in &tensors[1..] {
        if tensor.shape() != first_shape {
            return Err(TorshError::ShapeMismatch {
                expected: first_shape.dims().to_vec(),
                got: tensor.shape().dims().to_vec(),
            });
        }
    }

    // Batched shape: [batch_len, ...original dims].
    let original_dims = first_shape.dims();
    let mut new_dims = Vec::with_capacity(original_dims.len() + 1);
    new_dims.push(tensors.len());
    new_dims.extend_from_slice(original_dims);

    let mut all_row_indices = Vec::new();
    let mut all_col_indices = Vec::new();
    let mut all_values = Vec::new();
    let mut _total_nnz = 0;

    for (batch_idx, tensor) in tensors.iter().enumerate() {
        // NOTE(review): the per-tensor row indices are read but never used —
        // each nonzero's row coordinate becomes its batch index, so the
        // original row position is lost even though `new_dims` gains a
        // dimension. Confirm this flattening is intended for CooTensor's
        // two-level (row, col) index model.
        let _row_indices = tensor.row_indices();
        let col_indices = tensor.col_indices();
        let values = tensor.values();

        for i in 0..tensor.nnz() {
            all_row_indices.push(batch_idx);
            all_col_indices.push(col_indices[i]);
        }

        all_values.extend_from_slice(values);
        _total_nnz += tensor.nnz();
    }

    let shape = torsh_core::Shape::new(new_dims);
    CooTensor::new(all_row_indices, all_col_indices, all_values, shape)
}
604
/// Collate function for heterogeneous sparse tensors behind `dyn
/// SparseTensor`; converts every input to COO before batching.
#[cfg(feature = "sparse")]
pub struct MixedCollate;
608
#[cfg(feature = "sparse")]
impl Collate<Box<dyn SparseTensor>> for MixedCollate {
    type Output = Box<dyn SparseTensor>;

    /// Normalizes every tensor to COO, batches them, and boxes the result
    /// back into a trait object.
    fn collate(&self, batch: Vec<Box<dyn SparseTensor>>) -> Result<Self::Output> {
        if batch.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Cannot collate empty batch".to_string(),
            ));
        }

        // Convert the heterogeneous formats to COO, short-circuiting on the
        // first conversion failure.
        let coo_tensors = batch
            .into_iter()
            .map(|tensor| tensor.to_coo())
            .collect::<Result<Vec<_>>>()?;

        let collated = collate_sparse_tensors(&coo_tensors)?;
        Ok(Box::new(collated))
    }
}