Skip to main content

turbomcp_openapi/
provider.rs

1//! OpenAPI provider for generating MCP components from OpenAPI specs.
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::sync::Arc;
6
7use openapiv3::{OpenAPI, Operation, Parameter, ParameterSchemaOrContent, ReferenceOr, Schema};
8use serde_json::{Value, json};
9use url::Url;
10
11use crate::error::{OpenApiError, Result};
12use crate::handler::OpenApiHandler;
13use crate::mapping::{McpType, RouteMapping};
14use crate::parser::{fetch_from_url, load_from_file, parse_spec};
15
/// An operation extracted from an OpenAPI spec.
///
/// One value is produced for each (path, HTTP method) pair in the spec that the
/// route mapping does not classify as [`McpType::Skip`].
#[derive(Debug, Clone)]
pub struct ExtractedOperation {
    /// HTTP method in upper case: `"GET"`, `"POST"`, `"PUT"`, `"DELETE"` or `"PATCH"`.
    pub method: String,
    /// Path template exactly as written in the spec (e.g., "/users/{id}").
    pub path: String,
    /// The operation's `operationId`, if the spec provides one.
    pub operation_id: Option<String>,
    /// The operation's short `summary`, if present.
    pub summary: Option<String>,
    /// The operation's longer `description`, if present.
    pub description: Option<String>,
    /// Extracted parameters. Parameter `$ref`s that cannot be resolved are omitted.
    pub parameters: Vec<ExtractedParameter>,
    /// JSON Schema of the request body's `application/json` media type, if any.
    pub request_body_schema: Option<Value>,
    /// What MCP type this operation maps to.
    pub mcp_type: McpType,
}
36
/// A parameter extracted from an OpenAPI operation.
#[derive(Debug, Clone)]
pub struct ExtractedParameter {
    /// Parameter name as declared in the spec.
    pub name: String,
    /// Where the parameter goes: one of `"path"`, `"query"`, `"header"` or `"cookie"`.
    pub location: String,
    /// Whether the parameter is required. Always `true` for path parameters,
    /// regardless of what the spec declares.
    pub required: bool,
    /// Human-readable description from the spec, if any.
    pub description: Option<String>,
    /// JSON Schema for the parameter, if one could be derived from the spec.
    pub schema: Option<Value>,
}
51
/// Default request timeout, in seconds, for the HTTP client that
/// [`OpenApiProvider::from_spec`] builds.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
54
/// OpenAPI to MCP provider.
///
/// This provider parses OpenAPI specifications and converts them to MCP
/// tools and resources that can be used with a TurboMCP server.
///
/// # Security
///
/// The provider includes built-in SSRF protection that blocks requests to:
/// - Localhost and loopback addresses (127.0.0.0/8, ::1)
/// - Private IP ranges (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16)
/// - Link-local addresses (169.254.0.0/16) including cloud metadata endpoints
/// - Other reserved ranges
///
/// Requests have a default timeout of 30 seconds to prevent slowloris attacks.
// NOTE(review): no SSRF filtering happens in this file. `build_url` only resolves
// operation paths against `base_url`, and the default client is built with nothing
// but a timeout. Presumably the checks described above are enforced where requests
// are actually sent (e.g. the handler). Confirm that, and keep this doc in sync.
#[derive(Debug)]
pub struct OpenApiProvider {
    /// The parsed OpenAPI specification.
    spec: OpenAPI,
    /// Base URL for API calls. `None` until `with_base_url` is called; while
    /// unset, URL building fails with `OpenApiError::NoBaseUrl`.
    base_url: Option<Url>,
    /// Route mapping that decides which MCP type each operation becomes.
    mapping: RouteMapping,
    /// HTTP client used for API calls.
    client: reqwest::Client,
    /// Operations extracted from `spec`; rebuilt whenever `mapping` is replaced.
    operations: Vec<ExtractedOperation>,
    /// Request timeout set at construction or by `with_timeout`. Not updated by
    /// `with_client`, so it may not describe a custom client.
    timeout: std::time::Duration,
}
84
85impl OpenApiProvider {
86    /// Create a provider from a parsed OpenAPI specification.
87    pub fn from_spec(spec: OpenAPI) -> Self {
88        let mapping = RouteMapping::default_rules();
89        let timeout = std::time::Duration::from_secs(DEFAULT_TIMEOUT_SECS);
90        let client = reqwest::Client::builder()
91            .timeout(timeout)
92            .build()
93            .unwrap_or_else(|_| reqwest::Client::new());
94
95        let mut provider = Self {
96            spec,
97            base_url: None,
98            mapping,
99            client,
100            operations: Vec::new(),
101            timeout,
102        };
103        provider.extract_operations();
104        provider
105    }
106
107    /// Create a provider from an OpenAPI specification string.
108    pub fn from_string(content: &str) -> Result<Self> {
109        let spec = parse_spec(content)?;
110        Ok(Self::from_spec(spec))
111    }
112
113    /// Create a provider by loading from a file.
114    pub fn from_file(path: &Path) -> Result<Self> {
115        let spec = load_from_file(path)?;
116        Ok(Self::from_spec(spec))
117    }
118
119    /// Create a provider by fetching from a URL.
120    pub async fn from_url(url: &str) -> Result<Self> {
121        let spec = fetch_from_url(url).await?;
122        Ok(Self::from_spec(spec))
123    }
124
125    /// Set the base URL for API calls.
126    pub fn with_base_url(mut self, base_url: &str) -> Result<Self> {
127        self.base_url = Some(Url::parse(base_url)?);
128        Ok(self)
129    }
130
131    /// Set a custom route mapping configuration.
132    #[must_use]
133    pub fn with_route_mapping(mut self, mapping: RouteMapping) -> Self {
134        self.mapping = mapping;
135        self.extract_operations(); // Re-extract with new mapping
136        self
137    }
138
139    /// Set a custom HTTP client.
140    ///
141    /// # Warning
142    ///
143    /// When using a custom client, ensure it has appropriate timeout settings.
144    /// The default client uses a 30-second timeout.
145    #[must_use]
146    pub fn with_client(mut self, client: reqwest::Client) -> Self {
147        self.client = client;
148        self
149    }
150
151    /// Set a custom request timeout.
152    ///
153    /// This rebuilds the HTTP client with the new timeout. The default timeout
154    /// is 30 seconds.
155    #[must_use]
156    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
157        self.timeout = timeout;
158        self.client = reqwest::Client::builder()
159            .timeout(timeout)
160            .build()
161            .unwrap_or_else(|_| reqwest::Client::new());
162        self
163    }
164
165    /// Get the current request timeout.
166    pub fn timeout(&self) -> std::time::Duration {
167        self.timeout
168    }
169
170    /// Get the API title from the spec.
171    pub fn title(&self) -> &str {
172        &self.spec.info.title
173    }
174
175    /// Get the API version from the spec.
176    pub fn version(&self) -> &str {
177        &self.spec.info.version
178    }
179
180    /// Get all extracted operations.
181    pub fn operations(&self) -> &[ExtractedOperation] {
182        &self.operations
183    }
184
185    /// Get operations that map to MCP tools.
186    pub fn tools(&self) -> impl Iterator<Item = &ExtractedOperation> {
187        self.operations
188            .iter()
189            .filter(|op| op.mcp_type == McpType::Tool)
190    }
191
192    /// Get operations that map to MCP resources.
193    pub fn resources(&self) -> impl Iterator<Item = &ExtractedOperation> {
194        self.operations
195            .iter()
196            .filter(|op| op.mcp_type == McpType::Resource)
197    }
198
199    /// Convert this provider into an McpHandler.
200    pub fn into_handler(self) -> OpenApiHandler {
201        OpenApiHandler::new(Arc::new(self))
202    }
203
204    /// Extract operations from the OpenAPI spec.
205    fn extract_operations(&mut self) {
206        self.operations.clear();
207
208        for (path, path_item) in &self.spec.paths.paths {
209            let path_item = match path_item {
210                ReferenceOr::Item(item) => item,
211                ReferenceOr::Reference { .. } => continue, // Skip references for now
212            };
213
214            // Extract operations for each HTTP method
215            let methods = [
216                ("GET", &path_item.get),
217                ("POST", &path_item.post),
218                ("PUT", &path_item.put),
219                ("DELETE", &path_item.delete),
220                ("PATCH", &path_item.patch),
221            ];
222
223            for (method, operation) in methods {
224                if let Some(op) = operation {
225                    let mcp_type = self.mapping.get_mcp_type(method, path);
226                    if mcp_type == McpType::Skip {
227                        continue;
228                    }
229
230                    self.operations
231                        .push(self.extract_operation(method, path, op, mcp_type));
232                }
233            }
234        }
235    }
236
237    /// Extract a single operation.
238    fn extract_operation(
239        &self,
240        method: &str,
241        path: &str,
242        operation: &Operation,
243        mcp_type: McpType,
244    ) -> ExtractedOperation {
245        let parameters = operation
246            .parameters
247            .iter()
248            .filter_map(|p| match p {
249                ReferenceOr::Item(param) => Some(self.extract_parameter(param)),
250                ReferenceOr::Reference { .. } => None,
251            })
252            .collect();
253
254        let request_body_schema = operation.request_body.as_ref().and_then(|rb| match rb {
255            ReferenceOr::Item(body) => body
256                .content
257                .get("application/json")
258                .and_then(|mt| mt.schema.as_ref())
259                .and_then(|s| self.schema_to_json(s)),
260            ReferenceOr::Reference { .. } => None,
261        });
262
263        ExtractedOperation {
264            method: method.to_string(),
265            path: path.to_string(),
266            operation_id: operation.operation_id.clone(),
267            summary: operation.summary.clone(),
268            description: operation.description.clone(),
269            parameters,
270            request_body_schema,
271            mcp_type,
272        }
273    }
274
275    /// Extract a parameter definition.
276    fn extract_parameter(&self, param: &Parameter) -> ExtractedParameter {
277        let (name, location, required, description, schema) = match param {
278            Parameter::Query { parameter_data, .. } => (
279                parameter_data.name.clone(),
280                "query".to_string(),
281                parameter_data.required,
282                parameter_data.description.clone(),
283                self.extract_param_schema(&parameter_data.format),
284            ),
285            Parameter::Header { parameter_data, .. } => (
286                parameter_data.name.clone(),
287                "header".to_string(),
288                parameter_data.required,
289                parameter_data.description.clone(),
290                self.extract_param_schema(&parameter_data.format),
291            ),
292            Parameter::Path { parameter_data, .. } => (
293                parameter_data.name.clone(),
294                "path".to_string(),
295                true, // Path params are always required
296                parameter_data.description.clone(),
297                self.extract_param_schema(&parameter_data.format),
298            ),
299            Parameter::Cookie { parameter_data, .. } => (
300                parameter_data.name.clone(),
301                "cookie".to_string(),
302                parameter_data.required,
303                parameter_data.description.clone(),
304                self.extract_param_schema(&parameter_data.format),
305            ),
306        };
307
308        ExtractedParameter {
309            name,
310            location,
311            required,
312            description,
313            schema,
314        }
315    }
316
317    /// Extract schema from parameter format.
318    fn extract_param_schema(&self, format: &ParameterSchemaOrContent) -> Option<Value> {
319        match format {
320            ParameterSchemaOrContent::Schema(schema) => self.schema_to_json(schema),
321            ParameterSchemaOrContent::Content(_) => None,
322        }
323    }
324
325    /// Convert an OpenAPI schema to a JSON Schema value.
326    fn schema_to_json(&self, schema: &ReferenceOr<Schema>) -> Option<Value> {
327        match schema {
328            ReferenceOr::Item(s) => Some(serde_json::to_value(s).ok()?),
329            ReferenceOr::Reference { reference } => Some(json!({ "$ref": reference })),
330        }
331    }
332
333    /// Build the full URL for an operation.
334    pub(crate) fn build_url(
335        &self,
336        operation: &ExtractedOperation,
337        args: &HashMap<String, Value>,
338    ) -> Result<Url> {
339        let base = self.base_url.as_ref().ok_or(OpenApiError::NoBaseUrl)?;
340
341        // Replace path parameters
342        let mut path = operation.path.clone();
343        for param in &operation.parameters {
344            if param.location == "path" {
345                if let Some(value) = args.get(&param.name) {
346                    let value_str = match value {
347                        Value::String(s) => s.clone(),
348                        _ => value.to_string(),
349                    };
350                    path = path.replace(&format!("{{{}}}", param.name), &value_str);
351                } else if param.required {
352                    return Err(OpenApiError::MissingParameter(param.name.clone()));
353                }
354            }
355        }
356
357        let mut url = base.join(&path)?;
358
359        // Collect query parameters first
360        let mut query_params: Vec<(String, String)> = Vec::new();
361        for param in &operation.parameters {
362            if param.location == "query" {
363                if let Some(value) = args.get(&param.name) {
364                    let value_str = match value {
365                        Value::String(s) => s.clone(),
366                        Value::Bool(b) => b.to_string(),
367                        Value::Number(n) => n.to_string(),
368                        _ => value.to_string(),
369                    };
370                    query_params.push((param.name.clone(), value_str));
371                } else if param.required {
372                    return Err(OpenApiError::MissingParameter(param.name.clone()));
373                }
374            }
375        }
376
377        // Only add query string if there are parameters
378        if !query_params.is_empty() {
379            let mut query_pairs = url.query_pairs_mut();
380            for (key, value) in query_params {
381                query_pairs.append_pair(&key, &value);
382            }
383        }
384
385        Ok(url)
386    }
387
388    /// Get the HTTP client.
389    pub(crate) fn client(&self) -> &reqwest::Client {
390        &self.client
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_SPEC: &str = r#"{
        "openapi": "3.0.0",
        "info": {
            "title": "Test API",
            "version": "1.0.0"
        },
        "paths": {
            "/users": {
                "get": {
                    "operationId": "listUsers",
                    "summary": "List all users",
                    "responses": { "200": { "description": "Success" } }
                },
                "post": {
                    "operationId": "createUser",
                    "summary": "Create a user",
                    "responses": { "201": { "description": "Created" } }
                }
            },
            "/users/{id}": {
                "get": {
                    "operationId": "getUser",
                    "summary": "Get a user by ID",
                    "parameters": [
                        {
                            "name": "id",
                            "in": "path",
                            "required": true,
                            "schema": { "type": "string" }
                        }
                    ],
                    "responses": { "200": { "description": "Success" } }
                },
                "delete": {
                    "operationId": "deleteUser",
                    "summary": "Delete a user",
                    "parameters": [
                        {
                            "name": "id",
                            "in": "path",
                            "required": true,
                            "schema": { "type": "string" }
                        }
                    ],
                    "responses": { "204": { "description": "Deleted" } }
                }
            }
        }
    }"#;

    /// Parse `TEST_SPEC` into a provider.
    fn test_provider() -> OpenApiProvider {
        OpenApiProvider::from_string(TEST_SPEC).unwrap()
    }

    /// Parse `TEST_SPEC` and point the provider at `https://api.example.com`.
    fn test_provider_with_base() -> OpenApiProvider {
        test_provider()
            .with_base_url("https://api.example.com")
            .unwrap()
    }

    /// Find an extracted operation by its `operationId`.
    fn operation_by_id<'a>(provider: &'a OpenApiProvider, id: &str) -> &'a ExtractedOperation {
        provider
            .operations()
            .iter()
            .find(|op| op.operation_id.as_deref() == Some(id))
            .unwrap()
    }

    #[test]
    fn test_provider_from_string() {
        let provider = test_provider();

        assert_eq!(provider.title(), "Test API");
        assert_eq!(provider.version(), "1.0.0");
    }

    #[test]
    fn test_operation_extraction() {
        let provider = test_provider();
        assert_eq!(provider.operations().len(), 4);

        // GET /users becomes a resource.
        let list_users = operation_by_id(&provider, "listUsers");
        assert_eq!(list_users.mcp_type, McpType::Resource);
        assert_eq!(list_users.method, "GET");

        // POST /users becomes a tool.
        let create_user = operation_by_id(&provider, "createUser");
        assert_eq!(create_user.mcp_type, McpType::Tool);
        assert_eq!(create_user.method, "POST");
    }

    #[test]
    fn test_tools_and_resources() {
        let provider = test_provider();

        // Both GET operations are resources; POST and DELETE are tools.
        assert_eq!(provider.resources().count(), 2);
        assert_eq!(provider.tools().count(), 2);
    }

    #[test]
    fn test_build_url_with_path_params() {
        let provider = test_provider_with_base();
        let get_user = operation_by_id(&provider, "getUser");
        let args = HashMap::from([("id".to_string(), json!("123"))]);

        let url = provider.build_url(get_user, &args).unwrap();
        assert_eq!(url.as_str(), "https://api.example.com/users/123");
    }

    #[test]
    fn test_missing_required_param() {
        let provider = test_provider_with_base();
        let get_user = operation_by_id(&provider, "getUser");
        let args = HashMap::new(); // no "id" supplied

        let outcome = provider.build_url(get_user, &args);
        assert!(matches!(outcome, Err(OpenApiError::MissingParameter(_))));
    }
}