Skip to main content

torsh_nn/construction/
mod.rs

1//! Module construction and configuration system
2//!
3//! This module provides standardized construction patterns and configuration
4//! management for neural network modules.
5
6use torsh_core::device::DeviceType;
7use torsh_core::error::Result;
8
9// Conditional imports for std/no_std compatibility
10#[cfg(feature = "std")]
11use std::collections::HashMap;
12
13#[cfg(not(feature = "std"))]
14use hashbrown::HashMap;
15
16#[cfg(feature = "serialize")]
17use serde_json;
18
19/// Helper trait for module construction patterns
20///
21/// This trait provides standardized construction patterns for modules,
22/// ensuring consistent error handling and ergonomics across all implementations.
23pub trait ModuleConstruct {
24    /// Type returned by the constructor
25    type Output;
26
27    /// Attempt to create a module, returning Result for error handling
28    ///
29    /// This is the primary constructor that should be implemented.
30    fn try_new() -> Result<Self::Output>;
31
32    /// Create a module with panic on error (for convenience)
33    ///
34    /// This method provides a convenient interface for cases where
35    /// construction failure is not expected.
36    fn new() -> Self::Output
37    where
38        Self::Output: Sized,
39    {
40        Self::try_new().expect("Module construction failed")
41    }
42
43    /// Create a module with default parameters
44    ///
45    /// Default implementation delegates to `try_new()`. Override if your
46    /// module supports different default configurations.
47    fn default() -> Result<Self::Output> {
48        Self::try_new()
49    }
50
51    /// Create a module with a specific configuration
52    ///
53    /// Default implementation delegates to `try_new()`. Override if your
54    /// module supports configuration-based construction.
55    fn with_config(_config: &ModuleConfig) -> Result<Self::Output> {
56        Self::try_new()
57    }
58}
59
60/// Generic module configuration
61///
62/// This provides a standard configuration interface that can be extended
63/// by specific module types.
64#[derive(Debug, Clone)]
65pub struct ModuleConfig {
66    /// Training mode
67    pub training: bool,
68    /// Target device
69    pub device: DeviceType,
70    /// Whether to use bias terms
71    pub bias: bool,
72    /// Dropout probability
73    pub dropout: f32,
74    /// Custom parameters
75    #[cfg(feature = "serialize")]
76    pub custom: HashMap<String, serde_json::Value>,
77    /// Custom parameters (placeholder when serialize feature is disabled)
78    #[cfg(not(feature = "serialize"))]
79    pub custom: HashMap<String, String>,
80}
81
82impl Default for ModuleConfig {
83    fn default() -> Self {
84        Self {
85            training: true,
86            device: DeviceType::Cpu,
87            bias: true,
88            dropout: 0.0,
89            custom: HashMap::new(),
90        }
91    }
92}
93
94impl ModuleConfig {
95    /// Create a new configuration with default values
96    pub fn new() -> Self {
97        Self::default()
98    }
99
100    /// Set training mode
101    pub fn training(mut self, training: bool) -> Self {
102        self.training = training;
103        self
104    }
105
106    /// Set device
107    pub fn device(mut self, device: DeviceType) -> Self {
108        self.device = device;
109        self
110    }
111
112    /// Set bias usage
113    pub fn bias(mut self, bias: bool) -> Self {
114        self.bias = bias;
115        self
116    }
117
118    /// Set dropout probability
119    pub fn dropout(mut self, dropout: f32) -> Self {
120        self.dropout = dropout;
121        self
122    }
123
124    /// Add a custom parameter
125    #[cfg(feature = "serialize")]
126    pub fn custom_param<T: serde::Serialize>(mut self, name: &str, value: T) -> Self {
127        if let Ok(json_value) = serde_json::to_value(value) {
128            self.custom.insert(name.to_string(), json_value);
129        }
130        self
131    }
132
133    /// Add a custom parameter (simplified version when serialize feature is disabled)
134    #[cfg(not(feature = "serialize"))]
135    pub fn custom_param<T: std::fmt::Display>(mut self, name: &str, value: T) -> Self {
136        self.custom.insert(name.to_string(), value.to_string());
137        self
138    }
139
140    /// Get a custom parameter
141    #[cfg(feature = "serialize")]
142    pub fn get_custom<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
143        self.custom
144            .get(name)
145            .and_then(|v| serde_json::from_value(v.clone()).ok())
146    }
147
148    /// Get a custom parameter (simplified version when serialize feature is disabled)
149    #[cfg(not(feature = "serialize"))]
150    pub fn get_custom(&self, name: &str) -> Option<String> {
151        self.custom.get(name).cloned()
152    }
153}
154
/// Implements `ModuleConstruct` for a type from a single constructor expression.
///
/// `impl_module_constructor!(Type, expr)` expands to an
/// `impl ModuleConstruct for Type` whose `Output` is `Type` itself and whose
/// `try_new` body is `expr`. The expression must therefore evaluate to the
/// `Result<Self::Output>` type that `try_new` returns. A trailing comma after
/// the expression is accepted.
///
/// The expansion refers to `ModuleConstruct` and `Result` unqualified
/// (`macro_rules!` does not resolve item names hygienically). Both names must be
/// in scope at the invocation site, and `Result` must be the single-parameter
/// alias the trait is declared with.
///
/// ```ignore
/// impl_module_constructor!(MyLayer, MyLayer::build());
/// ```
#[macro_export]
macro_rules! impl_module_constructor {
    ($module_type:ty, $constructor:expr $(,)?) => {
        impl ModuleConstruct for $module_type {
            type Output = $module_type;

            fn try_new() -> Result<Self::Output> {
                $constructor
            }
        }
    };
}