1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
/// A value type that a workflow input or output field can declare.
///
/// The textual form (e.g. `"list<string>?"`) is parsed by [`parse_type`], and
/// JSON data is checked against a `SchemaType` by [`validate_value`]. No serde
/// attributes are applied, so serialization uses serde's default enum
/// representation rather than the textual form.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SchemaType {
    /// A JSON string; written `string`.
    String,
    /// Any JSON number, integral or floating point; written `number`.
    Number,
    /// A JSON number held by serde_json as a signed or unsigned 64-bit integer
    /// (a float-valued number is rejected); written `integer`.
    Integer,
    /// A JSON boolean; written `bool`.
    Bool,
    /// The inner type or `null`; a field of this type may also be omitted
    /// from the data object entirely. Written with a trailing `?`, e.g. `string?`.
    Optional(Box<SchemaType>),
    /// A JSON array in which every element matches the inner type; written `list<T>`.
    List(Box<SchemaType>),
    /// A JSON object in which every value matches the inner type (keys are not
    /// constrained); written `map<T>`.
    Map(Box<SchemaType>),
    /// Any JSON value, including `null`; written `any`.
    Any,
}
25
/// The declared input and output fields of a workflow, keyed by field name.
///
/// Either map may be absent when deserializing, in which case it defaults to empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WorkflowSchema {
    /// Fields checked by [`WorkflowSchema::validate_inputs`].
    #[serde(default)]
    pub inputs: HashMap<String, SchemaType>,
    /// Fields checked by [`WorkflowSchema::validate_outputs`].
    #[serde(default)]
    pub outputs: HashMap<String, SchemaType>,
}
36
37pub fn parse_type(s: &str) -> crate::Result<SchemaType> {
45 let s = s.trim();
46
47 if let Some(inner) = s.strip_suffix('?') {
49 let inner_type = parse_type(inner)?;
50 return Ok(SchemaType::Optional(Box::new(inner_type)));
51 }
52
53 if let Some(rest) = s.strip_prefix("list<") {
55 let inner = rest
56 .strip_suffix('>')
57 .ok_or_else(|| crate::WfeError::StepExecution(format!("Invalid type syntax: {s}")))?;
58 let inner_type = parse_type(inner)?;
59 return Ok(SchemaType::List(Box::new(inner_type)));
60 }
61 if let Some(rest) = s.strip_prefix("map<") {
62 let inner = rest
63 .strip_suffix('>')
64 .ok_or_else(|| crate::WfeError::StepExecution(format!("Invalid type syntax: {s}")))?;
65 let inner_type = parse_type(inner)?;
66 return Ok(SchemaType::Map(Box::new(inner_type)));
67 }
68
69 match s {
71 "string" => Ok(SchemaType::String),
72 "number" => Ok(SchemaType::Number),
73 "integer" => Ok(SchemaType::Integer),
74 "bool" => Ok(SchemaType::Bool),
75 "any" => Ok(SchemaType::Any),
76 _ => Err(crate::WfeError::StepExecution(format!("Unknown type: {s}"))),
77 }
78}
79
80pub fn validate_value(value: &serde_json::Value, expected: &SchemaType) -> Result<(), String> {
82 match expected {
83 SchemaType::String => {
84 if value.is_string() {
85 Ok(())
86 } else {
87 Err(format!("expected string, got {}", value_type_name(value)))
88 }
89 }
90 SchemaType::Number => {
91 if value.is_number() {
92 Ok(())
93 } else {
94 Err(format!("expected number, got {}", value_type_name(value)))
95 }
96 }
97 SchemaType::Integer => {
98 if value.is_i64() || value.is_u64() {
99 Ok(())
100 } else {
101 Err(format!("expected integer, got {}", value_type_name(value)))
102 }
103 }
104 SchemaType::Bool => {
105 if value.is_boolean() {
106 Ok(())
107 } else {
108 Err(format!("expected bool, got {}", value_type_name(value)))
109 }
110 }
111 SchemaType::Optional(inner) => {
112 if value.is_null() {
113 Ok(())
114 } else {
115 validate_value(value, inner)
116 }
117 }
118 SchemaType::List(inner) => {
119 if let Some(arr) = value.as_array() {
120 for (i, item) in arr.iter().enumerate() {
121 validate_value(item, inner).map_err(|e| format!("list element [{i}]: {e}"))?;
122 }
123 Ok(())
124 } else {
125 Err(format!("expected list, got {}", value_type_name(value)))
126 }
127 }
128 SchemaType::Map(inner) => {
129 if let Some(obj) = value.as_object() {
130 for (key, val) in obj {
131 validate_value(val, inner).map_err(|e| format!("map key \"{key}\": {e}"))?;
132 }
133 Ok(())
134 } else {
135 Err(format!("expected map, got {}", value_type_name(value)))
136 }
137 }
138 SchemaType::Any => Ok(()),
139 }
140}
141
142fn value_type_name(value: &serde_json::Value) -> &'static str {
143 match value {
144 serde_json::Value::Null => "null",
145 serde_json::Value::Bool(_) => "bool",
146 serde_json::Value::Number(_) => "number",
147 serde_json::Value::String(_) => "string",
148 serde_json::Value::Array(_) => "array",
149 serde_json::Value::Object(_) => "object",
150 }
151}
152
153impl WorkflowSchema {
154 pub fn validate_inputs(&self, data: &serde_json::Value) -> Result<(), Vec<String>> {
156 self.validate_fields(&self.inputs, data)
157 }
158
159 pub fn validate_outputs(&self, data: &serde_json::Value) -> Result<(), Vec<String>> {
161 self.validate_fields(&self.outputs, data)
162 }
163
164 fn validate_fields(
165 &self,
166 fields: &HashMap<String, SchemaType>,
167 data: &serde_json::Value,
168 ) -> Result<(), Vec<String>> {
169 let obj = match data.as_object() {
170 Some(o) => o,
171 None => {
172 return Err(vec!["expected an object".to_string()]);
173 }
174 };
175
176 let mut errors = Vec::new();
177
178 for (name, schema_type) in fields {
179 match obj.get(name) {
180 Some(value) => {
181 if let Err(e) = validate_value(value, schema_type) {
182 errors.push(format!("field \"{name}\": {e}"));
183 }
184 }
185 None => {
186 if !matches!(schema_type, SchemaType::Optional(_)) {
188 errors.push(format!("missing required field: \"{name}\""));
189 }
190 }
191 }
192 }
193
194 if errors.is_empty() {
195 Ok(())
196 } else {
197 Err(errors)
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    //! Unit tests for type parsing, value validation, and field-level schema checks.

    use super::*;
    use serde_json::json;

    // ---- parse_type -------------------------------------------------------

    #[test]
    fn parse_type_string() {
        assert_eq!(parse_type("string").unwrap(), SchemaType::String);
    }

    #[test]
    fn parse_type_number() {
        assert_eq!(parse_type("number").unwrap(), SchemaType::Number);
    }

    #[test]
    fn parse_type_integer() {
        assert_eq!(parse_type("integer").unwrap(), SchemaType::Integer);
    }

    #[test]
    fn parse_type_bool() {
        assert_eq!(parse_type("bool").unwrap(), SchemaType::Bool);
    }

    #[test]
    fn parse_type_any() {
        assert_eq!(parse_type("any").unwrap(), SchemaType::Any);
    }

    #[test]
    fn parse_type_optional_string() {
        assert_eq!(
            parse_type("string?").unwrap(),
            SchemaType::Optional(Box::new(SchemaType::String))
        );
    }

    #[test]
    fn parse_type_optional_number() {
        assert_eq!(
            parse_type("number?").unwrap(),
            SchemaType::Optional(Box::new(SchemaType::Number))
        );
    }

    #[test]
    fn parse_type_list_string() {
        assert_eq!(
            parse_type("list<string>").unwrap(),
            SchemaType::List(Box::new(SchemaType::String))
        );
    }

    #[test]
    fn parse_type_list_number() {
        assert_eq!(
            parse_type("list<number>").unwrap(),
            SchemaType::List(Box::new(SchemaType::Number))
        );
    }

    #[test]
    fn parse_type_map_string() {
        assert_eq!(
            parse_type("map<string>").unwrap(),
            SchemaType::Map(Box::new(SchemaType::String))
        );
    }

    #[test]
    fn parse_type_map_number() {
        assert_eq!(
            parse_type("map<number>").unwrap(),
            SchemaType::Map(Box::new(SchemaType::Number))
        );
    }

    #[test]
    fn parse_type_nested_list() {
        assert_eq!(
            parse_type("list<list<string>>").unwrap(),
            SchemaType::List(Box::new(SchemaType::List(Box::new(SchemaType::String))))
        );
    }

    #[test]
    fn parse_type_optional_list() {
        // The trailing `?` applies to the whole container.
        assert_eq!(
            parse_type("list<string>?").unwrap(),
            SchemaType::Optional(Box::new(SchemaType::List(Box::new(SchemaType::String))))
        );
    }

    #[test]
    fn parse_type_list_of_optional() {
        assert_eq!(
            parse_type("list<string?>").unwrap(),
            SchemaType::List(Box::new(SchemaType::Optional(Box::new(SchemaType::String))))
        );
    }

    #[test]
    fn parse_type_unknown_errors() {
        assert!(parse_type("foobar").is_err());
    }

    #[test]
    fn parse_type_unclosed_container_errors() {
        assert!(parse_type("list<string").is_err());
        assert!(parse_type("map<number").is_err());
    }

    #[test]
    fn parse_type_empty_errors() {
        assert!(parse_type("").is_err());
        assert!(parse_type("?").is_err());
    }

    #[test]
    fn parse_type_trims_whitespace() {
        assert_eq!(parse_type(" string ").unwrap(), SchemaType::String);
    }

    // ---- validate_value ---------------------------------------------------

    #[test]
    fn validate_string_match() {
        assert!(validate_value(&json!("hello"), &SchemaType::String).is_ok());
    }

    #[test]
    fn validate_string_mismatch() {
        assert!(validate_value(&json!(42), &SchemaType::String).is_err());
    }

    #[test]
    fn validate_number_match() {
        assert!(validate_value(&json!(2.78), &SchemaType::Number).is_ok());
    }

    #[test]
    fn validate_number_mismatch() {
        assert!(validate_value(&json!("not a number"), &SchemaType::Number).is_err());
    }

    #[test]
    fn validate_integer_match() {
        assert!(validate_value(&json!(42), &SchemaType::Integer).is_ok());
    }

    #[test]
    fn validate_integer_accepts_negative_and_large_unsigned() {
        let big = u64::MAX;
        assert!(validate_value(&json!(-7), &SchemaType::Integer).is_ok());
        assert!(validate_value(&json!(big), &SchemaType::Integer).is_ok());
    }

    #[test]
    fn validate_integer_mismatch_float() {
        assert!(validate_value(&json!(2.78), &SchemaType::Integer).is_err());
    }

    #[test]
    fn validate_bool_match() {
        assert!(validate_value(&json!(true), &SchemaType::Bool).is_ok());
    }

    #[test]
    fn validate_bool_mismatch() {
        assert!(validate_value(&json!(1), &SchemaType::Bool).is_err());
    }

    #[test]
    fn validate_optional_null_passes() {
        let ty = SchemaType::Optional(Box::new(SchemaType::String));
        assert!(validate_value(&json!(null), &ty).is_ok());
    }

    #[test]
    fn validate_optional_correct_inner_passes() {
        let ty = SchemaType::Optional(Box::new(SchemaType::String));
        assert!(validate_value(&json!("hello"), &ty).is_ok());
    }

    #[test]
    fn validate_optional_wrong_inner_fails() {
        let ty = SchemaType::Optional(Box::new(SchemaType::String));
        assert!(validate_value(&json!(42), &ty).is_err());
    }

    #[test]
    fn validate_list_match() {
        let ty = SchemaType::List(Box::new(SchemaType::Number));
        assert!(validate_value(&json!([1, 2, 3]), &ty).is_ok());
    }

    #[test]
    fn validate_list_mismatch_element() {
        let ty = SchemaType::List(Box::new(SchemaType::Number));
        assert!(validate_value(&json!([1, "two", 3]), &ty).is_err());
    }

    #[test]
    fn validate_list_error_names_index() {
        let ty = SchemaType::List(Box::new(SchemaType::Number));
        let err = validate_value(&json!([1, "two", 3]), &ty).unwrap_err();
        assert_eq!(err, "list element [1]: expected number, got string");
    }

    #[test]
    fn validate_nested_list_error_names_each_level() {
        let ty = SchemaType::List(Box::new(SchemaType::List(Box::new(SchemaType::String))));
        let err = validate_value(&json!([["x"], [1]]), &ty).unwrap_err();
        assert_eq!(
            err,
            "list element [1]: list element [0]: expected string, got number"
        );
    }

    #[test]
    fn validate_list_not_array() {
        let ty = SchemaType::List(Box::new(SchemaType::Number));
        assert!(validate_value(&json!("not a list"), &ty).is_err());
    }

    #[test]
    fn validate_map_match() {
        let ty = SchemaType::Map(Box::new(SchemaType::Number));
        assert!(validate_value(&json!({"a": 1, "b": 2}), &ty).is_ok());
    }

    #[test]
    fn validate_map_mismatch_value() {
        let ty = SchemaType::Map(Box::new(SchemaType::Number));
        assert!(validate_value(&json!({"a": 1, "b": "two"}), &ty).is_err());
    }

    #[test]
    fn validate_map_error_names_key() {
        let ty = SchemaType::Map(Box::new(SchemaType::Number));
        let err = validate_value(&json!({"a": 1, "b": "two"}), &ty).unwrap_err();
        assert_eq!(err, "map key \"b\": expected number, got string");
    }

    #[test]
    fn validate_map_not_object() {
        let ty = SchemaType::Map(Box::new(SchemaType::Number));
        assert!(validate_value(&json!([1, 2]), &ty).is_err());
    }

    #[test]
    fn validate_any_always_passes() {
        assert!(validate_value(&json!(null), &SchemaType::Any).is_ok());
        assert!(validate_value(&json!("str"), &SchemaType::Any).is_ok());
        assert!(validate_value(&json!(42), &SchemaType::Any).is_ok());
        assert!(validate_value(&json!([1, 2]), &SchemaType::Any).is_ok());
    }

    // ---- WorkflowSchema ---------------------------------------------------

    #[test]
    fn validate_inputs_all_present() {
        let schema = WorkflowSchema {
            inputs: HashMap::from([
                ("name".into(), SchemaType::String),
                ("age".into(), SchemaType::Integer),
            ]),
            outputs: HashMap::new(),
        };
        let data = json!({"name": "Alice", "age": 30});
        assert!(schema.validate_inputs(&data).is_ok());
    }

    #[test]
    fn validate_inputs_missing_required_field() {
        let schema = WorkflowSchema {
            inputs: HashMap::from([
                ("name".into(), SchemaType::String),
                ("age".into(), SchemaType::Integer),
            ]),
            outputs: HashMap::new(),
        };
        let data = json!({"name": "Alice"});
        let errs = schema.validate_inputs(&data).unwrap_err();
        assert!(errs.iter().any(|e| e.contains("age")));
    }

    #[test]
    fn validate_inputs_wrong_type() {
        let schema = WorkflowSchema {
            inputs: HashMap::from([("count".into(), SchemaType::Integer)]),
            outputs: HashMap::new(),
        };
        let data = json!({"count": "not-a-number"});
        let errs = schema.validate_inputs(&data).unwrap_err();
        assert!(!errs.is_empty());
    }

    #[test]
    fn validate_inputs_reports_every_error() {
        let schema = WorkflowSchema {
            inputs: HashMap::from([
                ("a".into(), SchemaType::String),
                ("b".into(), SchemaType::Integer),
            ]),
            outputs: HashMap::new(),
        };
        let errs = schema.validate_inputs(&json!({"a": 1})).unwrap_err();
        assert_eq!(errs.len(), 2);
        assert!(errs.iter().any(|e| e.contains("field \"a\"")));
        assert!(errs.iter().any(|e| e == "missing required field: \"b\""));
    }

    #[test]
    fn validate_inputs_ignores_undeclared_fields() {
        let schema = WorkflowSchema {
            inputs: HashMap::from([("name".into(), SchemaType::String)]),
            outputs: HashMap::new(),
        };
        let data = json!({"name": "Alice", "extra": 1});
        assert!(schema.validate_inputs(&data).is_ok());
    }

    #[test]
    fn validate_outputs_missing_field() {
        let schema = WorkflowSchema {
            inputs: HashMap::new(),
            outputs: HashMap::from([("result".into(), SchemaType::String)]),
        };
        let data = json!({});
        let errs = schema.validate_outputs(&data).unwrap_err();
        assert!(errs.iter().any(|e| e.contains("result")));
    }

    #[test]
    fn validate_inputs_optional_field_missing_is_ok() {
        let schema = WorkflowSchema {
            inputs: HashMap::from([(
                "nickname".into(),
                SchemaType::Optional(Box::new(SchemaType::String)),
            )]),
            outputs: HashMap::new(),
        };
        let data = json!({});
        assert!(schema.validate_inputs(&data).is_ok());
    }

    #[test]
    fn validate_inputs_optional_field_null_is_ok() {
        let schema = WorkflowSchema {
            inputs: HashMap::from([(
                "nickname".into(),
                SchemaType::Optional(Box::new(SchemaType::String)),
            )]),
            outputs: HashMap::new(),
        };
        assert!(schema.validate_inputs(&json!({"nickname": null})).is_ok());
    }

    #[test]
    fn validate_inputs_required_field_null_fails() {
        let schema = WorkflowSchema {
            inputs: HashMap::from([("name".into(), SchemaType::String)]),
            outputs: HashMap::new(),
        };
        let errs = schema.validate_inputs(&json!({"name": null})).unwrap_err();
        assert_eq!(
            errs,
            vec!["field \"name\": expected string, got null".to_string()]
        );
    }

    #[test]
    fn validate_not_object_errors() {
        let schema = WorkflowSchema {
            inputs: HashMap::from([("x".into(), SchemaType::String)]),
            outputs: HashMap::new(),
        };
        let errs = schema.validate_inputs(&json!("not an object")).unwrap_err();
        assert!(errs[0].contains("expected an object"));
    }

    #[test]
    fn schema_serde_round_trip() {
        let schema = WorkflowSchema {
            inputs: HashMap::from([("name".into(), SchemaType::String)]),
            outputs: HashMap::from([("result".into(), SchemaType::Bool)]),
        };
        let json_str = serde_json::to_string(&schema).unwrap();
        let deserialized: WorkflowSchema = serde_json::from_str(&json_str).unwrap();
        assert_eq!(deserialized.inputs["name"], SchemaType::String);
        assert_eq!(deserialized.outputs["result"], SchemaType::Bool);
    }

    #[test]
    fn schema_type_serde_round_trip_nested() {
        let ty = SchemaType::Optional(Box::new(SchemaType::Map(Box::new(SchemaType::List(
            Box::new(SchemaType::Integer),
        )))));
        let json_str = serde_json::to_string(&ty).unwrap();
        let back: SchemaType = serde_json::from_str(&json_str).unwrap();
        assert_eq!(back, ty);
    }
}