//! Batch-mixing (MixUp/CutMix) and sample-ordering helpers for
//! `yscv_model/dataset` (iter.rs).

use yscv_tensor::{Tensor, TensorError};

use crate::{ImageAugmentationOp, ImageAugmentationPipeline, ModelError};

use super::helpers::{
    LcgRng, class_balanced_sampling_weights, should_apply_probability, shuffle_indices,
};
use super::types::{BatchIterOptions, SamplingPolicy, SupervisedDataset};

/// Controls per-batch sample/label interpolation for regularized training.
#[derive(Debug, Clone, PartialEq)]
pub struct MixUpConfig {
    // Chance in [0, 1] that MixUp is applied to a given batch
    // (enforced by `validate_mixup_probability`).
    probability: f32,
    // Lower bound in [0, 0.5] for the sampled blend coefficient; the blend
    // lambda is drawn from [lambda_min, 1 - lambda_min] in `apply_mixup_batch`.
    lambda_min: f32,
}

17impl Default for MixUpConfig {
18    fn default() -> Self {
19        Self {
20            probability: 1.0,
21            lambda_min: 0.0,
22        }
23    }
24}
25
26impl MixUpConfig {
27    pub fn new() -> Self {
28        Self::default()
29    }
30
31    pub fn with_probability(mut self, probability: f32) -> Result<Self, ModelError> {
32        validate_mixup_probability(probability)?;
33        self.probability = probability;
34        Ok(self)
35    }
36
37    pub fn with_lambda_min(mut self, lambda_min: f32) -> Result<Self, ModelError> {
38        validate_mixup_lambda_min(lambda_min)?;
39        self.lambda_min = lambda_min;
40        Ok(self)
41    }
42
43    pub fn probability(&self) -> f32 {
44        self.probability
45    }
46
47    pub fn lambda_min(&self) -> f32 {
48        self.lambda_min
49    }
50}
51
/// Controls per-batch region replacement interpolation for image tensors.
#[derive(Debug, Clone, PartialEq)]
pub struct CutMixConfig {
    // Chance in [0, 1] that CutMix is applied to a given batch.
    probability: f32,
    // Smallest patch size as a fraction of each spatial axis, in [0, 1];
    // must be <= `max_patch_fraction`.
    min_patch_fraction: f32,
    // Largest patch size as a fraction of each spatial axis, in [0, 1].
    max_patch_fraction: f32,
}

60impl Default for CutMixConfig {
61    fn default() -> Self {
62        Self {
63            probability: 1.0,
64            min_patch_fraction: 0.1,
65            max_patch_fraction: 0.5,
66        }
67    }
68}
69
70impl CutMixConfig {
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    pub fn with_probability(mut self, probability: f32) -> Result<Self, ModelError> {
76        validate_cutmix_probability(probability)?;
77        self.probability = probability;
78        Ok(self)
79    }
80
81    pub fn with_min_patch_fraction(mut self, min_patch_fraction: f32) -> Result<Self, ModelError> {
82        validate_cutmix_patch_fraction("min_patch_fraction", min_patch_fraction)?;
83        self.min_patch_fraction = min_patch_fraction;
84        if self.min_patch_fraction > self.max_patch_fraction {
85            return Err(ModelError::InvalidCutMixArgument {
86                field: "min_patch_fraction",
87                value: self.min_patch_fraction,
88                message: format!(
89                    "min_patch_fraction must be <= max_patch_fraction ({})",
90                    self.max_patch_fraction
91                ),
92            });
93        }
94        Ok(self)
95    }
96
97    pub fn with_max_patch_fraction(mut self, max_patch_fraction: f32) -> Result<Self, ModelError> {
98        validate_cutmix_patch_fraction("max_patch_fraction", max_patch_fraction)?;
99        self.max_patch_fraction = max_patch_fraction;
100        if self.min_patch_fraction > self.max_patch_fraction {
101            return Err(ModelError::InvalidCutMixArgument {
102                field: "max_patch_fraction",
103                value: self.max_patch_fraction,
104                message: format!(
105                    "max_patch_fraction must be >= min_patch_fraction ({})",
106                    self.min_patch_fraction
107                ),
108            });
109        }
110        Ok(self)
111    }
112
113    pub fn probability(&self) -> f32 {
114        self.probability
115    }
116
117    pub fn min_patch_fraction(&self) -> f32 {
118        self.min_patch_fraction
119    }
120
121    pub fn max_patch_fraction(&self) -> f32 {
122        self.max_patch_fraction
123    }
124}
125
/// Result of a batch-level mixing operation: the blended inputs together
/// with the consistently blended targets.
pub(super) struct MixUpBatch {
    pub(super) inputs: Tensor,
    pub(super) targets: Tensor,
}

