Skip to main content

stepflow_flow/
discriminator_schema.rs

1// Copyright 2025 DataStax Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4// in compliance with the License. You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software distributed under the License
9// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10// or implied. See the License for the specific language governing permissions and limitations under
11// the License.
12
//! Reusable schemars transforms for tagged enum schemas: adding OpenAPI discriminators and
//! merging a wrapper struct's shared properties into each `oneOf` variant.
14
15use serde_json::{Map, Value};
16
/// A schemars [`Transform`](schemars::transform::Transform) that adds an OpenAPI
/// `discriminator` object to `oneOf` schemas generated from `#[serde(tag = "...")]` enums.
///
/// The emitted discriminator always carries `propertyName`. A `mapping` entry is added
/// only for `oneOf` variants that expose both a string `const` for the tag property and a
/// `$ref`; variants lacking either contribute no mapping entry, and the `mapping` key is
/// omitted entirely when no variant qualifies.
///
/// # Usage
///
/// ```ignore
/// #[derive(schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
/// #[serde(tag = "type", rename_all = "camelCase")]
/// #[schemars(transform = AddDiscriminator::new("type"))]
/// enum MyEnum {
///     #[schemars(title = "VariantA")]
///     VariantA,
///     #[schemars(title = "VariantB")]
///     VariantB { value: String },
/// }
/// ```
#[derive(Debug, Clone)]
pub struct AddDiscriminator {
    /// Name of the tag property; expected to match the `tag` value in `#[serde(tag = "...")]`.
    property_name: String,
}
36
37impl AddDiscriminator {
38    /// Create a new `AddDiscriminator` transform for the given tag property name.
39    ///
40    /// The `property_name` should match the `tag` value in `#[serde(tag = "...")]`.
41    pub fn new(property_name: impl Into<String>) -> Self {
42        Self {
43            property_name: property_name.into(),
44        }
45    }
46
47    /// Extract the discriminator tag's `const` value from a oneOf variant schema.
48    fn find_tag_const<'a>(variant: &'a Value, tag_property: &str) -> Option<&'a str> {
49        variant
50            .get("properties")
51            .and_then(|p| p.get(tag_property))
52            .and_then(|p| p.get("const"))
53            .and_then(|c| c.as_str())
54    }
55
56    /// Extract a `$ref` path from a oneOf variant schema.
57    fn find_ref(variant: &Value) -> Option<&str> {
58        variant.get("$ref").and_then(|r| r.as_str())
59    }
60}
61
62impl schemars::transform::Transform for AddDiscriminator {
63    fn transform(&mut self, schema: &mut schemars::Schema) {
64        let Some(obj) = schema.as_object_mut() else {
65            return;
66        };
67        let Some(one_of) = obj.get("oneOf").and_then(|v| v.as_array()) else {
68            return;
69        };
70
71        // Build mapping from discriminator values to $ref paths
72        let mut mapping = Map::new();
73        for variant in one_of {
74            if let Some(tag_value) = Self::find_tag_const(variant, &self.property_name)
75                && let Some(ref_path) = Self::find_ref(variant)
76            {
77                mapping.insert(tag_value.to_string(), Value::String(ref_path.to_string()));
78            }
79        }
80
81        // Build the discriminator object
82        let mut discriminator = Map::new();
83        discriminator.insert(
84            "propertyName".to_string(),
85            Value::String(self.property_name.clone()),
86        );
87        if !mapping.is_empty() {
88            discriminator.insert("mapping".to_string(), Value::Object(mapping));
89        }
90
91        obj.insert("discriminator".to_string(), Value::Object(discriminator));
92    }
93}
94
/// A schemars [`Transform`](schemars::transform::Transform) that merges top-level
/// `properties` and `required` fields from a wrapper struct into each `oneOf` variant.
///
/// This is used when a struct wraps a `#[serde(flatten)]` tagged enum: schemars places
/// the struct's own fields as top-level `properties` alongside the enum's `oneOf`.
/// OpenAPI code generators don't handle this combination, so this transform pushes
/// the shared properties into each variant so they appear in the generated models.
///
/// Variants that are `$ref`s are not modified by this transform.
///
/// # Usage
///
/// ```ignore
/// #[derive(schemars::JsonSchema, serde::Serialize, serde::Deserialize)]
/// #[serde(rename_all = "camelCase")]
/// #[schemars(transform = MergePropertiesIntoOneOf)]
/// struct Wrapper {
///     sequence_number: u64,
///     timestamp: DateTime<Utc>,
///     #[serde(flatten)]
///     kind: MyEnum,
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct MergePropertiesIntoOneOf;
117
118impl schemars::transform::Transform for MergePropertiesIntoOneOf {
119    fn transform(&mut self, schema: &mut schemars::Schema) {
120        let Some(obj) = schema.as_object_mut() else {
121            return;
122        };
123
124        // Only act when both properties and oneOf are present (flatten pattern)
125        if !obj.contains_key("properties") || !obj.contains_key("oneOf") {
126            return;
127        }
128
129        let properties = obj.remove("properties").unwrap();
130        let required = obj.remove("required");
131
132        // Remove "type": "object" — the oneOf variants define their own type
133        obj.remove("type");
134
135        let props_map = match properties.as_object() {
136            Some(m) => m.clone(),
137            None => return,
138        };
139        let req_items: Vec<Value> = required
140            .as_ref()
141            .and_then(|r| r.as_array())
142            .cloned()
143            .unwrap_or_default();
144
145        let Some(one_of) = obj.get_mut("oneOf").and_then(|v| v.as_array_mut()) else {
146            return;
147        };
148
149        for variant in one_of.iter_mut() {
150            // Skip $ref variants — they'll be resolved later by the pipeline
151            if variant.get("$ref").is_some() {
152                continue;
153            }
154
155            let Some(variant_obj) = variant.as_object_mut() else {
156                continue;
157            };
158
159            // Merge properties
160            let variant_props = variant_obj
161                .entry("properties")
162                .or_insert_with(|| Value::Object(Map::new()));
163            if let Some(vp) = variant_props.as_object_mut() {
164                for (key, value) in &props_map {
165                    vp.insert(key.clone(), value.clone());
166                }
167            }
168
169            // Merge required
170            if !req_items.is_empty() {
171                let variant_req = variant_obj
172                    .entry("required")
173                    .or_insert_with(|| Value::Array(Vec::new()));
174                if let Some(vr) = variant_req.as_array_mut() {
175                    for item in &req_items {
176                        if !vr.contains(item) {
177                            vr.push(item.clone());
178                        }
179                    }
180                }
181            }
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    // ---- Test 1: an internally tagged enum yields oneOf variants with const tags ----

    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(tag = "type", rename_all = "camelCase")]
    enum SimpleTaggedEnum {
        UnitVariant,
        DataVariant { value: String, count: i32 },
    }

    #[test]
    fn tagged_enum_has_one_of_with_const_tags() {
        let value = serde_json::to_value(schemars::schema_for!(SimpleTaggedEnum)).unwrap();

        // The schema must expose exactly one oneOf entry per enum variant.
        let variants = value
            .get("oneOf")
            .expect("should have oneOf")
            .as_array()
            .unwrap();
        assert_eq!(variants.len(), 2);

        // Every entry carries the tag property, pinned with `const`.
        for entry in variants {
            let tag_schema = entry
                .get("properties")
                .and_then(|props| props.get("type"))
                .expect("variant should have 'type' property");
            let tag_const = tag_schema.get("const");
            assert!(tag_const.is_some(), "tag should have const value");
        }

        insta::assert_yaml_snapshot!("tagged_enum_schema", value);
    }

    // ---- Test 2: the AddDiscriminator attribute produces a discriminator ----

    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(tag = "type", rename_all = "camelCase")]
    #[schemars(transform = AddDiscriminator::new("type"))]
    enum DiscriminatedEnum {
        UnitVariant,
        DataVariant { value: String },
    }

    #[test]
    fn add_discriminator_via_attribute() {
        let value = serde_json::to_value(schemars::schema_for!(DiscriminatedEnum)).unwrap();

        // The discriminator must name the tag property.
        let discriminator = value
            .get("discriminator")
            .expect("should have discriminator");
        let expected_name = Value::from("type");
        assert_eq!(discriminator.get("propertyName"), Some(&expected_name));

        insta::assert_yaml_snapshot!("discriminated_enum_schema", value);
    }

    // ---- Test 3: variant titles (consumed by Python codegen) ----

    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(tag = "action", rename_all = "camelCase")]
    #[schemars(transform = AddDiscriminator::new("action"))]
    enum TitledVariantsEnum {
        #[schemars(title = "Fail")]
        Fail,
        #[schemars(title = "UseDefault")]
        UseDefault { default_value: serde_json::Value },
        #[schemars(title = "Retry")]
        Retry,
    }

    #[test]
    fn variant_titles_for_codegen() {
        let value = serde_json::to_value(schemars::schema_for!(TitledVariantsEnum)).unwrap();

        // No oneOf entry may be missing its title.
        let entries = value
            .get("oneOf")
            .and_then(Value::as_array)
            .expect("should have oneOf array");
        for variant in entries {
            let title = variant.get("title");
            assert!(title.is_some(), "variant should have title: {variant:?}");
        }

        insta::assert_yaml_snapshot!("titled_variants_schema", value);
    }

    // ---- Test 4: Option / nullable fields ----

    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(rename_all = "camelCase")]
    struct NullableFields {
        required_field: String,
        optional_field: Option<String>,
        optional_complex: Option<Vec<i32>>,
    }

    #[test]
    fn nullable_types() {
        let value = serde_json::to_value(schemars::schema_for!(NullableFields)).unwrap();

        let required = value.get("required").and_then(Value::as_array).unwrap();

        // The non-Option field is listed as required ...
        let required_name = Value::from("requiredField");
        assert!(required.contains(&required_name));

        // ... while the Option field is not.
        let optional_name = Value::from("optionalField");
        assert!(!required.contains(&optional_name));

        insta::assert_yaml_snapshot!("nullable_fields_schema", value);
    }

    // ---- Test 5: hand-written JsonSchema impl (ValueExpr-like pattern) ----
    // Confirms a manual impl works for a type with custom serde behaviour.

    #[derive(Serialize, Deserialize)]
    #[serde(untagged)]
    enum SimpleExpr {
        Number(f64),
        Text(String),
        Bool(bool),
        Null,
    }

    impl JsonSchema for SimpleExpr {
        fn schema_name() -> std::borrow::Cow<'static, str> {
            std::borrow::Cow::Borrowed("SimpleExpr")
        }

        fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
            schemars::json_schema!({
                "oneOf": [
                    { "type": "number", "title": "Number" },
                    { "type": "string", "title": "Text" },
                    { "type": "boolean", "title": "Bool" },
                    { "type": "null", "title": "Null" }
                ]
            })
        }
    }

    #[test]
    fn manual_json_schema_impl() {
        let value = serde_json::to_value(schemars::schema_for!(SimpleExpr)).unwrap();

        let one_of = value.get("oneOf");
        assert!(one_of.is_some(), "should have oneOf");

        insta::assert_yaml_snapshot!("manual_json_schema", value);
    }

    // ---- Test 6: hand-written JsonSchema with discriminator (FlowResult pattern) ----
    // Confirms a manual impl can embed its own discriminator for custom-serde types.

    #[derive(Serialize, Deserialize)]
    #[serde(tag = "outcome")]
    enum CustomTaggedResult {
        #[serde(rename = "success")]
        Success { value: serde_json::Value },
        #[serde(rename = "failure")]
        Failure { error: String },
    }

    impl JsonSchema for CustomTaggedResult {
        fn schema_name() -> std::borrow::Cow<'static, str> {
            std::borrow::Cow::Borrowed("CustomTaggedResult")
        }

        fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
            schemars::json_schema!({
                "oneOf": [
                    {
                        "title": "Success",
                        "type": "object",
                        "properties": {
                            "outcome": { "type": "string", "const": "success" },
                            "value": {}
                        },
                        "required": ["outcome", "value"]
                    },
                    {
                        "title": "Failure",
                        "type": "object",
                        "properties": {
                            "outcome": { "type": "string", "const": "failure" },
                            "error": { "type": "string" }
                        },
                        "required": ["outcome", "error"]
                    }
                ],
                "discriminator": {
                    "propertyName": "outcome"
                }
            })
        }
    }

    #[test]
    fn manual_discriminator_in_json_schema() {
        let value = serde_json::to_value(schemars::schema_for!(CustomTaggedResult)).unwrap();

        let discriminator = value
            .get("discriminator")
            .expect("should have discriminator");
        let expected_name = Value::from("outcome");
        assert_eq!(discriminator.get("propertyName"), Some(&expected_name));

        insta::assert_yaml_snapshot!("custom_tagged_result_schema", value);
    }

    // ---- Test 7: tagged enum with a newtype variant wrapping another type ----
    // Snapshot-checks the schema shape for enums such as SupportedPlugin /
    // LeaseManagerConfig whose variants wrap structs.

    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(rename_all = "camelCase")]
    struct InnerConfig {
        endpoint: String,
        timeout_ms: u64,
    }

    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(tag = "type", rename_all = "camelCase")]
    #[schemars(transform = AddDiscriminator::new("type"))]
    enum PluginConfig {
        #[schemars(title = "NoOp")]
        NoOp,
        #[schemars(title = "Remote")]
        Remote(InnerConfig),
    }

    #[test]
    fn tagged_enum_with_newtype_variant() {
        let value = serde_json::to_value(schemars::schema_for!(PluginConfig)).unwrap();

        // The discriminator must be present and name the tag property.
        let discriminator = value
            .get("discriminator")
            .expect("should have discriminator");
        let expected_name = Value::from("type");
        assert_eq!(discriminator.get("propertyName"), Some(&expected_name));

        insta::assert_yaml_snapshot!("plugin_config_schema", value);
    }

    // ---- Test 8: full pipeline (inline → extract → discriminator mapping) ----

    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(tag = "action", rename_all = "camelCase")]
    #[schemars(transform = AddDiscriminator::new("action"))]
    enum PipelineEnum {
        #[schemars(title = "Fail")]
        Fail,
        #[schemars(title = "UseDefault")]
        UseDefault { default_value: serde_json::Value },
        #[schemars(title = "Retry")]
        Retry,
    }

    #[test]
    fn full_pipeline_extracts_and_maps() {
        // generate_json_schema_with_defs runs finalize_discriminators, which moves inline
        // variants into $defs and fills in the discriminator mapping.
        let value = crate::json_schema::generate_json_schema_with_defs::<PipelineEnum>();

        // After extraction every oneOf entry is a $ref.
        let entries = value
            .get("oneOf")
            .and_then(Value::as_array)
            .expect("should have oneOf");
        for variant in entries {
            let reference = variant.get("$ref");
            assert!(
                reference.is_some(),
                "variant should be $ref after extraction, got: {variant}"
            );
        }

        // $defs is keyed by the variant titles.
        let defs = value
            .get("$defs")
            .and_then(Value::as_object)
            .expect("should have $defs");
        for title in ["Fail", "UseDefault", "Retry"] {
            assert!(defs.contains_key(title), "$defs should have '{title}'");
        }

        // The discriminator maps every tag value to its extracted definition.
        let mapping = value
            .get("discriminator")
            .and_then(Value::as_object)
            .expect("should have discriminator")
            .get("mapping")
            .and_then(Value::as_object)
            .expect("discriminator should have mapping");
        let expected_targets = [
            ("fail", "#/$defs/Fail"),
            ("useDefault", "#/$defs/UseDefault"),
            ("retry", "#/$defs/Retry"),
        ];
        for (tag, target) in expected_targets {
            let expected = Value::from(target);
            assert_eq!(mapping.get(tag), Some(&expected));
        }

        insta::assert_yaml_snapshot!("full_pipeline_schema", value);
    }
}