Skip to main content

scirs2_neural/config/
schema.rs

//! JSON Schema definitions for model configurations
//!
//! This module provides JSON Schema definitions for validating model configurations.
//! These schemas can be used both for validating configurations and for generating documentation.

6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// JSON Schema for model configurations
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct JsonSchema {
12    /// Schema title
13    pub title: String,
14    /// Schema description
15    pub description: Option<String>,
16    /// Schema type
17    #[serde(rename = "type")]
18    pub schema_type: String,
19    /// Schema properties
20    pub properties: HashMap<String, JsonSchemaProperty>,
21    /// Required properties
22    pub required: Vec<String>,
23    /// Additional properties allowed
24    pub additional_properties: Option<bool>,
25}
26
27/// JSON Schema property
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct JsonSchemaProperty {
30    /// Property type
31    #[serde(rename = "type")]
32    pub property_type: Option<String>,
33    /// Property description
34    pub description: Option<String>,
35    /// Property enum values
36    #[serde(rename = "enum")]
37    pub enum_values: Option<Vec<String>>,
38    /// Property format
39    pub format: Option<String>,
40    /// Minimum value
41    pub minimum: Option<f64>,
42    /// Maximum value
43    pub maximum: Option<f64>,
44    /// Default value
45    pub default: Option<serde_json::Value>,
46    /// Items schema (for arrays)
47    pub items: Option<Box<JsonSchemaProperty>>,
48    /// Reference to another schema
49    #[serde(rename = "$ref")]
50    pub reference: Option<String>,
51    /// Properties for objects
52    pub properties: Option<HashMap<String, JsonSchemaProperty>>,
53    /// Required properties for objects
54    pub required: Option<Vec<String>>,
55    /// One of schemas
56    pub one_of: Option<Vec<JsonSchemaProperty>>,
57    /// All of schemas
58    pub all_of: Option<Vec<JsonSchemaProperty>>,
59}
60
/// Schema registry for all model configurations
///
/// A stateless namespace: every schema is built on demand by an associated function,
/// so the type carries no data and is freely copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct SchemaRegistry;
63
64impl SchemaRegistry {
65    /// Get schema for ResNet configuration
66    pub fn resnet_schema() -> JsonSchema {
67        let mut properties = HashMap::new();
68        properties.insert(
69            "num_layers".to_string(),
70            JsonSchemaProperty {
71                property_type: Some("integer".to_string()),
72                description: Some("Number of layers in the ResNet model".to_string()),
73                enum_values: Some(vec![
74                    "18".to_string(),
75                    "34".to_string(),
76                    "50".to_string(),
77                    "101".to_string(),
78                    "152".to_string(),
79                ]),
80                format: None,
81                minimum: Some(18.0),
82                maximum: Some(152.0),
83                default: Some(serde_json::json!(50)),
84                items: None,
85                reference: None,
86                properties: None,
87                required: None,
88                one_of: None,
89                all_of: None,
90            },
91        );
92        properties.insert(
93            "in_channels".to_string(),
94            JsonSchemaProperty {
95                property_type: Some("integer".to_string()),
96                description: Some("Number of input channels".to_string()),
97                enum_values: None,
98                format: None,
99                minimum: Some(1.0),
100                maximum: None,
101                default: Some(serde_json::json!(3)),
102                items: None,
103                reference: None,
104                properties: None,
105                required: None,
106                one_of: None,
107                all_of: None,
108            },
109        );
110        properties.insert(
111            "num_classes".to_string(),
112            JsonSchemaProperty {
113                property_type: Some("integer".to_string()),
114                description: Some("Number of output classes".to_string()),
115                enum_values: None,
116                format: None,
117                minimum: Some(1.0),
118                maximum: None,
119                default: Some(serde_json::json!(1000)),
120                items: None,
121                reference: None,
122                properties: None,
123                required: None,
124                one_of: None,
125                all_of: None,
126            },
127        );
128        properties.insert(
129            "zero_init_residual".to_string(),
130            JsonSchemaProperty {
131                property_type: Some("boolean".to_string()),
132                description: Some(
133                    "Whether to initialize residual connections with zero".to_string(),
134                ),
135                enum_values: None,
136                format: None,
137                minimum: None,
138                maximum: None,
139                default: Some(serde_json::json!(false)),
140                items: None,
141                reference: None,
142                properties: None,
143                required: None,
144                one_of: None,
145                all_of: None,
146            },
147        );
148        JsonSchema {
149            title: "ResNet Configuration".to_string(),
150            description: Some("Configuration for ResNet models".to_string()),
151            schema_type: "object".to_string(),
152            properties,
153            required: vec![
154                "num_layers".to_string(),
155                "in_channels".to_string(),
156                "num_classes".to_string(),
157            ],
158            additional_properties: Some(false),
159        }
160    }
161
162    /// Get schema for Vision Transformer configuration
163    pub fn vit_schema() -> JsonSchema {
164        let mut properties = HashMap::new();
165        let make_prop = |desc: &str, default: serde_json::Value| JsonSchemaProperty {
166            property_type: Some("integer".to_string()),
167            description: Some(desc.to_string()),
168            enum_values: None,
169            format: None,
170            minimum: Some(1.0),
171            maximum: None,
172            default: Some(default),
173            items: None,
174            reference: None,
175            properties: None,
176            required: None,
177            one_of: None,
178            all_of: None,
179        };
180        properties.insert(
181            "image_size".to_string(),
182            make_prop("Size of the input image (square)", serde_json::json!(224)),
183        );
184        properties.insert(
185            "patch_size".to_string(),
186            make_prop(
187                "Size of the patches to divide the image into",
188                serde_json::json!(16),
189            ),
190        );
191        properties.insert(
192            "hidden_size".to_string(),
193            make_prop(
194                "Dimension of transformer hidden layers",
195                serde_json::json!(768),
196            ),
197        );
198        properties.insert(
199            "num_layers".to_string(),
200            make_prop("Number of transformer layers", serde_json::json!(12)),
201        );
202        properties.insert(
203            "num_heads".to_string(),
204            make_prop("Number of attention heads", serde_json::json!(12)),
205        );
206        properties.insert(
207            "mlp_dim".to_string(),
208            make_prop("Dimension of the MLP layers", serde_json::json!(3072)),
209        );
210        properties.insert(
211            "dropout_rate".to_string(),
212            JsonSchemaProperty {
213                property_type: Some("number".to_string()),
214                description: Some("Dropout rate".to_string()),
215                enum_values: None,
216                format: None,
217                minimum: Some(0.0),
218                maximum: Some(1.0),
219                default: Some(serde_json::json!(0.1)),
220                items: None,
221                reference: None,
222                properties: None,
223                required: None,
224                one_of: None,
225                all_of: None,
226            },
227        );
228        properties.insert(
229            "attention_dropout_rate".to_string(),
230            JsonSchemaProperty {
231                property_type: Some("number".to_string()),
232                description: Some("Attention dropout rate".to_string()),
233                enum_values: None,
234                format: None,
235                minimum: Some(0.0),
236                maximum: Some(1.0),
237                default: Some(serde_json::json!(0.0)),
238                items: None,
239                reference: None,
240                properties: None,
241                required: None,
242                one_of: None,
243                all_of: None,
244            },
245        );
246        properties.insert(
247            "classifier".to_string(),
248            JsonSchemaProperty {
249                property_type: Some("string".to_string()),
250                description: Some("Type of classifier ('token' or 'gap')".to_string()),
251                enum_values: Some(vec!["token".to_string(), "gap".to_string()]),
252                format: None,
253                minimum: None,
254                maximum: None,
255                default: Some(serde_json::json!("token")),
256                items: None,
257                reference: None,
258                properties: None,
259                required: None,
260                one_of: None,
261                all_of: None,
262            },
263        );
264        properties.insert(
265            "include_top".to_string(),
266            JsonSchemaProperty {
267                property_type: Some("boolean".to_string()),
268                description: Some("Whether to include the classification head".to_string()),
269                enum_values: None,
270                format: None,
271                minimum: None,
272                maximum: None,
273                default: Some(serde_json::json!(true)),
274                items: None,
275                reference: None,
276                properties: None,
277                required: None,
278                one_of: None,
279                all_of: None,
280            },
281        );
282        JsonSchema {
283            title: "Vision Transformer Configuration".to_string(),
284            description: Some("Configuration for Vision Transformer models".to_string()),
285            schema_type: "object".to_string(),
286            properties,
287            required: vec![
288                "image_size".to_string(),
289                "patch_size".to_string(),
290                "hidden_size".to_string(),
291                "num_heads".to_string(),
292            ],
293            additional_properties: Some(false),
294        }
295    }
296
297    /// Get all available schemas
298    pub fn get_all_schemas() -> HashMap<String, JsonSchema> {
299        let mut schemas = HashMap::new();
300        schemas.insert("resnet".to_string(), Self::resnet_schema());
301        schemas.insert("vit".to_string(), Self::vit_schema());
302        schemas
303    }
304}