131pub(super) fn validate_augmentation_compatibility(
132    inputs: &Tensor,
133    pipeline: &ImageAugmentationPipeline,
134) -> Result<(), ModelError> {
135    if inputs.rank() != 4 {
136        return Err(ModelError::InvalidAugmentationInputShape {
137            got: inputs.shape().to_vec(),
138        });
139    }
140    let channels = inputs.shape()[3];
141    for op in pipeline.ops() {
142        if let ImageAugmentationOp::ChannelNormalize { mean, std: _ } = op
143            && mean.len() != channels
144        {
145            return Err(ModelError::InvalidAugmentationArgument {
146                operation: "channel_normalize",
147                message: format!(
148                    "channel count mismatch: dataset_channels={channels}, mean/std_len={}",
149                    mean.len()
150                ),
151            });
152        }
153    }
154    Ok(())
155}
156
157pub(super) fn validate_mixup_config(config: &MixUpConfig) -> Result<(), ModelError> {
158    validate_mixup_probability(config.probability())?;
159    validate_mixup_lambda_min(config.lambda_min())?;
160    Ok(())
161}
162
163pub(super) fn validate_cutmix_config(config: &CutMixConfig) -> Result<(), ModelError> {
164    validate_cutmix_probability(config.probability())?;
165    validate_cutmix_patch_fraction("min_patch_fraction", config.min_patch_fraction())?;
166    validate_cutmix_patch_fraction("max_patch_fraction", config.max_patch_fraction())?;
167    if config.min_patch_fraction() > config.max_patch_fraction() {
168        return Err(ModelError::InvalidCutMixArgument {
169            field: "min_patch_fraction",
170            value: config.min_patch_fraction(),
171            message: format!(
172                "min_patch_fraction must be <= max_patch_fraction ({})",
173                config.max_patch_fraction()
174            ),
175        });
176    }
177    Ok(())
178}
179
180fn validate_mixup_probability(probability: f32) -> Result<(), ModelError> {
181    if !probability.is_finite() || !(0.0..=1.0).contains(&probability) {
182        return Err(ModelError::InvalidMixupArgument {
183            field: "probability",
184            value: probability,
185            message: "probability must be finite and in [0, 1]".to_string(),
186        });
187    }
188    Ok(())
189}
190
191fn validate_cutmix_probability(probability: f32) -> Result<(), ModelError> {
192    if !probability.is_finite() || !(0.0..=1.0).contains(&probability) {
193        return Err(ModelError::InvalidCutMixArgument {
194            field: "probability",
195            value: probability,
196            message: "probability must be finite and in [0, 1]".to_string(),
197        });
198    }
199    Ok(())
200}
201
202fn validate_cutmix_patch_fraction(field: &'static str, value: f32) -> Result<(), ModelError> {
203    if !value.is_finite() || !(0.0..=1.0).contains(&value) {
204        return Err(ModelError::InvalidCutMixArgument {
205            field,
206            value,
207            message: format!("{field} must be finite and in [0, 1]"),
208        });
209    }
210    Ok(())
211}
212
213pub(super) fn validate_cutmix_compatibility(inputs: &Tensor) -> Result<(), ModelError> {
214    if inputs.rank() != 4 {
215        return Err(ModelError::InvalidCutMixInputShape {
216            got: inputs.shape().to_vec(),
217        });
218    }
219    Ok(())
220}
221
222fn validate_mixup_lambda_min(lambda_min: f32) -> Result<(), ModelError> {
223    if !lambda_min.is_finite() || !(0.0..=0.5).contains(&lambda_min) {
224        return Err(ModelError::InvalidMixupArgument {
225            field: "lambda_min",
226            value: lambda_min,
227            message: "lambda_min must be finite and in [0, 0.5]".to_string(),
228        });
229    }
230    Ok(())
231}
232
/// Applies MixUp to a batch: convexly blends each row of `inputs`/`targets`
/// with a partner row using a single lambda sampled once per batch.
///
/// Behavior notes (grounded in this body):
/// - Batches with fewer than 2 rows are returned unchanged (no partner).
/// - A Bernoulli draw with `config.probability()` gates application; skipped
///   batches are returned as clones.
/// - `lambda` is drawn uniformly from [lambda_min, 1 - lambda_min], and the
///   same lambda blends both inputs and targets so labels stay consistent.
/// - Partner pairing is seeded with `seed` XOR a fixed constant, so the
///   pairing shuffle is decorrelated from the probability/lambda draws.
///
/// # Errors
/// - `InvalidMixupArgument` when `config` fails validation.
/// - `InvalidDatasetRank` when either tensor has rank 0.
/// - `DatasetShapeMismatch` when the leading (batch) dimensions differ.
pub(super) fn apply_mixup_batch(
    inputs: &Tensor,
    targets: &Tensor,
    config: &MixUpConfig,
    seed: u64,
) -> Result<MixUpBatch, ModelError> {
    validate_mixup_config(config)?;
    // Rank-0 tensors have no batch axis to index.
    if inputs.rank() == 0 || targets.rank() == 0 {
        return Err(ModelError::InvalidDatasetRank {
            inputs_rank: inputs.rank(),
            targets_rank: targets.rank(),
        });
    }
    let batch_size = inputs.shape()[0];
    if batch_size != targets.shape()[0] {
        return Err(ModelError::DatasetShapeMismatch {
            inputs: inputs.shape().to_vec(),
            targets: targets.shape().to_vec(),
        });
    }
    // MixUp needs at least two rows to pair; pass tiny batches through.
    if batch_size < 2 {
        return Ok(MixUpBatch {
            inputs: inputs.clone(),
            targets: targets.clone(),
        });
    }

    let mut rng = LcgRng::new(seed);
    if !should_apply_probability(config.probability(), &mut rng) {
        return Ok(MixUpBatch {
            inputs: inputs.clone(),
            targets: targets.clone(),
        });
    }

    // Uniform draw mapped into [lambda_min, 1 - lambda_min].
    let lambda =
        config.lambda_min() + rng.next_unit_f64() as f32 * (1.0 - 2.0 * config.lambda_min());
    let partner_indices = build_partner_indices(batch_size, seed ^ 0xA5A5_A5A5_5A5A_5A5A);

    Ok(MixUpBatch {
        inputs: blend_rows(inputs, &partner_indices, lambda)?,
        targets: blend_rows(targets, &partner_indices, lambda)?,
    })
}

