rust_x402/
axum.rs

1//! Axum integration for x402 payments
2
3use crate::middleware::{PaymentMiddleware, PaymentMiddlewareConfig};
4use crate::X402Error;
5use axum::{
6    extract::{Request, State},
7    http::{HeaderMap, HeaderValue, StatusCode},
8    middleware::Next,
9    response::{IntoResponse, Response},
10    routing::{delete, get, post, put},
11    Json, Router,
12};
13use rust_decimal::Decimal;
14use std::sync::Arc;
15use tower::ServiceBuilder;
16
17/// Re-export the payment middleware for convenience
18pub use crate::middleware::{create_payment_service, payment_middleware};
19
20/// Create a new Axum router with x402 payment middleware
21pub fn create_payment_router(
22    middleware: PaymentMiddleware,
23    routes: impl FnOnce(&mut Router) -> &mut Router,
24) -> Router {
25    let mut router = Router::new();
26    routes(&mut router);
27
28    // Apply payment middleware to all routes
29    router.layer(axum::middleware::from_fn_with_state(
30        middleware,
31        payment_middleware_handler,
32    ))
33}
34
35/// Helper function to create a payment-protected route
36pub fn payment_route<H>(
37    method: &str,
38    path: &str,
39    handler: H,
40    middleware: PaymentMiddleware,
41) -> Router
42where
43    H: axum::handler::Handler<(), ()> + Clone + Send + 'static,
44{
45    let router = match method.to_uppercase().as_str() {
46        "GET" => Router::new().route(path, get(handler)),
47        "POST" => Router::new().route(path, post(handler)),
48        "PUT" => Router::new().route(path, put(handler)),
49        "DELETE" => Router::new().route(path, delete(handler)),
50        _ => {
51            // For unsupported methods, return an error router
52            return Router::new().route(
53                path,
54                axum::routing::any(|| async {
55                    (StatusCode::METHOD_NOT_ALLOWED, "Unsupported HTTP method")
56                }),
57            );
58        }
59    };
60
61    // Apply payment middleware
62    router.layer(axum::middleware::from_fn_with_state(
63        middleware,
64        payment_middleware_handler,
65    ))
66}
67
68/// Create a payment middleware for Axum
69pub fn create_payment_middleware(amount: Decimal, pay_to: impl Into<String>) -> PaymentMiddleware {
70    PaymentMiddleware::new(amount, pay_to)
71}
72
73/// Check if the request is from a web browser
74fn is_web_browser(headers: &HeaderMap) -> bool {
75    let user_agent = headers
76        .get("User-Agent")
77        .and_then(|h| h.to_str().ok())
78        .unwrap_or("");
79
80    let accept = headers
81        .get("Accept")
82        .and_then(|h| h.to_str().ok())
83        .unwrap_or("");
84
85    accept.contains("text/html") && user_agent.contains("Mozilla")
86}
87
/// Return the built-in paywall page served to browsers when no custom paywall HTML
/// is configured.
fn get_default_paywall_html() -> &'static str {
    const DEFAULT_PAYWALL_HTML: &str = r#"<!DOCTYPE html>
