rust_x402/
middleware.rs

1//! Middleware implementations for web frameworks
2
3use crate::types::{Network, *};
4use crate::{Result, X402Error};
5use axum::{
6    extract::{Request, State},
7    http::{HeaderValue, StatusCode},
8    middleware::Next,
9    response::{IntoResponse, Response},
10    Json,
11};
12use rust_decimal::Decimal;
13use std::sync::Arc;
14use tower::ServiceBuilder;
15use tower_http::trace::TraceLayer;
16
17/// Configuration for payment middleware
18#[derive(Debug, Clone)]
19pub struct PaymentMiddlewareConfig {
20    /// Payment amount in decimal units (e.g., 0.0001 for 1/10th of a cent)
21    pub amount: Decimal,
22    /// Recipient wallet address
23    pub pay_to: String,
24    /// Payment description
25    pub description: Option<String>,
26    /// MIME type of the expected response
27    pub mime_type: Option<String>,
28    /// Maximum timeout in seconds
29    pub max_timeout_seconds: u32,
30    /// JSON schema for response format
31    pub output_schema: Option<serde_json::Value>,
32    /// Facilitator configuration
33    pub facilitator_config: FacilitatorConfig,
34    /// Whether this is a testnet
35    pub testnet: bool,
36    /// Custom paywall HTML for web browsers
37    pub custom_paywall_html: Option<String>,
38    /// Resource URL (if different from request URL)
39    pub resource: Option<String>,
40    /// Resource root URL for constructing full resource URLs
41    pub resource_root_url: Option<String>,
42}
43
44impl PaymentMiddlewareConfig {
45    /// Create a new payment middleware config
46    pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
47        // Normalize pay_to to lowercase to avoid EIP-55 checksum mismatches
48        let pay_to_normalized = pay_to.into().to_lowercase();
49        Self {
50            amount,
51            pay_to: pay_to_normalized,
52            description: None,
53            mime_type: None,
54            max_timeout_seconds: 60,
55            output_schema: None,
56            facilitator_config: FacilitatorConfig::default(),
57            testnet: true,
58            custom_paywall_html: None,
59            resource: None,
60            resource_root_url: None,
61        }
62    }
63
64    /// Set the payment description
65    pub fn with_description(mut self, description: impl Into<String>) -> Self {
66        self.description = Some(description.into());
67        self
68    }
69
70    /// Set the MIME type
71    pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
72        self.mime_type = Some(mime_type.into());
73        self
74    }
75
76    /// Set the maximum timeout
77    pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
78        self.max_timeout_seconds = max_timeout_seconds;
79        self
80    }
81
82    /// Set the output schema
83    pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
84        self.output_schema = Some(output_schema);
85        self
86    }
87
88    /// Set the facilitator configuration
89    pub fn with_facilitator_config(mut self, facilitator_config: FacilitatorConfig) -> Self {
90        self.facilitator_config = facilitator_config;
91        self
92    }
93
94    /// Set whether this is a testnet
95    pub fn with_testnet(mut self, testnet: bool) -> Self {
96        self.testnet = testnet;
97        self
98    }
99
100    /// Set custom paywall HTML
101    pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
102        self.custom_paywall_html = Some(html.into());
103        self
104    }
105
106    /// Set the resource URL
107    pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
108        self.resource = Some(resource.into());
109        self
110    }
111
112    /// Set the resource root URL
113    pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
114        self.resource_root_url = Some(url.into());
115        self
116    }
117
118    /// Create payment requirements from this config
119    pub fn create_payment_requirements(&self, request_uri: &str) -> Result<PaymentRequirements> {
120        let network = if self.testnet {
121            networks::BASE_SEPOLIA
122        } else {
123            networks::BASE_MAINNET
124        };
125
126        let usdc_address =
127            networks::get_usdc_address(network).ok_or_else(|| X402Error::NetworkNotSupported {
128                network: network.to_string(),
129            })?;
130
131        let resource = if let Some(ref resource_url) = self.resource {
132            resource_url.clone()
133        } else if let Some(ref root_url) = self.resource_root_url {
134            format!("{}{}", root_url, request_uri)
135        } else {
136            request_uri.to_string()
137        };
138
139        let max_amount_required = (self.amount * Decimal::from(1_000_000u64))
140            .normalize()
141            .to_string();
142
143        // Normalize pay_to to lowercase to avoid EIP-55 checksum mismatches
144        let pay_to_normalized = self.pay_to.to_lowercase();
145
146        let mut requirements = PaymentRequirements::new(
147            schemes::EXACT,
148            network,
149            max_amount_required,
150            usdc_address,
151            &pay_to_normalized,
152            resource,
153            self.description.as_deref().unwrap_or("Payment required"),
154        );
155
156        requirements.mime_type = self.mime_type.clone();
157        requirements.output_schema = self.output_schema.clone();
158        requirements.max_timeout_seconds = self.max_timeout_seconds;
159
160        let network = if self.testnet {
161            Network::Testnet
162        } else {
163            Network::Mainnet
164        };
165        requirements.set_usdc_info(network)?;
166
167        Ok(requirements)
168    }
169}
170
171/// Axum middleware for x402 payments
172#[derive(Debug, Clone)]
173pub struct PaymentMiddleware {
174    pub config: Arc<PaymentMiddlewareConfig>,
175    pub facilitator: Option<crate::facilitator::FacilitatorClient>,
176    pub template_config: Option<crate::template::PaywallConfig>,
177}
178
179/// Payment processing result
180#[derive(Debug)]
181pub enum PaymentResult {
182    /// Payment verified and settled successfully
183    Success {
184        response: axum::response::Response,
185        settlement: crate::types::SettleResponse,
186    },
187    /// Payment required (402 response)
188    PaymentRequired { response: axum::response::Response },
189    /// Payment verification failed
190    VerificationFailed { response: axum::response::Response },
191    /// Payment settlement failed
192    SettlementFailed { response: axum::response::Response },
193}
194
195impl PaymentMiddleware {
196    /// Create a new payment middleware
197    pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
198        Self {
199            config: Arc::new(PaymentMiddlewareConfig::new(amount, pay_to)),
200            facilitator: None,
201            template_config: None,
202        }
203    }
204
205    /// Set the payment description
206    pub fn with_description(mut self, description: impl Into<String>) -> Self {
207        Arc::make_mut(&mut self.config).description = Some(description.into());
208        self
209    }
210
211    /// Set the MIME type
212    pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
213        Arc::make_mut(&mut self.config).mime_type = Some(mime_type.into());
214        self
215    }
216
217    /// Set the maximum timeout
218    pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
219        Arc::make_mut(&mut self.config).max_timeout_seconds = max_timeout_seconds;
220        self
221    }
222
223    /// Set the output schema
224    pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
225        Arc::make_mut(&mut self.config).output_schema = Some(output_schema);
226        self
227    }
228
229    /// Set the facilitator configuration
230    pub fn with_facilitator_config(mut self, facilitator_config: FacilitatorConfig) -> Self {
231        Arc::make_mut(&mut self.config).facilitator_config = facilitator_config;
232        self
233    }
234
235    /// Set whether this is a testnet
236    pub fn with_testnet(mut self, testnet: bool) -> Self {
237        Arc::make_mut(&mut self.config).testnet = testnet;
238        self
239    }
240
241    /// Set custom paywall HTML
242    pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
243        Arc::make_mut(&mut self.config).custom_paywall_html = Some(html.into());
244        self
245    }
246
247    /// Set the resource URL
248    pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
249        Arc::make_mut(&mut self.config).resource = Some(resource.into());
250        self
251    }
252
253    /// Set the resource root URL
254    pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
255        Arc::make_mut(&mut self.config).resource_root_url = Some(url.into());
256        self
257    }
258
259    /// Get the middleware configuration
260    pub fn config(&self) -> &PaymentMiddlewareConfig {
261        &self.config
262    }
263
264    /// Set the facilitator client
265    pub fn with_facilitator(mut self, facilitator: crate::facilitator::FacilitatorClient) -> Self {
266        self.facilitator = Some(facilitator);
267        self
268    }
269
270    /// Set the template configuration
271    pub fn with_template_config(mut self, template_config: crate::template::PaywallConfig) -> Self {
272        self.template_config = Some(template_config);
273        self
274    }
275
276    /// Verify a payment payload
277    pub async fn verify(&self, payment_payload: &PaymentPayload) -> bool {
278        // Create facilitator if not already configured
279        let facilitator = if let Some(facilitator) = &self.facilitator {
280            facilitator.clone()
281        } else {
282            match crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())
283            {
284                Ok(facilitator) => facilitator,
285                Err(_) => return false,
286            }
287        };
288
289        if let Ok(requirements) = self.config.create_payment_requirements("/") {
290            if let Ok(response) = facilitator.verify(payment_payload, &requirements).await {
291                return response.is_valid;
292            }
293        }
294        false
295    }
296
297    /// Settle a payment
298    pub async fn settle(&self, payment_payload: &PaymentPayload) -> crate::Result<SettleResponse> {
299        // Create facilitator if not already configured
300        let facilitator = if let Some(facilitator) = &self.facilitator {
301            facilitator.clone()
302        } else {
303            crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
304        };
305
306        let requirements = self.config.create_payment_requirements("/")?;
307        facilitator.settle(payment_payload, &requirements).await
308    }
309
310    /// Verify payment with specific requirements
311    pub async fn verify_with_requirements(
312        &self,
313        payment_payload: &PaymentPayload,
314        requirements: &PaymentRequirements,
315    ) -> crate::Result<bool> {
316        let facilitator = if let Some(facilitator) = &self.facilitator {
317            facilitator.clone()
318        } else {
319            crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
320        };
321
322        let response = facilitator.verify(payment_payload, requirements).await?;
323        Ok(response.is_valid)
324    }
325
326    /// Settle payment with specific requirements
327    pub async fn settle_with_requirements(
328        &self,
329        payment_payload: &PaymentPayload,
330        requirements: &PaymentRequirements,
331    ) -> crate::Result<SettleResponse> {
332        let facilitator = if let Some(facilitator) = &self.facilitator {
333            facilitator.clone()
334        } else {
335            crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
336        };
337
338        facilitator.settle(payment_payload, requirements).await
339    }
340
341    /// Process payment with unified flow
342    pub async fn process_payment(
343        &self,
344        request: Request,
345        next: Next,
346    ) -> crate::Result<PaymentResult> {
347        let headers = request.headers();
348        let uri = request.uri().to_string();
349
350        // Check if this is a web browser request
351        let user_agent = headers
352            .get("User-Agent")
353            .and_then(|v| v.to_str().ok())
354            .unwrap_or("");
355        let accept = headers
356            .get("Accept")
357            .and_then(|v| v.to_str().ok())
358            .unwrap_or("");
359
360        let is_web_browser = accept.contains("text/html") && user_agent.contains("Mozilla");
361
362        // Create payment requirements
363        let payment_requirements = self.config.create_payment_requirements(&uri)?;
364
365        // Check for payment header
366        let payment_header = headers.get("X-PAYMENT").and_then(|v| v.to_str().ok());
367
368        match payment_header {
369            Some(payment_b64) => {
370                // Decode payment payload
371                let payment_payload = PaymentPayload::from_base64(payment_b64).map_err(|e| {
372                    X402Error::invalid_payment_payload(format!("Failed to decode payment: {}", e))
373                })?;
374
375                // Get facilitator client
376                let facilitator = if let Some(facilitator) = &self.facilitator {
377                    facilitator.clone()
378                } else {
379                    crate::facilitator::FacilitatorClient::new(
380                        self.config.facilitator_config.clone(),
381                    )?
382                };
383
384                // Verify payment
385                let verify_response = facilitator
386                    .verify(&payment_payload, &payment_requirements)
387                    .await
388                    .map_err(|e| {
389                        X402Error::facilitator_error(format!("Payment verification failed: {}", e))
390                    })?;
391
392                if !verify_response.is_valid {
393                    let error_response = self.create_payment_required_response(
394                        "Payment verification failed",
395                        &payment_requirements,
396                        is_web_browser,
397                    )?;
398                    return Ok(PaymentResult::VerificationFailed {
399                        response: error_response,
400                    });
401                }
402
403                // Execute the handler
404                let mut response = next.run(request).await;
405
406                // Settle the payment
407                let settle_response = facilitator
408                    .settle(&payment_payload, &payment_requirements)
409                    .await
410                    .map_err(|e| {
411                        X402Error::facilitator_error(format!("Payment settlement failed: {}", e))
412                    })?;
413
414                // Add settlement header
415                let settlement_header = settle_response.to_base64().map_err(|e| {
416                    X402Error::config(format!("Failed to encode settlement response: {}", e))
417                })?;
418
419                if let Ok(header_value) = HeaderValue::from_str(&settlement_header) {
420                    response
421                        .headers_mut()
422                        .insert("X-PAYMENT-RESPONSE", header_value);
423                }
424
425                Ok(PaymentResult::Success {
426                    response,
427                    settlement: settle_response,
428                })
429            }
430            None => {
431                // No payment provided, return 402 with requirements
432                let response = self.create_payment_required_response(
433                    "X-PAYMENT header is required",
434                    &payment_requirements,
435                    is_web_browser,
436                )?;
437                Ok(PaymentResult::PaymentRequired { response })
438            }
439        }
440    }
441
442    /// Create payment required response
443    fn create_payment_required_response(
444        &self,
445        error: &str,
446        payment_requirements: &PaymentRequirements,
447        is_web_browser: bool,
448    ) -> crate::Result<axum::response::Response> {
449        if is_web_browser {
450            let html = if let Some(custom_html) = &self.config.custom_paywall_html {
451                custom_html.clone()
452            } else {
453                // Use the template system
454                let paywall_config = self.template_config.clone().unwrap_or_else(|| {
455                    crate::template::PaywallConfig::new()
456                        .with_app_name("x402 Service")
457                        .with_app_logo("💰")
458                });
459
460                crate::template::generate_paywall_html(
461                    error,
462                    std::slice::from_ref(payment_requirements),
463                    Some(&paywall_config),
464                )
465            };
466
467            let response = Response::builder()
468                .status(StatusCode::PAYMENT_REQUIRED)
469                .header("Content-Type", "text/html")
470                .body(html.into())
471                .map_err(|e| X402Error::config(format!("Failed to create HTML response: {}", e)))?;
472
473            Ok(response)
474        } else {
475            let payment_response =
476                PaymentRequirementsResponse::new(error, vec![payment_requirements.clone()]);
477
478            Ok(Json(payment_response).into_response())
479        }
480    }
481}
482
483/// Axum middleware function for handling x402 payments
484pub async fn payment_middleware(
485    State(middleware): State<PaymentMiddleware>,
486    request: Request,
487    next: Next,
488) -> crate::Result<impl IntoResponse> {
489    match middleware.process_payment(request, next).await? {
490        PaymentResult::Success { response, .. } => Ok(response),
491        PaymentResult::PaymentRequired { response } => Ok(response),
492        PaymentResult::VerificationFailed { response } => Ok(response),
493        PaymentResult::SettlementFailed { response } => Ok(response),
494    }
495}
496
497/// Create a service builder with x402 payment middleware
498pub fn create_payment_service(
499    middleware: PaymentMiddleware,
500) -> impl tower::Layer<tower::ServiceBuilder<tower::layer::util::Identity>> + Clone {
501    ServiceBuilder::new()
502        .layer(TraceLayer::new_for_http())
503        .layer(tower::layer::util::Stack::new(
504            tower::layer::util::Identity::new(),
505            PaymentServiceLayer::new(middleware),
506        ))
507}
508
509/// Tower service layer for x402 payment middleware
510#[derive(Clone)]
511pub struct PaymentServiceLayer {
512    middleware: PaymentMiddleware,
513}
514
515impl PaymentServiceLayer {
516    pub fn new(middleware: PaymentMiddleware) -> Self {
517        Self { middleware }
518    }
519}
520
521impl<S> tower::Layer<S> for PaymentServiceLayer {
522    type Service = PaymentService<S>;
523
524    fn layer(&self, inner: S) -> Self::Service {
525        PaymentService {
526            inner,
527            middleware: self.middleware.clone(),
528        }
529    }
530}
531
532/// Tower service for x402 payment middleware
533#[derive(Clone)]
534pub struct PaymentService<S> {
535    inner: S,
536    middleware: PaymentMiddleware,
537}
538
539impl<S, ReqBody, ResBody> tower::Service<http::Request<ReqBody>> for PaymentService<S>
540where
541    S: tower::Service<
542            http::Request<ReqBody>,
543            Response = http::Response<ResBody>,
544            Error = Box<dyn std::error::Error + Send + Sync>,
545        > + Send
546        + 'static,
547    S::Future: Send + 'static,
548    ReqBody: Send + 'static,
549    ResBody: Send + 'static,
550{
551    type Response = S::Response;
552    type Error = S::Error;
553    type Future = std::pin::Pin<
554        Box<
555            dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>>
556                + Send,
557        >,
558    >;
559
560    fn poll_ready(
561        &mut self,
562        cx: &mut std::task::Context<'_>,
563    ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
564        self.inner.poll_ready(cx)
565    }
566
567    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
568        let middleware = self.middleware.clone();
569
570        // Extract payment header before moving the request
571        let payment_header = req
572            .headers()
573            .get("X-PAYMENT")
574            .and_then(|h| h.to_str().ok())
575            .map(|s| s.to_string());
576        let uri_path = req.uri().path().to_string();
577
578        let future = self.inner.call(req);
579
580        Box::pin(async move {
581            match payment_header {
582                Some(payment_b64) => {
583                    // Parse payment payload
584                    match crate::types::PaymentPayload::from_base64(&payment_b64) {
585                        Ok(payment_payload) => {
586                            // Create payment requirements
587                            let requirements =
588                                match middleware.config.create_payment_requirements(&uri_path) {
589                                    Ok(req) => req,
590                                    Err(e) => {
591                                        // Return 500 error if we can't create requirements
592                                        return Err(
593                                            Box::new(e) as Box<dyn std::error::Error + Send + Sync>
594                                        );
595                                    }
596                                };
597
598                            // Verify payment
599                            match middleware
600                                .verify_with_requirements(&payment_payload, &requirements)
601                                .await
602                            {
603                                Ok(true) => {
604                                    // Payment is valid, proceed with request
605                                    let response = future.await?;
606
607                                    // Settle payment after successful response
608                                    if let Ok(settlement) = middleware
609                                        .settle_with_requirements(&payment_payload, &requirements)
610                                        .await
611                                    {
612                                        // Note: In a real implementation, we would need to modify the response
613                                        // to add the X-PAYMENT-RESPONSE header, but this requires
614                                        // more complex response handling in Tower
615                                        let _ = settlement; // Acknowledge settlement
616                                    }
617
618                                    Ok(response)
619                                }
620                                Ok(false) => {
621                                    // Payment verification failed
622                                    Err(Box::new(crate::X402Error::payment_verification_failed(
623                                        "Payment verification failed",
624                                    ))
625                                        as Box<dyn std::error::Error + Send + Sync>)
626                                }
627                                Err(e) => {
628                                    // Error during verification
629                                    Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
630                                }
631                            }
632                        }
633                        Err(e) => {
634                            // Invalid payment payload
635                            Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
636                        }
637                    }
638                }
639                None => {
640                    // No payment header provided
641                    Err(Box::new(crate::X402Error::payment_verification_failed(
642                        "X-PAYMENT header is required",
643                    ))
644                        as Box<dyn std::error::Error + Send + Sync>)
645                }
646            }
647        })
648    }
649}
650
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Recipient address in EIP-55 checksummed form, as a user would paste it.
    const CHECKSUMMED_PAY_TO: &str = "0x209693Bc6afc0C5328bA36FaF03C514EF312287C";
    /// The same address after the lowercase normalization applied by the constructors.
    const NORMALIZED_PAY_TO: &str = "0x209693bc6afc0c5328ba36faf03c514ef312287c";

    /// Parse a decimal literal used as a test fixture.
    fn dec(literal: &str) -> Decimal {
        Decimal::from_str(literal).unwrap()
    }

    #[test]
    fn test_payment_middleware_config() {
        let config = PaymentMiddlewareConfig::new(dec("0.0001"), CHECKSUMMED_PAY_TO)
            .with_description("Test payment")
            .with_testnet(true);

        assert_eq!(config.amount, dec("0.0001"));
        assert_eq!(config.pay_to, NORMALIZED_PAY_TO);
        assert_eq!(config.description.as_deref(), Some("Test payment"));
        assert!(config.testnet);
    }

    #[test]
    fn test_payment_middleware_creation() {
        let middleware = PaymentMiddleware::new(dec("0.0001"), CHECKSUMMED_PAY_TO)
            .with_description("Test payment");
        let config = middleware.config();

        assert_eq!(config.amount, dec("0.0001"));
        assert_eq!(config.pay_to, NORMALIZED_PAY_TO);
    }

    #[test]
    fn test_payment_requirements_creation() {
        let requirements = PaymentMiddlewareConfig::new(dec("0.0001"), CHECKSUMMED_PAY_TO)
            .with_testnet(true)
            .create_payment_requirements("/test")
            .unwrap();

        assert_eq!(requirements.scheme, "exact");
        assert_eq!(requirements.network, "base-sepolia");
        // 0.0001 tokens at 1_000_000 atomic units per token.
        assert_eq!(requirements.max_amount_required, "100");
        assert_eq!(requirements.pay_to, NORMALIZED_PAY_TO);
    }

    #[test]
    fn test_payment_middleware_config_builder() {
        let config = PaymentMiddlewareConfig::new(dec("0.01"), CHECKSUMMED_PAY_TO)
            .with_description("Test payment")
            .with_mime_type("application/json")
            .with_max_timeout_seconds(120)
            .with_testnet(false)
            .with_resource("https://example.com/test");

        assert_eq!(config.amount, dec("0.01"));
        assert_eq!(config.pay_to, NORMALIZED_PAY_TO);
        assert_eq!(config.description.as_deref(), Some("Test payment"));
        assert_eq!(config.mime_type.as_deref(), Some("application/json"));
        assert_eq!(config.max_timeout_seconds, 120);
        assert!(!config.testnet);
        assert_eq!(config.resource.as_deref(), Some("https://example.com/test"));
    }

    #[test]
    fn test_payment_middleware_creation_with_description() {
        let middleware = PaymentMiddleware::new(dec("0.001"), CHECKSUMMED_PAY_TO)
            .with_description("Test middleware");
        let config = middleware.config();

        assert_eq!(config.amount, dec("0.001"));
        assert_eq!(config.description.as_deref(), Some("Test middleware"));
    }
}