stepflow_flow/json_schema.rs
1// Copyright 2025 DataStax Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4// in compliance with the License. You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software distributed under the License
9// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10// or implied. See the License for the specific language governing permissions and limitations under
11// the License.
12
13//! Utility for generating standalone JSON Schema documents from schemars::JsonSchema types.
14//!
15//! This module provides functionality to generate standalone JSON Schema draft 2020-12
16//! documents suitable for code generation tools like datamodel-code-generator.
17
18use serde_json::Value;
19
/// Controls how external type references are handled in the generated schema.
///
/// The mode compares by value (`PartialEq` / `Eq`), so callers and tests can
/// match or assert on the configured reference handling directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Refs {
    /// Omit external schemas - just reference them by name without definitions.
    /// Produces compact schemas suitable for documentation.
    Omit,
    /// Include external schemas in `$defs` with local references (`#/$defs/TypeName`).
    /// Produces self-contained schemas for validation.
    Local,
    /// Reference external schemas from an external URL.
    /// References become `{base_url}#/$defs/TypeName`.
    External(String),
}
33
34/// Generate a standalone JSON Schema document from a type implementing JsonSchema.
35///
36/// This function generates a compact schema without `$defs` - any referenced types
37/// will appear as `$ref` without definitions. This is suitable for component schemas
38/// used for documentation purposes.
39///
40/// For a complete schema with all `$defs` included, use [`generate_json_schema_with_defs`].
41pub fn generate_json_schema<T: schemars::JsonSchema>() -> Value {
42 generate_json_schema_with_refs::<T>(Refs::Omit)
43}
44
45/// Generate a standalone JSON Schema document with all `$defs` included.
46///
47/// This produces a fully self-contained schema suitable for validation
48/// without external references.
49pub fn generate_json_schema_with_defs<T: schemars::JsonSchema>() -> Value {
50 generate_json_schema_with_refs::<T>(Refs::Local)
51}
52
53/// Generate a JSON Schema document with configurable reference handling.
54///
55/// # Arguments
56/// * `refs` - Controls how external type references are handled:
57/// - `Refs::Omit` - Omit `$defs`, just reference by name
58/// - `Refs::Local` - Include schemas in `$defs` with local references
59/// - `Refs::External(url)` - Reference schemas from an external URL
60pub fn generate_json_schema_with_refs<T: schemars::JsonSchema>(refs: Refs) -> Value {
61 generate_json_schema_custom::<T>(refs, |_| {})
62}
63
64/// Generate a JSON Schema document with configurable reference handling and
65/// additional types seeded into `$defs`.
66///
67/// The `seed` callback receives a `&mut SchemaGenerator` before the root schema
68/// is finalised. Calling `generator.subschema_for::<ExtraType>()` inside the
69/// callback ensures the type (and all its transitive deps) appear in `$defs`
70/// even when they are not reachable from the root type `T`.
71pub fn generate_json_schema_custom<T: schemars::JsonSchema>(
72 refs: Refs,
73 seed: impl FnOnce(&mut schemars::SchemaGenerator),
74) -> Value {
75 let settings = schemars::generate::SchemaSettings::draft2020_12();
76 let mut generator = settings.into_generator();
77 seed(&mut generator);
78 let schema = generator.into_root_schema_for::<T>();
79 let mut json = serde_json::to_value(schema).expect("Failed to serialize schema");
80
81 match refs {
82 Refs::Omit => {
83 // Remove $defs entirely
84 if let Some(obj) = json.as_object_mut() {
85 obj.remove("$defs");
86 }
87 }
88 Refs::Local => {
89 finalize_discriminators(&mut json);
90 }
91 Refs::External(ref base_url) => {
92 finalize_discriminators(&mut json);
93 // Transform #/$defs/X references to {base_url}#/$defs/X
94 transform_refs_external(&mut json, base_url);
95 }
96 }
97
98 json
99}
100
101/// Post-process a generated schema to make discriminated unions work correctly
102/// with code generators like `datamodel-code-generator`.
103///
104/// This runs three steps in order:
105/// 1. **Extract inline `oneOf` variants** to the definitions section — variants
106/// are keyed by their `title` attribute, so code generators produce the
107/// expected type names.
108/// 2. **Build discriminator mappings** by resolving `$ref` → definitions to read
109/// tag `const` values and populate `discriminator.mapping`.
110/// 3. **Add `default` alongside `const`** for discriminator tag properties —
111/// `datamodel-code-generator` uses `default` (not `const`) to set tag values.
112///
113/// Schemas are resolved using `#/$defs/` references. For OpenAPI documents
114/// where schemas live under `#/components/schemas/`, use
115/// [`finalize_discriminators_with_prefix`].
116pub fn finalize_discriminators(root: &mut Value) {
117 finalize_discriminators_with_prefix(root, "#/$defs/");
118}
119
120/// Like [`finalize_discriminators`], but with a configurable `$ref` prefix.
121///
122/// The `ref_prefix` determines both where definitions are stored in the JSON
123/// tree and the `$ref` prefix used in references:
124/// - `"#/$defs/"` — JSON Schema (definitions at `root.$defs`)
125/// - `"#/components/schemas/"` — OpenAPI (definitions at `root.components.schemas`)
126pub fn finalize_discriminators_with_prefix(root: &mut Value, ref_prefix: &str) {
127 flatten_string_enum_oneofs(root);
128 convert_nullable_anyof_to_oneof(root);
129 extract_inline_oneof_to_defs(root, ref_prefix);
130 build_discriminator_mappings(root, ref_prefix);
131 add_defaults_to_discriminator_consts(root, ref_prefix);
132}
133
/// Derive a JSON pointer path from a `$ref` prefix.
///
/// Strips an optional leading `#` and an optional trailing `/`:
/// - `"#/$defs/"` → `"/$defs"`
/// - `"#/components/schemas/"` → `"/components/schemas"`
/// - `"#/$defs"` (no trailing slash) → `"/$defs"`
fn defs_pointer(ref_prefix: &str) -> &str {
    let without_fragment = ref_prefix.strip_prefix('#').unwrap_or(ref_prefix);
    // Fall back to the `#`-stripped form, not the original prefix, so a missing
    // trailing slash never leaks the `#` back into the pointer.
    without_fragment
        .strip_suffix('/')
        .unwrap_or(without_fragment)
}
145
146/// Navigate to (and create if needed) the definitions object at the path
147/// implied by `ref_prefix`.
148fn get_or_create_defs_mut<'a>(
149 root: &'a mut Value,
150 ref_prefix: &str,
151) -> &'a mut serde_json::Map<String, Value> {
152 let pointer = defs_pointer(ref_prefix);
153 let mut current = root;
154 for segment in pointer.split('/').filter(|s| !s.is_empty()) {
155 current = current
156 .as_object_mut()
157 .unwrap()
158 .entry(segment.to_string())
159 .or_insert_with(|| Value::Object(serde_json::Map::new()));
160 }
161 current.as_object_mut().unwrap()
162}
163
164/// Convert `oneOf` schemas of string-const variants into simple string enums.
165///
166/// schemars generates documented Rust enums as `oneOf` arrays with per-variant
167/// `const` + `description` entries. This is valid JSON Schema but code generators
168/// (openapi-generator, datamodel-code-generator) produce broken or overly complex
169/// types because every variant resolves to `str`.
170///
171/// This rewrites such schemas into `{ "type": "string", "enum": ["a", "b", ...] }`
172/// which all code generators handle correctly. Schemas that have a `discriminator`
173/// are left untouched — those are tagged unions, not simple enums.
174fn flatten_string_enum_oneofs(root: &mut Value) {
175 match root {
176 Value::Object(obj) => {
177 // Check if this object is a string-const oneOf (without a discriminator)
178 let should_flatten = !obj.contains_key("discriminator")
179 && obj
180 .get("oneOf")
181 .and_then(|v| v.as_array())
182 .is_some_and(|arr| {
183 !arr.is_empty()
184 && arr.iter().all(|v| {
185 v.get("type").and_then(|t| t.as_str()) == Some("string")
186 && v.get("const").is_some()
187 })
188 });
189
190 if should_flatten {
191 if let Some(Value::Array(one_of)) = obj.remove("oneOf") {
192 let enum_values: Vec<Value> = one_of
193 .iter()
194 .filter_map(|v| v.get("const").cloned())
195 .collect();
196
197 // Append per-variant descriptions to the enum's description
198 let case_docs: Vec<String> = one_of
199 .iter()
200 .filter_map(|v| {
201 let name = v.get("const")?.as_str()?;
202 let desc = v.get("description")?.as_str()?;
203 Some(format!("* `{name}`: {desc}"))
204 })
205 .collect();
206
207 if !case_docs.is_empty() {
208 let existing = obj
209 .get("description")
210 .and_then(|d| d.as_str())
211 .unwrap_or_default();
212 let full = format!("{existing}\n\nCases:\n{}", case_docs.join("\n"));
213 obj.insert("description".to_string(), Value::String(full));
214 }
215
216 obj.insert("type".to_string(), Value::String("string".to_string()));
217 obj.insert("enum".to_string(), Value::Array(enum_values));
218 }
219 } else {
220 // Recurse into all values
221 for v in obj.values_mut() {
222 flatten_string_enum_oneofs(v);
223 }
224 }
225 }
226 Value::Array(arr) => {
227 for v in arr.iter_mut() {
228 flatten_string_enum_oneofs(v);
229 }
230 }
231 _ => {}
232 }
233}
234
235/// Convert nullable `anyOf` patterns to `oneOf`.
236///
237/// schemars generates `Option<T>` as `anyOf: [T, {type: null}]`, but
238/// code generators like openapi-generator handle `oneOf` nullable patterns
239/// correctly (the existing `fix_any_type_from_dict` post-processing in the
240/// Python codegen handles `OneOf` references). This matches the schema
241/// output that utoipa previously produced.
242fn convert_nullable_anyof_to_oneof(root: &mut Value) {
243 match root {
244 Value::Object(obj) => {
245 // Check for anyOf with exactly one null variant (nullable pattern)
246 let is_nullable_anyof =
247 obj.get("anyOf")
248 .and_then(|v| v.as_array())
249 .is_some_and(|arr| {
250 arr.len() == 2
251 && arr
252 .iter()
253 .any(|v| v.get("type").and_then(|t| t.as_str()) == Some("null"))
254 });
255
256 if is_nullable_anyof && let Some(any_of) = obj.remove("anyOf") {
257 obj.insert("oneOf".to_string(), any_of);
258 }
259
260 for v in obj.values_mut() {
261 convert_nullable_anyof_to_oneof(v);
262 }
263 }
264 Value::Array(arr) => {
265 for v in arr.iter_mut() {
266 convert_nullable_anyof_to_oneof(v);
267 }
268 }
269 _ => {}
270 }
271}
272
273/// Extract inline oneOf variants to the definitions section in schemas with
274/// discriminators.
275///
276/// schemars inlines all variants in the `oneOf` array. Discriminator mappings
277/// require `$ref` paths, so this extracts inline variants to definitions (using
278/// their `title` as the key) and replaces them with `$ref` entries.
279fn extract_inline_oneof_to_defs(root: &mut Value, ref_prefix: &str) {
280 let mut extractions: Vec<(String, Value)> = Vec::new();
281 extract_inline_oneof_recursive(root, ref_prefix, &mut extractions);
282
283 if extractions.is_empty() {
284 return;
285 }
286
287 let defs = get_or_create_defs_mut(root, ref_prefix);
288
289 for (key, schema) in extractions {
290 if let Some(existing) = defs.get_mut(&key) {
291 // Collision: the variant's title matches an existing $defs key (the inner
292 // type). Merge the discriminator tag property into the existing entry so
293 // that code generators can read the tag const value.
294 merge_tag_properties(existing, &schema);
295 } else {
296 defs.insert(key, schema);
297 }
298 }
299}
300
301fn extract_inline_oneof_recursive(
302 value: &mut Value,
303 ref_prefix: &str,
304 extractions: &mut Vec<(String, Value)>,
305) {
306 match value {
307 Value::Object(obj) => {
308 if obj.contains_key("discriminator")
309 && let Some(Value::Array(one_of)) = obj.get_mut("oneOf")
310 {
311 for variant in one_of.iter_mut() {
312 // Skip variants that are already pure $ref entries
313 if variant
314 .as_object()
315 .is_some_and(|o| o.len() == 1 && o.contains_key("$ref"))
316 {
317 continue;
318 }
319 // Extract inline variants with titles to $defs
320 if let Some(title) = variant
321 .get("title")
322 .and_then(|t| t.as_str())
323 .map(|s| s.to_string())
324 {
325 extractions.push((title.clone(), variant.clone()));
326 *variant = serde_json::json!({ "$ref": format!("{ref_prefix}{title}") });
327 }
328 }
329 }
330
331 for v in obj.values_mut() {
332 extract_inline_oneof_recursive(v, ref_prefix, extractions);
333 }
334 }
335 Value::Array(arr) => {
336 for v in arr.iter_mut() {
337 extract_inline_oneof_recursive(v, ref_prefix, extractions);
338 }
339 }
340 _ => {}
341 }
342}
343
344/// Merge discriminator tag properties from an extracted variant into an existing `$defs` entry.
345///
346/// When a variant's title matches an existing `$defs` key (e.g., `StepflowPluginConfig`
347/// is both the inner type and the variant title), this adds the tag `const` property
348/// and updates `required` so code generators can resolve the discriminator tag value.
349fn merge_tag_properties(existing: &mut Value, variant: &Value) {
350 // Merge properties (adds tag property from variant)
351 if let Some(variant_props) = variant.get("properties").and_then(|p| p.as_object()) {
352 let def_props = existing
353 .as_object_mut()
354 .unwrap()
355 .entry("properties")
356 .or_insert_with(|| Value::Object(serde_json::Map::new()))
357 .as_object_mut()
358 .unwrap();
359 for (key, value) in variant_props {
360 def_props
361 .entry(key.clone())
362 .or_insert_with(|| value.clone());
363 }
364 }
365
366 // Merge required arrays
367 if let Some(variant_required) = variant.get("required").and_then(|r| r.as_array()) {
368 let def_required = existing
369 .as_object_mut()
370 .unwrap()
371 .entry("required")
372 .or_insert_with(|| Value::Array(Vec::new()))
373 .as_array_mut()
374 .unwrap();
375 for req in variant_required {
376 if !def_required.contains(req) {
377 def_required.push(req.clone());
378 }
379 }
380 }
381}
382
383/// Build discriminator mappings by resolving `$ref` → definition entries
384/// and reading tag `const` values.
385fn build_discriminator_mappings(root: &mut Value, ref_prefix: &str) {
386 let defs = root
387 .pointer(defs_pointer(ref_prefix))
388 .and_then(|v| v.as_object())
389 .cloned();
390
391 // Recursively process all schemas in the document
392 build_discriminator_mappings_recursive(root, ref_prefix, defs.as_ref());
393}
394
395fn build_discriminator_mappings_recursive(
396 value: &mut Value,
397 ref_prefix: &str,
398 defs: Option<&serde_json::Map<String, Value>>,
399) {
400 let Some(defs) = defs else { return };
401 match value {
402 Value::Object(obj) => {
403 // Check if this object has a discriminator that needs mapping completion
404 let needs_mapping = obj
405 .get("discriminator")
406 .is_some_and(|d| d.get("propertyName").is_some());
407
408 if needs_mapping
409 && let Some(property_name) = obj
410 .get("discriminator")
411 .and_then(|d| d.get("propertyName"))
412 .and_then(|p| p.as_str())
413 .map(|s| s.to_string())
414 && let Some(one_of) = obj.get("oneOf").and_then(|v| v.as_array())
415 {
416 let mut mapping = serde_json::Map::new();
417
418 for variant in one_of {
419 // Resolve $ref to the definition entry
420 if let Some(ref_path) = variant.get("$ref").and_then(|r| r.as_str())
421 && let Some(def_key) = ref_path.strip_prefix(ref_prefix)
422 && let Some(def_schema) = defs.get(def_key)
423 {
424 // Read the const value for the discriminator property
425 if let Some(const_val) = def_schema
426 .get("properties")
427 .and_then(|p| p.get(&property_name))
428 .and_then(|p| p.get("const"))
429 .and_then(|c| c.as_str())
430 {
431 mapping
432 .insert(const_val.to_string(), Value::String(ref_path.to_string()));
433 }
434 }
435 }
436
437 if !mapping.is_empty()
438 && let Some(disc) = obj.get_mut("discriminator").and_then(|d| d.as_object_mut())
439 {
440 // Replace the mapping entirely — the post-processing steps
441 // (extract_inline_oneof_to_defs) may have changed $ref paths
442 disc.insert("mapping".to_string(), Value::Object(mapping));
443 }
444 }
445
446 // Recurse into all values
447 for v in obj.values_mut() {
448 build_discriminator_mappings_recursive(v, ref_prefix, Some(defs));
449 }
450 }
451 Value::Array(arr) => {
452 for v in arr.iter_mut() {
453 build_discriminator_mappings_recursive(v, ref_prefix, Some(defs));
454 }
455 }
456 _ => {}
457 }
458}
459
460/// Add `default` alongside `const` for discriminator tag properties in definitions.
461///
462/// Code generators like `datamodel-code-generator` use `default` (not `const`)
463/// to determine tag values for generated tagged union types. This walks all
464/// definition entries referenced by discriminator mappings and adds `default`
465/// equal to `const` for the discriminator tag property.
466fn add_defaults_to_discriminator_consts(root: &mut Value, ref_prefix: &str) {
467 let Some(root_obj) = root.as_object() else {
468 return;
469 };
470
471 // Collect (def_key, property_name) pairs from all discriminator mappings
472 let mut targets: Vec<(String, String)> = Vec::new();
473 collect_discriminator_targets(root_obj, ref_prefix, &mut targets);
474
475 if targets.is_empty() {
476 return;
477 }
478
479 // Apply defaults to the collected targets
480 let pointer = defs_pointer(ref_prefix);
481 let Some(defs) = root.pointer_mut(pointer).and_then(|d| d.as_object_mut()) else {
482 return;
483 };
484
485 for (def_key, property_name) in targets {
486 if let Some(def_schema) = defs.get_mut(&def_key)
487 && let Some(prop) = def_schema
488 .get_mut("properties")
489 .and_then(|p| p.get_mut(&property_name))
490 .and_then(|p| p.as_object_mut())
491 && let Some(const_val) = prop.get("const").cloned()
492 {
493 prop.entry("default").or_insert(const_val);
494 }
495 }
496}
497
498fn collect_discriminator_targets(
499 value: &serde_json::Map<String, Value>,
500 ref_prefix: &str,
501 targets: &mut Vec<(String, String)>,
502) {
503 if let Some(disc) = value.get("discriminator").and_then(|d| d.as_object())
504 && let Some(property_name) = disc.get("propertyName").and_then(|p| p.as_str())
505 && let Some(mapping) = disc.get("mapping").and_then(|m| m.as_object())
506 {
507 for ref_path in mapping.values() {
508 if let Some(ref_str) = ref_path.as_str()
509 && let Some(def_key) = ref_str.strip_prefix(ref_prefix)
510 {
511 targets.push((def_key.to_string(), property_name.to_string()));
512 }
513 }
514 }
515
516 // Recurse into nested objects
517 for v in value.values() {
518 if let Some(obj) = v.as_object() {
519 collect_discriminator_targets(obj, ref_prefix, targets);
520 } else if let Some(arr) = v.as_array() {
521 for item in arr {
522 if let Some(obj) = item.as_object() {
523 collect_discriminator_targets(obj, ref_prefix, targets);
524 }
525 }
526 }
527 }
528}
529
530/// Recursively transform `#/$defs/X` references to `{base_url}#/$defs/X`.
531fn transform_refs_external(value: &mut Value, base_url: &str) {
532 match value {
533 Value::Object(map) => {
534 if let Some(Value::String(ref_str)) = map.get_mut("$ref")
535 && let Some(name) = ref_str.strip_prefix("#/$defs/")
536 {
537 *ref_str = format!("{base_url}#/$defs/{name}");
538 }
539 for v in map.values_mut() {
540 transform_refs_external(v, base_url);
541 }
542 }
543 Value::Array(arr) => {
544 for v in arr.iter_mut() {
545 transform_refs_external(v, base_url);
546 }
547 }
548 _ => {}
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_json_schema_has_required_fields() {
        use crate::schema::SchemaRef;

        let schema = generate_json_schema::<SchemaRef>();

        // Should have $schema declaration
        assert_eq!(
            schema.get("$schema"),
            Some(&Value::String(
                "https://json-schema.org/draft/2020-12/schema".to_string()
            ))
        );

        // Should have title
        assert!(schema.get("title").is_some());
    }

    #[test]
    fn test_generate_json_schema_with_defs() {
        use crate::workflow::Flow;

        let schema = generate_json_schema_with_defs::<Flow>();

        // Should have $schema
        assert!(schema.get("$schema").is_some());
        // Should have title
        assert!(schema.get("title").is_some());
        // Should have $defs
        assert!(schema.get("$defs").is_some());
    }

    #[test]
    fn test_transform_refs_external() {
        let mut input = serde_json::json!({
            "$ref": "#/$defs/MyType",
            "nested": {
                "$ref": "#/$defs/OtherType"
            },
            "array": [
                { "$ref": "#/$defs/ArrayItem" }
            ]
        });

        transform_refs_external(&mut input, "https://stepflow.org/schemas/v1/flow.json");

        assert_eq!(
            input,
            serde_json::json!({
                "$ref": "https://stepflow.org/schemas/v1/flow.json#/$defs/MyType",
                "nested": {
                    "$ref": "https://stepflow.org/schemas/v1/flow.json#/$defs/OtherType"
                },
                "array": [
                    { "$ref": "https://stepflow.org/schemas/v1/flow.json#/$defs/ArrayItem" }
                ]
            })
        );
    }

    #[test]
    fn test_defs_pointer_with_trailing_slash() {
        assert_eq!(defs_pointer("#/$defs/"), "/$defs");
        assert_eq!(defs_pointer("#/components/schemas/"), "/components/schemas");
    }

    #[test]
    fn test_convert_nullable_anyof_to_oneof() {
        let mut input = serde_json::json!({
            "properties": {
                "maybe": { "anyOf": [{ "type": "string" }, { "type": "null" }] },
                "either": { "anyOf": [{ "type": "string" }, { "type": "integer" }] }
            }
        });

        convert_nullable_anyof_to_oneof(&mut input);

        // Only the nullable two-way union is converted.
        assert_eq!(
            input,
            serde_json::json!({
                "properties": {
                    "maybe": { "oneOf": [{ "type": "string" }, { "type": "null" }] },
                    "either": { "anyOf": [{ "type": "string" }, { "type": "integer" }] }
                }
            })
        );
    }

    #[test]
    fn test_flatten_string_enum_oneofs_keeps_case_docs() {
        let mut input = serde_json::json!({
            "description": "Execution mode.",
            "oneOf": [
                { "type": "string", "const": "fast", "description": "Go fast." },
                { "type": "string", "const": "slow" }
            ]
        });

        flatten_string_enum_oneofs(&mut input);

        assert_eq!(
            input,
            serde_json::json!({
                "description": "Execution mode.\n\nCases:\n* `fast`: Go fast.",
                "type": "string",
                "enum": ["fast", "slow"]
            })
        );
    }

    #[test]
    fn test_finalize_discriminators_extracts_variants_and_builds_mapping() {
        let mut schema = serde_json::json!({
            "title": "Root",
            "type": "object",
            "properties": {
                "u": {
                    "discriminator": { "propertyName": "kind" },
                    "oneOf": [
                        {
                            "title": "A",
                            "type": "object",
                            "properties": { "kind": { "type": "string", "const": "a" } },
                            "required": ["kind"]
                        },
                        {
                            "title": "B",
                            "type": "object",
                            "properties": { "kind": { "type": "string", "const": "b" } },
                            "required": ["kind"]
                        }
                    ]
                }
            }
        });

        finalize_discriminators(&mut schema);

        // Inline variants become $refs into $defs.
        assert_eq!(
            schema.pointer("/properties/u/oneOf"),
            Some(&serde_json::json!([
                { "$ref": "#/$defs/A" },
                { "$ref": "#/$defs/B" }
            ]))
        );
        assert_eq!(
            schema.pointer("/$defs/A/title"),
            Some(&serde_json::json!("A"))
        );

        // The mapping is built from the tag consts of the extracted definitions.
        assert_eq!(
            schema.pointer("/properties/u/discriminator/mapping"),
            Some(&serde_json::json!({ "a": "#/$defs/A", "b": "#/$defs/B" }))
        );

        // Tag properties gain a `default` equal to their `const`.
        assert_eq!(
            schema.pointer("/$defs/A/properties/kind/default"),
            Some(&serde_json::json!("a"))
        );
        assert_eq!(
            schema.pointer("/$defs/B/properties/kind/default"),
            Some(&serde_json::json!("b"))
        );
    }
}