1use std::collections::HashSet;
13
14use proc_macro::TokenStream;
15use quote::quote;
16use regex::Regex;
17use serde_json::Value;
18use syn::parse::{Parse, ParseStream};
19use syn::{Ident, LitStr, Token, braced, bracketed};
20
/// Arguments accepted by `include_openapi!`, as parsed from the macro input.
struct OpenApiConfig {
    /// Remote location of the OpenAPI spec. It is downloaded through an on-disk HTTP cache
    /// when `local_file` is unset or does not exist.
    url: Option<String>,
    /// Spec path relative to `CARGO_MANIFEST_DIR`. It is read in preference to `url`
    /// whenever the file exists.
    local_file: Option<String>,
    /// Schema names to generate. Generation is limited to these schemas and every schema
    /// they reference, directly or transitively.
    root_types: Vec<String>,
    /// JSON object text whose entries are merged into the schema definitions. An entry
    /// replaces a spec schema of the same name.
    extra_definitions: Option<String>,
    /// When set, the final JSON Schema handed to the type generator is written to this
    /// path (relative to `CARGO_MANIFEST_DIR`) for inspection.
    debug_schema_path: Option<String>,
}
28
29impl Parse for OpenApiConfig {
30 fn parse(input: ParseStream) -> syn::Result<Self> {
31 let mut url = None;
32 let mut local_file = None;
33 let mut root_types = Vec::new();
34 let mut extra_definitions = None;
35 let mut debug_schema_path = None;
36
37 while !input.is_empty() {
38 let key: Ident = input.parse()?;
39 input.parse::<Token![=]>()?;
40
41 match key.to_string().as_str() {
42 "url" => {
43 let lit: LitStr = input.parse()?;
44 url = Some(lit.value());
45 }
46 "local_file" => {
47 let lit: LitStr = input.parse()?;
48 local_file = Some(lit.value());
49 }
50 "root_types" => {
51 let content;
52 bracketed!(content in input);
53 while !content.is_empty() {
54 let lit: LitStr = content.parse()?;
55 root_types.push(lit.value());
56 if content.peek(Token![,]) {
57 content.parse::<Token![,]>()?;
58 }
59 }
60 }
61 "extra_definitions" => {
62 if input.peek(LitStr) {
63 let lit: LitStr = input.parse()?;
65 extra_definitions = Some(lit.value());
66 } else if input.peek(syn::token::Brace) {
67 let content;
69 braced!(content in input);
70 let tokens: proc_macro2::TokenStream = content.parse()?;
71 let json_str = format!("{{{}}}", tokens);
72 let _: serde_json::Value =
74 serde_json::from_str(&json_str).map_err(|e| {
75 syn::Error::new(key.span(), format!("invalid JSON: {}", e))
76 })?;
77 extra_definitions = Some(json_str);
78 } else {
79 return Err(syn::Error::new(
80 input.span(),
81 "expected string literal or JSON object",
82 ));
83 }
84 }
85 "debug_schema_path" => {
86 let lit: LitStr = input.parse()?;
87 debug_schema_path = Some(lit.value());
88 }
89 _ => {
90 return Err(syn::Error::new(key.span(), format!("unknown key: {}", key)));
91 }
92 }
93
94 if input.peek(Token![,]) {
95 input.parse::<Token![,]>()?;
96 }
97 }
98
99 if url.is_none() && local_file.is_none() {
101 return Err(syn::Error::new(
102 input.span(),
103 "missing `url` or `local_file`",
104 ));
105 }
106
107 Ok(OpenApiConfig {
108 url,
109 local_file,
110 root_types,
111 extra_definitions,
112 debug_schema_path,
113 })
114 }
115}
116
117#[proc_macro]
137pub fn include_openapi(input: TokenStream) -> TokenStream {
138 let config = syn::parse_macro_input!(input as OpenApiConfig);
139
140 let code = match generate_types(&config) {
141 Ok(code) => code,
142 Err(e) => {
143 return syn::Error::new(proc_macro2::Span::call_site(), e.to_string())
144 .to_compile_error()
145 .into();
146 }
147 };
148
149 let tokens: proc_macro2::TokenStream = match code.parse() {
150 Ok(t) => t,
151 Err(e) => {
152 return syn::Error::new(
153 proc_macro2::Span::call_site(),
154 format!("Failed to parse generated code: {}", e),
155 )
156 .to_compile_error()
157 .into();
158 }
159 };
160
161 quote! { #tokens }.into()
162}
163
164fn generate_types(config: &OpenApiConfig) -> Result<String, Box<dyn std::error::Error>> {
165 let spec_yaml = if let Some(ref local) = config.local_file {
167 let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")?;
168 let local_path = std::path::Path::new(&manifest_dir).join(local);
169 if local_path.exists() {
170 std::fs::read_to_string(&local_path)?
171 } else if let Some(ref url) = config.url {
172 fetch_with_cache(url)?
173 } else {
174 return Err(format!("Local file not found: {}", local_path.display()).into());
175 }
176 } else if let Some(ref url) = config.url {
177 fetch_with_cache(url)?
178 } else {
179 return Err("No URL or local file specified".into());
180 };
181
182 let spec_yaml = preprocess_yaml(&spec_yaml);
183 let spec: Value = serde_yaml_ng::from_str(&spec_yaml)?;
184
185 let mut schemas = spec
186 .get("components")
187 .and_then(|c| c.get("schemas"))
188 .ok_or("No components/schemas in OpenAPI spec")?
189 .clone();
190
191 convert_openapi_to_json_schema(&mut schemas);
192
193 extract_inline_type_enums(&mut schemas);
195
196 if let Some(ref extra) = config.extra_definitions {
198 let extra_defs: serde_json::Map<String, Value> = serde_json::from_str(extra)?;
199 if let Value::Object(ref mut map) = schemas {
200 for (k, v) in extra_defs {
201 map.insert(k, v);
202 }
203 }
204 }
205
206 let root_refs: Vec<&str> = config.root_types.iter().map(|s| s.as_str()).collect();
207 let schemas = filter_schemas(schemas, &root_refs);
208
209 let mut json_schema = serde_json::json!({
210 "$schema": "http://json-schema.org/draft-07/schema#",
211 "definitions": schemas,
212 });
213 convert_openapi_to_json_schema(&mut json_schema);
214
215 if let Some(ref debug_path) = config.debug_schema_path {
217 let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")?;
218 let debug_file = std::path::Path::new(&manifest_dir).join(debug_path);
219 let formatted = serde_json::to_string_pretty(&json_schema)?;
220 std::fs::write(&debug_file, formatted)?;
221 }
222
223 let mut type_space = typify::TypeSpace::new(
224 typify::TypeSpaceSettings::default().with_derive("PartialEq".to_string()),
225 );
226
227 let root_schema: schemars::schema::RootSchema = serde_json::from_value(json_schema.clone())
228 .map_err(|e| format!("Failed to parse JSON schema: {}", e,))?;
229 type_space
230 .add_root_schema(root_schema)
231 .map_err(|e| format!("Failed to add root schema to type space: {}", e))?;
232
233 Ok(type_space.to_stream().to_string())
234}
235
236fn fetch_with_cache(url: &str) -> Result<String, Box<dyn std::error::Error>> {
237 use http_cache_reqwest::{CACacheManager, Cache, CacheMode, HttpCache, HttpCacheOptions};
238 use reqwest_middleware::ClientBuilder;
239
240 let cache_dir = resolve_cache_dir()?;
241
242 let rt = tokio::runtime::Runtime::new()?;
243
244 rt.block_on(async {
245 let client = ClientBuilder::new(reqwest::Client::new())
246 .with(Cache(HttpCache {
247 mode: CacheMode::Default,
248 manager: CACacheManager { path: cache_dir },
249 options: HttpCacheOptions::default(),
250 }))
251 .build();
252
253 let response = client.get(url).send().await?;
254 let text = response.text().await?;
255 Ok(text)
256 })
257}
258
259fn resolve_cache_dir() -> Result<std::path::PathBuf, Box<dyn std::error::Error>> {
260 let mut candidates: Vec<std::path::PathBuf> = Vec::new();
261
262 if let Ok(dir) = std::env::var("YALLM_CACHE_DIR") {
263 candidates.push(std::path::PathBuf::from(dir));
264 }
265
266 if let Ok(dir) = std::env::var("CARGO_TARGET_DIR") {
267 candidates.push(std::path::PathBuf::from(dir).join("yallm-cache"));
268 }
269
270 if let Some(dir) = dirs::cache_dir() {
271 candidates.push(dir.join("yallm"));
272 }
273
274 candidates.push(std::env::temp_dir().join("yallm-cache"));
275
276 for candidate in candidates {
277 if ensure_writable_dir(&candidate).is_ok() {
278 return Ok(candidate);
279 }
280 }
281
282 Err("Failed to create cache directory for OpenAPI spec".into())
283}
284
/// Makes sure `path` exists as a directory and that this process can create files in it.
///
/// Creates the directory (and any missing parents), then creates, writes, and deletes a
/// uniquely named probe file. The probe is removed even when writing to it fails, so a
/// failed check does not leave stray files behind.
///
/// # Errors
/// Returns the first I/O error from creating the directory, creating or writing the
/// probe file, or removing it.
fn ensure_writable_dir(path: &std::path::Path) -> std::io::Result<()> {
    use std::io::Write;

    std::fs::create_dir_all(path)?;

    // PID + wall-clock nanoseconds keeps probes from separate calls/processes apart, and
    // `create_new` guarantees we never write to (or later delete) a file we did not create.
    let unique = format!(
        ".yallm_cache_write_test_{}_{}",
        std::process::id(),
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos()
    );
    let probe_path = path.join(unique);
    let mut probe = std::fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&probe_path)?;

    let write_result = probe.write_all(b"ok");
    // Close the handle before deleting so removal also works on platforms that refuse to
    // delete open files.
    drop(probe);
    let remove_result = std::fs::remove_file(&probe_path);

    write_result?;
    remove_result
}
306
307fn preprocess_yaml(yaml: &str) -> String {
313 let re = Regex::new(r"minimum:\s*-\d{15,}").unwrap();
314 let yaml = re.replace_all(yaml, "minimum: -2147483648").to_string();
315
316 let re = Regex::new(r"maximum:\s*\d{15,}").unwrap();
317 re.replace_all(&yaml, "maximum: 2147483647").to_string()
318}
319
320fn convert_openapi_to_json_schema(value: &mut Value) {
322 match value {
323 Value::Object(map) => {
324 let keys_to_remove: Vec<String> = map
326 .keys()
327 .filter(|k| k.starts_with("x-"))
328 .cloned()
329 .collect();
330 for key in keys_to_remove {
331 map.remove(&key);
332 }
333
334 if let Some(Value::String(ref_path)) = map.get_mut("$ref")
336 && ref_path.starts_with("#/components/schemas/")
337 {
338 *ref_path = ref_path.replace("#/components/schemas/", "#/definitions/");
339 }
340
341 let nullable_props: HashSet<String> = if map.get("type")
344 == Some(&Value::String("object".to_string()))
345 {
346 if let Some(Value::Object(props)) = map.get("properties") {
347 props
348 .iter()
349 .filter_map(|(name, prop_schema)| {
350 if let Value::Object(prop_obj) = prop_schema {
351 if let Some(Value::Array(any_of)) = prop_obj.get("anyOf") {
353 let has_null = any_of.iter().any(|v| {
354 matches!(v, Value::Object(m) if m.get("type") == Some(&Value::String("null".to_string())))
355 });
356 if has_null {
357 return Some(name.clone());
358 }
359 }
360 if prop_obj.get("default") == Some(&Value::Null) {
362 return Some(name.clone());
363 }
364 }
365 None
366 })
367 .collect()
368 } else {
369 HashSet::new()
370 }
371 } else {
372 HashSet::new()
373 };
374
375 if !nullable_props.is_empty()
377 && let Some(Value::Array(required)) = map.get_mut("required")
378 {
379 required.retain(|v| {
380 if let Value::String(s) = v {
381 !nullable_props.contains(s)
382 } else {
383 true
384 }
385 });
386 }
387
388 let replacement = if let Some(Value::Array(any_of)) = map.get("anyOf") {
390 let non_null: Vec<&Value> = any_of
391 .iter()
392 .filter(|v| {
393 !matches!(v, Value::Object(m) if m.get("type") == Some(&Value::String("null".to_string())))
394 })
395 .collect();
396
397 let has_null = any_of.len() != non_null.len();
398
399 if has_null && non_null.len() == 1 {
400 if let Value::Object(inner) = non_null[0] {
401 Some(inner.clone())
402 } else {
403 None
404 }
405 } else {
406 None
407 }
408 } else {
409 None
410 };
411
412 if let Some(inner) = replacement {
413 map.remove("anyOf");
414 for (k, v) in inner {
415 map.insert(k, v);
416 }
417 }
418
419 if let Some(Value::Bool(true)) = map.get("exclusiveMinimum") {
421 if let Some(min_val) = map.remove("minimum") {
422 map.insert("exclusiveMinimum".to_string(), min_val);
423 } else {
424 map.remove("exclusiveMinimum");
425 }
426 } else if let Some(Value::Bool(false)) = map.get("exclusiveMinimum") {
427 map.remove("exclusiveMinimum");
428 }
429
430 if let Some(Value::Bool(true)) = map.get("exclusiveMaximum") {
431 if let Some(max_val) = map.remove("maximum") {
432 map.insert("exclusiveMaximum".to_string(), max_val);
433 } else {
434 map.remove("exclusiveMaximum");
435 }
436 } else if let Some(Value::Bool(false)) = map.get("exclusiveMaximum") {
437 map.remove("exclusiveMaximum");
438 }
439
440 if let Some(Value::Bool(true)) = map.remove("nullable") {
442 if let Some(type_val) = map.get("type").cloned() {
443 match type_val {
446 Value::String(t) => {
447 map.insert(
448 "type".to_string(),
449 Value::Array(vec![
450 Value::String(t),
451 Value::String("null".to_string()),
452 ]),
453 );
454 }
455 Value::Array(mut arr) => {
456 if !arr.contains(&Value::String("null".to_string())) {
458 arr.push(Value::String("null".to_string()));
459 }
460 map.insert("type".to_string(), Value::Array(arr));
461 }
462 _ => {}
463 }
464 } else if let Some(Value::String(_)) = map.get("$ref") {
465 let ref_val = map.remove("$ref").unwrap();
467 let ref_schema = serde_json::json!({"$ref": ref_val});
468 let null_schema = serde_json::json!({"type": "null"});
469 map.insert(
470 "anyOf".to_string(),
471 Value::Array(vec![ref_schema, null_schema]),
472 );
473 } else {
474 map.insert("type".to_string(), Value::String("null".to_string()));
476 }
477 }
478
479 map.remove("discriminator");
481 map.remove("example");
482 map.remove("examples");
483 map.remove("externalDocs");
484 map.remove("xml");
485 map.remove("nullable");
486
487 if let Some(const_val) = map.remove("const") {
489 map.insert("enum".to_string(), Value::Array(vec![const_val]));
490 }
491
492 if let Some(Value::Null) = map.get("default") {
494 let type_val = map.get("type");
495 let is_nullable = match type_val {
496 Some(Value::Array(arr)) => arr.contains(&Value::String("null".to_string())),
497 Some(Value::String(s)) => s == "null",
498 _ => false,
499 };
500 if !is_nullable {
501 map.remove("default");
502 }
503 }
504
505 if let Some(Value::String(title)) = map.get("title") {
507 let is_string = match map.get("type") {
508 Some(Value::String(t)) => t == "string",
509 Some(Value::Array(arr)) => arr
510 .iter()
511 .any(|v| matches!(v, Value::String(s) if s == "string")),
512 _ => false,
513 };
514 if is_string {
515 if title == "Id" {
516 map.remove("pattern");
517 } else if title == "Name" {
518 map.remove("enum");
519 map.remove("minLength");
520 map.remove("maxLength");
521 }
522 }
523 }
524
525 for (_, v) in map.iter_mut() {
527 convert_openapi_to_json_schema(v);
528 }
529 }
530 Value::Array(arr) => {
531 for item in arr.iter_mut() {
532 convert_openapi_to_json_schema(item);
533 }
534 }
535 _ => {}
536 }
537}
538
539fn collect_refs(value: &Value, refs: &mut HashSet<String>) {
541 match value {
542 Value::Object(map) => {
543 if let Some(Value::String(ref_path)) = map.get("$ref") {
544 if let Some(name) = ref_path
546 .strip_prefix("#/definitions/")
547 .or_else(|| ref_path.strip_prefix("#/components/schemas/"))
548 {
549 refs.insert(name.to_string());
550 }
551 }
552 for v in map.values() {
553 collect_refs(v, refs);
554 }
555 }
556 Value::Array(arr) => {
557 for item in arr {
558 collect_refs(item, refs);
559 }
560 }
561 _ => {}
562 }
563}
564
565fn filter_schemas(schemas: Value, root_types: &[&str]) -> Value {
567 let schemas_map = match &schemas {
568 Value::Object(map) => map,
569 _ => return schemas,
570 };
571
572 let mut needed: HashSet<String> = root_types.iter().map(|s| s.to_string()).collect();
573 let mut to_process: Vec<String> = root_types.iter().map(|s| s.to_string()).collect();
574
575 while let Some(type_name) = to_process.pop() {
576 if let Some(schema) = schemas_map.get(&type_name) {
577 let mut refs = HashSet::new();
578 collect_refs(schema, &mut refs);
579 for r in refs {
580 if needed.insert(r.clone()) {
581 to_process.push(r);
582 }
583 }
584 }
585 }
586
587 let filtered: serde_json::Map<String, Value> = schemas_map
588 .iter()
589 .filter(|(k, _)| needed.contains(*k))
590 .map(|(k, v)| (k.clone(), v.clone()))
591 .collect();
592
593 Value::Object(filtered)
594}
595
596fn extract_inline_type_enums(schemas: &mut Value) {
601 let mut new_definitions: serde_json::Map<String, Value> = serde_json::Map::new();
602
603 if let Value::Object(schemas_map) = schemas {
604 let mut modifications: Vec<(String, String)> = Vec::new();
606
607 for (type_name, schema) in schemas_map.iter() {
608 if let Value::Object(obj) = schema
609 && let Some(Value::Object(props)) = obj.get("properties")
610 && let Some(Value::Object(type_prop)) = props.get("type")
611 {
612 if let Some(Value::Array(enum_vals)) = type_prop.get("enum")
614 && enum_vals.len() == 1
615 {
616 let unique_type_name = format!("{}Type", type_name);
618 modifications.push((type_name.clone(), unique_type_name.clone()));
619
620 let mut new_def = type_prop.clone();
622 new_def.insert("title".to_string(), Value::String(unique_type_name.clone()));
623 new_definitions.insert(unique_type_name, Value::Object(new_def));
624 }
625 }
626 }
627
628 for (type_name, unique_type_name) in modifications {
630 if let Some(Value::Object(obj)) = schemas_map.get_mut(&type_name)
631 && let Some(Value::Object(props)) = obj.get_mut("properties")
632 && props.contains_key("type")
633 {
634 props.insert(
636 "type".to_string(),
637 serde_json::json!({
638 "$ref": format!("#/definitions/{}", unique_type_name)
639 }),
640 );
641 }
642 }
643
644 for (name, def) in new_definitions {
646 schemas_map.insert(name, def);
647 }
648 }
649}