Skip to main content

torsh_data/
core_framework.rs

//! Core transform framework for data transformations.
//!
//! This module provides the fundamental building blocks for data transformations,
//! including core traits, combinators, and basic transform implementations.
//!
//! # Features
//!
//! - **Transform trait**: Core abstraction for data transformations
//! - **Transform combinators**: Chain, conditional, and composition operations
//! - **Builder pattern**: `TransformBuilder` trait for complex transform construction
//! - **Extension traits**: Convenient chainable API via `TransformExt`
//! - **Basic transforms**: Normalize (currently a pass-through placeholder), type
//!   conversion (not yet implemented), and lambda transforms
13
14use torsh_core::{
15    dtype::TensorElement,
16    error::{Result, TorshError},
17};
18use torsh_tensor::Tensor;
19
20#[cfg(not(feature = "std"))]
21use alloc::{boxed::Box, string::String, vec::Vec};
22
23/// Trait for data transformations
24///
25/// This is the core abstraction for all data transformations in the ToRSh ecosystem.
26/// Implementations should be stateless where possible and thread-safe.
27pub trait Transform<T>: Send + Sync {
28    /// Output type after transformation
29    type Output;
30
31    /// Apply the transformation to a single input
32    fn transform(&self, input: T) -> Result<Self::Output>;
33
34    /// Transform multiple items in batch
35    ///
36    /// Default implementation applies transform individually, but implementations
37    /// can override this for more efficient batch processing.
38    fn transform_batch(&self, inputs: Vec<T>) -> Result<Vec<Self::Output>> {
39        inputs
40            .into_iter()
41            .map(|input| self.transform(input))
42            .collect()
43    }
44
45    /// Check if the transform is deterministic
46    ///
47    /// A deterministic transform always produces the same output for the same input.
48    /// Non-deterministic transforms include random augmentations.
49    fn is_deterministic(&self) -> bool {
50        true
51    }
52}
53
/// Builder for transforms that need configuration before they can be used.
///
/// Implement this on a configuration type. Calling [`TransformBuilder::build`]
/// consumes the builder and yields the finished transform.
pub trait TransformBuilder {
    /// Concrete type produced by [`TransformBuilder::build`].
    ///
    /// NOTE(review): this trait does not bound the produced type by
    /// `Transform`; presumably implementors return a `Transform` — confirm
    /// whether a bound is intended.
    type Transform;

