Skip to main content

torsh_nn/composition/
mod.rs

//! Module composition system for neural network modules.
//!
//! This module provides traits and implementations for composing modules
//! sequentially, in parallel, with residual connections, and with
//! conditional execution.
6
7use torsh_core::device::DeviceType;
8use torsh_core::error::Result;
9use torsh_tensor::Tensor;
10
11// Conditional imports for std/no_std compatibility
12#[cfg(feature = "std")]
13use std::collections::HashMap;
14
15#[cfg(not(feature = "std"))]
16use hashbrown::HashMap;
17
18use crate::{Module, ModuleConfig, Parameter};
19
20/// Enhanced module building trait with fluent interface
21pub trait ModuleBuilder<T> {
22    /// Build the module
23    fn build(self) -> Result<T>;
24
25    /// Set training mode
26    fn training(self, training: bool) -> Self;
27
28    /// Set device
29    fn device(self, device: DeviceType) -> Self;
30
31    /// Add custom configuration
32    fn config<F>(self, config_fn: F) -> Self
33    where
34        F: FnOnce(&mut ModuleConfig);
35}
36
37/// Trait for modules that support functional composition
38pub trait ModuleComposition {
39    /// Compose this module with another module sequentially
40    ///
41    /// Creates a new module that applies self then other.
42    fn then<Other: Module + 'static>(self, other: Other) -> ComposedModule<Self, Other>
43    where
44        Self: Sized + 'static;
45
46    /// Compose this module with another module in parallel
47    ///
48    /// Creates a new module that applies both modules to the same input and combines results.
49    fn parallel<Other: Module + 'static>(self, other: Other) -> ParallelModule<Self, Other>
50    where
51        Self: Sized + 'static;
52
53    /// Add a residual connection
54    ///
55    /// Creates a new module that adds the input to the output of this module.
56    fn residual(self) -> ResidualModule<Self>
57    where
58        Self: Sized + 'static;
59
60    /// Add conditional execution
61    ///
62    /// Creates a new module that only applies this module when the condition is true.
63    fn conditional<F>(self, condition_fn: F) -> ConditionalModule<Self, F>
64    where
65        Self: Sized + 'static,
66        F: Fn() -> bool + Send + Sync;
67}
68
/// Sequential composition of two modules.
///
/// `Debug` and `Clone` are available whenever both wrapped values provide them.
#[derive(Debug, Clone)]
pub struct ComposedModule<First, Second> {
    /// Stage that receives the caller's input.
    first: First,
    /// Stage that receives the first stage's output.
    second: Second,
}
74
75impl<First: Module, Second: Module> Module for ComposedModule<First, Second> {
76    fn forward(&self, input: &Tensor) -> Result<Tensor> {
77        let intermediate = self.first.forward(input)?;
78        self.second.forward(&intermediate)
79    }
80
81    fn parameters(&self) -> HashMap<String, Parameter> {
82        let mut params = self.first.parameters();
83        let second_params = self.second.parameters();
84        for (name, param) in second_params {
85            params.insert(format!("second.{}", name), param);
86        }
87        params
88    }
89
90    fn training(&self) -> bool {
91        self.first.training() && self.second.training()
92    }
93
94    fn set_training(&mut self, training: bool) {
95        self.first.set_training(training);
96        self.second.set_training(training);
97    }
98}
99
/// Parallel composition of two modules.
///
/// `Debug` and `Clone` are available whenever both wrapped values provide them.
#[derive(Debug, Clone)]
pub struct ParallelModule<First, Second> {
    /// First branch; receives the caller's input.
    first: First,
    /// Second branch; receives the same input as the first branch.
    second: Second,
}
105
106impl<First: Module, Second: Module> Module for ParallelModule<First, Second> {
107    fn forward(&self, input: &Tensor) -> Result<Tensor> {
108        let first_output = self.first.forward(input)?;
109        let _second_output = self.second.forward(input)?;
110        // In a full implementation, this would concatenate or combine the outputs
111        // For now, just return the first output
112        Ok(first_output)
113    }
114
115    fn parameters(&self) -> HashMap<String, Parameter> {
116        let mut params = HashMap::new();
117        for (name, param) in self.first.parameters() {
118            params.insert(format!("first.{name}"), param);
119        }
120        for (name, param) in self.second.parameters() {
121            params.insert(format!("second.{name}"), param);
122        }
123        params
124    }
125
126    fn training(&self) -> bool {
127        self.first.training() && self.second.training()
128    }
129
130    fn set_training(&mut self, training: bool) {
131        self.first.set_training(training);
132        self.second.set_training(training);
133    }
134}
135
/// Module with residual connection.
///
/// `Debug` and `Clone` are available whenever the wrapped value provides them.
#[derive(Debug, Clone)]
pub struct ResidualModule<M> {
    /// The wrapped module whose output the residual connection applies to.
    module: M,
}
140
141impl<M: Module> Module for ResidualModule<M> {
142    fn forward(&self, input: &Tensor) -> Result<Tensor> {
143        let output = self.module.forward(input)?;
144        // In a full implementation, this would add input + output
145        // For now, just return the output
146        Ok(output)
147    }
148
149    fn parameters(&self) -> HashMap<String, Parameter> {
150        self.module.parameters()
151    }
152
153    fn training(&self) -> bool {
154        self.module.training()
155    }
156
157    fn set_training(&mut self, training: bool) {
158        self.module.set_training(training);
159    }
160}
161
/// Module with conditional execution.
///
/// `Clone` is available whenever both the wrapped module and the predicate
/// are `Clone`. `Debug` only requires the wrapped module to be `Debug`,
/// because the predicate is typically a closure, which has no `Debug`
/// implementation.
#[derive(Clone)]
pub struct ConditionalModule<M, F> {
    /// The wrapped module, applied only when the predicate returns `true`.
    module: M,
    /// Predicate evaluated on each forward pass to decide whether to run `module`.
    condition_fn: F,
}

impl<M: core::fmt::Debug, F> core::fmt::Debug for ConditionalModule<M, F> {
    /// Formats the wrapped module and elides the predicate with `..`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("ConditionalModule")
            .field("module", &self.module)
            .finish_non_exhaustive()
    }
}
167
168impl<M: Module, F: Fn() -> bool + Send + Sync> Module for ConditionalModule<M, F> {
169    fn forward(&self, input: &Tensor) -> Result<Tensor> {
170        if (self.condition_fn)() {
171            self.module.forward(input)
172        } else {
173            Ok(input.clone())
174        }
175    }
176
177    fn parameters(&self) -> HashMap<String, Parameter> {
178        self.module.parameters()
179    }
180
181    fn training(&self) -> bool {
182        self.module.training()
183    }
184
185    fn set_training(&mut self, training: bool) {
186        self.module.set_training(training);
187    }
188}
189
190/// Blanket implementation of ModuleComposition for all modules
191impl<T: Module + 'static> ModuleComposition for T {
192    fn then<Other: Module + 'static>(self, other: Other) -> ComposedModule<Self, Other> {
193        ComposedModule {
194            first: self,
195            second: other,
196        }
197    }
198
199    fn parallel<Other: Module + 'static>(self, other: Other) -> ParallelModule<Self, Other> {
200        ParallelModule {
201            first: self,
202            second: other,
203        }
204    }
205
206    fn residual(self) -> ResidualModule<Self> {
207        ResidualModule { module: self }
208    }
209
210    fn conditional<F>(self, condition_fn: F) -> ConditionalModule<Self, F>
211    where
212        F: Fn() -> bool + Send + Sync,
213    {
214        ConditionalModule {
215            module: self,
216            condition_fn,
217        }
218    }
219}