Skip to main content

tenflowers_dataset/transforms/
mod.rs

1//! Data transformation utilities for datasets
2//!
3//! This module provides common data preprocessing and augmentation transformations
4//! that can be applied to datasets during training and inference.
5
6use crate::Dataset;
7use std::marker::PhantomData;
8use tenflowers_core::{Result, Tensor};
9
10// Re-export all transform modules
11pub mod augmentation;
12pub mod feature_engineering;
13pub mod noise;
14pub mod normalization;
15pub mod pipeline;
16pub mod profiling;
17pub mod vision;
18
19// Re-export commonly used types
20pub use noise::*;
21pub use normalization::*;
22
23/// Trait for data transformations
24pub trait Transform<T> {
25    /// Apply the transformation to a sample
26    fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)>;
27}
28
/// Dataset wrapper that pairs an inner dataset with a transformation.
///
/// The `Dataset` implementation for this type fetches each sample from the
/// wrapped dataset and passes it through the stored transform before
/// returning it.
///
/// Trait bounds are deliberately declared on the `impl` blocks rather than on
/// the struct itself: none of the fields needs them, and keeping the type
/// definition unconstrained means code that merely names or stores a
/// `TransformedDataset` does not have to repeat the bounds.
pub struct TransformedDataset<T, D, Tr> {
    /// The wrapped dataset that samples are read from.
    dataset: D,
    /// The transformation applied to every sample read from `dataset`.
    transform: Tr,
    /// Ties the element type `T` to the wrapper without storing a value of it.
    _phantom: PhantomData<T>,
}
35
36impl<T, D: Dataset<T>, Tr: Transform<T>> TransformedDataset<T, D, Tr> {
37    pub fn new(dataset: D, transform: Tr) -> Self {
38        Self {
39            dataset,
40            transform,
41            _phantom: PhantomData,
42        }
43    }
44}
45
46impl<T, D: Dataset<T>, Tr: Transform<T>> Dataset<T> for TransformedDataset<T, D, Tr> {
47    fn len(&self) -> usize {
48        self.dataset.len()
49    }
50
51    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
52        let sample = self.dataset.get(index)?;
53        self.transform.apply(sample)
54    }
55}
56
57/// Dataset extension trait for convenience methods
58pub trait DatasetExt<T>: Dataset<T> + Sized {
59    /// Apply a transform to this dataset
60    fn transform<Tr: Transform<T>>(self, transform: Tr) -> TransformedDataset<T, Self, Tr> {
61        TransformedDataset::new(self, transform)
62    }
63}
64
65impl<T, D: Dataset<T>> DatasetExt<T> for D {}