rust_x402/
axum.rs

1//! Axum integration for x402 payments
2
3use crate::middleware::{PaymentMiddleware, PaymentMiddlewareConfig};
4use crate::X402Error;
5use axum::{
6    extract::{Request, State},
7    http::{HeaderMap, HeaderValue, StatusCode},
8    middleware::Next,
9    response::{IntoResponse, Response},
10    routing::{delete, get, post, put},
11    Json, Router,
12};
13use rust_decimal::Decimal;
14use std::sync::Arc;
15use tower::ServiceBuilder;
16
17/// Re-export the payment middleware for convenience
18pub use crate::middleware::{create_payment_service, payment_middleware};
19
20/// Create a new Axum router with x402 payment middleware
21pub fn create_payment_router(
22    middleware: PaymentMiddleware,
23    routes: impl FnOnce(&mut Router) -> &mut Router,
24) -> Router {
25    let mut router = Router::new();
26    routes(&mut router);
27
28    // Apply payment middleware to all routes
29    router.layer(axum::middleware::from_fn_with_state(
30        middleware,
31        payment_middleware_handler,
32    ))
33}
34
35/// Helper function to create a payment-protected route
36pub fn payment_route<H>(
37    method: &str,
38    path: &str,
39    handler: H,
40    middleware: PaymentMiddleware,
41) -> Router
42where
43    H: axum::handler::Handler<(), ()> + Clone + Send + 'static,
44{
45    let router = match method.to_uppercase().as_str() {
46        "GET" => Router::new().route(path, get(handler)),
47        "POST" => Router::new().route(path, post(handler)),
48        "PUT" => Router::new().route(path, put(handler)),
49        "DELETE" => Router::new().route(path, delete(handler)),
50        _ => {
51            // For unsupported methods, return an error router
52            return Router::new().route(
53                path,
54                axum::routing::any(|| async {
55                    (StatusCode::METHOD_NOT_ALLOWED, "Unsupported HTTP method")
56                }),
57            );
58        }
59    };
60
61    // Apply payment middleware
62    router.layer(axum::middleware::from_fn_with_state(
63        middleware,
64        payment_middleware_handler,
65    ))
66}
67
68/// Create a payment middleware for Axum
69pub fn create_payment_middleware(amount: Decimal, pay_to: impl Into<String>) -> PaymentMiddleware {
70    PaymentMiddleware::new(amount, pay_to)
71}
72
73/// Check if the request is from a web browser
74fn is_web_browser(headers: &HeaderMap) -> bool {
75    let user_agent = headers
76        .get("User-Agent")
77        .and_then(|h| h.to_str().ok())
78        .unwrap_or("");
79
80    let accept = headers
81        .get("Accept")
82        .and_then(|h| h.to_str().ok())
83        .unwrap_or("");
84
85    accept.contains("text/html") && user_agent.contains("Mozilla")
86}
87
/// Return the built-in HTML page shown to browsers that arrive without a valid payment.
///
/// The payment middleware uses this page when no `custom_paywall_html` is configured.
fn get_default_paywall_html() -> &'static str {
    const DEFAULT_PAYWALL_HTML: &str = r#"<!DOCTYPE html>
