// Source: spn_mcp/config/schema.rs

//! Configuration schema types for API wrappers.
//!
//! These types define the YAML structure for API wrapper configurations.

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

9/// Top-level API wrapper configuration.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ApiConfig {
12    /// API identifier (used in tool names: {name}_{tool_name})
13    pub name: String,
14
15    /// Configuration version
16    #[serde(default = "default_version")]
17    pub version: String,
18
19    /// Base URL for all API requests
20    pub base_url: String,
21
22    /// Human-readable description
23    #[serde(default)]
24    pub description: Option<String>,
25
26    /// Authentication configuration
27    pub auth: AuthConfig,
28
29    /// Rate limiting configuration
30    #[serde(default)]
31    pub rate_limit: Option<RateLimitConfig>,
32
33    /// Default headers for all requests
34    #[serde(default)]
35    pub headers: Option<HashMap<String, String>>,
36
37    /// Tool definitions
38    pub tools: Vec<ToolDef>,
39}
40
/// Serde default for `ApiConfig::version`: the string `"1.0"`.
fn default_version() -> String {
    String::from("1.0")
}

45/// Authentication configuration.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct AuthConfig {
48    /// Authentication type
49    #[serde(rename = "type")]
50    pub auth_type: AuthType,
51
52    /// Credential name (resolved via spn daemon)
53    pub credential: String,
54
55    /// For api_key: header or query
56    #[serde(default)]
57    pub location: Option<ApiKeyLocation>,
58
59    /// For api_key: header/param name (e.g., "X-API-Key")
60    #[serde(default)]
61    pub key_name: Option<String>,
62}
63
64/// Authentication type.
65#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
66#[serde(rename_all = "snake_case")]
67pub enum AuthType {
68    /// HTTP Basic Authentication (username:password base64)
69    Basic,
70    /// Bearer token (Authorization: Bearer <token>)
71    Bearer,
72    /// API key (in header or query param)
73    ApiKey,
74}
75
76/// API key location.
77#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
78#[serde(rename_all = "snake_case")]
79pub enum ApiKeyLocation {
80    Header,
81    Query,
82}
83
84/// Rate limiting configuration.
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct RateLimitConfig {
87    /// Maximum requests per minute
88    pub requests_per_minute: u32,
89
90    /// Burst allowance (default: 1)
91    #[serde(default = "default_burst")]
92    pub burst: u32,
93}
94
/// Serde default for `RateLimitConfig::burst`: a burst allowance of `1`.
fn default_burst() -> u32 {
    1
}

99/// Tool definition (maps to a single MCP tool).
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct ToolDef {
102    /// Tool name (combined with API name: {api}_{name})
103    pub name: String,
104
105    /// Tool description for MCP
106    #[serde(default)]
107    pub description: Option<String>,
108
109    /// HTTP method (GET, POST, PUT, DELETE, etc.)
110    #[serde(default = "default_method")]
111    pub method: String,
112
113    /// API path (appended to base_url)
114    pub path: String,
115
116    /// Request body template (Tera syntax)
117    #[serde(default)]
118    pub body_template: Option<String>,
119
120    /// Parameter definitions
121    #[serde(default)]
122    pub params: Vec<ParamDef>,
123
124    /// Response handling
125    #[serde(default)]
126    pub response: Option<ResponseConfig>,
127}
128
/// Serde default for `ToolDef::method`: the string `"GET"`.
fn default_method() -> String {
    String::from("GET")
}

