scirs2_neural/config/
mod.rs1mod schema;
8mod serialize;
9mod validation;
10pub use schema::*;
11pub use serialize::*;
12pub use validation::*;
13
14use crate::error::{Error, Result};
15use crate::models::architectures::{
16 BertConfig, CLIPConfig, ConvNeXtConfig, EfficientNetConfig, FeatureFusionConfig, GPTConfig,
17 MobileNetConfig, ResNetConfig, Seq2SeqConfig, ViTConfig,
18};
19use serde::{Deserialize, Serialize};
20use std::fs;
21use std::io::{Read, Write};
22use std::path::Path;
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(tag = "model_type")]
27pub enum ModelConfig {
28 #[serde(rename = "resnet")]
30 ResNet(ResNetConfig),
31 #[serde(rename = "vit")]
33 ViT(ViTConfig),
34 #[serde(rename = "bert")]
36 Bert(BertConfig),
37 #[serde(rename = "gpt")]
39 GPT(GPTConfig),
40 #[serde(rename = "efficientnet")]
42 EfficientNet(EfficientNetConfig),
43 #[serde(rename = "mobilenet")]
45 MobileNet(MobileNetConfig),
46 #[serde(rename = "convnext")]
48 ConvNeXt(ConvNeXtConfig),
49 #[serde(rename = "clip")]
51 CLIP(CLIPConfig),
52 #[serde(rename = "feature_fusion")]
54 FeatureFusion(FeatureFusionConfig),
55 #[serde(rename = "seq2seq")]
57 Seq2Seq(Seq2SeqConfig),
58}
59
/// Text encoding used when reading or writing a model configuration file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConfigFormat {
    /// JSON encoding (conventionally a `.json` file).
    JSON,
    /// YAML encoding (conventionally a `.yaml` or `.yml` file).
    YAML,
}
68
69impl ModelConfig {
70 pub fn from_file<P: AsRef<Path>>(path: P, format: Option<ConfigFormat>) -> Result<Self> {
72 let path = path.as_ref();
73 let format = if let Some(fmt) = format {
75 fmt
76 } else if let Some(ext) = path.extension() {
77 if ext == "json" {
78 ConfigFormat::JSON
79 } else if ext == "yaml" || ext == "yml" {
80 ConfigFormat::YAML
81 } else {
82 return Err(Error::InvalidArgument(format!(
83 "Unsupported file extension: {:?}. Expected .json, .yaml, or .yml",
84 ext
85 )));
86 }
87 } else {
88 return Err(Error::InvalidArgument("File has no extension".to_string()));
89 };
90 let mut file = fs::File::open(path)
92 .map_err(|e| Error::IOError(format!("Failed to open config file: {}", e)))?;
93 let mut content = String::new();
94 file.read_to_string(&mut content)
95 .map_err(|e| Error::IOError(format!("Failed to read config file: {}", e)))?;
96 match format {
98 ConfigFormat::JSON => serde_json::from_str(&content)
99 .map_err(|e| Error::DeserializationError(format!("Failed to parse JSON: {}", e))),
100 ConfigFormat::YAML => serde_yaml::from_str(&content)
101 .map_err(|e| Error::DeserializationError(format!("Failed to parse YAML: {}", e))),
102 }
103 }
104
105 pub fn to_file<P: AsRef<Path>>(&self, path: P, format: Option<ConfigFormat>) -> Result<()> {
107 let path = path.as_ref();
108 if let Some(parent) = path.parent() {
110 fs::create_dir_all(parent)
111 .map_err(|e| Error::IOError(format!("Failed to create directory: {}", e)))?;
112 }
113 let format = if let Some(fmt) = format {
115 fmt
116 } else if let Some(ext) = path.extension() {
117 if ext == "json" {
118 ConfigFormat::JSON
119 } else if ext == "yaml" || ext == "yml" {
120 ConfigFormat::YAML
121 } else {
122 ConfigFormat::JSON
123 }
124 } else {
125 ConfigFormat::JSON
126 };
127 let mut file = fs::File::create(path)
129 .map_err(|e| Error::IOError(format!("Failed to create config file: {}", e)))?;
130 match format {
132 ConfigFormat::JSON => {
133 let content = serde_json::to_string_pretty(self).map_err(|e| {
134 Error::SerializationError(format!("Failed to serialize to JSON: {}", e))
135 })?;
136 file.write_all(content.as_bytes())
137 .map_err(|e| Error::IOError(format!("Failed to write config file: {}", e)))?;
138 }
139 ConfigFormat::YAML => {
140 let content = serde_yaml::to_string(self).map_err(|e| {
141 Error::SerializationError(format!("Failed to serialize to YAML: {}", e))
142 })?;
143 file.write_all(content.as_bytes())
144 .map_err(|e| Error::IOError(format!("Failed to write config file: {}", e)))?;
145 }
146 }
147 Ok(())
148 }
149
150 pub fn to_json(&self) -> Result<String> {
152 serde_json::to_string_pretty(self)
153 .map_err(|e| Error::SerializationError(format!("Failed to serialize to JSON: {}", e)))
154 }
155
156 pub fn to_yaml(&self) -> Result<String> {
158 serde_yaml::to_string(self)
159 .map_err(|e| Error::SerializationError(format!("Failed to serialize to YAML: {}", e)))
160 }
161
162 pub fn from_json(json: &str) -> Result<Self> {
164 serde_json::from_str(json)
165 .map_err(|e| Error::DeserializationError(format!("Failed to parse JSON: {}", e)))
166 }
167
168 pub fn from_yaml(yaml: &str) -> Result<Self> {
170 serde_yaml::from_str(yaml)
171 .map_err(|e| Error::DeserializationError(format!("Failed to parse YAML: {}", e)))
172 }
173
174 pub fn validate(&self) -> Result<()> {
176 validation::validate_model_config(self)
177 }
178
179 pub fn model_type(&self) -> &'static str {
181 match self {
182 ModelConfig::ResNet(_) => "resnet",
183 ModelConfig::ViT(_) => "vit",
184 ModelConfig::Bert(_) => "bert",
185 ModelConfig::GPT(_) => "gpt",
186 ModelConfig::EfficientNet(_) => "efficientnet",
187 ModelConfig::MobileNet(_) => "mobilenet",
188 ModelConfig::ConvNeXt(_) => "convnext",
189 ModelConfig::CLIP(_) => "clip",
190 ModelConfig::FeatureFusion(_) => "feature_fusion",
191 ModelConfig::Seq2Seq(_) => "seq2seq",
192 }
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::architectures::{ResNetBlock, ResNetLayer};

    /// Builds a small, valid single-stage ResNet config shared by the tests.
    fn make_resnet_config() -> ModelConfig {
        ModelConfig::ResNet(ResNetConfig {
            block: ResNetBlock::Basic,
            layers: vec![ResNetLayer {
                blocks: 2,
                channels: 64,
                stride: 1,
            }],
            input_channels: 3,
            num_classes: 10,
            dropout_rate: 0.0,
        })
    }

    /// Returns a temp-dir path tagged with the current process id.
    ///
    /// Concurrent test processes (e.g. several crates or CI shards sharing one
    /// temp dir) would otherwise read and delete each other's files.
    fn temp_config_path(file_name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "scirs2_test_{}_{}",
            std::process::id(),
            file_name
        ))
    }

    #[test]
    fn test_config_json_roundtrip() {
        let config = make_resnet_config();
        let json = config.to_json().expect("serialization failed");
        let restored: ModelConfig = ModelConfig::from_json(&json).expect("deserialization failed");
        assert_eq!(restored.model_type(), "resnet");
    }

    #[test]
    fn test_config_yaml_roundtrip() {
        let config = make_resnet_config();
        let yaml = config.to_yaml().expect("yaml serialization failed");
        let restored: ModelConfig =
            ModelConfig::from_yaml(&yaml).expect("yaml deserialization failed");
        assert_eq!(restored.model_type(), "resnet");
    }

    #[test]
    fn test_config_validation_resnet_valid() {
        let config = make_resnet_config();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_validation_resnet_invalid_channels() {
        let config = ModelConfig::ResNet(ResNetConfig {
            block: ResNetBlock::Basic,
            layers: vec![],
            // Zero input channels must be rejected.
            input_channels: 0,
            num_classes: 10,
            dropout_rate: 0.0,
        });
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_file_roundtrip() {
        let config = make_resnet_config();
        let tmp = temp_config_path("config.json");
        config
            .to_file(&tmp, Some(ConfigFormat::JSON))
            .expect("write failed");
        let loaded = ModelConfig::from_file(&tmp, Some(ConfigFormat::JSON));
        // Remove before asserting so a failed load does not leak the file.
        let _ = std::fs::remove_file(&tmp);
        assert_eq!(loaded.expect("read failed").model_type(), "resnet");
    }

    #[test]
    fn test_config_file_roundtrip_inferred_yaml() {
        let config = make_resnet_config();
        let tmp = temp_config_path("config.yml");
        config.to_file(&tmp, None).expect("write failed");
        let raw = std::fs::read_to_string(&tmp);
        let loaded = ModelConfig::from_file(&tmp, None);
        let _ = std::fs::remove_file(&tmp);
        // `.yml` must select YAML, not the JSON fallback (pretty JSON starts with '{').
        let raw = raw.expect("raw read failed");
        assert!(!raw.trim_start().starts_with('{'));
        assert_eq!(loaded.expect("read failed").model_type(), "resnet");
    }

    #[test]
    fn test_from_file_rejects_unknown_or_missing_extension() {
        // Format inference fails before any I/O, so these paths need not exist.
        assert!(matches!(
            ModelConfig::from_file("does_not_exist.txt", None),
            Err(Error::InvalidArgument(_))
        ));
        assert!(matches!(
            ModelConfig::from_file("does_not_exist", None),
            Err(Error::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_model_type_names() {
        assert_eq!(make_resnet_config().model_type(), "resnet");
    }
}