<html>
<head>
    <title>Payment Required</title>
    <style>
        body { font-family: Arial, sans-serif; text-align: center; padding: 50px; }
        .container { max-width: 500px; margin: 0 auto; }
        h1 { color: #333; }
        p { color: #666; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Payment Required</h1>
        <p>This resource requires payment to access. Please provide a valid X-PAYMENT header.</p>
    </div>
</body>
</html>"#;

    DEFAULT_PAYWALL_HTML
}
109
110/// Axum middleware handler for payment processing with settlement
111pub async fn payment_middleware_handler(
112    State(middleware): State<PaymentMiddleware>,
113    request: Request,
114    next: Next,
115) -> impl IntoResponse {
116    let config = middleware.config().clone();
117    let headers = request.headers().clone();
118
119    // Determine the resource URL
120    let resource = if let Some(ref resource_url) = config.resource {
121        resource_url.clone()
122    } else if let Some(ref root_url) = config.resource_root_url {
123        format!("{}{}", root_url, request.uri().path())
124    } else {
125        request.uri().path().to_string()
126    };
127
128    // Create payment requirements
129    let requirements = match config.create_payment_requirements(&resource) {
130        Ok(req) => req,
131        Err(_) => {
132            return (
133                StatusCode::INTERNAL_SERVER_ERROR,
134                Json(serde_json::json!({
135                    "error": "Failed to create payment requirements",
136                    "x402Version": 1
137                })),
138            )
139                .into_response();
140        }
141    };
142
143    // Check for payment header
144    if let Some(payment_header) = headers.get("X-PAYMENT") {
145        if let Ok(payment_str) = payment_header.to_str() {
146            // Parse the payment payload
147            match crate::types::PaymentPayload::from_base64(payment_str) {
148                Ok(payment_payload) => {
149                    // Verify the payment using the middleware's verify method
150                    match middleware
151                        .verify_with_requirements(&payment_payload, &requirements)
152                        .await
153                    {
154                        Ok(true) => {
155                            // Payment is valid, proceed to next handler
156                            let mut response = next.run(request).await;
157
158                            // After successful response, settle the payment
159                            match middleware
160                                .settle_with_requirements(&payment_payload, &requirements)
161                                .await
162                            {
163                                Ok(settlement_response) => {
164                                    if let Ok(settlement_header) = settlement_response.to_base64() {
165                                        if let Ok(header_value) =
166                                            HeaderValue::from_str(&settlement_header)
167                                        {
168                                            response
169                                                .headers_mut()
170                                                .insert("X-PAYMENT-RESPONSE", header_value);
171                                        }
172                                    }
173                                }
174                                Err(e) => {
175                                    // Log settlement error but don't fail the request
176                                    tracing::warn!("Payment settlement failed: {}", e);
177                                }
178                            }
179
180                            return response;
181                        }
182                        Ok(false) => {
183                            // Payment verification failed
184                            let response_body = serde_json::json!({
185                                "x402Version": 1,
186                                "error": "Payment verification failed",
187                                "accepts": vec![requirements],
188                            });
189                            return (StatusCode::PAYMENT_REQUIRED, Json(response_body))
190                                .into_response();
191                        }
192                        Err(e) => {
193                            // Error during verification
194                            let response_body = serde_json::json!({
195                                "x402Version": 1,
196                                "error": format!("Payment verification error: {}", e),
197                                "accepts": vec![requirements],
198                            });
199                            return (StatusCode::PAYMENT_REQUIRED, Json(response_body))
200                                .into_response();
201                        }
202                    }
203                }
204                Err(e) => {
205                    // Invalid payment payload
206                    let response_body = serde_json::json!({
207                        "x402Version": 1,
208                        "error": format!("Invalid payment payload: {}", e),
209                        "accepts": vec![requirements],
210                    });
211                    return (StatusCode::PAYMENT_REQUIRED, Json(response_body)).into_response();
212                }
213            }
214        }
215    }
216
217    // No valid payment found, check if this is a web browser request
218    if is_web_browser(&headers) {
219        let html = config
220            .custom_paywall_html
221            .clone()
222            .unwrap_or_else(|| get_default_paywall_html().to_string());
223
224        let mut response = Response::new(axum::body::Body::from(html));
225        *response.status_mut() = StatusCode::PAYMENT_REQUIRED;
226        response
227            .headers_mut()
228            .insert("Content-Type", HeaderValue::from_static("text/html"));
229
230        return response.into_response();
231    }
232
233    // Return JSON response for API clients
234    let response_body = serde_json::json!({
235        "x402Version": 1,
236        "error": "X-PAYMENT header is required",
237        "accepts": vec![requirements],
238    });
239
240    (StatusCode::PAYMENT_REQUIRED, Json(response_body)).into_response()
241}
242
/// Builder-style configuration for an x402-protected Axum application.
///
/// Combines the framework-agnostic [`PaymentMiddlewareConfig`] with the Axum-specific
/// [`AxumOptions`]. Construct it with `AxumPaymentConfig::new`, chain the `with_*`
/// methods, then convert it with `into_middleware`.
#[derive(Debug, Clone)]
pub struct AxumPaymentConfig {
    /// Payment settings shared with the framework-agnostic middleware (amount, `pay_to`,
    /// description, MIME type, facilitator, resource URLs, paywall HTML, ...).
    pub base_config: PaymentMiddlewareConfig,
    /// Axum-specific options.
    ///
    /// NOTE(review): `into_middleware` does not carry these over and `create_service`
    /// does not read them, so they currently have no effect. Confirm before relying on them.
    pub axum_options: AxumOptions,
}
251
/// Axum-specific options carried by `AxumPaymentConfig`.
///
/// NOTE(review): nothing in this module reads these fields. `create_service` returns a
/// bare `ServiceBuilder`, and `into_middleware` drops them, so setting them currently has
/// no runtime effect. Confirm against the rest of the crate.
#[derive(Clone, Default)]
pub struct AxumOptions {
    /// Whether CORS should be enabled (set by `AxumPaymentConfig::with_cors`).
    pub enable_cors: bool,
    /// Allowed CORS origins (set by `AxumPaymentConfig::with_cors`).
    pub cors_origins: Vec<String>,
    /// Whether request tracing should be enabled (set by `AxumPaymentConfig::with_tracing`).
    pub enable_tracing: bool,
    /// Optional callback mapping an [`X402Error`] to the HTTP status to return.
    ///
    /// Held in an `Arc` so the derived `Clone` works. `Debug` is implemented by hand
    /// because closures are not `Debug`.
    pub error_handler: Option<Arc<dyn Fn(X402Error) -> StatusCode + Send + Sync>>,
}
264
265impl std::fmt::Debug for AxumOptions {
266    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267        f.debug_struct("AxumOptions")
268            .field("enable_cors", &self.enable_cors)
269            .field("cors_origins", &self.cors_origins)
270            .field("enable_tracing", &self.enable_tracing)
271            .field("error_handler", &"<function>")
272            .finish()
273    }
274}
275
276impl AxumPaymentConfig {
277    /// Create a new Axum payment config
278    pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
279        Self {
280            base_config: PaymentMiddlewareConfig::new(amount, pay_to),
281            axum_options: AxumOptions::default(),
282        }
283    }
284
285    /// Set the payment description
286    pub fn with_description(mut self, description: impl Into<String>) -> Self {
287        self.base_config.description = Some(description.into());
288        self
289    }
290
291    /// Set the MIME type
292    pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
293        self.base_config.mime_type = Some(mime_type.into());
294        self
295    }
296
297    /// Set the maximum timeout
298    pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
299        self.base_config.max_timeout_seconds = max_timeout_seconds;
300        self
301    }
302
303    /// Set the output schema
304    pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
305        self.base_config.output_schema = Some(output_schema);
306        self
307    }
308
309    /// Set the facilitator configuration
310    pub fn with_facilitator_config(
311        mut self,
312        facilitator_config: crate::types::FacilitatorConfig,
313    ) -> Self {
314        self.base_config.facilitator_config = facilitator_config;
315        self
316    }
317
318    /// Set whether this is a testnet
319    pub fn with_testnet(mut self, testnet: bool) -> Self {
320        self.base_config.testnet = testnet;
321        self
322    }
323
324    /// Set custom paywall HTML
325    pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
326        self.base_config.custom_paywall_html = Some(html.into());
327        self
328    }
329
330    /// Set the resource URL
331    pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
332        self.base_config.resource = Some(resource.into());
333        self
334    }
335
336    /// Set the resource root URL
337    pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
338        self.base_config.resource_root_url = Some(url.into());
339        self
340    }
341
342    /// Enable CORS
343    pub fn with_cors(mut self, origins: Vec<String>) -> Self {
344        self.axum_options.enable_cors = true;
345        self.axum_options.cors_origins = origins;
346        self
347    }
348
349    /// Enable request tracing
350    pub fn with_tracing(mut self) -> Self {
351        self.axum_options.enable_tracing = true;
352        self
353    }
354
355    /// Set custom error handler
356    pub fn with_error_handler<F>(mut self, handler: F) -> Self
357    where
358        F: Fn(X402Error) -> StatusCode + Send + Sync + 'static,
359    {
360        self.axum_options.error_handler = Some(Arc::new(handler));
361        self
362    }
363
364    /// Convert to PaymentMiddleware
365    pub fn into_middleware(self) -> PaymentMiddleware {
366        PaymentMiddleware {
367            config: Arc::new(self.base_config),
368            facilitator: None,
369            template_config: None,
370        }
371    }
372
373    /// Create a service builder with this configuration
374    pub fn create_service(&self) -> ServiceBuilder<tower::layer::util::Identity> {
375        // Note: Service layer integration is simplified for now
376        // In a full implementation, you would conditionally add layers based on options
377        ServiceBuilder::new()
378    }
379}
380
381/// Create a complete Axum application with x402 payment support
382pub fn create_payment_app(
383    config: AxumPaymentConfig,
384    routes: impl FnOnce(Router) -> Router,
385) -> Router {
386    let router = Router::new();
387    let router = routes(router);
388
389    // Apply service layers
390    router.layer(config.create_service())
391}
392
393/// Helper for creating payment-protected handlers
394pub mod handlers {
395    use super::*;
396    use serde_json::json;
397
398    /// Create a simple JSON response handler
399    pub fn json_response<T: serde::Serialize>(data: T) -> impl IntoResponse {
400        Json(data)
401    }
402
403    /// Create a simple text response handler
404    pub fn text_response(text: impl Into<String>) -> impl IntoResponse {
405        text.into()
406    }
407
408    /// Create an error response handler
409    pub fn error_response(error: impl Into<String>) -> impl IntoResponse {
410        (
411            StatusCode::INTERNAL_SERVER_ERROR,
412            Json(json!({"error": error.into()})),
413        )
414    }
415
416    /// Create a success response handler
417    pub fn success_response<T: serde::Serialize>(data: T) -> impl IntoResponse {
418        (StatusCode::OK, Json(data))
419    }
420}
421
422/// Example handlers for common use cases
423pub mod examples {
424    use super::*;
425    use serde_json::json;
426
427    /// Example joke handler
428    pub async fn joke_handler() -> impl IntoResponse {
429        axum::Json(json!({
430            "joke": "Why do programmers prefer dark mode? Because light attracts bugs!"
431        }))
432    }
433
434    /// Example API data handler
435    pub async fn api_data_handler() -> impl IntoResponse {
436        axum::Json(json!({
437            "data": "This is premium API data that requires payment to access",
438            "timestamp": chrono::Utc::now().to_rfc3339(),
439            "source": "x402-protected-api"
440        }))
441    }
442
443    /// Example file download handler
444    pub async fn download_handler() -> impl IntoResponse {
445        let content = "This is premium content that requires payment to download.";
446        (StatusCode::OK, content)
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Recipient address shared by the tests.
    const PAY_TO: &str = "0x209693Bc6afc0C5328bA36FaF03C514EF312287C";

    /// The price used throughout the tests (0.0001).
    fn test_amount() -> Decimal {
        Decimal::from_str("0.0001").expect("0.0001 is a valid decimal literal")
    }

    #[test]
    fn test_axum_payment_config() {
        let config = AxumPaymentConfig::new(test_amount(), PAY_TO)
            .with_description("Test payment")
            .with_testnet(true)
            .with_cors(vec!["http://localhost:3000".to_string()])
            .with_tracing();

        assert_eq!(config.base_config.amount, test_amount());
        assert!(config.axum_options.enable_cors);
        assert!(config.axum_options.enable_tracing);
    }

    #[test]
    fn test_payment_middleware_creation() {
        let middleware = create_payment_middleware(test_amount(), PAY_TO);

        assert_eq!(middleware.config().amount, test_amount());
    }
}