tenflowers_dataset/transforms/
mod.rs1use crate::Dataset;
7use std::marker::PhantomData;
8use tenflowers_core::{Result, Tensor};
9
10pub mod augmentation;
12pub mod feature_engineering;
13pub mod noise;
14pub mod normalization;
15pub mod pipeline;
16pub mod profiling;
17pub mod vision;
18
19pub use noise::*;
21pub use normalization::*;
22
23pub trait Transform<T> {
25 fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)>;
27}
28
/// A dataset adapter that applies a transform to every sample it yields.
///
/// Build one with `TransformedDataset::new` or `DatasetExt::transform`.
///
/// The struct itself carries no trait bounds; they live on the `impl` blocks
/// that actually need them, so generic code can name this type without
/// repeating `D: Dataset<T>` / `Tr: Transform<T>`.
pub struct TransformedDataset<T, D, Tr> {
    /// The wrapped source dataset.
    dataset: D,
    /// The transformation applied to each sample on access.
    transform: Tr,
    // `fn() -> T` rather than `T`: the adapter never stores or drops a `T`, so it
    // should not inherit `T`'s `Send`/`Sync` or drop-check obligations. It stays
    // covariant in `T`, exactly as `PhantomData<T>` was.
    _phantom: PhantomData<fn() -> T>,
}
35
36impl<T, D: Dataset<T>, Tr: Transform<T>> TransformedDataset<T, D, Tr> {
37 pub fn new(dataset: D, transform: Tr) -> Self {
38 Self {
39 dataset,
40 transform,
41 _phantom: PhantomData,
42 }
43 }
44}
45
46impl<T, D: Dataset<T>, Tr: Transform<T>> Dataset<T> for TransformedDataset<T, D, Tr> {
47 fn len(&self) -> usize {
48 self.dataset.len()
49 }
50
51 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
52 let sample = self.dataset.get(index)?;
53 self.transform.apply(sample)
54 }
55}
56
57pub trait DatasetExt<T>: Dataset<T> + Sized {
59 fn transform<Tr: Transform<T>>(self, transform: Tr) -> TransformedDataset<T, Self, Tr> {
61 TransformedDataset::new(self, transform)
62 }
63}
64
65impl<T, D: Dataset<T>> DatasetExt<T> for D {}