1use torsh_core::{
15 dtype::TensorElement,
16 error::{Result, TorshError},
17};
18use torsh_tensor::Tensor;
19
20#[cfg(not(feature = "std"))]
21use alloc::{boxed::Box, string::String, vec::Vec};
22
23pub trait Transform<T>: Send + Sync {
28 type Output;
30
31 fn transform(&self, input: T) -> Result<Self::Output>;
33
34 fn transform_batch(&self, inputs: Vec<T>) -> Result<Vec<Self::Output>> {
39 inputs
40 .into_iter()
41 .map(|input| self.transform(input))
42 .collect()
43 }
44
45 fn is_deterministic(&self) -> bool {
50 true
51 }
52}
53
/// A builder that is consumed to produce a configured transform.
pub trait TransformBuilder {
    /// The value returned by [`TransformBuilder::build`].
    type Transform;

    /// Consumes the builder and returns the finished value.
    fn build(self) -> Self::Transform;
}
62
/// Declares a unit-struct transform from a plain function or closure.
///
/// `simple_transform!(Name, In, Out, f)` generates `pub struct Name;` whose
/// `transform` returns `Ok(f(input))`. An optional trailing
/// `deterministic = <literal>` sets what `is_deterministic` reports; without it
/// the generated transform reports `true`.
///
/// The expansion names `$crate::core_framework::{Transform, Result}`, so it
/// assumes this module is reachable as `core_framework` from the crate root.
/// NOTE(review): confirm the module path if this file moves.
#[macro_export]
macro_rules! simple_transform {
    (
        $name:ident,
        $input:ty,
        $output:ty,
        $transform_fn:expr
        $(, deterministic = $det:literal)?
    ) => {
        #[derive(Clone, Debug, Default)]
        pub struct $name;

        impl $crate::core_framework::Transform<$input> for $name {
            type Output = $output;

            fn transform(&self, input: $input) -> $crate::core_framework::Result<Self::Output> {
                Ok($transform_fn(input))
            }

            fn is_deterministic(&self) -> bool {
                // `true` alone when no override is given; `true && $det` otherwise.
                true $(&& $det)?
            }
        }
    };
}
101
102pub trait TransformExt<T>: Transform<T> + Sized + 'static {
104 fn then<U>(self, next: U) -> Chain<Self, U>
108 where
109 U: Transform<Self::Output>,
110 {
111 Chain::new(self, next)
112 }
113
114 fn when<P>(self, predicate: P) -> Conditional<Self, P>
118 where
119 P: Fn(&T) -> bool + Send + Sync,
120 {
121 Conditional::new(self, predicate)
122 }
123
124 fn boxed(self) -> Box<dyn Transform<T, Output = Self::Output> + Send + Sync> {
126 Box::new(self)
127 }
128}
129
130impl<T, U: Transform<T> + 'static> TransformExt<T> for U {}
132
/// Two stages run back to back: the output of `first` becomes the input of `second`.
#[derive(Debug, Clone)]
pub struct Chain<A, B> {
    first: A,
    second: B,
}

impl<A, B> Chain<A, B> {
    /// Stores both stages; no transformation happens here.
    pub fn new(first: A, second: B) -> Self {
        Chain { first, second }
    }
}
146
147impl<T, T1, T2> Transform<T> for Chain<T1, T2>
148where
149 T1: Transform<T>,
150 T2: Transform<T1::Output>,
151{
152 type Output = T2::Output;
153
154 fn transform(&self, input: T) -> Result<Self::Output> {
155 let intermediate = self.first.transform(input)?;
156 self.second.transform(intermediate)
157 }
158
159 fn is_deterministic(&self) -> bool {
160 self.first.is_deterministic() && self.second.is_deterministic()
161 }
162}
163
/// Wraps a transform so it is only applied when `predicate` accepts the input.
#[derive(Debug, Clone)]
pub struct Conditional<Inner, Pred> {
    transform: Inner,
    predicate: Pred,
}