<html>
<head>
    <title>Payment Required</title>
    <style>
        body { font-family: Arial, sans-serif; text-align: center; padding: 50px; }
        .container { max-width: 500px; margin: 0 auto; }
        h1 { color: #333; }
        p { color: #666; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Payment Required</h1>
        <p>This resource requires payment to access. Please provide a valid X-PAYMENT header.</p>
    </div>
</body>
</html>"#;

    DEFAULT_PAYWALL_HTML
}
109
110/// Axum middleware handler for payment processing with settlement
111pub async fn payment_middleware_handler(
112    State(middleware): State<PaymentMiddleware>,
113    request: Request,
114    next: Next,
115) -> impl IntoResponse {
116    let config = middleware.config().clone();
117    let headers = request.headers().clone();
118    let path = request.uri().path();
119
120    // Skip payment middleware for health check endpoints
121    if path == "/health" || path.starts_with("/health/") {
122        return next.run(request).await;
123    }
124
125    // Determine the resource URL
126    let resource = if let Some(ref resource_url) = config.resource {
127        resource_url.clone()
128    } else if let Some(ref root_url) = config.resource_root_url {
129        format!("{}{}", root_url, path)
130    } else {
131        path.to_string()
132    };
133
134    // Create payment requirements
135    let requirements = match config.create_payment_requirements(&resource) {
136        Ok(req) => req,
137        Err(_) => {
138            return (
139                StatusCode::INTERNAL_SERVER_ERROR,
140                Json(serde_json::json!({
141                    "error": "Failed to create payment requirements",
142                    "x402Version": 1
143                })),
144            )
145                .into_response();
146        }
147    };
148
149    // Check for payment header
150    if let Some(payment_header) = headers.get("X-PAYMENT") {
151        if let Ok(payment_str) = payment_header.to_str() {
152            // Parse the payment payload
153            match crate::types::PaymentPayload::from_base64(payment_str) {
154                Ok(payment_payload) => {
155                    // Verify the payment using the middleware's verify method
156                    match middleware
157                        .verify_with_requirements(&payment_payload, &requirements)
158                        .await
159                    {
160                        Ok(true) => {
161                            // Payment is valid, proceed to next handler
162                            let mut response = next.run(request).await;
163
164                            // After successful response, settle the payment
165                            match middleware
166                                .settle_with_requirements(&payment_payload, &requirements)
167                                .await
168                            {
169                                Ok(settlement_response) => {
170                                    if let Ok(settlement_header) = settlement_response.to_base64() {
171                                        if let Ok(header_value) =
172                                            HeaderValue::from_str(&settlement_header)
173                                        {
174                                            response
175                                                .headers_mut()
176                                                .insert("X-PAYMENT-RESPONSE", header_value);
177                                        }
178                                    }
179                                }
180                                Err(e) => {
181                                    // Log settlement error but don't fail the request
182                                    tracing::warn!("Payment settlement failed: {}", e);
183                                }
184                            }
185
186                            return response;
187                        }
188                        Ok(false) => {
189                            // Payment verification failed
190                            let response_body = serde_json::json!({
191                                "x402Version": 1,
192                                "error": "Payment verification failed",
193                                "accepts": vec![requirements],
194                            });
195                            return (StatusCode::PAYMENT_REQUIRED, Json(response_body))
196                                .into_response();
197                        }
198                        Err(e) => {
199                            // Error during verification
200                            let response_body = serde_json::json!({
201                                "x402Version": 1,
202                                "error": format!("Payment verification error: {}", e),
203                                "accepts": vec![requirements],
204                            });
205                            return (StatusCode::PAYMENT_REQUIRED, Json(response_body))
206                                .into_response();
207                        }
208                    }
209                }
210                Err(e) => {
211                    // Invalid payment payload
212                    let response_body = serde_json::json!({
213                        "x402Version": 1,
214                        "error": format!("Invalid payment payload: {}", e),
215                        "accepts": vec![requirements],
216                    });
217                    return (StatusCode::PAYMENT_REQUIRED, Json(response_body)).into_response();
218                }
219            }
220        }
221    }
222
223    // No valid payment found, check if this is a web browser request
224    if is_web_browser(&headers) {
225        let html = config
226            .custom_paywall_html
227            .clone()
228            .unwrap_or_else(|| get_default_paywall_html().to_string());
229
230        let mut response = Response::new(axum::body::Body::from(html));
231        *response.status_mut() = StatusCode::PAYMENT_REQUIRED;
232        response
233            .headers_mut()
234            .insert("Content-Type", HeaderValue::from_static("text/html"));
235
236        return response.into_response();
237    }
238
239    // Return JSON response for API clients
240    let response_body = serde_json::json!({
241        "x402Version": 1,
242        "error": "X-PAYMENT header is required",
243        "accepts": vec![requirements],
244    });
245
246    (StatusCode::PAYMENT_REQUIRED, Json(response_body)).into_response()
247}
248
249/// Axum-specific payment middleware configuration
250#[derive(Debug, Clone)]
251pub struct AxumPaymentConfig {
252    /// Base payment middleware config
253    pub base_config: PaymentMiddlewareConfig,
254    /// Additional Axum-specific options
255    pub axum_options: AxumOptions,
256}
257
258/// Axum-specific options
259#[derive(Clone, Default)]
260pub struct AxumOptions {
261    /// Whether to enable CORS
262    pub enable_cors: bool,
263    /// CORS origins
264    pub cors_origins: Vec<String>,
265    /// Whether to enable request tracing
266    pub enable_tracing: bool,
267    /// Custom error handler
268    pub error_handler: Option<Arc<dyn Fn(X402Error) -> StatusCode + Send + Sync>>,
269}
270
271impl std::fmt::Debug for AxumOptions {
272    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273        f.debug_struct("AxumOptions")
274            .field("enable_cors", &self.enable_cors)
275            .field("cors_origins", &self.cors_origins)
276            .field("enable_tracing", &self.enable_tracing)
277            .field("error_handler", &"<function>")
278            .finish()
279    }
280}
281
282impl AxumPaymentConfig {
283    /// Create a new Axum payment config
284    pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
285        Self {
286            base_config: PaymentMiddlewareConfig::new(amount, pay_to),
287            axum_options: AxumOptions::default(),
288        }
289    }
290
291    /// Set the payment description
292    pub fn with_description(mut self, description: impl Into<String>) -> Self {
293        self.base_config.description = Some(description.into());
294        self
295    }
296
297    /// Set the MIME type
298    pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
299        self.base_config.mime_type = Some(mime_type.into());
300        self
301    }
302
303    /// Set the maximum timeout
304    pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
305        self.base_config.max_timeout_seconds = max_timeout_seconds;
306        self
307    }
308
309    /// Set the output schema
310    pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
311        self.base_config.output_schema = Some(output_schema);
312        self
313    }
314
315    /// Set the facilitator configuration
316    pub fn with_facilitator_config(
317        mut self,
318        facilitator_config: crate::types::FacilitatorConfig,
319    ) -> Self {
320        self.base_config.facilitator_config = facilitator_config;
321        self
322    }
323
324    /// Set whether this is a testnet
325    pub fn with_testnet(mut self, testnet: bool) -> Self {
326        self.base_config.testnet = testnet;
327        self
328    }
329
330    /// Set custom paywall HTML
331    pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
332        self.base_config.custom_paywall_html = Some(html.into());
333        self
334    }
335
336    /// Set the resource URL
337    pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
338        self.base_config.resource = Some(resource.into());
339        self
340    }
341
342    /// Set the resource root URL
343    pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
344        self.base_config.resource_root_url = Some(url.into());
345        self
346    }
347
348    /// Enable CORS
349    pub fn with_cors(mut self, origins: Vec<String>) -> Self {
350        self.axum_options.enable_cors = true;
351        self.axum_options.cors_origins = origins;
352        self
353    }
354
355    /// Enable request tracing
356    pub fn with_tracing(mut self) -> Self {
357        self.axum_options.enable_tracing = true;
358        self
359    }
360
361    /// Set custom error handler
362    pub fn with_error_handler<F>(mut self, handler: F) -> Self
363    where
364        F: Fn(X402Error) -> StatusCode + Send + Sync + 'static,
365    {
366        self.axum_options.error_handler = Some(Arc::new(handler));
367        self
368    }
369
370    /// Convert to PaymentMiddleware
371    pub fn into_middleware(self) -> PaymentMiddleware {
372        PaymentMiddleware {
373            config: Arc::new(self.base_config),
374            facilitator: None,
375            template_config: None,
376        }
377    }
378
379    /// Create a service builder with this configuration
380    pub fn create_service(&self) -> ServiceBuilder<tower::layer::util::Identity> {
381        // Note: Service layer integration is simplified for now
382        // In a full implementation, you would conditionally add layers based on options
383        ServiceBuilder::new()
384    }
385}
386
387/// Create a complete Axum application with x402 payment support
388pub fn create_payment_app(
389    config: AxumPaymentConfig,
390    routes: impl FnOnce(Router) -> Router,
391) -> Router {
392    let router = Router::new();
393    let router = routes(router);
394
395    // Convert config to middleware
396    let middleware = config.into_middleware();
397
398    // Apply payment middleware to all routes
399    router.layer(axum::middleware::from_fn_with_state(
400        middleware,
401        payment_middleware_handler,
402    ))
403}
404
405/// Helper for creating payment-protected handlers
406pub mod handlers {
407    use super::*;
408    use serde_json::json;
409
410    /// Create a simple JSON response handler
411    pub fn json_response<T: serde::Serialize>(data: T) -> impl IntoResponse {
412        Json(data)
413    }
414
415    /// Create a simple text response handler
416    pub fn text_response(text: impl Into<String>) -> impl IntoResponse {
417        text.into()
418    }
419
420    /// Create an error response handler
421    pub fn error_response(error: impl Into<String>) -> impl IntoResponse {
422        (
423            StatusCode::INTERNAL_SERVER_ERROR,
424            Json(json!({"error": error.into()})),
425        )
426    }
427
428    /// Create a success response handler
429    pub fn success_response<T: serde::Serialize>(data: T) -> impl IntoResponse {
430        (StatusCode::OK, Json(data))
431    }
432}
433
434/// Example handlers for common use cases
435pub mod examples {
436    use super::*;
437    use serde_json::json;
438
439    /// Example joke handler
440    pub async fn joke_handler() -> impl IntoResponse {
441        axum::Json(json!({
442            "joke": "Why do programmers prefer dark mode? Because light attracts bugs!"
443        }))
444    }
445
446    /// Example API data handler
447    pub async fn api_data_handler() -> impl IntoResponse {
448        axum::Json(json!({
449            "data": "This is premium API data that requires payment to access",
450            "timestamp": chrono::Utc::now().to_rfc3339(),
451            "source": "x402-protected-api"
452        }))
453    }
454
455    /// Example file download handler
456    pub async fn download_handler() -> impl IntoResponse {
457        let content = "This is premium content that requires payment to download.";
458        (StatusCode::OK, content)
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Recipient address shared by the tests below.
    const PAY_TO: &str = "0x209693Bc6afc0C5328bA36FaF03C514EF312287C";

    /// Price used throughout the tests (0.0001).
    fn price() -> Decimal {
        Decimal::from_str("0.0001").unwrap()
    }

    #[test]
    fn test_axum_payment_config() {
        let config = AxumPaymentConfig::new(price(), PAY_TO)
            .with_description("Test payment")
            .with_testnet(true)
            .with_cors(vec!["http://localhost:3000".to_string()])
            .with_tracing();

        assert_eq!(config.base_config.amount, price());
        assert!(config.axum_options.enable_cors);
        assert!(config.axum_options.enable_tracing);
    }

    #[test]
    fn test_payment_middleware_creation() {
        let middleware = create_payment_middleware(price(), PAY_TO);

        assert_eq!(middleware.config().amount, price());
    }
}