1use crate::middleware::{PaymentMiddleware, PaymentMiddlewareConfig};
4use crate::X402Error;
5use axum::{
6 extract::{Request, State},
7 http::{HeaderMap, HeaderValue, StatusCode},
8 middleware::Next,
9 response::{IntoResponse, Response},
10 routing::{delete, get, post, put},
11 Json, Router,
12};
13use rust_decimal::Decimal;
14use std::sync::Arc;
15use tower::ServiceBuilder;
16
17pub use crate::middleware::{create_payment_service, payment_middleware};
19
20pub fn create_payment_router(
22 middleware: PaymentMiddleware,
23 routes: impl FnOnce(&mut Router) -> &mut Router,
24) -> Router {
25 let mut router = Router::new();
26 routes(&mut router);
27
28 router.layer(axum::middleware::from_fn_with_state(
30 middleware,
31 payment_middleware_handler,
32 ))
33}
34
35pub fn payment_route<H>(
37 method: &str,
38 path: &str,
39 handler: H,
40 middleware: PaymentMiddleware,
41) -> Router
42where
43 H: axum::handler::Handler<(), ()> + Clone + Send + 'static,
44{
45 let router = match method.to_uppercase().as_str() {
46 "GET" => Router::new().route(path, get(handler)),
47 "POST" => Router::new().route(path, post(handler)),
48 "PUT" => Router::new().route(path, put(handler)),
49 "DELETE" => Router::new().route(path, delete(handler)),
50 _ => {
51 return Router::new().route(
53 path,
54 axum::routing::any(|| async {
55 (StatusCode::METHOD_NOT_ALLOWED, "Unsupported HTTP method")
56 }),
57 );
58 }
59 };
60
61 router.layer(axum::middleware::from_fn_with_state(
63 middleware,
64 payment_middleware_handler,
65 ))
66}
67
68pub fn create_payment_middleware(amount: Decimal, pay_to: impl Into<String>) -> PaymentMiddleware {
70 PaymentMiddleware::new(amount, pay_to)
71}
72
73fn is_web_browser(headers: &HeaderMap) -> bool {
75 let user_agent = headers
76 .get("User-Agent")
77 .and_then(|h| h.to_str().ok())
78 .unwrap_or("");
79
80 let accept = headers
81 .get("Accept")
82 .and_then(|h| h.to_str().ok())
83 .unwrap_or("");
84
85 accept.contains("text/html") && user_agent.contains("Mozilla")
86}
87
/// Returns the built-in HTML page served to browser clients that request a
/// paywalled resource without a payment.
fn get_default_paywall_html() -> &'static str {
    const PAYWALL_PAGE: &str = r#"<!DOCTYPE html>