impl<Inner, Pred> Conditional<Inner, Pred> {
    /// Pairs the wrapped `transform` with the gating `predicate`.
    pub fn new(transform: Inner, predicate: Pred) -> Self {
        Conditional {
            predicate,
            transform,
        }
    }
}
180
181impl<T, U, P> Transform<T> for Conditional<U, P>
182where
183 U: Transform<T, Output = T>,
184 P: Fn(&T) -> bool + Send + Sync,
185{
186 type Output = T;
187
188 fn transform(&self, input: T) -> Result<Self::Output> {
189 if (self.predicate)(&input) {
190 self.transform.transform(input)
191 } else {
192 Ok(input)
193 }
194 }
195
196 fn is_deterministic(&self) -> bool {
197 self.transform.is_deterministic()
198 }
199}
200
201pub struct Compose<T> {
203 transforms: Vec<Box<dyn Transform<T, Output = T> + Send + Sync>>,
204}
205
206impl<T> Compose<T> {
207 pub fn new(transforms: Vec<Box<dyn Transform<T, Output = T> + Send + Sync>>) -> Self {
209 Self { transforms }
210 }
211
212 pub fn add<U>(&mut self, transform: U)
214 where
215 U: Transform<T, Output = T> + Send + Sync + 'static,
216 {
217 self.transforms.push(Box::new(transform));
218 }
219
220 pub fn len(&self) -> usize {
222 self.transforms.len()
223 }
224
225 pub fn is_empty(&self) -> bool {
227 self.transforms.is_empty()
228 }
229}
230
231impl<T> Transform<T> for Compose<T> {
232 type Output = T;
233
234 fn transform(&self, mut input: T) -> Result<Self::Output> {
235 for transform in &self.transforms {
236 input = transform.transform(input)?;
237 }
238 Ok(input)
239 }
240
241 fn is_deterministic(&self) -> bool {
242 self.transforms.iter().all(|t| t.is_deterministic())
243 }
244}
245
246#[derive(Debug, Clone)]
248pub struct Normalize<T: TensorElement> {
249 #[allow(dead_code)] mean: Vec<T>,
251 #[allow(dead_code)] std: Vec<T>,
253}
254
255impl<T: TensorElement> Normalize<T> {
256 pub fn new(mean: Vec<T>, std: Vec<T>) -> Result<Self> {
258 if mean.len() != std.len() {
259 return Err(TorshError::InvalidArgument(
260 "Mean and std vectors must have the same length".to_string(),
261 ));
262 }
263 Ok(Self { mean, std })
264 }
265}
266
impl<T: TensorElement> Transform<Tensor<T>> for Normalize<T> {
    type Output = Tensor<T>;

    /// Intended to normalize `input` with the stored `mean` / `std`.
    ///
    /// NOTE(review): currently a passthrough — `input` is returned unchanged
    /// and `self.mean` / `self.std` are never read, so callers silently receive
    /// un-normalized data. TODO: implement (presumably per-channel
    /// `(x - mean) / std`) once the tensor arithmetic API to use is confirmed.
    fn transform(&self, input: Tensor<T>) -> Result<Self::Output> {
        Ok(input)
    }
}
281
/// Zero-sized marker transform converting a `Tensor<Src>` into a `Tensor<Dst>`.
///
/// `Clone`, `Copy` and `Debug` are implemented by hand: `#[derive]` would
/// require `Src`/`Dst` themselves to implement those traits even though no
/// value of either type is stored.
pub struct ToType<Src, Dst> {
    // `fn() -> (Src, Dst)` keeps the marker covariant in both parameters (as
    // `PhantomData<(Src, Dst)>` was) but makes it unconditionally `Send + Sync`,
    // so the `Transform: Send + Sync` supertraits never depend on the parameters.
    _phantom: core::marker::PhantomData<fn() -> (Src, Dst)>,
}

impl<Src, Dst> Clone for ToType<Src, Dst> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<Src, Dst> Copy for ToType<Src, Dst> {}

impl<Src, Dst> core::fmt::Debug for ToType<Src, Dst> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("ToType")
            .field("_phantom", &self._phantom)
            .finish()
    }
}
287
288impl<From, To> Default for ToType<From, To> {
289 fn default() -> Self {
290 Self::new()
291 }
292}
293
294impl<From, To> ToType<From, To> {
295 pub fn new() -> Self {
297 Self {
298 _phantom: core::marker::PhantomData,
299 }
300 }
301}
302
303impl<From: TensorElement, To: TensorElement> Transform<Tensor<From>> for ToType<From, To> {
304 type Output = Tensor<To>;
305
306 fn transform(&self, _input: Tensor<From>) -> Result<Self::Output> {
307 Err(TorshError::InvalidArgument(
319 "Type conversion not yet implemented".to_string(),
320 ))
321 }
322}
323
/// Adapter that turns a function or closure into a transform.
///
/// `Debug` is implemented by hand: closures never implement `Debug`, so the
/// previously derived impl (which required `F: Debug`) never applied to the
/// values built from closures. `Clone` is derived and is available whenever
/// `F: Clone`.
#[derive(Clone)]
pub struct Lambda<F> {
    func: F,
}

impl<F> core::fmt::Debug for Lambda<F> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The wrapped callable is opaque, so only the wrapper is shown.
        f.debug_struct("Lambda").finish_non_exhaustive()
    }
}

