torch_web/
api.rs

1//! API versioning, documentation, and OpenAPI support
2
3use std::collections::HashMap;
4use crate::{Request, Response, App, Handler};
5
6#[cfg(feature = "json")]
7use serde_json::{json, Value};
8
/// A published API version together with its lifecycle status.
///
/// New versions start out active; [`ApiVersion::deprecated`] flags one as deprecated and
/// optionally records a sunset date (kept verbatim as text, never parsed).
#[derive(Clone, Debug)]
pub struct ApiVersion {
    /// Version identifier, e.g. `"1.0"`.
    pub version: String,
    /// Human-readable summary of this version.
    pub description: String,
    /// Whether the version has been marked deprecated.
    pub deprecated: bool,
    /// Optional sunset date, stored exactly as supplied.
    pub sunset_date: Option<String>,
}

impl ApiVersion {
    /// Creates an active (non-deprecated) version with no sunset date.
    pub fn new(version: &str, description: &str) -> Self {
        let (version, description) = (String::from(version), String::from(description));
        Self {
            version,
            description,
            deprecated: false,
            sunset_date: None,
        }
    }

    /// Consumes the version and returns it marked deprecated.
    ///
    /// Any previously stored sunset date is replaced by `sunset_date` (so `None` clears it).
    pub fn deprecated(self, sunset_date: Option<&str>) -> Self {
        Self {
            deprecated: true,
            sunset_date: sunset_date.map(String::from),
            ..self
        }
    }
}
34
35/// API endpoint documentation
36#[derive(Debug, Clone)]
37pub struct EndpointDoc {
38    pub method: String,
39    pub path: String,
40    pub summary: String,
41    pub description: String,
42    pub parameters: Vec<ParameterDoc>,
43    pub responses: HashMap<u16, ResponseDoc>,
44    pub tags: Vec<String>,
45}
46
47/// Internal API endpoint representation
48#[derive(Debug, Clone)]
49pub struct ApiEndpoint {
50    pub method: String,
51    pub path: String,
52    pub summary: String,
53    pub description: String,
54    pub parameters: Vec<ParameterDoc>,
55    pub responses: HashMap<u16, ResponseDoc>,
56    pub tags: Vec<String>,
57}
58
59/// Complete API documentation
60#[derive(Debug, Clone)]
61pub struct ApiDocumentation {
62    pub title: String,
63    pub version: String,
64    pub description: String,
65    pub endpoints: Vec<ApiEndpoint>,
66}
67
68#[derive(Debug, Clone)]
69pub struct ParameterDoc {
70    pub name: String,
71    pub location: ParameterLocation,
72    pub description: String,
73    pub required: bool,
74    pub schema_type: String,
75    pub example: Option<String>,
76}
77
/// Where a [`ParameterDoc`] is carried in the HTTP request.
///
/// This is a fieldless enum, so it is `Copy` and comparable. Callers can match or
/// compare locations (`p.location == ParameterLocation::Path`) and use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParameterLocation {
    /// A segment of the URL path.
    Path,
    /// A query-string parameter.
    Query,
    /// An HTTP request header.
    Header,
    /// The request body.
    Body,
}
85
/// Documentation for one response an endpoint can return.
#[derive(Clone, Debug)]
pub struct ResponseDoc {
    /// Free-text description of the response.
    pub description: String,
    /// Media type of the response body, e.g. `"application/json"`.
    pub content_type: String,
    /// Optional example body.
    pub example: Option<String>,
}
92
93/// API documentation builder
94#[derive(Clone)]
95pub struct ApiDocBuilder {
96    title: String,
97    description: String,
98    version: String,
99    base_url: String,
100    endpoints: Vec<EndpointDoc>,
101    versions: HashMap<String, ApiVersion>,
102}
103
104impl ApiDocBuilder {
105    pub fn new(title: &str, version: &str) -> Self {
106        Self {
107            title: title.to_string(),
108            description: String::new(),
109            version: version.to_string(),
110            base_url: "/".to_string(),
111            endpoints: Vec::new(),
112            versions: HashMap::new(),
113        }
114    }
115
116    pub fn description(mut self, description: &str) -> Self {
117        self.description = description.to_string();
118        self
119    }
120
121    pub fn base_url(mut self, base_url: &str) -> Self {
122        self.base_url = base_url.to_string();
123        self
124    }
125
126    pub fn add_version(mut self, version: ApiVersion) -> Self {
127        self.versions.insert(version.version.clone(), version);
128        self
129    }
130
131    pub fn add_endpoint(mut self, endpoint: EndpointDoc) -> Self {
132        self.endpoints.push(endpoint);
133        self
134    }
135
136    /// Generate OpenAPI 3.0 specification
137    #[cfg(feature = "json")]
138    pub fn generate_openapi(&self) -> Value {
139        let mut paths = serde_json::Map::new();
140        
141        for endpoint in &self.endpoints {
142            let path_item = paths.entry(&endpoint.path).or_insert_with(|| json!({}));
143            
144            let mut operation = serde_json::Map::new();
145            operation.insert("summary".to_string(), json!(endpoint.summary));
146            operation.insert("description".to_string(), json!(endpoint.description));
147            operation.insert("tags".to_string(), json!(endpoint.tags));
148            
149            // Note: deprecated field removed from ApiEndpoint for simplicity
150            
151            // Parameters
152            if !endpoint.parameters.is_empty() {
153                let params: Vec<Value> = endpoint.parameters.iter().map(|p| {
154                    json!({
155                        "name": p.name,
156                        "in": match p.location {
157                            ParameterLocation::Path => "path",
158                            ParameterLocation::Query => "query",
159                            ParameterLocation::Header => "header",
160                            ParameterLocation::Body => "body",
161                        },
162                        "description": p.description,
163                        "required": p.required,
164                        "schema": {
165                            "type": p.schema_type
166                        }
167                    })
168                }).collect();
169                operation.insert("parameters".to_string(), json!(params));
170            }
171            
172            // Responses
173            let mut responses = serde_json::Map::new();
174            for (status, response) in &endpoint.responses {
175                responses.insert(status.to_string(), json!({
176                    "description": response.description,
177                    "content": {
178                        response.content_type.clone(): {
179                            "example": response.example
180                        }
181                    }
182                }));
183            }
184            operation.insert("responses".to_string(), json!(responses));
185            
186            path_item[endpoint.method.to_lowercase()] = json!(operation);
187        }
188        
189        json!({
190            "openapi": "3.0.0",
191            "info": {
192                "title": self.title,
193                "description": self.description,
194                "version": self.version
195            },
196            "servers": [{
197                "url": self.base_url
198            }],
199            "paths": paths
200        })
201    }
202
203    #[cfg(not(feature = "json"))]
204    pub fn generate_openapi(&self) -> String {
205        "OpenAPI generation requires 'json' feature".to_string()
206    }
207
208    /// Generate simple HTML documentation
209    pub fn generate_html_docs(&self) -> String {
210        let mut html = format!(
211            r#"<!DOCTYPE html>
212<html>
213<head>
214    <title>{} API Documentation</title>
215    <style>
216        body {{ font-family: Arial, sans-serif; margin: 40px; }}
217        .endpoint {{ margin: 20px 0; padding: 20px; border: 1px solid #ddd; border-radius: 5px; }}
218        .method {{ display: inline-block; padding: 4px 8px; border-radius: 3px; color: white; font-weight: bold; }}
219        .get {{ background-color: #61affe; }}
220        .post {{ background-color: #49cc90; }}
221        .put {{ background-color: #fca130; }}
222        .delete {{ background-color: #f93e3e; }}
223        .deprecated {{ opacity: 0.6; }}
224        .parameter {{ margin: 10px 0; padding: 10px; background-color: #f8f9fa; border-radius: 3px; }}
225    </style>
226</head>
227<body>
228    <h1>{} API Documentation</h1>
229    <p>{}</p>
230    <p><strong>Version:</strong> {}</p>
231"#,
232            self.title, self.title, self.description, self.version
233        );
234
235        if !self.versions.is_empty() {
236            html.push_str("<h2>Available Versions</h2>");
237            for version in self.versions.values() {
238                let deprecated_class = if version.deprecated { " class=\"deprecated\"" } else { "" };
239                html.push_str(&format!(
240                    "<div{}><strong>v{}</strong> - {}</div>",
241                    deprecated_class, version.version, version.description
242                ));
243            }
244        }
245
246        html.push_str("<h2>Endpoints</h2>");
247        
248        for endpoint in &self.endpoints {
249            let deprecated_class = ""; // Deprecated field removed for simplicity
250            let method_class = endpoint.method.to_lowercase();
251            
252            html.push_str(&format!(
253                r#"<div class="endpoint{}">
254                    <h3><span class="method {}">{}</span> {}</h3>
255                    <p><strong>Summary:</strong> {}</p>
256                    <p>{}</p>
257"#,
258                deprecated_class, method_class, endpoint.method, endpoint.path,
259                endpoint.summary, endpoint.description
260            ));
261
262            if !endpoint.parameters.is_empty() {
263                html.push_str("<h4>Parameters</h4>");
264                for param in &endpoint.parameters {
265                    html.push_str(&format!(
266                        r#"<div class="parameter">
267                            <strong>{}</strong> ({:?}) - {}
268                            {}</div>"#,
269                        param.name,
270                        param.location,
271                        param.description,
272                        if param.required { " <em>(required)</em>" } else { "" }
273                    ));
274                }
275            }
276
277            if !endpoint.responses.is_empty() {
278                html.push_str("<h4>Responses</h4>");
279                for (status, response) in &endpoint.responses {
280                    html.push_str(&format!(
281                        "<div><strong>{}</strong> - {}</div>",
282                        status, response.description
283                    ));
284                }
285            }
286
287            html.push_str("</div>");
288        }
289
290        html.push_str("</body></html>");
291        html
292    }
293}
294
/// API versioning middleware.
///
/// Resolves the version a request asks for and rejects versions that were not registered.
/// Derives `Debug` and `Clone` so a configured instance can be inspected or reused across
/// routers.
#[derive(Debug, Clone)]
pub struct ApiVersioning {
    /// Version assumed when a request does not specify one.
    default_version: String,
    /// Versions the middleware accepts; `new` seeds this with `default_version`.
    supported_versions: Vec<String>,
    /// Request header inspected for an explicit version; also echoed on responses.
    version_header: String,
}
301
302impl ApiVersioning {
303    pub fn new(default_version: &str) -> Self {
304        Self {
305            default_version: default_version.to_string(),
306            supported_versions: vec![default_version.to_string()],
307            version_header: "API-Version".to_string(),
308        }
309    }
310
311    pub fn add_version(mut self, version: &str) -> Self {
312        self.supported_versions.push(version.to_string());
313        self
314    }
315
316    pub fn version_header(mut self, header: &str) -> Self {
317        self.version_header = header.to_string();
318        self
319    }
320
321    fn extract_version(&self, req: &Request) -> String {
322        // Try header first
323        if let Some(version) = req.header(&self.version_header) {
324            return version.to_string();
325        }
326
327        // Try query parameter
328        if let Some(version) = req.query("version") {
329            return version.to_string();
330        }
331
332        // Try path prefix (e.g., /v1/users)
333        let path = req.path();
334        if path.starts_with("/v") {
335            if let Some(version_part) = path.split('/').nth(1) {
336                if version_part.starts_with('v') {
337                    return version_part[1..].to_string();
338                }
339            }
340        }
341
342        self.default_version.clone()
343    }
344}
345
346impl crate::middleware::Middleware for ApiVersioning {
347    fn call(
348        &self,
349        req: Request,
350        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
351    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
352        let version = self.extract_version(&req);
353        let supported_versions = self.supported_versions.clone();
354        let version_header = self.version_header.clone();
355
356        Box::pin(async move {
357            // Check if version is supported
358            if !supported_versions.contains(&version) {
359                return Response::bad_request()
360                    .json(&json!({
361                        "error": "Unsupported API version",
362                        "requested_version": version,
363                        "supported_versions": supported_versions
364                    }))
365                    .unwrap_or_else(|_| Response::bad_request().body("Unsupported API version"));
366            }
367
368            // Add version info to request context (would need to extend Request struct)
369            let mut response = next(req).await;
370            response = response.header(&version_header, &version);
371            response
372        })
373    }
374}
375
376/// Convenience methods for App to add documented endpoints
377impl App {
378    /// Add a documented GET endpoint
379    pub fn documented_get<H, T>(
380        self,
381        path: &str,
382        handler: H,
383        doc: EndpointDoc,
384    ) -> Self
385    where
386        H: Handler<T>,
387    {
388        // Store the documentation for later use in API doc generation
389        #[cfg(feature = "api")]
390        {
391            let mut app = self;
392            if let Some(ref mut api_docs) = app.api_docs {
393                let mut endpoint_doc = doc;
394                endpoint_doc.method = "GET".to_string();
395                endpoint_doc.path = path.to_string();
396                *api_docs = api_docs.clone().add_endpoint(endpoint_doc);
397            }
398            app.get(path, handler)
399        }
400
401        #[cfg(not(feature = "api"))]
402        {
403            let _ = doc; // Suppress unused warning
404            self.get(path, handler)
405        }
406    }
407
408    /// Add a documented POST endpoint
409    pub fn documented_post<H, T>(
410        self,
411        path: &str,
412        handler: H,
413        doc: EndpointDoc,
414    ) -> Self
415    where
416        H: Handler<T>,
417    {
418        // Store the documentation for later use in API doc generation
419        #[cfg(feature = "api")]
420        {
421            let mut app = self;
422            if let Some(ref mut api_docs) = app.api_docs {
423                let mut endpoint_doc = doc;
424                endpoint_doc.method = "POST".to_string();
425                endpoint_doc.path = path.to_string();
426                *api_docs = api_docs.clone().add_endpoint(endpoint_doc);
427            }
428            app.post(path, handler)
429        }
430
431        #[cfg(not(feature = "api"))]
432        {
433            let _ = doc; // Suppress unused warning
434            self.post(path, handler)
435        }
436    }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created version keeps its identifier and is not deprecated.
    #[test]
    fn test_api_version() {
        let initial = ApiVersion::new("1.0", "Initial version");
        assert!(!initial.deprecated);
        assert_eq!(initial.version, "1.0");
    }

    /// Builder setters keep the title and version passed to `new`.
    #[test]
    fn test_api_doc_builder() {
        let builder = ApiDocBuilder::new("Test API", "1.0");
        let builder = builder.description("A test API");
        let builder = builder.base_url("https://api.example.com");

        assert_eq!(builder.version, "1.0");
        assert_eq!(builder.title, "Test API");
    }

    /// The generated spec declares an OpenAPI 3.0.x version and carries the API title.
    #[cfg(feature = "json")]
    #[test]
    fn test_openapi_generation() {
        let list_users = EndpointDoc {
            method: "GET".to_string(),
            path: "/users".to_string(),
            summary: "Get users".to_string(),
            description: "Retrieve all users".to_string(),
            parameters: Vec::new(),
            responses: HashMap::new(),
            tags: vec!["users".to_string()],
        };

        let spec = ApiDocBuilder::new("Test API", "1.0")
            .add_endpoint(list_users)
            .generate_openapi();

        let openapi_version = spec["openapi"].as_str().unwrap();
        assert!(openapi_version.starts_with("3.0"));
        assert_eq!(spec["info"]["title"], "Test API");
    }
}