    /// Consume the builder and produce the configured transform.
    fn build(self) -> Self::Transform;
}
62
63/// Macro to create simple stateless transforms
64///
65/// This macro generates a transform struct and implementation for simple cases
66/// where the transform logic can be expressed as a function.
67#[macro_export]
68macro_rules! simple_transform {
69    ($name:ident, $input:ty, $output:ty, $transform_fn:expr) => {
70        /// Auto-generated simple transform
71        #[derive(Clone, Debug, Default)]
72        pub struct $name;
73
74        impl $crate::core_framework::Transform<$input> for $name {
75            type Output = $output;
76
77            fn transform(&self, input: $input) -> $crate::core_framework::Result<Self::Output> {
78                Ok($transform_fn(input))
79            }
80        }
81    };
82
83    ($name:ident, $input:ty, $output:ty, $transform_fn:expr, deterministic = $det:literal) => {
84        /// Auto-generated simple transform with determinism setting
85        #[derive(Clone, Debug, Default)]
86        pub struct $name;
87
88        impl $crate::core_framework::Transform<$input> for $name {
89            type Output = $output;
90
91            fn transform(&self, input: $input) -> $crate::core_framework::Result<Self::Output> {
92                Ok($transform_fn(input))
93            }
94
95            fn is_deterministic(&self) -> bool {
96                $det
97            }
98        }
99    };
100}
101
102/// Extension trait for chainable transform operations
103pub trait TransformExt<T>: Transform<T> + Sized + 'static {
104    /// Chain this transform with another
105    ///
106    /// Creates a new transform that applies this transform first, then the next.
107    fn then<U>(self, next: U) -> Chain<Self, U>
108    where
109        U: Transform<Self::Output>,
110    {
111        Chain::new(self, next)
112    }
113
114    /// Apply this transform conditionally based on a predicate
115    ///
116    /// The transform is only applied if the predicate returns true for the input.
117    fn when<P>(self, predicate: P) -> Conditional<Self, P>
118    where
119        P: Fn(&T) -> bool + Send + Sync,
120    {
121        Conditional::new(self, predicate)
122    }
123
124    /// Convert to a boxed trait object for dynamic dispatch
125    fn boxed(self) -> Box<dyn Transform<T, Output = Self::Output> + Send + Sync> {
126        Box::new(self)
127    }
128}
129
130// Blanket implementation for all transforms
131impl<T, U: Transform<T> + 'static> TransformExt<T> for U {}
132
133/// Chain two transforms together sequentially
134#[derive(Debug, Clone)]
135pub struct Chain<T1, T2> {
136    first: T1,
137    second: T2,
138}
139
140impl<T1, T2> Chain<T1, T2> {
141    /// Create a new chain of transforms
142    pub fn new(first: T1, second: T2) -> Self {
143        Self { first, second }
144    }
145}
146
147impl<T, T1, T2> Transform<T> for Chain<T1, T2>
148where
149    T1: Transform<T>,
150    T2: Transform<T1::Output>,
151{
152    type Output = T2::Output;
153
154    fn transform(&self, input: T) -> Result<Self::Output> {
155        let intermediate = self.first.transform(input)?;
156        self.second.transform(intermediate)
157    }
158
159    fn is_deterministic(&self) -> bool {
160        self.first.is_deterministic() && self.second.is_deterministic()
161    }
162}
163
164/// Conditionally apply a transform based on a predicate
165#[derive(Debug, Clone)]
166pub struct Conditional<T, P> {
167    transform: T,
168    predicate: P,
169}
170
171impl<T, P> Conditional<T, P> {
172    /// Create a new conditional transform
173    pub fn new(transform: T, predicate: P) -> Self {
174        Self {
175            transform,
176            predicate,
177        }
178    }
179}
180
181impl<T, U, P> Transform<T> for Conditional<U, P>
182where
183    U: Transform<T, Output = T>,
184    P: Fn(&T) -> bool + Send + Sync,
185{
186    type Output = T;
187
188    fn transform(&self, input: T) -> Result<Self::Output> {
189        if (self.predicate)(&input) {
190            self.transform.transform(input)
191        } else {
192            Ok(input)
193        }
194    }
195
196    fn is_deterministic(&self) -> bool {
197        self.transform.is_deterministic()
198    }
199}
200
201/// Compose multiple transforms that operate on the same type
202pub struct Compose<T> {
203    transforms: Vec<Box<dyn Transform<T, Output = T> + Send + Sync>>,
204}
205
206impl<T> Compose<T> {
207    /// Create a new compose transform from a vector of transforms
208    pub fn new(transforms: Vec<Box<dyn Transform<T, Output = T> + Send + Sync>>) -> Self {
209        Self { transforms }
210    }
211
212    /// Add a transform to the composition
213    pub fn add<U>(&mut self, transform: U)
214    where
215        U: Transform<T, Output = T> + Send + Sync + 'static,
216    {
217        self.transforms.push(Box::new(transform));
218    }
219
220    /// Get the number of transforms in the composition
221    pub fn len(&self) -> usize {
222        self.transforms.len()
223    }
224
225    /// Check if the composition is empty
226    pub fn is_empty(&self) -> bool {
227        self.transforms.is_empty()
228    }
229}
230
231impl<T> Transform<T> for Compose<T> {
232    type Output = T;
233
234    fn transform(&self, mut input: T) -> Result<Self::Output> {
235        for transform in &self.transforms {
236            input = transform.transform(input)?;
237        }
238        Ok(input)
239    }
240
241    fn is_deterministic(&self) -> bool {
242        self.transforms.iter().all(|t| t.is_deterministic())
243    }
244}
245
246/// Normalize tensor values using mean and standard deviation
247#[derive(Debug, Clone)]
248pub struct Normalize<T: TensorElement> {
249    #[allow(dead_code)] // Used in future full implementation
250    mean: Vec<T>,
251    #[allow(dead_code)] // Used in future full implementation
252    std: Vec<T>,
253}
254
255impl<T: TensorElement> Normalize<T> {
256    /// Create a new normalize transform
257    pub fn new(mean: Vec<T>, std: Vec<T>) -> Result<Self> {
258        if mean.len() != std.len() {
259            return Err(TorshError::InvalidArgument(
260                "Mean and std vectors must have the same length".to_string(),
261            ));
262        }
263        Ok(Self { mean, std })
264    }
265}
266
267impl<T: TensorElement> Transform<Tensor<T>> for Normalize<T> {
268    type Output = Tensor<T>;
269
270    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
271        // Placeholder implementation - real normalization would require tensor operations
272        // For now, just return the input tensor
273        // NOTE: tracing disabled (not a dependency)
274        // tracing::debug!(
275        //     "Normalize transform applied with {} channels",
276        //     self.mean.len()
277        // );
278        Ok(input)
279    }
280}
281
/// Converts a tensor with element type `From` into one with element type `To`.
///
/// This is a zero-sized marker: the two element types live only in
/// `PhantomData`, and no runtime state is stored.
#[derive(Debug, Clone)]
pub struct ToType<From, To> {
    _phantom: core::marker::PhantomData<(From, To)>,
}

impl<From, To> Default for ToType<From, To> {
    fn default() -> Self {
        ToType {
            _phantom: core::marker::PhantomData,
        }
    }
}

