// spn_mcp/openapi/parser.rs

//! OpenAPI 3.0+ parser.
//!
//! Parses OpenAPI specifications and converts them to [`ApiConfig`].

use std::collections::HashMap;
use std::path::Path;

use serde::Deserialize;
use thiserror::Error;

use crate::config::{ApiConfig, AuthConfig, AuthType, ParamDef, ParamType, ToolDef};

13/// Errors that can occur when parsing OpenAPI specs.
14#[derive(Debug, Error)]
15pub enum OpenApiError {
16    #[error("Failed to read file: {0}")]
17    Io(#[from] std::io::Error),
18
19    #[error("Failed to parse YAML: {0}")]
20    Yaml(#[from] serde_yaml::Error),
21
22    #[error("Failed to parse JSON: {0}")]
23    Json(#[from] serde_json::Error),
24
25    #[error("Unsupported OpenAPI version: {0}. Only 3.0+ is supported.")]
26    UnsupportedVersion(String),
27
28    #[error("Missing required field: {0}")]
29    MissingField(String),
30}
31
32/// Result type for OpenAPI operations.
33pub type Result<T> = std::result::Result<T, OpenApiError>;
34
35/// OpenAPI 3.0+ specification (partial - only what we need).
36#[derive(Debug, Clone, Deserialize)]
37pub struct OpenApiSpec {
38    /// OpenAPI version (must be 3.0+)
39    pub openapi: String,
40
41    /// API info
42    pub info: OpenApiInfo,
43
44    /// Servers (for base URL)
45    #[serde(default)]
46    pub servers: Vec<OpenApiServer>,
47
48    /// Path definitions
49    #[serde(default)]
50    pub paths: HashMap<String, PathItem>,
51
52    /// Security definitions
53    #[serde(default)]
54    pub components: Option<Components>,
55
56    /// Top-level security requirements
57    #[serde(default)]
58    pub security: Vec<SecurityRequirement>,
59}
60
61/// API information.
62#[derive(Debug, Clone, Deserialize)]
63pub struct OpenApiInfo {
64    pub title: String,
65    #[serde(default)]
66    pub version: String,
67    #[serde(default)]
68    pub description: Option<String>,
69}
70
71/// Server definition.
72#[derive(Debug, Clone, Deserialize)]
73pub struct OpenApiServer {
74    pub url: String,
75    #[serde(default)]
76    pub description: Option<String>,
77}
78
79/// Path item containing operations.
80#[derive(Debug, Clone, Deserialize)]
81pub struct PathItem {
82    #[serde(default)]
83    pub get: Option<Operation>,
84    #[serde(default)]
85    pub post: Option<Operation>,
86    #[serde(default)]
87    pub put: Option<Operation>,
88    #[serde(default)]
89    pub patch: Option<Operation>,
90    #[serde(default)]
91    pub delete: Option<Operation>,
92    #[serde(default)]
93    pub parameters: Vec<Parameter>,
94}
95
96/// Operation definition.
97#[derive(Debug, Clone, Deserialize)]
98#[serde(rename_all = "camelCase")]
99pub struct Operation {
100    #[serde(default)]
101    pub operation_id: Option<String>,
102    #[serde(default)]
103    pub summary: Option<String>,
104    #[serde(default)]
105    pub description: Option<String>,
106    #[serde(default)]
107    pub tags: Vec<String>,
108    #[serde(default)]
109    pub parameters: Vec<Parameter>,
110    #[serde(default)]
111    pub request_body: Option<RequestBody>,
112    #[serde(default)]
113    pub security: Vec<SecurityRequirement>,
114}
115
116/// Parameter definition.
117#[derive(Debug, Clone, Deserialize)]
118pub struct Parameter {
119    pub name: String,
120    #[serde(rename = "in")]
121    pub location: String, // path, query, header, cookie
122    #[serde(default)]
123    pub required: Option<bool>,
124    #[serde(default)]
125    pub description: Option<String>,
126    #[serde(default)]
127    pub schema: Option<SchemaRef>,
128}
129
130/// Schema reference or inline schema.
131#[derive(Debug, Clone, Deserialize)]
132#[serde(untagged)]
133pub enum SchemaRef {
134    Ref {
135        #[serde(rename = "$ref")]
136        reference: String,
137    },
138    Inline(Schema),
139}
140
141/// Inline schema definition.
142#[derive(Debug, Clone, Deserialize)]
143pub struct Schema {
144    #[serde(rename = "type", default)]
145    pub schema_type: Option<String>,
146    #[serde(default)]
147    pub format: Option<String>,
148    #[serde(default)]
149    pub items: Option<Box<Schema>>,
150}
151
152/// Request body definition.
153#[derive(Debug, Clone, Deserialize)]
154pub struct RequestBody {
155    #[serde(default)]
156    pub required: Option<bool>,
157    #[serde(default)]
158    pub content: HashMap<String, MediaType>,
159}
160
161/// Media type definition.
162#[derive(Debug, Clone, Deserialize)]
163pub struct MediaType {
164    #[serde(default)]
165    pub schema: Option<SchemaRef>,
166}
167
168/// Components section.
169#[derive(Debug, Clone, Deserialize)]
170#[serde(rename_all = "camelCase")]
171pub struct Components {
172    #[serde(default)]
173    pub security_schemes: HashMap<String, SecurityScheme>,
174}
175
176/// Security scheme definition.
177#[derive(Debug, Clone, Deserialize)]
178pub struct SecurityScheme {
179    #[serde(rename = "type")]
180    pub scheme_type: String,
181    #[serde(default)]
182    pub scheme: Option<String>, // For http type: basic, bearer
183    #[serde(default)]
184    pub name: Option<String>, // For apiKey: header/query param name
185    #[serde(rename = "in", default)]
186    pub location: Option<String>, // For apiKey: header, query
187}
188
/// Security requirement: maps a security scheme name to a list of strings
/// (scopes, in OpenAPI terms).
pub type SecurityRequirement = HashMap<String, Vec<String>>;

192/// Parse an OpenAPI spec from a file.
193pub fn parse_openapi(path: &Path) -> Result<OpenApiSpec> {
194    let content = std::fs::read_to_string(path)?;
195
196    // Detect format by extension or content
197    let spec: OpenApiSpec =
198        if path.extension().is_some_and(|e| e == "json") || content.trim().starts_with('{') {
199            serde_json::from_str(&content)?
200        } else {
201            serde_yaml::from_str(&content)?
202        };
203
204    // Validate version
205    if !spec.openapi.starts_with("3.") {
206        return Err(OpenApiError::UnsupportedVersion(spec.openapi));
207    }
208
209    Ok(spec)
210}
211
212impl OpenApiSpec {
213    /// Convert the OpenAPI spec to an ApiConfig.
214    pub fn to_api_config(&self, api_name: Option<&str>) -> ApiConfig {
215        let name = api_name
216            .map(String::from)
217            .unwrap_or_else(|| slugify(&self.info.title));
218
219        let base_url = self
220            .servers
221            .first()
222            .map(|s| s.url.clone())
223            .unwrap_or_default();
224
225        let auth = self.detect_auth();
226        let tools = self.extract_tools(&name);
227
228        ApiConfig {
229            name: name.clone(),
230            version: "1.0".to_string(),
231            base_url,
232            description: self.info.description.clone(),
233            auth,
234            rate_limit: None,
235            headers: None,
236            tools,
237        }
238    }
239
240    /// Detect authentication from security schemes.
241    fn detect_auth(&self) -> AuthConfig {
242        // Check components.securitySchemes
243        if let Some(components) = &self.components {
244            for (name, scheme) in &components.security_schemes {
245                match scheme.scheme_type.as_str() {
246                    "http" => {
247                        if scheme.scheme.as_deref() == Some("bearer") {
248                            return AuthConfig {
249                                auth_type: AuthType::Bearer,
250                                credential: name.clone(),
251                                location: None,
252                                key_name: None,
253                            };
254                        } else if scheme.scheme.as_deref() == Some("basic") {
255                            return AuthConfig {
256                                auth_type: AuthType::Basic,
257                                credential: name.clone(),
258                                location: None,
259                                key_name: None,
260                            };
261                        }
262                    }
263                    "apiKey" => {
264                        let location = match scheme.location.as_deref() {
265                            Some("query") => Some(crate::config::ApiKeyLocation::Query),
266                            _ => Some(crate::config::ApiKeyLocation::Header),
267                        };
268                        return AuthConfig {
269                            auth_type: AuthType::ApiKey,
270                            credential: name.clone(),
271                            location,
272                            key_name: scheme.name.clone(),
273                        };
274                    }
275                    _ => {}
276                }
277            }
278        }
279
280        // Default to bearer token
281        AuthConfig {
282            auth_type: AuthType::Bearer,
283            credential: "api_key".to_string(),
284            location: None,
285            key_name: None,
286        }
287    }
288
289    /// Extract tools from paths.
290    fn extract_tools(&self, api_name: &str) -> Vec<ToolDef> {
291        let mut tools = Vec::new();
292
293        for (path, item) in &self.paths {
294            // Collect path-level parameters
295            let path_params: Vec<_> = item.parameters.iter().collect();
296
297            // Process each HTTP method
298            if let Some(op) = &item.get {
299                tools.push(self.operation_to_tool(api_name, "GET", path, op, &path_params));
300            }
301            if let Some(op) = &item.post {
302                tools.push(self.operation_to_tool(api_name, "POST", path, op, &path_params));
303            }
304            if let Some(op) = &item.put {
305                tools.push(self.operation_to_tool(api_name, "PUT", path, op, &path_params));
306            }
307            if let Some(op) = &item.patch {
308                tools.push(self.operation_to_tool(api_name, "PATCH", path, op, &path_params));
309            }
310            if let Some(op) = &item.delete {
311                tools.push(self.operation_to_tool(api_name, "DELETE", path, op, &path_params));
312            }
313        }
314
315        // Sort by name for consistent ordering
316        tools.sort_by(|a, b| a.name.cmp(&b.name));
317        tools
318    }
319
320    /// Convert an operation to a ToolDef.
321    fn operation_to_tool(
322        &self,
323        api_name: &str,
324        method: &str,
325        path: &str,
326        op: &Operation,
327        path_params: &[&Parameter],
328    ) -> ToolDef {
329        // Generate tool name
330        let name = op
331            .operation_id
332            .clone()
333            .unwrap_or_else(|| generate_tool_name(api_name, method, path));
334
335        // Merge path-level and operation-level parameters
336        let mut params = Vec::new();
337        for param in path_params.iter().copied() {
338            params.push(parameter_to_param_def(param));
339        }
340        for param in &op.parameters {
341            params.push(parameter_to_param_def(param));
342        }
343
344        // Use summary or description
345        let description = op.summary.clone().or_else(|| op.description.clone());
346
347        ToolDef {
348            name,
349            description,
350            method: method.to_string(),
351            path: path.to_string(),
352            body_template: None,
353            params,
354            response: None,
355        }
356    }
357
358    /// Get tools filtered by tag.
359    pub fn tools_by_tag(&self, tag: &str) -> Vec<(&str, &str, &Operation)> {
360        let mut results = Vec::new();
361
362        for (path, item) in &self.paths {
363            let ops = [
364                ("GET", &item.get),
365                ("POST", &item.post),
366                ("PUT", &item.put),
367                ("PATCH", &item.patch),
368                ("DELETE", &item.delete),
369            ];
370
371            for (method, op_opt) in ops {
372                if let Some(op) = op_opt {
373                    if op.tags.iter().any(|t| t.eq_ignore_ascii_case(tag)) {
374                        results.push((path.as_str(), method, op));
375                    }
376                }
377            }
378        }
379
380        results
381    }
382
383    /// Get all unique tags.
384    pub fn tags(&self) -> Vec<String> {
385        let mut tags = std::collections::HashSet::new();
386
387        for item in self.paths.values() {
388            let ops: [&Option<Operation>; 5] =
389                [&item.get, &item.post, &item.put, &item.patch, &item.delete];
390            for op_opt in ops.into_iter().flatten() {
391                for tag in &op_opt.tags {
392                    tags.insert(tag.clone());
393                }
394            }
395        }
396
397        let mut sorted: Vec<_> = tags.into_iter().collect();
398        sorted.sort();
399        sorted
400    }
401
402    /// Count total endpoints.
403    pub fn endpoint_count(&self) -> usize {
404        self.paths
405            .values()
406            .map(|item| {
407                [
408                    item.get.is_some(),
409                    item.post.is_some(),
410                    item.put.is_some(),
411                    item.patch.is_some(),
412                    item.delete.is_some(),
413                ]
414                .iter()
415                .filter(|&&b| b)
416                .count()
417            })
418            .sum()
419    }
420}
421
422/// Convert a Parameter to ParamDef.
423fn parameter_to_param_def(param: &Parameter) -> ParamDef {
424    let param_type = param
425        .schema
426        .as_ref()
427        .map(schema_to_param_type)
428        .unwrap_or(ParamType::String);
429
430    let required = param.location == "path" || param.required.unwrap_or(false);
431
432    ParamDef {
433        name: param.name.clone(),
434        param_type,
435        items: None,
436        required,
437        default: None,
438        description: param.description.clone(),
439    }
440}
441
442/// Convert schema type to ParamType.
443fn schema_to_param_type(schema: &SchemaRef) -> ParamType {
444    match schema {
445        SchemaRef::Ref { .. } => ParamType::Object,
446        SchemaRef::Inline(s) => match s.schema_type.as_deref() {
447            Some("integer") => ParamType::Integer,
448            Some("number") => ParamType::Number,
449            Some("boolean") => ParamType::Boolean,
450            Some("array") => ParamType::Array,
451            Some("object") => ParamType::Object,
452            _ => ParamType::String,
453        },
454    }
455}
456
/// Generate a tool name from method and path.
///
/// The name is `<verb>_<segments>`, where the verb is derived from the HTTP
/// method (`get`, `create`, `update`, `delete`, or `call` for anything else)
/// and the segments are the path's non-empty segments that do not start with
/// `{`, joined by `_`.
///
/// When the path contributes no segments (e.g. `/` or `/{id}`), the name
/// falls back to `<api_name>_<method lowercased>`.
fn generate_tool_name(api_name: &str, method: &str, path: &str) -> String {
    let path_part = path
        .split('/')
        .filter(|segment| !segment.is_empty() && !segment.starts_with('{'))
        .collect::<Vec<_>>()
        .join("_");

    // The verb prefix is never empty, so the fallback has to key on the path
    // part; testing the combined name for emptiness could never succeed.
    if path_part.trim_matches('_').is_empty() {
        return format!("{}_{}", api_name, method.to_lowercase());
    }

    let method_prefix = match method.to_uppercase().as_str() {
        "GET" => "get",
        "POST" => "create",
        "PUT" | "PATCH" => "update",
        "DELETE" => "delete",
        _ => "call",
    };

    format!("{}_{}", method_prefix, path_part)
        .trim_matches('_')
        .to_string()
}

/// Convert a string to a slug (lowercase, underscores).
///
/// The input is lowercased, every run of non-alphanumeric characters acts as
/// a single separator, and the remaining words are joined with `_`. Leading
/// and trailing separators produce no underscores.
fn slugify(s: &str) -> String {
    let lowered = s.to_lowercase();
    let mut slug = String::with_capacity(lowered.len());

    let words = lowered
        .split(|c: char| !c.is_alphanumeric())
        .filter(|word| !word.is_empty());
    for word in words {
        if !slug.is_empty() {
            slug.push('_');
        }
        slug.push_str(word);
    }

    slug
}

#[cfg(test)]
mod tests {
    //! Unit tests for the OpenAPI parser; specs are inline YAML fixtures.

    use super::*;

    /// Slugs are lowercase, underscore-separated, with no edge separators.
    #[test]
    fn test_slugify() {
        assert_eq!(slugify("GitHub API"), "github_api");
        assert_eq!(slugify("My-Cool_API v2"), "my_cool_api_v2");
        assert_eq!(slugify("  spaces  "), "spaces");
    }

    /// Generated names combine a method verb with the non-templated segments.
    #[test]
    fn test_generate_tool_name() {
        assert_eq!(
            generate_tool_name("github", "GET", "/repos/{owner}/{repo}"),
            "get_repos"
        );
        assert_eq!(
            generate_tool_name("github", "POST", "/repos/{owner}/{repo}/issues"),
            "create_repos_issues"
        );
        assert_eq!(
            generate_tool_name("github", "DELETE", "/repos/{owner}/{repo}"),
            "delete_repos"
        );
    }

    /// A minimal YAML document deserializes into the spec model.
    #[test]
    fn test_parse_yaml() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test API
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /users:
    get:
      operationId: listUsers
      summary: List all users
      parameters:
        - name: limit
          in: query
          schema:
            type: integer
"#;
        let spec: OpenApiSpec = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(spec.info.title, "Test API");
        assert_eq!(spec.paths.len(), 1);
        assert!(spec.paths.get("/users").unwrap().get.is_some());
    }

    /// Name, base URL and tool count are derived from the spec.
    #[test]
    fn test_to_api_config() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test API
  version: "1.0"
servers:
  - url: https://api.example.com
paths:
  /users:
    get:
      operationId: listUsers
      summary: List users
  /users/{id}:
    get:
      operationId: getUser
      parameters:
        - name: id
          in: path
          required: true
          schema:
            type: string
"#;
        let spec: OpenApiSpec = serde_yaml::from_str(yaml).unwrap();
        let config = spec.to_api_config(None);

        assert_eq!(config.name, "test_api");
        assert_eq!(config.base_url, "https://api.example.com");
        assert_eq!(config.tools.len(), 2);
    }

    /// An http/bearer scheme yields bearer auth named after the scheme.
    #[test]
    fn test_detect_bearer_auth() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths: {}
components:
  securitySchemes:
    bearerAuth:
      type: http
      scheme: bearer
"#;
        let spec: OpenApiSpec = serde_yaml::from_str(yaml).unwrap();
        let config = spec.to_api_config(None);

        assert_eq!(config.auth.auth_type, AuthType::Bearer);
        assert_eq!(config.auth.credential, "bearerAuth");
    }

    /// An apiKey scheme yields API-key auth carrying the key name.
    #[test]
    fn test_detect_api_key_auth() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths: {}
components:
  securitySchemes:
    apiKey:
      type: apiKey
      name: X-API-Key
      in: header
"#;
        let spec: OpenApiSpec = serde_yaml::from_str(yaml).unwrap();
        let config = spec.to_api_config(None);

        assert_eq!(config.auth.auth_type, AuthType::ApiKey);
        assert_eq!(config.auth.key_name, Some("X-API-Key".to_string()));
    }

    /// Every declared operation counts as one endpoint.
    #[test]
    fn test_endpoint_count() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /a:
    get: {}
    post: {}
  /b:
    delete: {}
"#;
        let spec: OpenApiSpec = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(spec.endpoint_count(), 3);
    }

    /// Tags are de-duplicated across operations and returned sorted.
    #[test]
    fn test_tags() {
        let yaml = r#"
openapi: "3.0.0"
info:
  title: Test
  version: "1.0"
paths:
  /users:
    get:
      tags: [users, admin]
  /posts:
    get:
      tags: [posts]
"#;
        let spec: OpenApiSpec = serde_yaml::from_str(yaml).unwrap();
        let tags = spec.tags();
        assert_eq!(tags, vec!["admin", "posts", "users"]);
    }
}