133/// Parameter definition for a tool.
134#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
135pub struct ParamDef {
136    /// Parameter name
137    pub name: String,
138
139    /// Parameter type
140    #[serde(rename = "type")]
141    pub param_type: ParamType,
142
143    /// For array type: item type
144    #[serde(default)]
145    pub items: Option<ParamType>,
146
147    /// Whether parameter is required
148    #[serde(default)]
149    pub required: bool,
150
151    /// Default value (JSON)
152    #[serde(default)]
153    pub default: Option<serde_json::Value>,
154
155    /// Parameter description
156    #[serde(default)]
157    pub description: Option<String>,
158}
159
160/// Parameter type.
161#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
162#[serde(rename_all = "snake_case")]
163pub enum ParamType {
164    String,
165    Integer,
166    Number,
167    Boolean,
168    Array,
169    Object,
170}
171
172/// Response handling configuration.
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct ResponseConfig {
175    /// JSON path to extract from response (e.g., "tasks[0].result")
176    #[serde(default)]
177    pub extract: Option<String>,
178}
179
180impl ToolDef {
181    /// Generate JSON Schema for this tool's parameters.
182    pub fn to_json_schema(&self) -> serde_json::Value {
183        let mut properties = serde_json::Map::new();
184        let mut required = Vec::new();
185
186        for param in &self.params {
187            let mut prop = serde_json::Map::new();
188
189            // Type
190            let type_str = match param.param_type {
191                ParamType::String => "string",
192                ParamType::Integer => "integer",
193                ParamType::Number => "number",
194                ParamType::Boolean => "boolean",
195                ParamType::Array => "array",
196                ParamType::Object => "object",
197            };
198            prop.insert("type".into(), serde_json::Value::String(type_str.into()));
199
200            // Description
201            if let Some(desc) = &param.description {
202                prop.insert(
203                    "description".into(),
204                    serde_json::Value::String(desc.clone()),
205                );
206            }
207
208            // Default
209            if let Some(default) = &param.default {
210                prop.insert("default".into(), default.clone());
211            }
212
213            // Array items
214            if param.param_type == ParamType::Array {
215                if let Some(items_type) = &param.items {
216                    let items_type_str = match items_type {
217                        ParamType::String => "string",
218                        ParamType::Integer => "integer",
219                        ParamType::Number => "number",
220                        ParamType::Boolean => "boolean",
221                        ParamType::Array => "array",
222                        ParamType::Object => "object",
223                    };
224                    prop.insert("items".into(), serde_json::json!({"type": items_type_str}));
225                }
226            }
227
228            properties.insert(param.name.clone(), serde_json::Value::Object(prop));
229
230            if param.required {
231                required.push(serde_json::Value::String(param.name.clone()));
232            }
233        }
234
235        serde_json::json!({
236            "type": "object",
237            "properties": properties,
238            "required": required
239        })
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// A config holding only the mandatory keys parses and picks up defaults.
    #[test]
    fn test_parse_minimal_config() {
        let yaml = r#"
name: test
base_url: https://api.example.com
auth:
  type: bearer
  credential: test
tools:
  - name: get_data
    path: /data
"#;

        let config: ApiConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.name, "test");
        assert_eq!(config.version, "1.0"); // default
        assert_eq!(config.auth.auth_type, AuthType::Bearer);
        assert!(config.rate_limit.is_none());
        assert_eq!(config.tools.len(), 1);
        assert_eq!(config.tools[0].method, "GET"); // default
        assert!(config.tools[0].params.is_empty()); // default
    }

    /// A config exercising every top-level section parses into the expected values.
    #[test]
    fn test_parse_full_config() {
        let yaml = r#"
name: dataforseo
version: "1.0"
base_url: https://api.dataforseo.com/v3
description: "DataForSEO API"
auth:
  type: basic
  credential: dataforseo
rate_limit:
  requests_per_minute: 12
  burst: 3
headers:
  Content-Type: application/json
tools:
  - name: keyword_ideas
    description: "Get keyword ideas"
    method: POST
    path: /dataforseo_labs/google/keyword_ideas/live
    body_template: |
      [{"keywords": {{ keywords | json }}}]
    params:
      - name: keywords
        type: array
        items: string
        required: true
        description: "Seed keywords"
"#;

        let config: ApiConfig = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.name, "dataforseo");
        assert_eq!(config.auth.auth_type, AuthType::Basic);

        let rate_limit = config
            .rate_limit
            .as_ref()
            .expect("rate_limit section is present in the YAML");
        assert_eq!(rate_limit.requests_per_minute, 12);
        assert_eq!(rate_limit.burst, 3);

        let keywords = &config.tools[0].params[0];
        assert_eq!(keywords.param_type, ParamType::Array);
        assert_eq!(keywords.items, Some(ParamType::String));
        assert!(keywords.required);
    }

    /// Scalar parameters end up in `properties`, and required ones in `required`.
    #[test]
    fn test_to_json_schema() {
        let tool = ToolDef {
            name: "test".into(),
            description: Some("Test tool".into()),
            method: "POST".into(),
            path: "/test".into(),
            body_template: None,
            params: vec![
                ParamDef {
                    name: "query".into(),
                    param_type: ParamType::String,
                    items: None,
                    required: true,
                    default: None,
                    description: Some("Search query".into()),
                },
                ParamDef {
                    name: "limit".into(),
                    param_type: ParamType::Integer,
                    items: None,
                    required: false,
                    default: Some(serde_json::json!(10)),
                    description: None,
                },
            ],
            response: None,
        };

        let schema = tool.to_json_schema();
        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["query"]["type"], "string");
        assert_eq!(schema["properties"]["query"]["description"], "Search query");
        assert_eq!(schema["properties"]["limit"]["type"], "integer");
        assert_eq!(schema["properties"]["limit"]["default"], 10);
        assert_eq!(schema["required"], serde_json::json!(["query"]));
    }

    /// `items` is emitted only for array parameters that declare an item type.
    #[test]
    fn test_to_json_schema_array_items() {
        let tool = ToolDef {
            name: "test".into(),
            description: None,
            method: "GET".into(),
            path: "/test".into(),
            body_template: None,
            params: vec![
                ParamDef {
                    name: "tags".into(),
                    param_type: ParamType::Array,
                    items: Some(ParamType::String),
                    required: false,
                    default: None,
                    description: None,
                },
                ParamDef {
                    name: "ids".into(),
                    param_type: ParamType::Array,
                    items: None,
                    required: false,
                    default: None,
                    description: None,
                },
                ParamDef {
                    name: "count".into(),
                    param_type: ParamType::Integer,
                    items: Some(ParamType::Integer),
                    required: false,
                    default: None,
                    description: None,
                },
            ],
            response: None,
        };

        let schema = tool.to_json_schema();
        let props = &schema["properties"];
        assert_eq!(props["tags"]["type"], "array");
        assert_eq!(props["tags"]["items"]["type"], "string");
        assert!(props["ids"].get("items").is_none());
        assert!(props["count"].get("items").is_none());
        assert_eq!(schema["required"], serde_json::json!([]));
    }
}
343}