1use std::collections::HashMap;
6use std::path::Path;
7
8use serde::Deserialize;
9use thiserror::Error;
10
11use crate::config::{ApiConfig, AuthConfig, AuthType, ParamDef, ParamType, ToolDef};
12
13#[derive(Debug, Error)]
15pub enum OpenApiError {
16 #[error("Failed to read file: {0}")]
17 Io(#[from] std::io::Error),
18
19 #[error("Failed to parse YAML: {0}")]
20 Yaml(#[from] serde_yaml::Error),
21
22 #[error("Failed to parse JSON: {0}")]
23 Json(#[from] serde_json::Error),
24
25 #[error("Unsupported OpenAPI version: {0}. Only 3.0+ is supported.")]
26 UnsupportedVersion(String),
27
28 #[error("Missing required field: {0}")]
29 MissingField(String),
30}
31
32pub type Result<T> = std::result::Result<T, OpenApiError>;
34
35#[derive(Debug, Clone, Deserialize)]
37pub struct OpenApiSpec {
38 pub openapi: String,
40
41 pub info: OpenApiInfo,
43
44 #[serde(default)]
46 pub servers: Vec<OpenApiServer>,
47
48 #[serde(default)]
50 pub paths: HashMap<String, PathItem>,
51
52 #[serde(default)]
54 pub components: Option<Components>,
55
56 #[serde(default)]
58 pub security: Vec<SecurityRequirement>,
59}
60
61#[derive(Debug, Clone, Deserialize)]
63pub struct OpenApiInfo {
64 pub title: String,
65 #[serde(default)]
66 pub version: String,
67 #[serde(default)]
68 pub description: Option<String>,
69}
70
71#[derive(Debug, Clone, Deserialize)]
73pub struct OpenApiServer {
74 pub url: String,
75 #[serde(default)]
76 pub description: Option<String>,
77}
78
79#[derive(Debug, Clone, Deserialize)]
81pub struct PathItem {
82 #[serde(default)]
83 pub get: Option<Operation>,
84 #[serde(default)]
85 pub post: Option<Operation>,
86 #[serde(default)]
87 pub put: Option<Operation>,
88 #[serde(default)]
89 pub patch: Option<Operation>,
90 #[serde(default)]
91 pub delete: Option<Operation>,
92 #[serde(default)]
93 pub parameters: Vec<Parameter>,
94}
95
96#[derive(Debug, Clone, Deserialize)]
98#[serde(rename_all = "camelCase")]
99pub struct Operation {
100 #[serde(default)]
101 pub operation_id: Option<String>,
102 #[serde(default)]
103 pub summary: Option<String>,
104 #[serde(default)]
105 pub description: Option<String>,
106 #[serde(default)]
107 pub tags: Vec<String>,
108 #[serde(default)]
109 pub parameters: Vec<Parameter>,
110 #[serde(default)]
111 pub request_body: Option<RequestBody>,
112 #[serde(default)]
113 pub security: Vec<SecurityRequirement>,
114}
115
116#[derive(Debug, Clone, Deserialize)]
118pub struct Parameter {
119 pub name: String,
120 #[serde(rename = "in")]
121 pub location: String, #[serde(default)]
123 pub required: Option<bool>,
124 #[serde(default)]
125 pub description: Option<String>,
126 #[serde(default)]
127 pub schema: Option<SchemaRef>,
128}
129
130#[derive(Debug, Clone, Deserialize)]
132#[serde(untagged)]
133pub enum SchemaRef {
134 Ref {
135 #[serde(rename = "$ref")]
136 reference: String,
137 },
138 Inline(Schema),
139}
140
141#[derive(Debug, Clone, Deserialize)]
143pub struct Schema {
144 #[serde(rename = "type", default)]
145 pub schema_type: Option<String>,
146 #[serde(default)]
147 pub format: Option<String>,
148 #[serde(default)]
149 pub items: Option<Box<Schema>>,
150}
151
152#[derive(Debug, Clone, Deserialize)]
154pub struct RequestBody {
155 #[serde(default)]
156 pub required: Option<bool>,
157 #[serde(default)]
158 pub content: HashMap<String, MediaType>,
159}
160
161#[derive(Debug, Clone, Deserialize)]
163pub struct MediaType {
164 #[serde(default)]
165 pub schema: Option<SchemaRef>,
166}
167
168#[derive(Debug, Clone, Deserialize)]
170#[serde(rename_all = "camelCase")]
171pub struct Components {
172 #[serde(default)]
173 pub security_schemes: HashMap<String, SecurityScheme>,
174}
175
176#[derive(Debug, Clone, Deserialize)]
178pub struct SecurityScheme {
179 #[serde(rename = "type")]
180 pub scheme_type: String,
181 #[serde(default)]
182 pub scheme: Option<String>, #[serde(default)]
184 pub name: Option<String>, #[serde(rename = "in", default)]
186 pub location: Option<String>, }
188
/// A security requirement object: security scheme names mapped to the list of scope
/// names each one requires.
pub type SecurityRequirement = std::collections::HashMap<String, Vec<String>>;
191
192pub fn parse_openapi(path: &Path) -> Result<OpenApiSpec> {
194 let content = std::fs::read_to_string(path)?;
195
196 let spec: OpenApiSpec =
198 if path.extension().is_some_and(|e| e == "json") || content.trim().starts_with('{') {
199 serde_json::from_str(&content)?
200 } else {
201 serde_yaml::from_str(&content)?
202 };
203
204 if !spec.openapi.starts_with("3.") {
206 return Err(OpenApiError::UnsupportedVersion(spec.openapi));
207 }
208
209 Ok(spec)
210}
211
212impl OpenApiSpec {
213 pub fn to_api_config(&self, api_name: Option<&str>) -> ApiConfig {
215 let name = api_name
216 .map(String::from)
217 .unwrap_or_else(|| slugify(&self.info.title));
218
219 let base_url = self
220 .servers
221 .first()
222 .map(|s| s.url.clone())
223 .unwrap_or_default();
224
225 let auth = self.detect_auth();
226 let tools = self.extract_tools(&name);
227
228 ApiConfig {
229 name: name.clone(),
230 version: "1.0".to_string(),
231 base_url,
232 description: self.info.description.clone(),
233 auth,
234 rate_limit: None,
235 headers: None,
236 tools,
237 }
238 }
239
240 fn detect_auth(&self) -> AuthConfig {
242 if let Some(components) = &self.components {
244 for (name, scheme) in &components.security_schemes {
245 match scheme.scheme_type.as_str() {
246 "http" => {
247 if scheme.scheme.as_deref() == Some("bearer") {
248 return AuthConfig {
249 auth_type: AuthType::Bearer,
250 credential: name.clone(),
251 location: None,
252 key_name: None,
253 };
254 } else if scheme.scheme.as_deref() == Some("basic") {
255 return AuthConfig {
256 auth_type: AuthType::Basic,
257 credential: name.clone(),
258 location: None,
259 key_name: None,
260 };
261 }
262 }
263 "apiKey" => {
264 let location = match scheme.location.as_deref() {
265 Some("query") => Some(crate::config::ApiKeyLocation::Query),
266 _ => Some(crate::config::ApiKeyLocation::Header),
267 };
268 return AuthConfig {
269 auth_type: AuthType::ApiKey,
270 credential: name.clone(),
271 location,
272 key_name: scheme.name.clone(),
273 };
274 }
275 _ => {}
276 }
277 }
278 }
279
280 AuthConfig {
282 auth_type: AuthType::Bearer,
283 credential: "api_key".to_string(),
284 location: None,
285 key_name: None,
286 }
287 }
288
289 fn extract_tools(&self, api_name: &str) -> Vec<ToolDef> {
291 let mut tools = Vec::new();
292
293 for (path, item) in &self.paths {
294 let path_params: Vec<_> = item.parameters.iter().collect();
296
297 if let Some(op) = &item.get {
299 tools.push(self.operation_to_tool(api_name, "GET", path, op, &path_params));
300 }
301 if let Some(op) = &item.post {
302 tools.push(self.operation_to_tool(api_name, "POST", path, op, &path_params));
303 }
304 if let Some(op) = &item.put {
305 tools.push(self.operation_to_tool(api_name, "PUT", path, op, &path_params));
306 }
307 if let Some(op) = &item.patch {
308 tools.push(self.operation_to_tool(api_name, "PATCH", path, op, &path_params));
309 }
310 if let Some(op) = &item.delete {
311 tools.push(self.operation_to_tool(api_name, "DELETE", path, op, &path_params));
312 }
313 }
314
315 tools.sort_by(|a, b| a.name.cmp(&b.name));
317 tools
318 }
319
320 fn operation_to_tool(
322 &self,
323 api_name: &str,
324 method: &str,
325 path: &str,
326 op: &Operation,
327 path_params: &[&Parameter],
328 ) -> ToolDef {
329 let name = op
331 .operation_id
332 .clone()
333 .unwrap_or_else(|| generate_tool_name(api_name, method, path));
334
335 let mut params = Vec::new();
337 for param in path_params.iter().copied() {
338 params.push(parameter_to_param_def(param));
339 }
340 for param in &op.parameters {
341 params.push(parameter_to_param_def(param));
342 }
343
344 let description = op.summary.clone().or_else(|| op.description.clone());
346
347 ToolDef {
348 name,
349 description,
350 method: method.to_string(),
351 path: path.to_string(),
352 body_template: None,
353 params,
354 response: None,
355 }
356 }
357
358 pub fn tools_by_tag(&self, tag: &str) -> Vec<(&str, &str, &Operation)> {
360 let mut results = Vec::new();
361
362 for (path, item) in &self.paths {
363 let ops = [
364 ("GET", &item.get),
365 ("POST", &item.post),
366 ("PUT", &item.put),
367 ("PATCH", &item.patch),
368 ("DELETE", &item.delete),
369 ];
370
371 for (method, op_opt) in ops {
372 if let Some(op) = op_opt {
373 if op.tags.iter().any(|t| t.eq_ignore_ascii_case(tag)) {
374 results.push((path.as_str(), method, op));
375 }
376 }
377 }
378 }
379
380 results
381 }
382
383 pub fn tags(&self) -> Vec<String> {
385 let mut tags = std::collections::HashSet::new();
386
387 for item in self.paths.values() {
388 let ops: [&Option<Operation>; 5] =
389 [&item.get, &item.post, &item.put, &item.patch, &item.delete];
390 for op_opt in ops.into_iter().flatten() {
391 for tag in &op_opt.tags {
392 tags.insert(tag.clone());
393 }
394 }
395 }
396
397 let mut sorted: Vec<_> = tags.into_iter().collect();
398 sorted.sort();
399 sorted
400 }
401
402 pub fn endpoint_count(&self) -> usize {
404 self.paths
405 .values()
406 .map(|item| {
407 [
408 item.get.is_some(),
409 item.post.is_some(),
410 item.put.is_some(),
411 item.patch.is_some(),
412 item.delete.is_some(),
413 ]
414 .iter()
415 .filter(|&&b| b)
416 .count()
417 })
418 .sum()
419 }
420}
421
422fn parameter_to_param_def(param: &Parameter) -> ParamDef {
424 let param_type = param
425 .schema
426 .as_ref()
427 .map(schema_to_param_type)
428 .unwrap_or(ParamType::String);
429
430 let required = param.location == "path" || param.required.unwrap_or(false);
431
432 ParamDef {
433 name: param.name.clone(),
434 param_type,
435 items: None,
436 required,
437 default: None,
438 description: param.description.clone(),
439 }
440}
441
442fn schema_to_param_type(schema: &SchemaRef) -> ParamType {
444 match schema {
445 SchemaRef::Ref { .. } => ParamType::Object,
446 SchemaRef::Inline(s) => match s.schema_type.as_deref() {
447 Some("integer") => ParamType::Integer,
448 Some("number") => ParamType::Number,
449 Some("boolean") => ParamType::Boolean,
450 Some("array") => ParamType::Array,
451 Some("object") => ParamType::Object,
452 _ => ParamType::String,
453 },
454 }
455}
456
/// Generates a tool name for an operation that has no `operationId`.
///
/// The name is a verb derived from `method`, followed by the path's literal words joined
/// with `_`. The verb is `get`, `create` (POST), `update` (PUT/PATCH), `delete`, or `call`
/// for any other method; the method is matched case-insensitively.
///
/// Template segments such as `{id}` are skipped. Any character other than a letter or
/// digit acts as a word separator, so `/user-profiles/{id}` contributes `user_profiles`.
/// When the path has no literal words (e.g. `/` or `/{id}`), the name falls back to
/// `<api_name>_<method>`, with the method lower-cased.
fn generate_tool_name(api_name: &str, method: &str, path: &str) -> String {
    let path_part = path
        .split('/')
        .filter(|segment| !segment.starts_with('{'))
        .flat_map(|segment| segment.split(|c: char| !c.is_alphanumeric()))
        .filter(|word| !word.is_empty())
        .collect::<Vec<_>>()
        .join("_");

    // With no literal words there is nothing that names the resource, so qualify the
    // method with the API name instead of emitting a bare verb such as `get`.
    if path_part.is_empty() {
        return format!("{}_{}", api_name, method.to_lowercase());
    }

    let method_prefix = match method.to_uppercase().as_str() {
        "GET" => "get",
        "POST" => "create",
        "PUT" | "PATCH" => "update",
        "DELETE" => "delete",
        _ => "call",
    };

    format!("{}_{}", method_prefix, path_part)
}
483
/// Turns free text, such as an API title, into a lower-case `snake_case` identifier.
///
/// The input is lower-cased. Each maximal run of non-alphanumeric characters (including
/// `_`) between words becomes a single `_`, and leading or trailing runs are dropped, so
/// `"My-Cool_API v2"` becomes `"my_cool_api_v2"`. The result is empty when the input has
/// no alphanumeric characters.
fn slugify(s: &str) -> String {
    let lowered = s.to_lowercase();
    let mut slug = String::with_capacity(lowered.len());
    let mut separator_pending = false;

    for c in lowered.chars() {
        if !c.is_alphanumeric() {
            separator_pending = true;
            continue;
        }
        if separator_pending && !slug.is_empty() {
            slug.push('_');
        }
        separator_pending = false;
        slug.push(c);
    }

    slug
}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes an inline YAML fixture, panicking if it is malformed.
    fn spec_from_yaml(yaml: &str) -> OpenApiSpec {
        serde_yaml::from_str(yaml).unwrap()
    }

    #[test]
    fn test_slugify() {
        let cases = [
            ("GitHub API", "github_api"),
            ("My-Cool_API v2", "my_cool_api_v2"),
            (" spaces ", "spaces"),
        ];
        for (input, expected) in cases {
            assert_eq!(slugify(input), expected);
        }
    }

    #[test]
    fn test_generate_tool_name() {
        let cases = [
            ("GET", "/repos/{owner}/{repo}", "get_repos"),
            ("POST", "/repos/{owner}/{repo}/issues", "create_repos_issues"),
            ("DELETE", "/repos/{owner}/{repo}", "delete_repos"),
        ];
        for (method, path, expected) in cases {
            assert_eq!(generate_tool_name("github", method, path), expected);
        }
    }

    #[test]
    fn test_parse_yaml() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test API
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /users:
    get:
      operationId: listUsers
      summary: List all users
      parameters:
        - name: limit
          in: query
          schema:
            type: integer
"#;
        let spec = spec_from_yaml(yaml);

        assert_eq!(spec.info.title, "Test API");
        assert_eq!(spec.paths.len(), 1);
        let users = spec.paths.get("/users").unwrap();
        assert!(users.get.is_some());
    }

    #[test]
    fn test_to_api_config() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test API
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /users:
    get:
      operationId: listUsers
      summary: List users
  /users/{id}:
    get:
      operationId: getUser
      parameters:
        - name: id
          in: path
          required: true
          schema:
            type: string
"#;
        let config = spec_from_yaml(yaml).to_api_config(None);

        assert_eq!(config.name, "test_api");
        assert_eq!(config.base_url, "https://api.example.com");
        assert_eq!(config.tools.len(), 2);
    }

    #[test]
    fn test_detect_bearer_auth() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths: {}
components:
  securitySchemes:
    bearerAuth:
      type: http
      scheme: bearer
"#;
        let config = spec_from_yaml(yaml).to_api_config(None);

        assert_eq!(config.auth.auth_type, AuthType::Bearer);
        assert_eq!(config.auth.credential, "bearerAuth");
    }

    #[test]
    fn test_detect_api_key_auth() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths: {}
components:
  securitySchemes:
    apiKey:
      type: apiKey
      name: X-API-Key
      in: header
"#;
        let config = spec_from_yaml(yaml).to_api_config(None);

        assert_eq!(config.auth.auth_type, AuthType::ApiKey);
        assert_eq!(config.auth.key_name, Some("X-API-Key".to_string()));
    }

    #[test]
    fn test_endpoint_count() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /a:
    get: {}
    post: {}
  /b:
    delete: {}
"#;
        assert_eq!(spec_from_yaml(yaml).endpoint_count(), 3);
    }

    #[test]
    fn test_tags() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /users:
    get:
      tags: [users, admin]
  /posts:
    get:
      tags: [posts]
"#;
        let tags = spec_from_yaml(yaml).tags();
        assert_eq!(tags, vec!["admin", "posts", "users"]);
    }
}
661}