/// Applies CutMix to an NHWC image batch: each row receives a rectangular
/// patch copied from a partner row, and its target row is blended in
/// proportion to the surviving (non-replaced) pixel area.
///
/// Behavior notes (grounded in this body):
/// - Batches with fewer than 2 rows, or with a zero-sized H/W/C axis, are
///   returned unchanged as clones.
/// - A Bernoulli draw with `config.probability()` gates application.
/// - Patch pixels are read from the ORIGINAL `inputs` buffer, so a row that
///   is both destination and partner never propagates already-mixed pixels.
/// - Per row: patch side lengths are `floor(dim * fraction)` clamped to
///   [1, dim]; the top-left corner is sampled uniformly so the patch fits.
/// - Target blending uses `lambda = 1 - replaced_pixels / total_pixels`.
///
/// # Errors
/// Propagates config validation failures, rank/shape mismatches, and
/// `SizeOverflow` when index arithmetic would overflow `usize`.
pub(super) fn apply_cutmix_batch(
    inputs: &Tensor,
    targets: &Tensor,
    config: &CutMixConfig,
    seed: u64,
) -> Result<MixUpBatch, ModelError> {
    validate_cutmix_config(config)?;
    validate_cutmix_compatibility(inputs)?;
    // Targets only need a leading batch axis; trailing dims are blended flat.
    if targets.rank() == 0 {
        return Err(ModelError::InvalidDatasetRank {
            inputs_rank: inputs.rank(),
            targets_rank: targets.rank(),
        });
    }
    let batch_size = inputs.shape()[0];
    if batch_size != targets.shape()[0] {
        return Err(ModelError::DatasetShapeMismatch {
            inputs: inputs.shape().to_vec(),
            targets: targets.shape().to_vec(),
        });
    }
    // No partner row exists for batches of size 0 or 1.
    if batch_size < 2 {
        return Ok(MixUpBatch {
            inputs: inputs.clone(),
            targets: targets.clone(),
        });
    }

    let mut rng = LcgRng::new(seed);
    if !should_apply_probability(config.probability(), &mut rng) {
        return Ok(MixUpBatch {
            inputs: inputs.clone(),
            targets: targets.clone(),
        });
    }

    let height = inputs.shape()[1];
    let width = inputs.shape()[2];
    let channels = inputs.shape()[3];
    // Degenerate images carry no pixels to cut; return the batch untouched.
    if height == 0 || width == 0 || channels == 0 {
        return Ok(MixUpBatch {
            inputs: inputs.clone(),
            targets: targets.clone(),
        });
    }

    // Flattened element count of one image row; checked so later index
    // arithmetic cannot silently wrap.
    let input_row_width = height
        .checked_mul(width)
        .and_then(|value| value.checked_mul(channels))
        .ok_or_else(|| {
            ModelError::Tensor(TensorError::SizeOverflow {
                shape: inputs.shape().to_vec(),
            })
        })?;
    // Flattened element count of one target row (product of trailing dims).
    let target_row_width = targets.shape()[1..]
        .iter()
        .try_fold(1usize, |acc, dim| acc.checked_mul(*dim))
        .ok_or_else(|| {
            ModelError::Tensor(TensorError::SizeOverflow {
                shape: targets.shape().to_vec(),
            })
        })?;

    let mut mixed_inputs = inputs.data().to_vec();
    let mut mixed_targets = targets.data().to_vec();
    // Distinct seed constant decorrelates the pairing shuffle from the
    // probability/patch draws made through `rng`.
    let partner_indices = build_partner_indices(batch_size, seed ^ 0x5A5A_A5A5_DEAD_BEEF);
    // `height * width` fits in usize: `input_row_width` (which also includes
    // `channels >= 1`) was checked above.
    let total_pixels = (height * width) as f32;

    for (row_index, partner_index) in partner_indices.iter().enumerate() {
        // Per-row patch geometry: same fraction on both axes, clamped to [1, dim].
        let patch_fraction = sample_cutmix_patch_fraction(config, &mut rng);
        let patch_height = ((height as f32 * patch_fraction).floor() as usize)
            .max(1)
            .min(height);
        let patch_width = ((width as f32 * patch_fraction).floor() as usize)
            .max(1)
            .min(width);
        // Uniform top-left corner such that the patch stays in bounds.
        let top = rng.next_usize(height - patch_height + 1);
        let left = rng.next_usize(width - patch_width + 1);

        let row_start = row_index.checked_mul(input_row_width).ok_or_else(|| {
            ModelError::Tensor(TensorError::SizeOverflow {
                shape: inputs.shape().to_vec(),
            })
        })?;
        let partner_start = partner_index.checked_mul(input_row_width).ok_or_else(|| {
            ModelError::Tensor(TensorError::SizeOverflow {
                shape: inputs.shape().to_vec(),
            })
        })?;

        // Copy the patch one channel-run at a time from the pristine source.
        for y in 0..patch_height {
            for x in 0..patch_width {
                let pixel_offset = ((top + y) * width + (left + x)) * channels;
                let dst = row_start + pixel_offset;
                let src = partner_start + pixel_offset;
                mixed_inputs[dst..(dst + channels)]
                    .copy_from_slice(&inputs.data()[src..(src + channels)]);
            }
        }

        // Label mass follows pixel mass: `lambda` is the fraction kept.
        let replaced_ratio = (patch_height * patch_width) as f32 / total_pixels;
        let lambda = 1.0 - replaced_ratio;
        let target_row_start = row_index.checked_mul(target_row_width).ok_or_else(|| {
            ModelError::Tensor(TensorError::SizeOverflow {
                shape: targets.shape().to_vec(),
            })
        })?;
        let partner_target_start =
            partner_index.checked_mul(target_row_width).ok_or_else(|| {
                ModelError::Tensor(TensorError::SizeOverflow {
                    shape: targets.shape().to_vec(),
                })
            })?;
        for offset in 0..target_row_width {
            mixed_targets[target_row_start + offset] = lambda
                * targets.data()[target_row_start + offset]
                + (1.0 - lambda) * targets.data()[partner_target_start + offset];
        }
    }

    Ok(MixUpBatch {
        inputs: Tensor::from_vec(inputs.shape().to_vec(), mixed_inputs)?,
        targets: Tensor::from_vec(targets.shape().to_vec(), mixed_targets)?,
    })
}

