1use serde_json::{Map, Value};
16
/// Schemars transform that adds an OpenAPI-style `discriminator` object to a
/// schema containing a `oneOf` array.
///
/// `Debug` is derived so the transform can be inspected. `Clone` is derived so
/// it can also be registered as a generator-level transform
/// (`SchemaSettings::with_transform` requires a cloneable transform).
#[derive(Debug, Clone)]
pub struct AddDiscriminator {
    /// Name of the property whose value selects the `oneOf` variant; emitted as
    /// `discriminator.propertyName`.
    property_name: String,
}
36
37impl AddDiscriminator {
38 pub fn new(property_name: impl Into<String>) -> Self {
42 Self {
43 property_name: property_name.into(),
44 }
45 }
46
47 fn find_tag_const<'a>(variant: &'a Value, tag_property: &str) -> Option<&'a str> {
49 variant
50 .get("properties")
51 .and_then(|p| p.get(tag_property))
52 .and_then(|p| p.get("const"))
53 .and_then(|c| c.as_str())
54 }
55
56 fn find_ref(variant: &Value) -> Option<&str> {
58 variant.get("$ref").and_then(|r| r.as_str())
59 }
60}
61
62impl schemars::transform::Transform for AddDiscriminator {
63 fn transform(&mut self, schema: &mut schemars::Schema) {
64 let Some(obj) = schema.as_object_mut() else {
65 return;
66 };
67 let Some(one_of) = obj.get("oneOf").and_then(|v| v.as_array()) else {
68 return;
69 };
70
71 let mut mapping = Map::new();
73 for variant in one_of {
74 if let Some(tag_value) = Self::find_tag_const(variant, &self.property_name)
75 && let Some(ref_path) = Self::find_ref(variant)
76 {
77 mapping.insert(tag_value.to_string(), Value::String(ref_path.to_string()));
78 }
79 }
80
81 let mut discriminator = Map::new();
83 discriminator.insert(
84 "propertyName".to_string(),
85 Value::String(self.property_name.clone()),
86 );
87 if !mapping.is_empty() {
88 discriminator.insert("mapping".to_string(), Value::Object(mapping));
89 }
90
91 obj.insert("discriminator".to_string(), Value::Object(discriminator));
92 }
93}
94
/// Schemars transform that pushes a schema's top-level `properties` and
/// `required` down into each inline `oneOf` variant.
///
/// Stateless, so it is trivially `Copy`/`Default`. `Clone` also allows it to be
/// registered as a generator-level transform.
#[derive(Debug, Clone, Copy, Default)]
pub struct MergePropertiesIntoOneOf;
117
118impl schemars::transform::Transform for MergePropertiesIntoOneOf {
119 fn transform(&mut self, schema: &mut schemars::Schema) {
120 let Some(obj) = schema.as_object_mut() else {
121 return;
122 };
123
124 if !obj.contains_key("properties") || !obj.contains_key("oneOf") {
126 return;
127 }
128
129 let properties = obj.remove("properties").unwrap();
130 let required = obj.remove("required");
131
132 obj.remove("type");
134
135 let props_map = match properties.as_object() {
136 Some(m) => m.clone(),
137 None => return,
138 };
139 let req_items: Vec<Value> = required
140 .as_ref()
141 .and_then(|r| r.as_array())
142 .cloned()
143 .unwrap_or_default();
144
145 let Some(one_of) = obj.get_mut("oneOf").and_then(|v| v.as_array_mut()) else {
146 return;
147 };
148
149 for variant in one_of.iter_mut() {
150 if variant.get("$ref").is_some() {
152 continue;
153 }
154
155 let Some(variant_obj) = variant.as_object_mut() else {
156 continue;
157 };
158
159 let variant_props = variant_obj
161 .entry("properties")
162 .or_insert_with(|| Value::Object(Map::new()));
163 if let Some(vp) = variant_props.as_object_mut() {
164 for (key, value) in &props_map {
165 vp.insert(key.clone(), value.clone());
166 }
167 }
168
169 if !req_items.is_empty() {
171 let variant_req = variant_obj
172 .entry("required")
173 .or_insert_with(|| Value::Array(Vec::new()));
174 if let Some(vr) = variant_req.as_array_mut() {
175 for item in &req_items {
176 if !vr.contains(item) {
177 vr.push(item.clone());
178 }
179 }
180 }
181 }
182 }
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    /// Internally tagged enum with a unit and a struct variant, no transforms.
    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(tag = "type", rename_all = "camelCase")]
    enum SimpleTaggedEnum {
        UnitVariant,
        DataVariant { value: String, count: i32 },
    }

    #[test]
    fn tagged_enum_has_one_of_with_const_tags() {
        let schema = schemars::schema_for!(SimpleTaggedEnum);
        let value = serde_json::to_value(&schema).unwrap();

        let one_of = value.get("oneOf").expect("should have oneOf");
        let variants = one_of.as_array().unwrap();
        assert_eq!(variants.len(), 2);

        // Every variant pins its tag through `properties.type.const`.
        for variant in variants {
            let tag = variant
                .get("properties")
                .and_then(|p| p.get("type"))
                .expect("variant should have 'type' property");
            assert!(tag.get("const").is_some(), "tag should have const value");
        }

        insta::assert_yaml_snapshot!("tagged_enum_schema", value);
    }

    /// Tagged enum with `AddDiscriminator` applied through the schemars attribute.
    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(tag = "type", rename_all = "camelCase")]
    #[schemars(transform = AddDiscriminator::new("type"))]
    enum DiscriminatedEnum {
        UnitVariant,
        DataVariant { value: String },
    }

    #[test]
    fn add_discriminator_via_attribute() {
        let schema = schemars::schema_for!(DiscriminatedEnum);
        let value = serde_json::to_value(&schema).unwrap();

        let discriminator = value
            .get("discriminator")
            .expect("should have discriminator");
        assert_eq!(
            discriminator.get("propertyName"),
            Some(&Value::String("type".to_string()))
        );

        insta::assert_yaml_snapshot!("discriminated_enum_schema", value);
    }

    /// Tagged enum whose variants carry explicit titles (used as codegen names).
    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(tag = "action", rename_all = "camelCase")]
    #[schemars(transform = AddDiscriminator::new("action"))]
    enum TitledVariantsEnum {
        #[schemars(title = "Fail")]
        Fail,
        #[schemars(title = "UseDefault")]
        UseDefault { default_value: serde_json::Value },
        #[schemars(title = "Retry")]
        Retry,
    }

    #[test]
    fn variant_titles_for_codegen() {
        let schema = schemars::schema_for!(TitledVariantsEnum);
        let value = serde_json::to_value(&schema).unwrap();

        let variants = value
            .get("oneOf")
            .and_then(|v| v.as_array())
            .expect("should have oneOf array");
        for variant in variants {
            assert!(
                variant.get("title").is_some(),
                "variant should have title: {variant:?}"
            );
        }

        insta::assert_yaml_snapshot!("titled_variants_schema", value);
    }

    /// Struct mixing a required field with optional scalar and optional complex fields.
    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(rename_all = "camelCase")]
    struct NullableFields {
        required_field: String,
        optional_field: Option<String>,
        optional_complex: Option<Vec<i32>>,
    }

    #[test]
    fn nullable_types() {
        let schema = schemars::schema_for!(NullableFields);
        let value = serde_json::to_value(&schema).unwrap();

        let required = value.get("required").and_then(|r| r.as_array()).unwrap();
        assert!(required.contains(&Value::String("requiredField".to_string())));

        // Neither `Option` field may be listed as required.
        assert!(!required.contains(&Value::String("optionalField".to_string())));
        assert!(!required.contains(&Value::String("optionalComplex".to_string())));

        insta::assert_yaml_snapshot!("nullable_fields_schema", value);
    }

    /// Untagged enum whose schema is written by hand instead of derived.
    #[derive(Serialize, Deserialize)]
    #[serde(untagged)]
    enum SimpleExpr {
        Number(f64),
        Text(String),
        Bool(bool),
        Null,
    }

    impl JsonSchema for SimpleExpr {
        fn schema_name() -> std::borrow::Cow<'static, str> {
            "SimpleExpr".into()
        }

        fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
            schemars::json_schema!({
                "oneOf": [
                    { "type": "number", "title": "Number" },
                    { "type": "string", "title": "Text" },
                    { "type": "boolean", "title": "Bool" },
                    { "type": "null", "title": "Null" }
                ]
            })
        }
    }

    #[test]
    fn manual_json_schema_impl() {
        let schema = schemars::schema_for!(SimpleExpr);
        let value = serde_json::to_value(&schema).unwrap();

        assert!(value.get("oneOf").is_some(), "should have oneOf");

        insta::assert_yaml_snapshot!("manual_json_schema", value);
    }

    /// Tagged enum with a hand-written schema that already declares a discriminator.
    #[derive(Serialize, Deserialize)]
    #[serde(tag = "outcome")]
    enum CustomTaggedResult {
        #[serde(rename = "success")]
        Success { value: serde_json::Value },
        #[serde(rename = "failure")]
        Failure { error: String },
    }

    impl JsonSchema for CustomTaggedResult {
        fn schema_name() -> std::borrow::Cow<'static, str> {
            "CustomTaggedResult".into()
        }

        fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
            schemars::json_schema!({
                "oneOf": [
                    {
                        "title": "Success",
                        "type": "object",
                        "properties": {
                            "outcome": { "type": "string", "const": "success" },
                            "value": {}
                        },
                        "required": ["outcome", "value"]
                    },
                    {
                        "title": "Failure",
                        "type": "object",
                        "properties": {
                            "outcome": { "type": "string", "const": "failure" },
                            "error": { "type": "string" }
                        },
                        "required": ["outcome", "error"]
                    }
                ],
                "discriminator": {
                    "propertyName": "outcome"
                }
            })
        }
    }

    #[test]
    fn manual_discriminator_in_json_schema() {
        let schema = schemars::schema_for!(CustomTaggedResult);
        let value = serde_json::to_value(&schema).unwrap();

        let discriminator = value
            .get("discriminator")
            .expect("should have discriminator");
        assert_eq!(
            discriminator.get("propertyName"),
            Some(&Value::String("outcome".to_string()))
        );

        insta::assert_yaml_snapshot!("custom_tagged_result_schema", value);
    }

    /// Payload type used by a newtype variant of `PluginConfig`.
    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(rename_all = "camelCase")]
    struct InnerConfig {
        endpoint: String,
        timeout_ms: u64,
    }

    /// Tagged enum with a unit variant and a newtype variant wrapping a struct.
    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(tag = "type", rename_all = "camelCase")]
    #[schemars(transform = AddDiscriminator::new("type"))]
    enum PluginConfig {
        #[schemars(title = "NoOp")]
        NoOp,
        #[schemars(title = "Remote")]
        Remote(InnerConfig),
    }

    #[test]
    fn tagged_enum_with_newtype_variant() {
        let schema = schemars::schema_for!(PluginConfig);
        let value = serde_json::to_value(&schema).unwrap();

        let discriminator = value
            .get("discriminator")
            .expect("should have discriminator");
        assert_eq!(
            discriminator.get("propertyName"),
            Some(&Value::String("type".to_string()))
        );

        insta::assert_yaml_snapshot!("plugin_config_schema", value);
    }

    /// Tagged enum run through the full `generate_json_schema_with_defs` pipeline.
    #[derive(Serialize, Deserialize, JsonSchema)]
    #[serde(tag = "action", rename_all = "camelCase")]
    #[schemars(transform = AddDiscriminator::new("action"))]
    enum PipelineEnum {
        #[schemars(title = "Fail")]
        Fail,
        #[schemars(title = "UseDefault")]
        UseDefault { default_value: serde_json::Value },
        #[schemars(title = "Retry")]
        Retry,
    }

    #[test]
    fn full_pipeline_extracts_and_maps() {
        let value = crate::json_schema::generate_json_schema_with_defs::<PipelineEnum>();

        let one_of = value
            .get("oneOf")
            .and_then(|v| v.as_array())
            .expect("should have oneOf");
        for variant in one_of {
            assert!(
                variant.get("$ref").is_some(),
                "variant should be $ref after extraction, got: {variant}"
            );
        }

        let defs = value
            .get("$defs")
            .and_then(|v| v.as_object())
            .expect("should have $defs");
        assert!(defs.contains_key("Fail"), "$defs should have 'Fail'");
        assert!(
            defs.contains_key("UseDefault"),
            "$defs should have 'UseDefault'"
        );
        assert!(defs.contains_key("Retry"), "$defs should have 'Retry'");

        let disc = value
            .get("discriminator")
            .and_then(|d| d.as_object())
            .expect("should have discriminator");
        let mapping = disc
            .get("mapping")
            .and_then(|m| m.as_object())
            .expect("discriminator should have mapping");
        assert_eq!(
            mapping.get("fail"),
            Some(&Value::String("#/$defs/Fail".to_string()))
        );
        assert_eq!(
            mapping.get("useDefault"),
            Some(&Value::String("#/$defs/UseDefault".to_string()))
        );
        assert_eq!(
            mapping.get("retry"),
            Some(&Value::String("#/$defs/Retry".to_string()))
        );

        insta::assert_yaml_snapshot!("full_pipeline_schema", value);
    }

    /// `AddDiscriminator` maps tag values to `$ref`s only for variants that have both.
    #[test]
    fn add_discriminator_builds_mapping_from_ref_variants() {
        let mut schema = schemars::json_schema!({
            "oneOf": [
                { "$ref": "#/$defs/A", "properties": { "kind": { "const": "a" } } },
                { "$ref": "#/$defs/B", "properties": { "kind": { "const": "b" } } },
                { "properties": { "kind": { "const": "c" } } }
            ]
        });
        schemars::transform::Transform::transform(&mut AddDiscriminator::new("kind"), &mut schema);
        let value = serde_json::to_value(&schema).unwrap();

        assert_eq!(
            value["discriminator"],
            serde_json::json!({
                "propertyName": "kind",
                "mapping": { "a": "#/$defs/A", "b": "#/$defs/B" }
            })
        );
    }

    /// `AddDiscriminator` leaves schemas without a `oneOf` array untouched.
    #[test]
    fn add_discriminator_ignores_schema_without_one_of() {
        let mut schema = schemars::json_schema!({ "type": "string" });
        schemars::transform::Transform::transform(&mut AddDiscriminator::new("kind"), &mut schema);

        assert_eq!(
            serde_json::to_value(&schema).unwrap(),
            serde_json::json!({ "type": "string" })
        );
    }

    /// `MergePropertiesIntoOneOf` copies shared properties/required into inline
    /// variants, skips `$ref` variants, and strips the parent's copies.
    #[test]
    fn merge_properties_into_one_of_copies_shared_fields() {
        let mut schema = schemars::json_schema!({
            "type": "object",
            "properties": { "id": { "type": "string" } },
            "required": ["id"],
            "oneOf": [
                { "properties": { "kind": { "const": "a" } }, "required": ["kind"] },
                { "properties": { "kind": { "const": "b" } } },
                { "$ref": "#/$defs/Other" }
            ]
        });
        schemars::transform::Transform::transform(&mut MergePropertiesIntoOneOf, &mut schema);
        let value = serde_json::to_value(&schema).unwrap();

        assert!(value.get("properties").is_none());
        assert!(value.get("required").is_none());
        assert!(value.get("type").is_none());

        let variants = value["oneOf"].as_array().expect("oneOf stays an array");
        assert_eq!(
            variants[0]["properties"]["id"],
            serde_json::json!({ "type": "string" })
        );
        assert_eq!(variants[0]["required"], serde_json::json!(["kind", "id"]));
        assert_eq!(
            variants[1]["properties"]["id"],
            serde_json::json!({ "type": "string" })
        );
        assert_eq!(variants[1]["required"], serde_json::json!(["id"]));
        assert_eq!(variants[2], serde_json::json!({ "$ref": "#/$defs/Other" }));
    }
}