torsh_nn/construction/
mod.rs1use torsh_core::device::DeviceType;
7use torsh_core::error::Result;
8
9#[cfg(feature = "std")]
11use std::collections::HashMap;
12
13#[cfg(not(feature = "std"))]
14use hashbrown::HashMap;
15
16#[cfg(feature = "serialize")]
17use serde_json;
18
19pub trait ModuleConstruct {
24 type Output;
26
27 fn try_new() -> Result<Self::Output>;
31
32 fn new() -> Self::Output
37 where
38 Self::Output: Sized,
39 {
40 Self::try_new().expect("Module construction failed")
41 }
42
43 fn default() -> Result<Self::Output> {
48 Self::try_new()
49 }
50
51 fn with_config(_config: &ModuleConfig) -> Result<Self::Output> {
56 Self::try_new()
57 }
58}
59
60#[derive(Debug, Clone)]
65pub struct ModuleConfig {
66 pub training: bool,
68 pub device: DeviceType,
70 pub bias: bool,
72 pub dropout: f32,
74 #[cfg(feature = "serialize")]
76 pub custom: HashMap<String, serde_json::Value>,
77 #[cfg(not(feature = "serialize"))]
79 pub custom: HashMap<String, String>,
80}
81
82impl Default for ModuleConfig {
83 fn default() -> Self {
84 Self {
85 training: true,
86 device: DeviceType::Cpu,
87 bias: true,
88 dropout: 0.0,
89 custom: HashMap::new(),
90 }
91 }
92}
93
94impl ModuleConfig {
95 pub fn new() -> Self {
97 Self::default()
98 }
99
100 pub fn training(mut self, training: bool) -> Self {
102 self.training = training;
103 self
104 }
105
106 pub fn device(mut self, device: DeviceType) -> Self {
108 self.device = device;
109 self
110 }
111
112 pub fn bias(mut self, bias: bool) -> Self {
114 self.bias = bias;
115 self
116 }
117
118 pub fn dropout(mut self, dropout: f32) -> Self {
120 self.dropout = dropout;
121 self
122 }
123
124 #[cfg(feature = "serialize")]
126 pub fn custom_param<T: serde::Serialize>(mut self, name: &str, value: T) -> Self {
127 if let Ok(json_value) = serde_json::to_value(value) {
128 self.custom.insert(name.to_string(), json_value);
129 }
130 self
131 }
132
133 #[cfg(not(feature = "serialize"))]
135 pub fn custom_param<T: std::fmt::Display>(mut self, name: &str, value: T) -> Self {
136 self.custom.insert(name.to_string(), value.to_string());
137 self
138 }
139
140 #[cfg(feature = "serialize")]
142 pub fn get_custom<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
143 self.custom
144 .get(name)
145 .and_then(|v| serde_json::from_value(v.clone()).ok())
146 }
147
148 #[cfg(not(feature = "serialize"))]
150 pub fn get_custom(&self, name: &str) -> Option<String> {
151 self.custom.get(name).cloned()
152 }
153}
154
/// Implement `ModuleConstruct` for a type from a single constructor expression.
///
/// `$module_type` becomes both the implementing type and its `Output`, and
/// `$constructor` is pasted verbatim as the body of `try_new`, so it must
/// evaluate to `Result<$module_type>`. A trailing comma after the constructor
/// expression is accepted.
///
/// # Hygiene
/// The expansion names `ModuleConstruct` and `Result` unqualified, so both are
/// resolved at the call site. `ModuleConstruct` must be in scope, and `Result`
/// must be a single-parameter alias such as `torsh_core::error::Result`. The
/// prelude's two-parameter `Result` does not fit.
/// NOTE(review): consider `$crate::`-qualified paths once this module's public
/// path is confirmed.
///
/// # Examples
/// ```ignore
/// impl_module_constructor!(MyLayer, Ok(MyLayer { units: 8 }));
/// ```
#[macro_export]
macro_rules! impl_module_constructor {
    ($module_type:ty, $constructor:expr $(,)?) => {
        impl ModuleConstruct for $module_type {
            type Output = $module_type;

            fn try_new() -> Result<Self::Output> {
                $constructor
            }
        }
    };
}