Skip to main content

vld/combinators/
discriminated_union.rs

1use serde_json::Value;
2
3use crate::error::{value_type_name, IssueCode, VldError};
4use crate::object::DynSchema;
5use crate::schema::VldSchema;
6
7/// Entry in a discriminated union: a discriminator value and its associated schema.
8struct Variant {
9    discriminator_value: Value,
10    schema: Box<dyn DynSchema>,
11}
12
13/// Discriminated union: chooses a schema based on a discriminator field value.
14///
15/// More efficient than a regular union because it looks up the correct variant
16/// by the discriminator field value instead of trying each schema in order.
17///
18/// Created via [`vld::discriminated_union()`](crate::discriminated_union).
19///
20/// # Example
21/// ```ignore
22/// let schema = vld::discriminated_union("type")
23///     .variant("dog", vld::object().field("type", vld::literal("dog")).field("bark", vld::boolean()))
24///     .variant("cat", vld::object().field("type", vld::literal("cat")).field("lives", vld::number().int()));
25/// ```
26pub struct ZDiscriminatedUnion {
27    discriminator: String,
28    variants: Vec<Variant>,
29}
30
31impl ZDiscriminatedUnion {
32    pub fn new(discriminator: impl Into<String>) -> Self {
33        Self {
34            discriminator: discriminator.into(),
35            variants: vec![],
36        }
37    }
38
39    /// Add a variant: when the discriminator field equals `value`, use `schema`.
40    pub fn variant<S: DynSchema + 'static>(mut self, value: impl Into<Value>, schema: S) -> Self {
41        self.variants.push(Variant {
42            discriminator_value: value.into(),
43            schema: Box::new(schema),
44        });
45        self
46    }
47
48    /// Add a string variant (convenience).
49    pub fn variant_str<S: DynSchema + 'static>(self, value: &str, schema: S) -> Self {
50        self.variant(Value::String(value.to_string()), schema)
51    }
52}
53
54impl VldSchema for ZDiscriminatedUnion {
55    type Output = Value;
56
57    fn parse_value(&self, value: &Value) -> Result<Value, VldError> {
58        let obj = value.as_object().ok_or_else(|| {
59            VldError::single(
60                IssueCode::InvalidType {
61                    expected: "object".to_string(),
62                    received: value_type_name(value),
63                },
64                format!("Expected object, received {}", value_type_name(value)),
65            )
66        })?;
67
68        let disc_value = obj.get(&self.discriminator).ok_or_else(|| {
69            VldError::single(
70                IssueCode::MissingField,
71                format!("Missing discriminator field \"{}\"", self.discriminator),
72            )
73        })?;
74
75        for variant in &self.variants {
76            if *disc_value == variant.discriminator_value {
77                return variant.schema.dyn_parse(value);
78            }
79        }
80
81        let known: Vec<String> = self
82            .variants
83            .iter()
84            .map(|v| format!("{}", v.discriminator_value))
85            .collect();
86
87        Err(VldError::single(
88            IssueCode::Custom {
89                code: "invalid_discriminator".to_string(),
90            },
91            format!(
92                "Invalid discriminator value {}. Expected one of: {}",
93                disc_value,
94                known.join(", ")
95            ),
96        ))
97    }
98}