<html>
<head>
    <title>Payment Required</title>
    <style>
        body { font-family: Arial, sans-serif; text-align: center; padding: 50px; }
        .container { max-width: 500px; margin: 0 auto; }
        h1 { color: #333; }
        p { color: #666; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Payment Required</h1>
        <p>This resource requires payment to access. Please provide a valid X-PAYMENT header.</p>
    </div>
</body>
</html>"#;

    PAYWALL_PAGE
}
109
110pub async fn payment_middleware_handler(
112 State(middleware): State<PaymentMiddleware>,
113 request: Request,
114 next: Next,
115) -> impl IntoResponse {
116 let config = middleware.config().clone();
117 let headers = request.headers().clone();
118 let path = request.uri().path();
119
120 if path == "/health" || path.starts_with("/health/") {
122 return next.run(request).await;
123 }
124
125 let resource = if let Some(ref resource_url) = config.resource {
127 resource_url.clone()
128 } else if let Some(ref root_url) = config.resource_root_url {
129 format!("{}{}", root_url, path)
130 } else {
131 path.to_string()
132 };
133
134 let requirements = match config.create_payment_requirements(&resource) {
136 Ok(req) => req,
137 Err(_) => {
138 return (
139 StatusCode::INTERNAL_SERVER_ERROR,
140 Json(serde_json::json!({
141 "error": "Failed to create payment requirements",
142 "x402Version": 1
143 })),
144 )
145 .into_response();
146 }
147 };
148
149 if let Some(payment_header) = headers.get("X-PAYMENT") {
151 if let Ok(payment_str) = payment_header.to_str() {
152 match crate::types::PaymentPayload::from_base64(payment_str) {
154 Ok(payment_payload) => {
155 match middleware
157 .verify_with_requirements(&payment_payload, &requirements)
158 .await
159 {
160 Ok(true) => {
161 let mut response = next.run(request).await;
163
164 match middleware
166 .settle_with_requirements(&payment_payload, &requirements)
167 .await
168 {
169 Ok(settlement_response) => {
170 if let Ok(settlement_header) = settlement_response.to_base64() {
171 if let Ok(header_value) =
172 HeaderValue::from_str(&settlement_header)
173 {
174 response
175 .headers_mut()
176 .insert("X-PAYMENT-RESPONSE", header_value);
177 }
178 }
179 }
180 Err(e) => {
181 tracing::warn!("Payment settlement failed: {}", e);
183 }
184 }
185
186 return response;
187 }
188 Ok(false) => {
189 let response_body = serde_json::json!({
191 "x402Version": 1,
192 "error": "Payment verification failed",
193 "accepts": vec![requirements],
194 });
195 return (StatusCode::PAYMENT_REQUIRED, Json(response_body))
196 .into_response();
197 }
198 Err(e) => {
199 let response_body = serde_json::json!({
201 "x402Version": 1,
202 "error": format!("Payment verification error: {}", e),
203 "accepts": vec![requirements],
204 });
205 return (StatusCode::PAYMENT_REQUIRED, Json(response_body))
206 .into_response();
207 }
208 }
209 }
210 Err(e) => {
211 let response_body = serde_json::json!({
213 "x402Version": 1,
214 "error": format!("Invalid payment payload: {}", e),
215 "accepts": vec![requirements],
216 });
217 return (StatusCode::PAYMENT_REQUIRED, Json(response_body)).into_response();
218 }
219 }
220 }
221 }
222
223 if is_web_browser(&headers) {
225 let html = config
226 .custom_paywall_html
227 .clone()
228 .unwrap_or_else(|| get_default_paywall_html().to_string());
229
230 let mut response = Response::new(axum::body::Body::from(html));
231 *response.status_mut() = StatusCode::PAYMENT_REQUIRED;
232 response
233 .headers_mut()
234 .insert("Content-Type", HeaderValue::from_static("text/html"));
235
236 return response.into_response();
237 }
238
239 let response_body = serde_json::json!({
241 "x402Version": 1,
242 "error": "X-PAYMENT header is required",
243 "accepts": vec![requirements],
244 });
245
246 (StatusCode::PAYMENT_REQUIRED, Json(response_body)).into_response()
247}
248
249#[derive(Debug, Clone)]
251pub struct AxumPaymentConfig {
252 pub base_config: PaymentMiddlewareConfig,
254 pub axum_options: AxumOptions,
256}
257
258#[derive(Clone, Default)]
260pub struct AxumOptions {
261 pub enable_cors: bool,
263 pub cors_origins: Vec<String>,
265 pub enable_tracing: bool,
267 pub error_handler: Option<Arc<dyn Fn(X402Error) -> StatusCode + Send + Sync>>,
269}
270
271impl std::fmt::Debug for AxumOptions {
272 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273 f.debug_struct("AxumOptions")
274 .field("enable_cors", &self.enable_cors)
275 .field("cors_origins", &self.cors_origins)
276 .field("enable_tracing", &self.enable_tracing)
277 .field("error_handler", &"<function>")
278 .finish()
279 }
280}
281
282impl AxumPaymentConfig {
283 pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
285 Self {
286 base_config: PaymentMiddlewareConfig::new(amount, pay_to),
287 axum_options: AxumOptions::default(),
288 }
289 }
290
291 pub fn with_description(mut self, description: impl Into<String>) -> Self {
293 self.base_config.description = Some(description.into());
294 self
295 }
296
297 pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
299 self.base_config.mime_type = Some(mime_type.into());
300 self
301 }
302
303 pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
305 self.base_config.max_timeout_seconds = max_timeout_seconds;
306 self
307 }
308
309 pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
311 self.base_config.output_schema = Some(output_schema);
312 self
313 }
314
315 pub fn with_facilitator_config(
317 mut self,
318 facilitator_config: crate::types::FacilitatorConfig,
319 ) -> Self {
320 self.base_config.facilitator_config = facilitator_config;
321 self
322 }
323
324 pub fn with_testnet(mut self, testnet: bool) -> Self {
326 self.base_config.testnet = testnet;
327 self
328 }
329
330 pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
332 self.base_config.custom_paywall_html = Some(html.into());
333 self
334 }
335
336 pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
338 self.base_config.resource = Some(resource.into());
339 self
340 }
341
342 pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
344 self.base_config.resource_root_url = Some(url.into());
345 self
346 }
347
348 pub fn with_cors(mut self, origins: Vec<String>) -> Self {
350 self.axum_options.enable_cors = true;
351 self.axum_options.cors_origins = origins;
352 self
353 }
354
355 pub fn with_tracing(mut self) -> Self {
357 self.axum_options.enable_tracing = true;
358 self
359 }
360
361 pub fn with_error_handler<F>(mut self, handler: F) -> Self
363 where
364 F: Fn(X402Error) -> StatusCode + Send + Sync + 'static,
365 {
366 self.axum_options.error_handler = Some(Arc::new(handler));
367 self
368 }
369
370 pub fn into_middleware(self) -> PaymentMiddleware {
372 PaymentMiddleware {
373 config: Arc::new(self.base_config),
374 facilitator: None,
375 template_config: None,
376 }
377 }
378
379 pub fn create_service(&self) -> ServiceBuilder<tower::layer::util::Identity> {
381 ServiceBuilder::new()
384 }
385}
386
387pub fn create_payment_app(
389 config: AxumPaymentConfig,
390 routes: impl FnOnce(Router) -> Router,
391) -> Router {
392 let router = Router::new();
393 let router = routes(router);
394
395 let middleware = config.into_middleware();
397
398 router.layer(axum::middleware::from_fn_with_state(
400 middleware,
401 payment_middleware_handler,
402 ))
403}
404
405pub mod handlers {
407 use super::*;
408 use serde_json::json;
409
410 pub fn json_response<T: serde::Serialize>(data: T) -> impl IntoResponse {
412 Json(data)
413 }
414
415 pub fn text_response(text: impl Into<String>) -> impl IntoResponse {
417 text.into()
418 }
419
420 pub fn error_response(error: impl Into<String>) -> impl IntoResponse {
422 (
423 StatusCode::INTERNAL_SERVER_ERROR,
424 Json(json!({"error": error.into()})),
425 )
426 }
427
428 pub fn success_response<T: serde::Serialize>(data: T) -> impl IntoResponse {
430 (StatusCode::OK, Json(data))
431 }
432}
433
434pub mod examples {
436 use super::*;
437 use serde_json::json;
438
439 pub async fn joke_handler() -> impl IntoResponse {
441 axum::Json(json!({
442 "joke": "Why do programmers prefer dark mode? Because light attracts bugs!"
443 }))
444 }
445
446 pub async fn api_data_handler() -> impl IntoResponse {
448 axum::Json(json!({
449 "data": "This is premium API data that requires payment to access",
450 "timestamp": chrono::Utc::now().to_rfc3339(),
451 "source": "x402-protected-api"
452 }))
453 }
454
455 pub async fn download_handler() -> impl IntoResponse {
457 let content = "This is premium content that requires payment to download.";
458 (StatusCode::OK, content)
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Recipient address shared by the tests.
    const PAY_TO: &str = "0x209693Bc6afc0C5328bA36FaF03C514EF312287C";

    /// Price shared by the tests.
    fn price() -> Decimal {
        "0.0001".parse().expect("0.0001 is a valid Decimal")
    }

    #[test]
    fn test_axum_payment_config() {
        let config = AxumPaymentConfig::new(price(), PAY_TO)
            .with_description("Test payment")
            .with_testnet(true)
            .with_cors(vec!["http://localhost:3000".to_string()])
            .with_tracing();

        assert_eq!(config.base_config.amount, price());
        assert!(config.axum_options.enable_cors);
        assert!(config.axum_options.enable_tracing);
    }

    #[test]
    fn test_payment_middleware_creation() {
        let middleware = create_payment_middleware(price(), PAY_TO);

        assert_eq!(middleware.config().amount, price());
    }
}