404fn sample_cutmix_patch_fraction(config: &CutMixConfig, rng: &mut LcgRng) -> f32 {
405    if (config.max_patch_fraction() - config.min_patch_fraction()).abs() <= f32::EPSILON {
406        return config.min_patch_fraction();
407    }
408    config.min_patch_fraction()
409        + rng.next_unit_f64() as f32 * (config.max_patch_fraction() - config.min_patch_fraction())
410}
411
412fn blend_rows(
413    tensor: &Tensor,
414    partner_indices: &[usize],
415    lambda: f32,
416) -> Result<Tensor, ModelError> {
417    if tensor.rank() == 0 {
418        return Err(ModelError::InvalidDatasetRank {
419            inputs_rank: tensor.rank(),
420            targets_rank: tensor.rank(),
421        });
422    }
423    let batch_size = tensor.shape()[0];
424    if partner_indices.len() != batch_size {
425        return Err(ModelError::DatasetShapeMismatch {
426            inputs: tensor.shape().to_vec(),
427            targets: vec![partner_indices.len()],
428        });
429    }
430    let row_width = tensor.shape()[1..]
431        .iter()
432        .try_fold(1usize, |acc, dim| acc.checked_mul(*dim))
433        .ok_or_else(|| {
434            ModelError::Tensor(TensorError::SizeOverflow {
435                shape: tensor.shape().to_vec(),
436            })
437        })?;
438
439    let mut out = vec![0.0f32; tensor.len()];
440    let left_weight = lambda;
441    let right_weight = 1.0 - lambda;
442    for (row_index, partner_index) in partner_indices.iter().enumerate() {
443        if *partner_index >= batch_size {
444            return Err(ModelError::DatasetShapeMismatch {
445                inputs: tensor.shape().to_vec(),
446                targets: vec![*partner_index, batch_size],
447            });
448        }
449        let row_start = row_index.checked_mul(row_width).ok_or_else(|| {
450            ModelError::Tensor(TensorError::SizeOverflow {
451                shape: tensor.shape().to_vec(),
452            })
453        })?;
454        let partner_start = partner_index.checked_mul(row_width).ok_or_else(|| {
455            ModelError::Tensor(TensorError::SizeOverflow {
456                shape: tensor.shape().to_vec(),
457            })
458        })?;
459        for offset in 0..row_width {
460            let dst = row_start + offset;
461            out[dst] = left_weight * tensor.data()[row_start + offset]
462                + right_weight * tensor.data()[partner_start + offset];
463        }
464    }
465    Tensor::from_vec(tensor.shape().to_vec(), out).map_err(Into::into)
466}
467
468fn build_partner_indices(batch_size: usize, seed: u64) -> Vec<usize> {
469    let mut partner_indices = (0..batch_size).collect::<Vec<_>>();
470    shuffle_indices(&mut partner_indices, seed);
471    if partner_indices
472        .iter()
473        .enumerate()
474        .all(|(index, partner)| index == *partner)
475    {
476        partner_indices.rotate_left(1);
477    }
478    partner_indices
479}
480
481pub(super) fn build_sample_order(
482    dataset: &SupervisedDataset,
483    options: &BatchIterOptions,
484) -> Result<Vec<usize>, ModelError> {
485    if let Some(policy) = options.sampling.as_ref() {
486        return build_sample_order_from_policy(dataset, policy);
487    }
488
489    let mut order = (0..dataset.len()).collect::<Vec<_>>();
490    if options.shuffle {
491        shuffle_indices(&mut order, options.shuffle_seed);
492    }
493    Ok(order)
494}
495
496fn build_sample_order_from_policy(
497    dataset: &SupervisedDataset,
498    policy: &SamplingPolicy,
499) -> Result<Vec<usize>, ModelError> {
500    let dataset_len = dataset.len();
501    match policy {
502        SamplingPolicy::Sequential => Ok((0..dataset_len).collect()),
503        SamplingPolicy::Shuffled { seed } => {
504            let mut order = (0..dataset_len).collect::<Vec<_>>();
505            shuffle_indices(&mut order, *seed);
506            Ok(order)
507        }
508        SamplingPolicy::BalancedByClass {
509            seed,
510            with_replacement,
511        } => {
512            let weights = class_balanced_sampling_weights(dataset.targets())?;
513            if *with_replacement {
514                sample_weighted_with_replacement(&weights, dataset_len, *seed)
515            } else {
516                sample_weighted_without_replacement(&weights, *seed)
517            }
518        }
519        SamplingPolicy::Weighted {
520            weights,
521            seed,
522            with_replacement,
523        } => {
524            validate_sampling_weights(weights, dataset_len)?;
525            if *with_replacement {
526                sample_weighted_with_replacement(weights, dataset_len, *seed)
527            } else {
528                sample_weighted_without_replacement(weights, *seed)
529            }
530        }
531    }
532}
533
534fn validate_sampling_weights(weights: &[f32], dataset_len: usize) -> Result<(), ModelError> {
535    if weights.len() != dataset_len {
536        return Err(ModelError::InvalidSamplingWeightsLength {
537            expected: dataset_len,
538            got: weights.len(),
539        });
540    }
541    let mut positive = false;
542    for (index, weight) in weights.iter().enumerate() {
543        if !weight.is_finite() || *weight < 0.0 {
544            return Err(ModelError::InvalidSamplingWeight {
545                index,
546                value: *weight,
547            });
548        }
549        if *weight > 0.0 {
550            positive = true;
551        }
552    }
553    if !positive && dataset_len > 0 {
554        return Err(ModelError::InvalidSamplingDistribution);
555    }
556    Ok(())
557}
558
559fn sample_weighted_with_replacement(
560    weights: &[f32],
561    draw_count: usize,
562    seed: u64,
563) -> Result<Vec<usize>, ModelError> {
564    if draw_count == 0 {
565        return Ok(Vec::new());
566    }
567
568    let mut cumulative = Vec::with_capacity(weights.len());
569    let mut total = 0.0f64;
570    for weight in weights {
571        total += *weight as f64;
572        cumulative.push(total);
573    }
574    if total <= 0.0 {
575        return Err(ModelError::InvalidSamplingDistribution);
576    }
577
578    let mut rng = LcgRng::new(seed);
579    let mut out = Vec::with_capacity(draw_count);
580    for _ in 0..draw_count {
581        let draw = rng.next_unit_f64() * total;
582        let mut sampled = cumulative.partition_point(|prefix| *prefix <= draw);
583        if sampled >= weights.len() {
584            sampled = weights.len() - 1;
585        }
586        out.push(sampled);
587    }
588    Ok(out)
589}
590
591fn sample_weighted_without_replacement(
592    weights: &[f32],
593    seed: u64,
594) -> Result<Vec<usize>, ModelError> {
595    if weights.is_empty() {
596        return Ok(Vec::new());
597    }
598
599    let mut rng = LcgRng::new(seed);
600    let mut keyed = Vec::with_capacity(weights.len());
601    for (index, weight) in weights.iter().enumerate() {
602        let key = if *weight == 0.0 {
603            0.0
604        } else {
605            let u = rng.next_unit_open_f64();
606            u.powf(1.0 / *weight as f64)
607        };
608        keyed.push((index, key));
609    }
610
611    keyed.sort_by(|left, right| {
612        right
613            .1
614            .total_cmp(&left.1)
615            .then_with(|| left.0.cmp(&right.0))
616    });
617    Ok(keyed.into_iter().map(|(index, _)| index).collect())
618}