1use crate::types::{Network, *};
4use crate::{Result, X402Error};
5use axum::{
6 extract::{Request, State},
7 http::{HeaderValue, StatusCode},
8 middleware::Next,
9 response::{IntoResponse, Response},
10 Json,
11};
12use rust_decimal::Decimal;
13use std::sync::Arc;
14use tower::ServiceBuilder;
15use tower_http::trace::TraceLayer;
16
17#[derive(Debug, Clone)]
19pub struct PaymentMiddlewareConfig {
20 pub amount: Decimal,
22 pub pay_to: String,
24 pub description: Option<String>,
26 pub mime_type: Option<String>,
28 pub max_timeout_seconds: u32,
30 pub output_schema: Option<serde_json::Value>,
32 pub facilitator_config: FacilitatorConfig,
34 pub testnet: bool,
36 pub custom_paywall_html: Option<String>,
38 pub resource: Option<String>,
40 pub resource_root_url: Option<String>,
42}
43
44impl PaymentMiddlewareConfig {
45 pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
47 Self {
48 amount,
49 pay_to: pay_to.into(),
50 description: None,
51 mime_type: None,
52 max_timeout_seconds: 60,
53 output_schema: None,
54 facilitator_config: FacilitatorConfig::default(),
55 testnet: true,
56 custom_paywall_html: None,
57 resource: None,
58 resource_root_url: None,
59 }
60 }
61
62 pub fn with_description(mut self, description: impl Into<String>) -> Self {
64 self.description = Some(description.into());
65 self
66 }
67
68 pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
70 self.mime_type = Some(mime_type.into());
71 self
72 }
73
74 pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
76 self.max_timeout_seconds = max_timeout_seconds;
77 self
78 }
79
80 pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
82 self.output_schema = Some(output_schema);
83 self
84 }
85
86 pub fn with_facilitator_config(mut self, facilitator_config: FacilitatorConfig) -> Self {
88 self.facilitator_config = facilitator_config;
89 self
90 }
91
92 pub fn with_testnet(mut self, testnet: bool) -> Self {
94 self.testnet = testnet;
95 self
96 }
97
98 pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
100 self.custom_paywall_html = Some(html.into());
101 self
102 }
103
104 pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
106 self.resource = Some(resource.into());
107 self
108 }
109
110 pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
112 self.resource_root_url = Some(url.into());
113 self
114 }
115
116 pub fn create_payment_requirements(&self, request_uri: &str) -> Result<PaymentRequirements> {
118 let network = if self.testnet {
119 networks::BASE_SEPOLIA
120 } else {
121 networks::BASE_MAINNET
122 };
123
124 let usdc_address =
125 networks::get_usdc_address(network).ok_or_else(|| X402Error::NetworkNotSupported {
126 network: network.to_string(),
127 })?;
128
129 let resource = if let Some(ref resource_url) = self.resource {
130 resource_url.clone()
131 } else if let Some(ref root_url) = self.resource_root_url {
132 format!("{}{}", root_url, request_uri)
133 } else {
134 request_uri.to_string()
135 };
136
137 let max_amount_required = (self.amount * Decimal::from(1_000_000u64))
138 .normalize()
139 .to_string();
140
141 let mut requirements = PaymentRequirements::new(
142 schemes::EXACT,
143 network,
144 max_amount_required,
145 usdc_address,
146 &self.pay_to,
147 resource,
148 self.description.as_deref().unwrap_or("Payment required"),
149 );
150
151 requirements.mime_type = self.mime_type.clone();
152 requirements.output_schema = self.output_schema.clone();
153 requirements.max_timeout_seconds = self.max_timeout_seconds;
154
155 let network = if self.testnet {
156 Network::Testnet
157 } else {
158 Network::Mainnet
159 };
160 requirements.set_usdc_info(network)?;
161
162 Ok(requirements)
163 }
164}
165
166#[derive(Debug, Clone)]
168pub struct PaymentMiddleware {
169 pub config: Arc<PaymentMiddlewareConfig>,
170 pub facilitator: Option<crate::facilitator::FacilitatorClient>,
171 pub template_config: Option<crate::template::PaywallConfig>,
172}
173
174#[derive(Debug)]
176pub enum PaymentResult {
177 Success {
179 response: axum::response::Response,
180 settlement: crate::types::SettleResponse,
181 },
182 PaymentRequired { response: axum::response::Response },
184 VerificationFailed { response: axum::response::Response },
186 SettlementFailed { response: axum::response::Response },
188}
189
190impl PaymentMiddleware {
191 pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
193 Self {
194 config: Arc::new(PaymentMiddlewareConfig::new(amount, pay_to)),
195 facilitator: None,
196 template_config: None,
197 }
198 }
199
200 pub fn with_description(mut self, description: impl Into<String>) -> Self {
202 Arc::make_mut(&mut self.config).description = Some(description.into());
203 self
204 }
205
206 pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
208 Arc::make_mut(&mut self.config).mime_type = Some(mime_type.into());
209 self
210 }
211
212 pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
214 Arc::make_mut(&mut self.config).max_timeout_seconds = max_timeout_seconds;
215 self
216 }
217
218 pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
220 Arc::make_mut(&mut self.config).output_schema = Some(output_schema);
221 self
222 }
223
224 pub fn with_facilitator_config(mut self, facilitator_config: FacilitatorConfig) -> Self {
226 Arc::make_mut(&mut self.config).facilitator_config = facilitator_config;
227 self
228 }
229
230 pub fn with_testnet(mut self, testnet: bool) -> Self {
232 Arc::make_mut(&mut self.config).testnet = testnet;
233 self
234 }
235
236 pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
238 Arc::make_mut(&mut self.config).custom_paywall_html = Some(html.into());
239 self
240 }
241
242 pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
244 Arc::make_mut(&mut self.config).resource = Some(resource.into());
245 self
246 }
247
248 pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
250 Arc::make_mut(&mut self.config).resource_root_url = Some(url.into());
251 self
252 }
253
254 pub fn config(&self) -> &PaymentMiddlewareConfig {
256 &self.config
257 }
258
259 pub fn with_facilitator(mut self, facilitator: crate::facilitator::FacilitatorClient) -> Self {
261 self.facilitator = Some(facilitator);
262 self
263 }
264
265 pub fn with_template_config(mut self, template_config: crate::template::PaywallConfig) -> Self {
267 self.template_config = Some(template_config);
268 self
269 }
270
271 pub async fn verify(&self, payment_payload: &PaymentPayload) -> bool {
273 let facilitator = if let Some(facilitator) = &self.facilitator {
275 facilitator.clone()
276 } else {
277 match crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())
278 {
279 Ok(facilitator) => facilitator,
280 Err(_) => return false,
281 }
282 };
283
284 if let Ok(requirements) = self.config.create_payment_requirements("/") {
285 if let Ok(response) = facilitator.verify(payment_payload, &requirements).await {
286 return response.is_valid;
287 }
288 }
289 false
290 }
291
292 pub async fn settle(&self, payment_payload: &PaymentPayload) -> crate::Result<SettleResponse> {
294 let facilitator = if let Some(facilitator) = &self.facilitator {
296 facilitator.clone()
297 } else {
298 crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
299 };
300
301 let requirements = self.config.create_payment_requirements("/")?;
302 facilitator.settle(payment_payload, &requirements).await
303 }
304
305 pub async fn verify_with_requirements(
307 &self,
308 payment_payload: &PaymentPayload,
309 requirements: &PaymentRequirements,
310 ) -> crate::Result<bool> {
311 let facilitator = if let Some(facilitator) = &self.facilitator {
312 facilitator.clone()
313 } else {
314 crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
315 };
316
317 let response = facilitator.verify(payment_payload, requirements).await?;
318 Ok(response.is_valid)
319 }
320
321 pub async fn settle_with_requirements(
323 &self,
324 payment_payload: &PaymentPayload,
325 requirements: &PaymentRequirements,
326 ) -> crate::Result<SettleResponse> {
327 let facilitator = if let Some(facilitator) = &self.facilitator {
328 facilitator.clone()
329 } else {
330 crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
331 };
332
333 facilitator.settle(payment_payload, requirements).await
334 }
335
336 pub async fn process_payment(
338 &self,
339 request: Request,
340 next: Next,
341 ) -> crate::Result<PaymentResult> {
342 let headers = request.headers();
343 let uri = request.uri().to_string();
344
345 let user_agent = headers
347 .get("User-Agent")
348 .and_then(|v| v.to_str().ok())
349 .unwrap_or("");
350 let accept = headers
351 .get("Accept")
352 .and_then(|v| v.to_str().ok())
353 .unwrap_or("");
354
355 let is_web_browser = accept.contains("text/html") && user_agent.contains("Mozilla");
356
357 let payment_requirements = self.config.create_payment_requirements(&uri)?;
359
360 let payment_header = headers.get("X-PAYMENT").and_then(|v| v.to_str().ok());
362
363 match payment_header {
364 Some(payment_b64) => {
365 let payment_payload = PaymentPayload::from_base64(payment_b64).map_err(|e| {
367 X402Error::invalid_payment_payload(format!("Failed to decode payment: {}", e))
368 })?;
369
370 let facilitator = if let Some(facilitator) = &self.facilitator {
372 facilitator.clone()
373 } else {
374 crate::facilitator::FacilitatorClient::new(
375 self.config.facilitator_config.clone(),
376 )?
377 };
378
379 let verify_response = facilitator
381 .verify(&payment_payload, &payment_requirements)
382 .await
383 .map_err(|e| {
384 X402Error::facilitator_error(format!("Payment verification failed: {}", e))
385 })?;
386
387 if !verify_response.is_valid {
388 let error_response = self.create_payment_required_response(
389 "Payment verification failed",
390 &payment_requirements,
391 is_web_browser,
392 )?;
393 return Ok(PaymentResult::VerificationFailed {
394 response: error_response,
395 });
396 }
397
398 let mut response = next.run(request).await;
400
401 let settle_response = facilitator
403 .settle(&payment_payload, &payment_requirements)
404 .await
405 .map_err(|e| {
406 X402Error::facilitator_error(format!("Payment settlement failed: {}", e))
407 })?;
408
409 let settlement_header = settle_response.to_base64().map_err(|e| {
411 X402Error::config(format!("Failed to encode settlement response: {}", e))
412 })?;
413
414 if let Ok(header_value) = HeaderValue::from_str(&settlement_header) {
415 response
416 .headers_mut()
417 .insert("X-PAYMENT-RESPONSE", header_value);
418 }
419
420 Ok(PaymentResult::Success {
421 response,
422 settlement: settle_response,
423 })
424 }
425 None => {
426 let response = self.create_payment_required_response(
428 "X-PAYMENT header is required",
429 &payment_requirements,
430 is_web_browser,
431 )?;
432 Ok(PaymentResult::PaymentRequired { response })
433 }
434 }
435 }
436
437 fn create_payment_required_response(
439 &self,
440 error: &str,
441 payment_requirements: &PaymentRequirements,
442 is_web_browser: bool,
443 ) -> crate::Result<axum::response::Response> {
444 if is_web_browser {
445 let html = if let Some(custom_html) = &self.config.custom_paywall_html {
446 custom_html.clone()
447 } else {
448 let paywall_config = self.template_config.clone().unwrap_or_else(|| {
450 crate::template::PaywallConfig::new()
451 .with_app_name("x402 Service")
452 .with_app_logo("💰")
453 });
454
455 crate::template::generate_paywall_html(
456 error,
457 std::slice::from_ref(payment_requirements),
458 Some(&paywall_config),
459 )
460 };
461
462 let response = Response::builder()
463 .status(StatusCode::PAYMENT_REQUIRED)
464 .header("Content-Type", "text/html")
465 .body(html.into())
466 .map_err(|e| X402Error::config(format!("Failed to create HTML response: {}", e)))?;
467
468 Ok(response)
469 } else {
470 let payment_response =
471 PaymentRequirementsResponse::new(error, vec![payment_requirements.clone()]);
472
473 Ok(Json(payment_response).into_response())
474 }
475 }
476}
477
478pub async fn payment_middleware(
480 State(middleware): State<PaymentMiddleware>,
481 request: Request,
482 next: Next,
483) -> crate::Result<impl IntoResponse> {
484 match middleware.process_payment(request, next).await? {
485 PaymentResult::Success { response, .. } => Ok(response),
486 PaymentResult::PaymentRequired { response } => Ok(response),
487 PaymentResult::VerificationFailed { response } => Ok(response),
488 PaymentResult::SettlementFailed { response } => Ok(response),
489 }
490}
491
492pub fn create_payment_service(
494 middleware: PaymentMiddleware,
495) -> impl tower::Layer<tower::ServiceBuilder<tower::layer::util::Identity>> + Clone {
496 ServiceBuilder::new()
497 .layer(TraceLayer::new_for_http())
498 .layer(tower::layer::util::Stack::new(
499 tower::layer::util::Identity::new(),
500 PaymentServiceLayer::new(middleware),
501 ))
502}
503
504#[derive(Clone)]
506pub struct PaymentServiceLayer {
507 middleware: PaymentMiddleware,
508}
509
510impl PaymentServiceLayer {
511 pub fn new(middleware: PaymentMiddleware) -> Self {
512 Self { middleware }
513 }
514}
515
516impl<S> tower::Layer<S> for PaymentServiceLayer {
517 type Service = PaymentService<S>;
518
519 fn layer(&self, inner: S) -> Self::Service {
520 PaymentService {
521 inner,
522 middleware: self.middleware.clone(),
523 }
524 }
525}
526
527#[derive(Clone)]
529pub struct PaymentService<S> {
530 inner: S,
531 middleware: PaymentMiddleware,
532}
533
534impl<S, ReqBody, ResBody> tower::Service<http::Request<ReqBody>> for PaymentService<S>
535where
536 S: tower::Service<
537 http::Request<ReqBody>,
538 Response = http::Response<ResBody>,
539 Error = Box<dyn std::error::Error + Send + Sync>,
540 > + Send
541 + 'static,
542 S::Future: Send + 'static,
543 ReqBody: Send + 'static,
544 ResBody: Send + 'static,
545{
546 type Response = S::Response;
547 type Error = S::Error;
548 type Future = std::pin::Pin<
549 Box<
550 dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>>
551 + Send,
552 >,
553 >;
554
555 fn poll_ready(
556 &mut self,
557 cx: &mut std::task::Context<'_>,
558 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
559 self.inner.poll_ready(cx)
560 }
561
562 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
563 let middleware = self.middleware.clone();
564
565 let payment_header = req
567 .headers()
568 .get("X-PAYMENT")
569 .and_then(|h| h.to_str().ok())
570 .map(|s| s.to_string());
571 let uri_path = req.uri().path().to_string();
572
573 let future = self.inner.call(req);
574
575 Box::pin(async move {
576 match payment_header {
577 Some(payment_b64) => {
578 match crate::types::PaymentPayload::from_base64(&payment_b64) {
580 Ok(payment_payload) => {
581 let requirements =
583 match middleware.config.create_payment_requirements(&uri_path) {
584 Ok(req) => req,
585 Err(e) => {
586 return Err(
588 Box::new(e) as Box<dyn std::error::Error + Send + Sync>
589 );
590 }
591 };
592
593 match middleware
595 .verify_with_requirements(&payment_payload, &requirements)
596 .await
597 {
598 Ok(true) => {
599 let response = future.await?;
601
602 if let Ok(settlement) = middleware
604 .settle_with_requirements(&payment_payload, &requirements)
605 .await
606 {
607 let _ = settlement; }
612
613 Ok(response)
614 }
615 Ok(false) => {
616 Err(Box::new(crate::X402Error::payment_verification_failed(
618 "Payment verification failed",
619 ))
620 as Box<dyn std::error::Error + Send + Sync>)
621 }
622 Err(e) => {
623 Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
625 }
626 }
627 }
628 Err(e) => {
629 Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
631 }
632 }
633 }
634 None => {
635 Err(Box::new(crate::X402Error::payment_verification_failed(
637 "X-PAYMENT header is required",
638 ))
639 as Box<dyn std::error::Error + Send + Sync>)
640 }
641 }
642 })
643 }
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Recipient address shared by every test.
    const PAY_TO: &str = "0x209693Bc6afc0C5328bA36FaF03C514EF312287C";

    /// Parses a decimal literal; test inputs are always well-formed.
    fn dec(literal: &str) -> Decimal {
        Decimal::from_str(literal).unwrap()
    }

    #[test]
    fn test_payment_middleware_config() {
        let config = PaymentMiddlewareConfig::new(dec("0.0001"), PAY_TO)
            .with_description("Test payment")
            .with_testnet(true);

        assert_eq!(config.amount, dec("0.0001"));
        assert_eq!(config.pay_to, PAY_TO);
        assert_eq!(config.description.as_deref(), Some("Test payment"));
        assert!(config.testnet);
    }

    #[test]
    fn test_payment_middleware_creation() {
        let middleware =
            PaymentMiddleware::new(dec("0.0001"), PAY_TO).with_description("Test payment");
        let config = middleware.config();

        assert_eq!(config.amount, dec("0.0001"));
        assert_eq!(config.pay_to, PAY_TO);
    }

    #[test]
    fn test_payment_requirements_creation() {
        let requirements = PaymentMiddlewareConfig::new(dec("0.0001"), PAY_TO)
            .with_testnet(true)
            .create_payment_requirements("/test")
            .unwrap();

        assert_eq!(requirements.scheme, "exact");
        assert_eq!(requirements.network, "base-sepolia");
        // 0.0001 tokens * 1_000_000 atomic units per token.
        assert_eq!(requirements.max_amount_required, "100");
        assert_eq!(requirements.pay_to, PAY_TO);
    }

    #[test]
    fn test_payment_middleware_config_builder() {
        let config = PaymentMiddlewareConfig::new(dec("0.01"), PAY_TO)
            .with_description("Test payment")
            .with_mime_type("application/json")
            .with_max_timeout_seconds(120)
            .with_testnet(false)
            .with_resource("https://example.com/test");

        assert_eq!(config.amount, dec("0.01"));
        assert_eq!(config.pay_to, PAY_TO);
        assert_eq!(config.description.as_deref(), Some("Test payment"));
        assert_eq!(config.mime_type.as_deref(), Some("application/json"));
        assert_eq!(config.max_timeout_seconds, 120);
        assert!(!config.testnet);
        assert_eq!(config.resource.as_deref(), Some("https://example.com/test"));
    }

    #[test]
    fn test_payment_middleware_creation_with_description() {
        let middleware =
            PaymentMiddleware::new(dec("0.001"), PAY_TO).with_description("Test middleware");
        let config = middleware.config();

        assert_eq!(config.amount, dec("0.001"));
        assert_eq!(config.description.as_deref(), Some("Test middleware"));
    }
}