Skip to main content

scirs2_neural/config/
mod.rs

//! Neural network model configuration system
//!
//! This module provides utilities for loading, saving, and validating model
//! configurations stored as JSON or YAML. Keeping architectures in declarative
//! configuration files makes model creation flexible and experiments reproducible.
6
7mod schema;
8mod serialize;
9mod validation;
10pub use schema::*;
11pub use serialize::*;
12pub use validation::*;
13
14use crate::error::{Error, Result};
15use crate::models::architectures::{
16    BertConfig, CLIPConfig, ConvNeXtConfig, EfficientNetConfig, FeatureFusionConfig, GPTConfig,
17    MobileNetConfig, ResNetConfig, Seq2SeqConfig, ViTConfig,
18};
19use serde::{Deserialize, Serialize};
20use std::fs;
21use std::io::{Read, Write};
22use std::path::Path;
23
24/// Model configuration container
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(tag = "model_type")]
27pub enum ModelConfig {
28    /// ResNet configuration
29    #[serde(rename = "resnet")]
30    ResNet(ResNetConfig),
31    /// Vision Transformer configuration
32    #[serde(rename = "vit")]
33    ViT(ViTConfig),
34    /// BERT configuration
35    #[serde(rename = "bert")]
36    Bert(BertConfig),
37    /// GPT configuration
38    #[serde(rename = "gpt")]
39    GPT(GPTConfig),
40    /// EfficientNet configuration
41    #[serde(rename = "efficientnet")]
42    EfficientNet(EfficientNetConfig),
43    /// MobileNet configuration
44    #[serde(rename = "mobilenet")]
45    MobileNet(MobileNetConfig),
46    /// ConvNeXt configuration
47    #[serde(rename = "convnext")]
48    ConvNeXt(ConvNeXtConfig),
49    /// CLIP configuration
50    #[serde(rename = "clip")]
51    CLIP(CLIPConfig),
52    /// Feature Fusion configuration
53    #[serde(rename = "feature_fusion")]
54    FeatureFusion(FeatureFusionConfig),
55    /// Seq2Seq configuration
56    #[serde(rename = "seq2seq")]
57    Seq2Seq(Seq2SeqConfig),
58}
59
/// Text serialization format used when reading or writing a configuration file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConfigFormat {
    /// JSON text (conventionally a `.json` file).
    JSON,
    /// YAML text (conventionally a `.yaml` or `.yml` file).
    YAML,
}
68
69impl ModelConfig {
70    /// Load a model configuration from a file
71    pub fn from_file<P: AsRef<Path>>(path: P, format: Option<ConfigFormat>) -> Result<Self> {
72        let path = path.as_ref();
73        // Determine format from extension if not specified
74        let format = if let Some(fmt) = format {
75            fmt
76        } else if let Some(ext) = path.extension() {
77            if ext == "json" {
78                ConfigFormat::JSON
79            } else if ext == "yaml" || ext == "yml" {
80                ConfigFormat::YAML
81            } else {
82                return Err(Error::InvalidArgument(format!(
83                    "Unsupported file extension: {:?}. Expected .json, .yaml, or .yml",
84                    ext
85                )));
86            }
87        } else {
88            return Err(Error::InvalidArgument("File has no extension".to_string()));
89        };
90        // Read file content
91        let mut file = fs::File::open(path)
92            .map_err(|e| Error::IOError(format!("Failed to open config file: {}", e)))?;
93        let mut content = String::new();
94        file.read_to_string(&mut content)
95            .map_err(|e| Error::IOError(format!("Failed to read config file: {}", e)))?;
96        // Parse based on format
97        match format {
98            ConfigFormat::JSON => serde_json::from_str(&content)
99                .map_err(|e| Error::DeserializationError(format!("Failed to parse JSON: {}", e))),
100            ConfigFormat::YAML => serde_yaml::from_str(&content)
101                .map_err(|e| Error::DeserializationError(format!("Failed to parse YAML: {}", e))),
102        }
103    }
104
105    /// Save a model configuration to a file
106    pub fn to_file<P: AsRef<Path>>(&self, path: P, format: Option<ConfigFormat>) -> Result<()> {
107        let path = path.as_ref();
108        // Create directory if needed
109        if let Some(parent) = path.parent() {
110            fs::create_dir_all(parent)
111                .map_err(|e| Error::IOError(format!("Failed to create directory: {}", e)))?;
112        }
113        // Determine format from extension if not specified
114        let format = if let Some(fmt) = format {
115            fmt
116        } else if let Some(ext) = path.extension() {
117            if ext == "json" {
118                ConfigFormat::JSON
119            } else if ext == "yaml" || ext == "yml" {
120                ConfigFormat::YAML
121            } else {
122                ConfigFormat::JSON
123            }
124        } else {
125            ConfigFormat::JSON
126        };
127        // Create file
128        let mut file = fs::File::create(path)
129            .map_err(|e| Error::IOError(format!("Failed to create config file: {}", e)))?;
130        // Serialize based on format
131        match format {
132            ConfigFormat::JSON => {
133                let content = serde_json::to_string_pretty(self).map_err(|e| {
134                    Error::SerializationError(format!("Failed to serialize to JSON: {}", e))
135                })?;
136                file.write_all(content.as_bytes())
137                    .map_err(|e| Error::IOError(format!("Failed to write config file: {}", e)))?;
138            }
139            ConfigFormat::YAML => {
140                let content = serde_yaml::to_string(self).map_err(|e| {
141                    Error::SerializationError(format!("Failed to serialize to YAML: {}", e))
142                })?;
143                file.write_all(content.as_bytes())
144                    .map_err(|e| Error::IOError(format!("Failed to write config file: {}", e)))?;
145            }
146        }
147        Ok(())
148    }
149
150    /// Convert configuration to JSON string
151    pub fn to_json(&self) -> Result<String> {
152        serde_json::to_string_pretty(self)
153            .map_err(|e| Error::SerializationError(format!("Failed to serialize to JSON: {}", e)))
154    }
155
156    /// Convert configuration to YAML string
157    pub fn to_yaml(&self) -> Result<String> {
158        serde_yaml::to_string(self)
159            .map_err(|e| Error::SerializationError(format!("Failed to serialize to YAML: {}", e)))
160    }
161
162    /// Parse configuration from JSON string
163    pub fn from_json(json: &str) -> Result<Self> {
164        serde_json::from_str(json)
165            .map_err(|e| Error::DeserializationError(format!("Failed to parse JSON: {}", e)))
166    }
167
168    /// Parse configuration from YAML string
169    pub fn from_yaml(yaml: &str) -> Result<Self> {
170        serde_yaml::from_str(yaml)
171            .map_err(|e| Error::DeserializationError(format!("Failed to parse YAML: {}", e)))
172    }
173
174    /// Validate the configuration against schema and parameter constraints
175    pub fn validate(&self) -> Result<()> {
176        validation::validate_model_config(self)
177    }
178
179    /// Return the model type name as a string
180    pub fn model_type(&self) -> &'static str {
181        match self {
182            ModelConfig::ResNet(_) => "resnet",
183            ModelConfig::ViT(_) => "vit",
184            ModelConfig::Bert(_) => "bert",
185            ModelConfig::GPT(_) => "gpt",
186            ModelConfig::EfficientNet(_) => "efficientnet",
187            ModelConfig::MobileNet(_) => "mobilenet",
188            ModelConfig::ConvNeXt(_) => "convnext",
189            ModelConfig::CLIP(_) => "clip",
190            ModelConfig::FeatureFusion(_) => "feature_fusion",
191            ModelConfig::Seq2Seq(_) => "seq2seq",
192        }
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::architectures::{ResNetBlock, ResNetLayer};

    /// Build a small, valid ResNet configuration shared by the tests below.
    fn make_resnet_config() -> ModelConfig {
        ModelConfig::ResNet(ResNetConfig {
            block: ResNetBlock::Basic,
            layers: vec![ResNetLayer {
                blocks: 2,
                channels: 64,
                stride: 1,
            }],
            input_channels: 3,
            num_classes: 10,
            dropout_rate: 0.0,
        })
    }

    /// Path in the system temp directory that is deleted on drop, even when a test panics.
    ///
    /// The process id is embedded in the file name so that concurrently running test
    /// binaries sharing the temp directory do not overwrite each other's files.
    struct TempPath(std::path::PathBuf);

    impl TempPath {
        fn new(name: &str) -> Self {
            TempPath(std::env::temp_dir().join(format!("scirs2_{}_{}", std::process::id(), name)))
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            // Best-effort cleanup: the file may legitimately not exist.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_config_json_roundtrip() {
        let config = make_resnet_config();
        let json = config.to_json().expect("serialization failed");
        // The internally tagged representation must carry the variant tag.
        assert!(json.contains("\"model_type\": \"resnet\""), "missing tag in: {}", json);
        let restored: ModelConfig = ModelConfig::from_json(&json).expect("deserialization failed");
        assert_eq!(restored.model_type(), "resnet");
    }

    #[test]
    fn test_config_yaml_roundtrip() {
        let config = make_resnet_config();
        let yaml = config.to_yaml().expect("yaml serialization failed");
        assert!(yaml.contains("model_type: resnet"), "missing tag in: {}", yaml);
        let restored: ModelConfig =
            ModelConfig::from_yaml(&yaml).expect("yaml deserialization failed");
        assert_eq!(restored.model_type(), "resnet");
    }

    #[test]
    fn test_config_validation_resnet_valid() {
        let config = make_resnet_config();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_validation_resnet_invalid_channels() {
        let config = ModelConfig::ResNet(ResNetConfig {
            block: ResNetBlock::Basic,
            layers: vec![],
            input_channels: 0, // invalid
            num_classes: 10,
            dropout_rate: 0.0,
        });
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_file_roundtrip() {
        let config = make_resnet_config();
        let tmp = TempPath::new("config_roundtrip.json");
        config
            .to_file(&tmp.0, Some(ConfigFormat::JSON))
            .expect("write failed");
        let loaded = ModelConfig::from_file(&tmp.0, Some(ConfigFormat::JSON)).expect("read failed");
        assert_eq!(loaded.model_type(), "resnet");
    }

    #[test]
    fn test_config_yaml_file_roundtrip_infers_format() {
        let config = make_resnet_config();
        let tmp = TempPath::new("config_roundtrip.yaml");
        config.to_file(&tmp.0, None).expect("write failed");
        let raw = std::fs::read_to_string(&tmp.0).expect("raw read failed");
        assert!(raw.contains("model_type: resnet"), "expected YAML, got: {}", raw);
        let loaded = ModelConfig::from_file(&tmp.0, None).expect("read failed");
        assert_eq!(loaded.model_type(), "resnet");
    }

    #[test]
    fn test_from_file_rejects_unsupported_extension() {
        let path = std::env::temp_dir().join("scirs2_unsupported_config.txt");
        let err = ModelConfig::from_file(&path, None).expect_err("txt must be rejected");
        assert!(matches!(err, Error::InvalidArgument(_)));
    }

    #[test]
    fn test_from_file_rejects_missing_extension() {
        let path = std::env::temp_dir().join("scirs2_config_without_extension");
        let err = ModelConfig::from_file(&path, None).expect_err("missing extension rejected");
        assert!(matches!(err, Error::InvalidArgument(_)));
    }

    #[test]
    fn test_from_file_missing_file_is_io_error() {
        let tmp = TempPath::new("does_not_exist.json");
        let err = ModelConfig::from_file(&tmp.0, Some(ConfigFormat::JSON))
            .expect_err("missing file must fail");
        assert!(matches!(err, Error::IOError(_)));
    }

    #[test]
    fn test_model_type_names() {
        assert_eq!(make_resnet_config().model_type(), "resnet");
    }
}