1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct JsonSchema {
12 pub title: String,
14 pub description: Option<String>,
16 #[serde(rename = "type")]
18 pub schema_type: String,
19 pub properties: HashMap<String, JsonSchemaProperty>,
21 pub required: Vec<String>,
23 pub additional_properties: Option<bool>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct JsonSchemaProperty {
30 #[serde(rename = "type")]
32 pub property_type: Option<String>,
33 pub description: Option<String>,
35 #[serde(rename = "enum")]
37 pub enum_values: Option<Vec<String>>,
38 pub format: Option<String>,
40 pub minimum: Option<f64>,
42 pub maximum: Option<f64>,
44 pub default: Option<serde_json::Value>,
46 pub items: Option<Box<JsonSchemaProperty>>,
48 #[serde(rename = "$ref")]
50 pub reference: Option<String>,
51 pub properties: Option<HashMap<String, JsonSchemaProperty>>,
53 pub required: Option<Vec<String>>,
55 pub one_of: Option<Vec<JsonSchemaProperty>>,
57 pub all_of: Option<Vec<JsonSchemaProperty>>,
59}
60
61pub struct SchemaRegistry;
63
64impl SchemaRegistry {
65 pub fn resnet_schema() -> JsonSchema {
67 let mut properties = HashMap::new();
68 properties.insert(
69 "num_layers".to_string(),
70 JsonSchemaProperty {
71 property_type: Some("integer".to_string()),
72 description: Some("Number of layers in the ResNet model".to_string()),
73 enum_values: Some(vec![
74 "18".to_string(),
75 "34".to_string(),
76 "50".to_string(),
77 "101".to_string(),
78 "152".to_string(),
79 ]),
80 format: None,
81 minimum: Some(18.0),
82 maximum: Some(152.0),
83 default: Some(serde_json::json!(50)),
84 items: None,
85 reference: None,
86 properties: None,
87 required: None,
88 one_of: None,
89 all_of: None,
90 },
91 );
92 properties.insert(
93 "in_channels".to_string(),
94 JsonSchemaProperty {
95 property_type: Some("integer".to_string()),
96 description: Some("Number of input channels".to_string()),
97 enum_values: None,
98 format: None,
99 minimum: Some(1.0),
100 maximum: None,
101 default: Some(serde_json::json!(3)),
102 items: None,
103 reference: None,
104 properties: None,
105 required: None,
106 one_of: None,
107 all_of: None,
108 },
109 );
110 properties.insert(
111 "num_classes".to_string(),
112 JsonSchemaProperty {
113 property_type: Some("integer".to_string()),
114 description: Some("Number of output classes".to_string()),
115 enum_values: None,
116 format: None,
117 minimum: Some(1.0),
118 maximum: None,
119 default: Some(serde_json::json!(1000)),
120 items: None,
121 reference: None,
122 properties: None,
123 required: None,
124 one_of: None,
125 all_of: None,
126 },
127 );
128 properties.insert(
129 "zero_init_residual".to_string(),
130 JsonSchemaProperty {
131 property_type: Some("boolean".to_string()),
132 description: Some(
133 "Whether to initialize residual connections with zero".to_string(),
134 ),
135 enum_values: None,
136 format: None,
137 minimum: None,
138 maximum: None,
139 default: Some(serde_json::json!(false)),
140 items: None,
141 reference: None,
142 properties: None,
143 required: None,
144 one_of: None,
145 all_of: None,
146 },
147 );
148 JsonSchema {
149 title: "ResNet Configuration".to_string(),
150 description: Some("Configuration for ResNet models".to_string()),
151 schema_type: "object".to_string(),
152 properties,
153 required: vec![
154 "num_layers".to_string(),
155 "in_channels".to_string(),
156 "num_classes".to_string(),
157 ],
158 additional_properties: Some(false),
159 }
160 }
161
162 pub fn vit_schema() -> JsonSchema {
164 let mut properties = HashMap::new();
165 let make_prop = |desc: &str, default: serde_json::Value| JsonSchemaProperty {
166 property_type: Some("integer".to_string()),
167 description: Some(desc.to_string()),
168 enum_values: None,
169 format: None,
170 minimum: Some(1.0),
171 maximum: None,
172 default: Some(default),
173 items: None,
174 reference: None,
175 properties: None,
176 required: None,
177 one_of: None,
178 all_of: None,
179 };
180 properties.insert(
181 "image_size".to_string(),
182 make_prop("Size of the input image (square)", serde_json::json!(224)),
183 );
184 properties.insert(
185 "patch_size".to_string(),
186 make_prop(
187 "Size of the patches to divide the image into",
188 serde_json::json!(16),
189 ),
190 );
191 properties.insert(
192 "hidden_size".to_string(),
193 make_prop(
194 "Dimension of transformer hidden layers",
195 serde_json::json!(768),
196 ),
197 );
198 properties.insert(
199 "num_layers".to_string(),
200 make_prop("Number of transformer layers", serde_json::json!(12)),
201 );
202 properties.insert(
203 "num_heads".to_string(),
204 make_prop("Number of attention heads", serde_json::json!(12)),
205 );
206 properties.insert(
207 "mlp_dim".to_string(),
208 make_prop("Dimension of the MLP layers", serde_json::json!(3072)),
209 );
210 properties.insert(
211 "dropout_rate".to_string(),
212 JsonSchemaProperty {
213 property_type: Some("number".to_string()),
214 description: Some("Dropout rate".to_string()),
215 enum_values: None,
216 format: None,
217 minimum: Some(0.0),
218 maximum: Some(1.0),
219 default: Some(serde_json::json!(0.1)),
220 items: None,
221 reference: None,
222 properties: None,
223 required: None,
224 one_of: None,
225 all_of: None,
226 },
227 );
228 properties.insert(
229 "attention_dropout_rate".to_string(),
230 JsonSchemaProperty {
231 property_type: Some("number".to_string()),
232 description: Some("Attention dropout rate".to_string()),
233 enum_values: None,
234 format: None,
235 minimum: Some(0.0),
236 maximum: Some(1.0),
237 default: Some(serde_json::json!(0.0)),
238 items: None,
239 reference: None,
240 properties: None,
241 required: None,
242 one_of: None,
243 all_of: None,
244 },
245 );
246 properties.insert(
247 "classifier".to_string(),
248 JsonSchemaProperty {
249 property_type: Some("string".to_string()),
250 description: Some("Type of classifier ('token' or 'gap')".to_string()),
251 enum_values: Some(vec!["token".to_string(), "gap".to_string()]),
252 format: None,
253 minimum: None,
254 maximum: None,
255 default: Some(serde_json::json!("token")),
256 items: None,
257 reference: None,
258 properties: None,
259 required: None,
260 one_of: None,
261 all_of: None,
262 },
263 );
264 properties.insert(
265 "include_top".to_string(),
266 JsonSchemaProperty {
267 property_type: Some("boolean".to_string()),
268 description: Some("Whether to include the classification head".to_string()),
269 enum_values: None,
270 format: None,
271 minimum: None,
272 maximum: None,
273 default: Some(serde_json::json!(true)),
274 items: None,
275 reference: None,
276 properties: None,
277 required: None,
278 one_of: None,
279 all_of: None,
280 },
281 );
282 JsonSchema {
283 title: "Vision Transformer Configuration".to_string(),
284 description: Some("Configuration for Vision Transformer models".to_string()),
285 schema_type: "object".to_string(),
286 properties,
287 required: vec![
288 "image_size".to_string(),
289 "patch_size".to_string(),
290 "hidden_size".to_string(),
291 "num_heads".to_string(),
292 ],
293 additional_properties: Some(false),
294 }
295 }
296
297 pub fn get_all_schemas() -> HashMap<String, JsonSchema> {
299 let mut schemas = HashMap::new();
300 schemas.insert("resnet".to_string(), Self::resnet_schema());
301 schemas.insert("vit".to_string(), Self::vit_schema());
302 schemas
303 }
304}