impl<From, To> ToType<From, To> {
    /// Create a new `From` → `To` conversion transform.
    pub fn new() -> Self {
        Self::default()
    }
}
302
303impl<From: TensorElement, To: TensorElement> Transform<Tensor<From>> for ToType<From, To> {
304    type Output = Tensor<To>;
305
306    fn transform(&self, _input: Tensor<From>) -> Result<Self::Output> {
307        // Placeholder implementation - real type conversion would require tensor operations
308        // For now, create a new tensor with the target type (this is a simplification)
309        // NOTE: tracing disabled (not a dependency)
310        // tracing::debug!(
311        //     "Type conversion from {} to {} requested",
312        //     core::any::type_name::<From>(),
313        //     core::any::type_name::<To>()
314        // );
315
316        // In a real implementation, this would convert the tensor data
317        // For now, we return an error as this requires complex tensor operations
318        Err(TorshError::InvalidArgument(
319            "Type conversion not yet implemented".to_string(),
320        ))
321    }
322}
323
324/// Apply a custom function as a transform
325#[derive(Debug)]
326pub struct Lambda<F> {
327    func: F,
328}
329
330impl<F> Lambda<F> {
331    /// Create a new lambda transform
332    pub fn new(func: F) -> Self {
333        Self { func }
334    }
335}
336
337impl<T, O, F> Transform<T> for Lambda<F>
338where
339    F: Fn(T) -> Result<O> + Send + Sync,
340{
341    type Output = O;
342
343    fn transform(&self, input: T) -> Result<Self::Output> {
344        (self.func)(input)
345    }
346
347    fn is_deterministic(&self) -> bool {
348        // Lambda functions are assumed to be deterministic unless specified otherwise
349        true
350    }
351}
352
353/// Convenience function to create a normalize transform
354pub fn normalize<T: TensorElement>(mean: Vec<T>, std: Vec<T>) -> Result<Normalize<T>> {
355    Normalize::new(mean, std)
356}
357
358/// Convenience function to create a type conversion transform
359pub fn to_type<From: TensorElement, To: TensorElement>() -> ToType<From, To> {
360    ToType::new()
361}
362
363/// Convenience function to create a lambda transform
364pub fn lambda<F>(func: F) -> Lambda<F> {
365    Lambda::new(func)
366}
367
368/// Convenience function to create a composition transform
369pub fn compose<T>(transforms: Vec<Box<dyn Transform<T, Output = T> + Send + Sync>>) -> Compose<T> {
370    Compose::new(transforms)
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// 2x2 `f32` tensor on the CPU, kept as a fixture for tensor-level tests.
    #[allow(dead_code)] // No test references this fixture yet.
    fn mock_tensor() -> Tensor<f32> {
        let values = vec![1.0f32, 2.0, 3.0, 4.0];
        let shape = vec![2, 2];
        Tensor::from_data(values, shape, torsh_core::device::DeviceType::Cpu).unwrap()
    }

    #[test]
    fn test_chain_transform() {
        let times_two = lambda(|v: i32| Ok(v * 2));
        let plus_one = lambda(|v: i32| Ok(v + 1));

        let pipeline = times_two.then(plus_one);
        // (5 * 2) + 1
        assert_eq!(pipeline.transform(5).unwrap(), 11);
    }

    #[test]
    fn test_conditional_transform() {
        let times_two = lambda(|v: i32| Ok(v * 2));
        let gated = times_two.when(|&v| v > 5);

        // Below the threshold the input passes through untouched.
        assert_eq!(gated.transform(3).unwrap(), 3);
        // Above it the wrapped transform runs.
        assert_eq!(gated.transform(7).unwrap(), 14);
    }

    #[test]
    fn test_compose_transform() {
        let plus_one = lambda(|v: i32| Ok(v + 1));
        let times_two = lambda(|v: i32| Ok(v * 2));

        let mut pipeline = Compose::new(Vec::new());
        pipeline.add(plus_one);
        pipeline.add(times_two);

        // (5 + 1) * 2
        assert_eq!(pipeline.transform(5).unwrap(), 12);
    }

    #[test]
    fn test_normalize_creation() {
        let channel_means = vec![0.485f32, 0.456, 0.406];
        let channel_stds = vec![0.229f32, 0.224, 0.225];

        assert!(normalize(channel_means, channel_stds).is_ok());
    }

    #[test]
    fn test_normalize_invalid_dimensions() {
        let channel_means = vec![0.485f32, 0.456];
        let channel_stds = vec![0.229f32, 0.224, 0.225];

        assert!(normalize(channel_means, channel_stds).is_err());
    }

    #[test]
    fn test_determinism() {
        let plus_one = lambda(|v: i32| Ok(v + 1));
        assert!(plus_one.is_deterministic());

        let pipeline = plus_one.then(lambda(|v: i32| Ok(v * 2)));
        assert!(pipeline.is_deterministic());
    }
}