1use crate::types::{Network, *};
4use crate::{Result, X402Error};
5use axum::{
6 extract::{Request, State},
7 http::{HeaderValue, StatusCode},
8 middleware::Next,
9 response::{IntoResponse, Response},
10 Json,
11};
12use rust_decimal::Decimal;
13use std::sync::Arc;
14use tower::ServiceBuilder;
15use tower_http::trace::TraceLayer;
16
17#[derive(Debug, Clone)]
19pub struct PaymentMiddlewareConfig {
20 pub amount: Decimal,
22 pub pay_to: String,
24 pub description: Option<String>,
26 pub mime_type: Option<String>,
28 pub max_timeout_seconds: u32,
30 pub output_schema: Option<serde_json::Value>,
32 pub facilitator_config: FacilitatorConfig,
34 pub testnet: bool,
36 pub custom_paywall_html: Option<String>,
38 pub resource: Option<String>,
40 pub resource_root_url: Option<String>,
42}
43
44impl PaymentMiddlewareConfig {
45 pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
47 let pay_to_normalized = pay_to.into().to_lowercase();
49 Self {
50 amount,
51 pay_to: pay_to_normalized,
52 description: None,
53 mime_type: None,
54 max_timeout_seconds: 60,
55 output_schema: None,
56 facilitator_config: FacilitatorConfig::default(),
57 testnet: true,
58 custom_paywall_html: None,
59 resource: None,
60 resource_root_url: None,
61 }
62 }
63
64 pub fn with_description(mut self, description: impl Into<String>) -> Self {
66 self.description = Some(description.into());
67 self
68 }
69
70 pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
72 self.mime_type = Some(mime_type.into());
73 self
74 }
75
76 pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
78 self.max_timeout_seconds = max_timeout_seconds;
79 self
80 }
81
82 pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
84 self.output_schema = Some(output_schema);
85 self
86 }
87
88 pub fn with_facilitator_config(mut self, facilitator_config: FacilitatorConfig) -> Self {
90 self.facilitator_config = facilitator_config;
91 self
92 }
93
94 pub fn with_testnet(mut self, testnet: bool) -> Self {
96 self.testnet = testnet;
97 self
98 }
99
100 pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
102 self.custom_paywall_html = Some(html.into());
103 self
104 }
105
106 pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
108 self.resource = Some(resource.into());
109 self
110 }
111
112 pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
114 self.resource_root_url = Some(url.into());
115 self
116 }
117
118 pub fn create_payment_requirements(&self, request_uri: &str) -> Result<PaymentRequirements> {
120 let network = if self.testnet {
121 networks::BASE_SEPOLIA
122 } else {
123 networks::BASE_MAINNET
124 };
125
126 let usdc_address =
127 networks::get_usdc_address(network).ok_or_else(|| X402Error::NetworkNotSupported {
128 network: network.to_string(),
129 })?;
130
131 let resource = if let Some(ref resource_url) = self.resource {
132 resource_url.clone()
133 } else if let Some(ref root_url) = self.resource_root_url {
134 format!("{}{}", root_url, request_uri)
135 } else {
136 request_uri.to_string()
137 };
138
139 let max_amount_required = (self.amount * Decimal::from(1_000_000u64))
140 .normalize()
141 .to_string();
142
143 let pay_to_normalized = self.pay_to.to_lowercase();
145
146 let mut requirements = PaymentRequirements::new(
147 schemes::EXACT,
148 network,
149 max_amount_required,
150 usdc_address,
151 &pay_to_normalized,
152 resource,
153 self.description.as_deref().unwrap_or("Payment required"),
154 );
155
156 requirements.mime_type = self.mime_type.clone();
157 requirements.output_schema = self.output_schema.clone();
158 requirements.max_timeout_seconds = self.max_timeout_seconds;
159
160 let network = if self.testnet {
161 Network::Testnet
162 } else {
163 Network::Mainnet
164 };
165 requirements.set_usdc_info(network)?;
166
167 Ok(requirements)
168 }
169}
170
171#[derive(Debug, Clone)]
173pub struct PaymentMiddleware {
174 pub config: Arc<PaymentMiddlewareConfig>,
175 pub facilitator: Option<crate::facilitator::FacilitatorClient>,
176 pub template_config: Option<crate::template::PaywallConfig>,
177}
178
179#[derive(Debug)]
181pub enum PaymentResult {
182 Success {
184 response: axum::response::Response,
185 settlement: crate::types::SettleResponse,
186 },
187 PaymentRequired { response: axum::response::Response },
189 VerificationFailed { response: axum::response::Response },
191 SettlementFailed { response: axum::response::Response },
193}
194
195impl PaymentMiddleware {
196 pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
198 Self {
199 config: Arc::new(PaymentMiddlewareConfig::new(amount, pay_to)),
200 facilitator: None,
201 template_config: None,
202 }
203 }
204
205 pub fn with_description(mut self, description: impl Into<String>) -> Self {
207 Arc::make_mut(&mut self.config).description = Some(description.into());
208 self
209 }
210
211 pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
213 Arc::make_mut(&mut self.config).mime_type = Some(mime_type.into());
214 self
215 }
216
217 pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
219 Arc::make_mut(&mut self.config).max_timeout_seconds = max_timeout_seconds;
220 self
221 }
222
223 pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
225 Arc::make_mut(&mut self.config).output_schema = Some(output_schema);
226 self
227 }
228
229 pub fn with_facilitator_config(mut self, facilitator_config: FacilitatorConfig) -> Self {
231 Arc::make_mut(&mut self.config).facilitator_config = facilitator_config;
232 self
233 }
234
235 pub fn with_testnet(mut self, testnet: bool) -> Self {
237 Arc::make_mut(&mut self.config).testnet = testnet;
238 self
239 }
240
241 pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
243 Arc::make_mut(&mut self.config).custom_paywall_html = Some(html.into());
244 self
245 }
246
247 pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
249 Arc::make_mut(&mut self.config).resource = Some(resource.into());
250 self
251 }
252
253 pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
255 Arc::make_mut(&mut self.config).resource_root_url = Some(url.into());
256 self
257 }
258
259 pub fn config(&self) -> &PaymentMiddlewareConfig {
261 &self.config
262 }
263
264 pub fn with_facilitator(mut self, facilitator: crate::facilitator::FacilitatorClient) -> Self {
266 self.facilitator = Some(facilitator);
267 self
268 }
269
270 pub fn with_template_config(mut self, template_config: crate::template::PaywallConfig) -> Self {
272 self.template_config = Some(template_config);
273 self
274 }
275
276 pub async fn verify(&self, payment_payload: &PaymentPayload) -> bool {
278 let facilitator = if let Some(facilitator) = &self.facilitator {
280 facilitator.clone()
281 } else {
282 match crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())
283 {
284 Ok(facilitator) => facilitator,
285 Err(_) => return false,
286 }
287 };
288
289 if let Ok(requirements) = self.config.create_payment_requirements("/") {
290 if let Ok(response) = facilitator.verify(payment_payload, &requirements).await {
291 return response.is_valid;
292 }
293 }
294 false
295 }
296
297 pub async fn settle(&self, payment_payload: &PaymentPayload) -> crate::Result<SettleResponse> {
299 let facilitator = if let Some(facilitator) = &self.facilitator {
301 facilitator.clone()
302 } else {
303 crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
304 };
305
306 let requirements = self.config.create_payment_requirements("/")?;
307 facilitator.settle(payment_payload, &requirements).await
308 }
309
310 pub async fn verify_with_requirements(
312 &self,
313 payment_payload: &PaymentPayload,
314 requirements: &PaymentRequirements,
315 ) -> crate::Result<bool> {
316 let facilitator = if let Some(facilitator) = &self.facilitator {
317 facilitator.clone()
318 } else {
319 crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
320 };
321
322 let response = facilitator.verify(payment_payload, requirements).await?;
323 Ok(response.is_valid)
324 }
325
326 pub async fn settle_with_requirements(
328 &self,
329 payment_payload: &PaymentPayload,
330 requirements: &PaymentRequirements,
331 ) -> crate::Result<SettleResponse> {
332 let facilitator = if let Some(facilitator) = &self.facilitator {
333 facilitator.clone()
334 } else {
335 crate::facilitator::FacilitatorClient::new(self.config.facilitator_config.clone())?
336 };
337
338 facilitator.settle(payment_payload, requirements).await
339 }
340
341 pub async fn process_payment(
343 &self,
344 request: Request,
345 next: Next,
346 ) -> crate::Result<PaymentResult> {
347 let headers = request.headers();
348 let uri = request.uri().to_string();
349
350 let user_agent = headers
352 .get("User-Agent")
353 .and_then(|v| v.to_str().ok())
354 .unwrap_or("");
355 let accept = headers
356 .get("Accept")
357 .and_then(|v| v.to_str().ok())
358 .unwrap_or("");
359
360 let is_web_browser = accept.contains("text/html") && user_agent.contains("Mozilla");
361
362 let payment_requirements = self.config.create_payment_requirements(&uri)?;
364
365 let payment_header = headers.get("X-PAYMENT").and_then(|v| v.to_str().ok());
367
368 match payment_header {
369 Some(payment_b64) => {
370 let payment_payload = PaymentPayload::from_base64(payment_b64).map_err(|e| {
372 X402Error::invalid_payment_payload(format!("Failed to decode payment: {}", e))
373 })?;
374
375 let facilitator = if let Some(facilitator) = &self.facilitator {
377 facilitator.clone()
378 } else {
379 crate::facilitator::FacilitatorClient::new(
380 self.config.facilitator_config.clone(),
381 )?
382 };
383
384 let verify_response = facilitator
386 .verify(&payment_payload, &payment_requirements)
387 .await
388 .map_err(|e| {
389 X402Error::facilitator_error(format!("Payment verification failed: {}", e))
390 })?;
391
392 if !verify_response.is_valid {
393 let error_response = self.create_payment_required_response(
394 "Payment verification failed",
395 &payment_requirements,
396 is_web_browser,
397 )?;
398 return Ok(PaymentResult::VerificationFailed {
399 response: error_response,
400 });
401 }
402
403 let mut response = next.run(request).await;
405
406 let settle_response = facilitator
408 .settle(&payment_payload, &payment_requirements)
409 .await
410 .map_err(|e| {
411 X402Error::facilitator_error(format!("Payment settlement failed: {}", e))
412 })?;
413
414 let settlement_header = settle_response.to_base64().map_err(|e| {
416 X402Error::config(format!("Failed to encode settlement response: {}", e))
417 })?;
418
419 if let Ok(header_value) = HeaderValue::from_str(&settlement_header) {
420 response
421 .headers_mut()
422 .insert("X-PAYMENT-RESPONSE", header_value);
423 }
424
425 Ok(PaymentResult::Success {
426 response,
427 settlement: settle_response,
428 })
429 }
430 None => {
431 let response = self.create_payment_required_response(
433 "X-PAYMENT header is required",
434 &payment_requirements,
435 is_web_browser,
436 )?;
437 Ok(PaymentResult::PaymentRequired { response })
438 }
439 }
440 }
441
442 fn create_payment_required_response(
444 &self,
445 error: &str,
446 payment_requirements: &PaymentRequirements,
447 is_web_browser: bool,
448 ) -> crate::Result<axum::response::Response> {
449 if is_web_browser {
450 let html = if let Some(custom_html) = &self.config.custom_paywall_html {
451 custom_html.clone()
452 } else {
453 let paywall_config = self.template_config.clone().unwrap_or_else(|| {
455 crate::template::PaywallConfig::new()
456 .with_app_name("x402 Service")
457 .with_app_logo("💰")
458 });
459
460 crate::template::generate_paywall_html(
461 error,
462 std::slice::from_ref(payment_requirements),
463 Some(&paywall_config),
464 )
465 };
466
467 let response = Response::builder()
468 .status(StatusCode::PAYMENT_REQUIRED)
469 .header("Content-Type", "text/html")
470 .body(html.into())
471 .map_err(|e| X402Error::config(format!("Failed to create HTML response: {}", e)))?;
472
473 Ok(response)
474 } else {
475 let payment_response =
476 PaymentRequirementsResponse::new(error, vec![payment_requirements.clone()]);
477
478 Ok(Json(payment_response).into_response())
479 }
480 }
481}
482
483pub async fn payment_middleware(
485 State(middleware): State<PaymentMiddleware>,
486 request: Request,
487 next: Next,
488) -> crate::Result<impl IntoResponse> {
489 match middleware.process_payment(request, next).await? {
490 PaymentResult::Success { response, .. } => Ok(response),
491 PaymentResult::PaymentRequired { response } => Ok(response),
492 PaymentResult::VerificationFailed { response } => Ok(response),
493 PaymentResult::SettlementFailed { response } => Ok(response),
494 }
495}
496
497pub fn create_payment_service(
499 middleware: PaymentMiddleware,
500) -> impl tower::Layer<tower::ServiceBuilder<tower::layer::util::Identity>> + Clone {
501 ServiceBuilder::new()
502 .layer(TraceLayer::new_for_http())
503 .layer(tower::layer::util::Stack::new(
504 tower::layer::util::Identity::new(),
505 PaymentServiceLayer::new(middleware),
506 ))
507}
508
509#[derive(Clone)]
511pub struct PaymentServiceLayer {
512 middleware: PaymentMiddleware,
513}
514
515impl PaymentServiceLayer {
516 pub fn new(middleware: PaymentMiddleware) -> Self {
517 Self { middleware }
518 }
519}
520
521impl<S> tower::Layer<S> for PaymentServiceLayer {
522 type Service = PaymentService<S>;
523
524 fn layer(&self, inner: S) -> Self::Service {
525 PaymentService {
526 inner,
527 middleware: self.middleware.clone(),
528 }
529 }
530}
531
532#[derive(Clone)]
534pub struct PaymentService<S> {
535 inner: S,
536 middleware: PaymentMiddleware,
537}
538
539impl<S, ReqBody, ResBody> tower::Service<http::Request<ReqBody>> for PaymentService<S>
540where
541 S: tower::Service<
542 http::Request<ReqBody>,
543 Response = http::Response<ResBody>,
544 Error = Box<dyn std::error::Error + Send + Sync>,
545 > + Send
546 + 'static,
547 S::Future: Send + 'static,
548 ReqBody: Send + 'static,
549 ResBody: Send + 'static,
550{
551 type Response = S::Response;
552 type Error = S::Error;
553 type Future = std::pin::Pin<
554 Box<
555 dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>>
556 + Send,
557 >,
558 >;
559
560 fn poll_ready(
561 &mut self,
562 cx: &mut std::task::Context<'_>,
563 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
564 self.inner.poll_ready(cx)
565 }
566
567 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
568 let middleware = self.middleware.clone();
569
570 let payment_header = req
572 .headers()
573 .get("X-PAYMENT")
574 .and_then(|h| h.to_str().ok())
575 .map(|s| s.to_string());
576 let uri_path = req.uri().path().to_string();
577
578 let future = self.inner.call(req);
579
580 Box::pin(async move {
581 match payment_header {
582 Some(payment_b64) => {
583 match crate::types::PaymentPayload::from_base64(&payment_b64) {
585 Ok(payment_payload) => {
586 let requirements =
588 match middleware.config.create_payment_requirements(&uri_path) {
589 Ok(req) => req,
590 Err(e) => {
591 return Err(
593 Box::new(e) as Box<dyn std::error::Error + Send + Sync>
594 );
595 }
596 };
597
598 match middleware
600 .verify_with_requirements(&payment_payload, &requirements)
601 .await
602 {
603 Ok(true) => {
604 let response = future.await?;
606
607 if let Ok(settlement) = middleware
609 .settle_with_requirements(&payment_payload, &requirements)
610 .await
611 {
612 let _ = settlement; }
617
618 Ok(response)
619 }
620 Ok(false) => {
621 Err(Box::new(crate::X402Error::payment_verification_failed(
623 "Payment verification failed",
624 ))
625 as Box<dyn std::error::Error + Send + Sync>)
626 }
627 Err(e) => {
628 Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
630 }
631 }
632 }
633 Err(e) => {
634 Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
636 }
637 }
638 }
639 None => {
640 Err(Box::new(crate::X402Error::payment_verification_failed(
642 "X-PAYMENT header is required",
643 ))
644 as Box<dyn std::error::Error + Send + Sync>)
645 }
646 }
647 })
648 }
649}
650
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Mixed-case recipient address passed to the constructors.
    const PAY_TO_INPUT: &str = "0x209693Bc6afc0C5328bA36FaF03C514EF312287C";
    /// The same address after lowercase normalization.
    const PAY_TO_NORMALIZED: &str = "0x209693bc6afc0c5328ba36faf03c514ef312287c";

    /// Parses a decimal literal; panics on malformed test input.
    fn dec(text: &str) -> Decimal {
        Decimal::from_str(text).unwrap()
    }

    #[test]
    fn test_payment_middleware_config() {
        let config = PaymentMiddlewareConfig::new(dec("0.0001"), PAY_TO_INPUT)
            .with_description("Test payment")
            .with_testnet(true);

        assert_eq!(config.amount, dec("0.0001"));
        assert_eq!(config.pay_to, PAY_TO_NORMALIZED);
        assert_eq!(config.description.as_deref(), Some("Test payment"));
        assert!(config.testnet);
    }

    #[test]
    fn test_payment_middleware_creation() {
        let middleware =
            PaymentMiddleware::new(dec("0.0001"), PAY_TO_INPUT).with_description("Test payment");
        let config = middleware.config();

        assert_eq!(config.amount, dec("0.0001"));
        assert_eq!(config.pay_to, PAY_TO_NORMALIZED);
    }

    #[test]
    fn test_payment_requirements_creation() {
        let config = PaymentMiddlewareConfig::new(dec("0.0001"), PAY_TO_INPUT).with_testnet(true);

        let requirements = config.create_payment_requirements("/test").unwrap();

        assert_eq!(requirements.scheme, "exact");
        assert_eq!(requirements.network, "base-sepolia");
        assert_eq!(requirements.max_amount_required, "100");
        assert_eq!(requirements.pay_to, PAY_TO_NORMALIZED);
    }

    #[test]
    fn test_payment_middleware_config_builder() {
        let config = PaymentMiddlewareConfig::new(dec("0.01"), PAY_TO_INPUT)
            .with_description("Test payment")
            .with_mime_type("application/json")
            .with_max_timeout_seconds(120)
            .with_testnet(false)
            .with_resource("https://example.com/test");

        assert_eq!(config.amount, dec("0.01"));
        assert_eq!(config.pay_to, PAY_TO_NORMALIZED);
        assert_eq!(config.description.as_deref(), Some("Test payment"));
        assert_eq!(config.mime_type.as_deref(), Some("application/json"));
        assert_eq!(config.max_timeout_seconds, 120);
        assert!(!config.testnet);
        assert_eq!(config.resource.as_deref(), Some("https://example.com/test"));
    }

    #[test]
    fn test_payment_middleware_creation_with_description() {
        let middleware =
            PaymentMiddleware::new(dec("0.001"), PAY_TO_INPUT).with_description("Test middleware");
        let config = middleware.config();

        assert_eq!(config.amount, dec("0.001"));
        assert_eq!(config.description.as_deref(), Some("Test middleware"));
    }
}