impl<F> Lambda<F> {
    /// Wraps `func`; it is invoked by the `Transform` impl for each input.
    pub fn new(func: F) -> Self {
        Self { func }
    }
}
336
337impl<T, O, F> Transform<T> for Lambda<F>
338where
339 F: Fn(T) -> Result<O> + Send + Sync,
340{
341 type Output = O;
342
343 fn transform(&self, input: T) -> Result<Self::Output> {
344 (self.func)(input)
345 }
346
347 fn is_deterministic(&self) -> bool {
348 true
350 }
351}
352
353pub fn normalize<T: TensorElement>(mean: Vec<T>, std: Vec<T>) -> Result<Normalize<T>> {
355 Normalize::new(mean, std)
356}
357
358pub fn to_type<From: TensorElement, To: TensorElement>() -> ToType<From, To> {
360 ToType::new()
361}
362
363pub fn lambda<F>(func: F) -> Lambda<F> {
365 Lambda::new(func)
366}
367
368pub fn compose<T>(transforms: Vec<Box<dyn Transform<T, Output = T> + Send + Sync>>) -> Compose<T> {
370 Compose::new(transforms)
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// 2x2 `f32` CPU tensor fixture for tensor-level tests.
    #[allow(dead_code)] // not referenced yet; kept for upcoming tensor-based transform tests
    fn mock_tensor() -> Tensor<f32> {
        Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0],
            vec![2, 2],
            torsh_core::device::DeviceType::Cpu,
        )
        .unwrap()
    }

    /// Identity transform that reports itself as non-deterministic; used to
    /// check that combinators propagate `is_deterministic() == false`.
    struct NonDeterministic;

    impl Transform<i32> for NonDeterministic {
        type Output = i32;

        fn transform(&self, input: i32) -> Result<Self::Output> {
            Ok(input)
        }

        fn is_deterministic(&self) -> bool {
            false
        }
    }

    #[test]
    fn test_chain_transform() {
        let lambda1 = lambda(|x: i32| Ok(x * 2));
        let lambda2 = lambda(|x: i32| Ok(x + 1));

        let chained = lambda1.then(lambda2);
        let result = chained.transform(5).unwrap();
        assert_eq!(result, 11); // (5 * 2) + 1
    }

    #[test]
    fn test_conditional_transform() {
        let double = lambda(|x: i32| Ok(x * 2));
        let conditional = double.when(|&x| x > 5);

        assert_eq!(conditional.transform(3).unwrap(), 3); // predicate false: passthrough
        assert_eq!(conditional.transform(7).unwrap(), 14); // predicate true: doubled
    }

    #[test]
    fn test_compose_transform() {
        let lambda1 = lambda(|x: i32| Ok(x + 1));
        let lambda2 = lambda(|x: i32| Ok(x * 2));

        let mut composition = Compose::new(vec![]);
        composition.add(lambda1);
        composition.add(lambda2);

        let result = composition.transform(5).unwrap();
        assert_eq!(result, 12); // (5 + 1) * 2
    }

    #[test]
    fn test_empty_compose_is_identity() {
        let composition: Compose<i32> = compose(Vec::new());
        assert!(composition.is_empty());
        assert_eq!(composition.len(), 0);
        assert_eq!(composition.transform(42).unwrap(), 42);
        assert!(composition.is_deterministic());
    }

    #[test]
    fn test_transform_batch() {
        let double = lambda(|x: i32| Ok(x * 2));
        assert_eq!(double.transform_batch(vec![1, 2, 3]).unwrap(), vec![2, 4, 6]);
        assert!(double.transform_batch(Vec::new()).unwrap().is_empty());
    }

    #[test]
    fn test_transform_batch_propagates_error() {
        let reject_negative = lambda(|x: i32| {
            if x < 0 {
                Err(TorshError::InvalidArgument("negative input".to_string()))
            } else {
                Ok(x)
            }
        });
        assert!(reject_negative.transform_batch(vec![1, -1, 2]).is_err());
    }

    #[test]
    fn test_boxed_transform() {
        let boxed: Box<dyn Transform<i32, Output = i32> + Send + Sync> =
            lambda(|x: i32| Ok(x - 1)).boxed();
        assert_eq!(boxed.transform(10).unwrap(), 9);
    }

    #[test]
    fn test_normalize_creation() {
        let mean = vec![0.485f32, 0.456, 0.406];
        let std = vec![0.229f32, 0.224, 0.225];

        let normalize_transform = normalize(mean, std);
        assert!(normalize_transform.is_ok());
    }

    #[test]
    fn test_normalize_invalid_dimensions() {
        let mean = vec![0.485f32, 0.456];
        let std = vec![0.229f32, 0.224, 0.225];

        let normalize_transform = normalize(mean, std);
        assert!(normalize_transform.is_err());
    }

    #[test]
    fn test_determinism() {
        let deterministic = lambda(|x: i32| Ok(x + 1));
        assert!(deterministic.is_deterministic());

        let chain = deterministic.then(lambda(|x: i32| Ok(x * 2)));
        assert!(chain.is_deterministic());
    }

    #[test]
    fn test_non_determinism_propagates() {
        assert!(!NonDeterministic.is_deterministic());

        let chain = lambda(|x: i32| Ok(x)).then(NonDeterministic);
        assert!(!chain.is_deterministic());

        let conditional = NonDeterministic.when(|x: &i32| *x > 0);
        assert!(!Transform::<i32>::is_deterministic(&conditional));
        assert_eq!(conditional.transform(-4).unwrap(), -4);

        let mut composition = Compose::new(vec![]);
        composition.add(lambda(|x: i32| Ok(x)));
        composition.add(NonDeterministic);
        assert!(!composition.is_deterministic());
    }
}