1use crate::middleware::{PaymentMiddleware, PaymentMiddlewareConfig};
4use crate::X402Error;
5use axum::{
6 extract::{Request, State},
7 http::{HeaderMap, HeaderValue, StatusCode},
8 middleware::Next,
9 response::{IntoResponse, Response},
10 routing::{delete, get, post, put},
11 Json, Router,
12};
13use rust_decimal::Decimal;
14use std::sync::Arc;
15use tower::ServiceBuilder;
16
17pub use crate::middleware::{create_payment_service, payment_middleware};
19
20pub fn create_payment_router(
22 middleware: PaymentMiddleware,
23 routes: impl FnOnce(&mut Router) -> &mut Router,
24) -> Router {
25 let mut router = Router::new();
26 routes(&mut router);
27
28 router.layer(axum::middleware::from_fn_with_state(
30 middleware,
31 payment_middleware_handler,
32 ))
33}
34
35pub fn payment_route<H>(
37 method: &str,
38 path: &str,
39 handler: H,
40 middleware: PaymentMiddleware,
41) -> Router
42where
43 H: axum::handler::Handler<(), ()> + Clone + Send + 'static,
44{
45 let router = match method.to_uppercase().as_str() {
46 "GET" => Router::new().route(path, get(handler)),
47 "POST" => Router::new().route(path, post(handler)),
48 "PUT" => Router::new().route(path, put(handler)),
49 "DELETE" => Router::new().route(path, delete(handler)),
50 _ => {
51 return Router::new().route(
53 path,
54 axum::routing::any(|| async {
55 (StatusCode::METHOD_NOT_ALLOWED, "Unsupported HTTP method")
56 }),
57 );
58 }
59 };
60
61 router.layer(axum::middleware::from_fn_with_state(
63 middleware,
64 payment_middleware_handler,
65 ))
66}
67
68pub fn create_payment_middleware(amount: Decimal, pay_to: impl Into<String>) -> PaymentMiddleware {
70 PaymentMiddleware::new(amount, pay_to)
71}
72
73fn is_web_browser(headers: &HeaderMap) -> bool {
75 let user_agent = headers
76 .get("User-Agent")
77 .and_then(|h| h.to_str().ok())
78 .unwrap_or("");
79
80 let accept = headers
81 .get("Accept")
82 .and_then(|h| h.to_str().ok())
83 .unwrap_or("");
84
85 accept.contains("text/html") && user_agent.contains("Mozilla")
86}
87
/// Returns the built-in HTML paywall page.
///
/// The middleware serves this page to browser clients (with `402 Payment Required`) when
/// no custom paywall HTML has been configured.
fn get_default_paywall_html() -> &'static str {
    const DEFAULT_PAYWALL_PAGE: &str = r#"<!DOCTYPE html>
<html>
<head>
    <title>Payment Required</title>
    <style>
        body { font-family: Arial, sans-serif; text-align: center; padding: 50px; }
        .container { max-width: 500px; margin: 0 auto; }
        h1 { color: #333; }
        p { color: #666; }
    </style>
</head>
<body>
    <div class="container">
        <h1>Payment Required</h1>
        <p>This resource requires payment to access. Please provide a valid X-PAYMENT header.</p>
    </div>
</body>
</html>"#;

    DEFAULT_PAYWALL_PAGE
}
109
110pub async fn payment_middleware_handler(
112 State(middleware): State<PaymentMiddleware>,
113 request: Request,
114 next: Next,
115) -> impl IntoResponse {
116 let config = middleware.config().clone();
117 let headers = request.headers().clone();
118
119 let resource = if let Some(ref resource_url) = config.resource {
121 resource_url.clone()
122 } else if let Some(ref root_url) = config.resource_root_url {
123 format!("{}{}", root_url, request.uri().path())
124 } else {
125 request.uri().path().to_string()
126 };
127
128 let requirements = match config.create_payment_requirements(&resource) {
130 Ok(req) => req,
131 Err(_) => {
132 return (
133 StatusCode::INTERNAL_SERVER_ERROR,
134 Json(serde_json::json!({
135 "error": "Failed to create payment requirements",
136 "x402Version": 1
137 })),
138 )
139 .into_response();
140 }
141 };
142
143 if let Some(payment_header) = headers.get("X-PAYMENT") {
145 if let Ok(payment_str) = payment_header.to_str() {
146 match crate::types::PaymentPayload::from_base64(payment_str) {
148 Ok(payment_payload) => {
149 match middleware
151 .verify_with_requirements(&payment_payload, &requirements)
152 .await
153 {
154 Ok(true) => {
155 let mut response = next.run(request).await;
157
158 match middleware
160 .settle_with_requirements(&payment_payload, &requirements)
161 .await
162 {
163 Ok(settlement_response) => {
164 if let Ok(settlement_header) = settlement_response.to_base64() {
165 if let Ok(header_value) =
166 HeaderValue::from_str(&settlement_header)
167 {
168 response
169 .headers_mut()
170 .insert("X-PAYMENT-RESPONSE", header_value);
171 }
172 }
173 }
174 Err(e) => {
175 tracing::warn!("Payment settlement failed: {}", e);
177 }
178 }
179
180 return response;
181 }
182 Ok(false) => {
183 let response_body = serde_json::json!({
185 "x402Version": 1,
186 "error": "Payment verification failed",
187 "accepts": vec![requirements],
188 });
189 return (StatusCode::PAYMENT_REQUIRED, Json(response_body))
190 .into_response();
191 }
192 Err(e) => {
193 let response_body = serde_json::json!({
195 "x402Version": 1,
196 "error": format!("Payment verification error: {}", e),
197 "accepts": vec![requirements],
198 });
199 return (StatusCode::PAYMENT_REQUIRED, Json(response_body))
200 .into_response();
201 }
202 }
203 }
204 Err(e) => {
205 let response_body = serde_json::json!({
207 "x402Version": 1,
208 "error": format!("Invalid payment payload: {}", e),
209 "accepts": vec![requirements],
210 });
211 return (StatusCode::PAYMENT_REQUIRED, Json(response_body)).into_response();
212 }
213 }
214 }
215 }
216
217 if is_web_browser(&headers) {
219 let html = config
220 .custom_paywall_html
221 .clone()
222 .unwrap_or_else(|| get_default_paywall_html().to_string());
223
224 let mut response = Response::new(axum::body::Body::from(html));
225 *response.status_mut() = StatusCode::PAYMENT_REQUIRED;
226 response
227 .headers_mut()
228 .insert("Content-Type", HeaderValue::from_static("text/html"));
229
230 return response.into_response();
231 }
232
233 let response_body = serde_json::json!({
235 "x402Version": 1,
236 "error": "X-PAYMENT header is required",
237 "accepts": vec![requirements],
238 });
239
240 (StatusCode::PAYMENT_REQUIRED, Json(response_body)).into_response()
241}
242
243#[derive(Debug, Clone)]
245pub struct AxumPaymentConfig {
246 pub base_config: PaymentMiddlewareConfig,
248 pub axum_options: AxumOptions,
250}
251
252#[derive(Clone, Default)]
254pub struct AxumOptions {
255 pub enable_cors: bool,
257 pub cors_origins: Vec<String>,
259 pub enable_tracing: bool,
261 pub error_handler: Option<Arc<dyn Fn(X402Error) -> StatusCode + Send + Sync>>,
263}
264
265impl std::fmt::Debug for AxumOptions {
266 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 f.debug_struct("AxumOptions")
268 .field("enable_cors", &self.enable_cors)
269 .field("cors_origins", &self.cors_origins)
270 .field("enable_tracing", &self.enable_tracing)
271 .field("error_handler", &"<function>")
272 .finish()
273 }
274}
275
276impl AxumPaymentConfig {
277 pub fn new(amount: Decimal, pay_to: impl Into<String>) -> Self {
279 Self {
280 base_config: PaymentMiddlewareConfig::new(amount, pay_to),
281 axum_options: AxumOptions::default(),
282 }
283 }
284
285 pub fn with_description(mut self, description: impl Into<String>) -> Self {
287 self.base_config.description = Some(description.into());
288 self
289 }
290
291 pub fn with_mime_type(mut self, mime_type: impl Into<String>) -> Self {
293 self.base_config.mime_type = Some(mime_type.into());
294 self
295 }
296
297 pub fn with_max_timeout_seconds(mut self, max_timeout_seconds: u32) -> Self {
299 self.base_config.max_timeout_seconds = max_timeout_seconds;
300 self
301 }
302
303 pub fn with_output_schema(mut self, output_schema: serde_json::Value) -> Self {
305 self.base_config.output_schema = Some(output_schema);
306 self
307 }
308
309 pub fn with_facilitator_config(
311 mut self,
312 facilitator_config: crate::types::FacilitatorConfig,
313 ) -> Self {
314 self.base_config.facilitator_config = facilitator_config;
315 self
316 }
317
318 pub fn with_testnet(mut self, testnet: bool) -> Self {
320 self.base_config.testnet = testnet;
321 self
322 }
323
324 pub fn with_custom_paywall_html(mut self, html: impl Into<String>) -> Self {
326 self.base_config.custom_paywall_html = Some(html.into());
327 self
328 }
329
330 pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
332 self.base_config.resource = Some(resource.into());
333 self
334 }
335
336 pub fn with_resource_root_url(mut self, url: impl Into<String>) -> Self {
338 self.base_config.resource_root_url = Some(url.into());
339 self
340 }
341
342 pub fn with_cors(mut self, origins: Vec<String>) -> Self {
344 self.axum_options.enable_cors = true;
345 self.axum_options.cors_origins = origins;
346 self
347 }
348
349 pub fn with_tracing(mut self) -> Self {
351 self.axum_options.enable_tracing = true;
352 self
353 }
354
355 pub fn with_error_handler<F>(mut self, handler: F) -> Self
357 where
358 F: Fn(X402Error) -> StatusCode + Send + Sync + 'static,
359 {
360 self.axum_options.error_handler = Some(Arc::new(handler));
361 self
362 }
363
364 pub fn into_middleware(self) -> PaymentMiddleware {
366 PaymentMiddleware {
367 config: Arc::new(self.base_config),
368 facilitator: None,
369 template_config: None,
370 }
371 }
372
373 pub fn create_service(&self) -> ServiceBuilder<tower::layer::util::Identity> {
375 ServiceBuilder::new()
378 }
379}
380
381pub fn create_payment_app(
383 config: AxumPaymentConfig,
384 routes: impl FnOnce(Router) -> Router,
385) -> Router {
386 let router = Router::new();
387 let router = routes(router);
388
389 router.layer(config.create_service())
391}
392
393pub mod handlers {
395 use super::*;
396 use serde_json::json;
397
398 pub fn json_response<T: serde::Serialize>(data: T) -> impl IntoResponse {
400 Json(data)
401 }
402
403 pub fn text_response(text: impl Into<String>) -> impl IntoResponse {
405 text.into()
406 }
407
408 pub fn error_response(error: impl Into<String>) -> impl IntoResponse {
410 (
411 StatusCode::INTERNAL_SERVER_ERROR,
412 Json(json!({"error": error.into()})),
413 )
414 }
415
416 pub fn success_response<T: serde::Serialize>(data: T) -> impl IntoResponse {
418 (StatusCode::OK, Json(data))
419 }
420}
421
422pub mod examples {
424 use super::*;
425 use serde_json::json;
426
427 pub async fn joke_handler() -> impl IntoResponse {
429 axum::Json(json!({
430 "joke": "Why do programmers prefer dark mode? Because light attracts bugs!"
431 }))
432 }
433
434 pub async fn api_data_handler() -> impl IntoResponse {
436 axum::Json(json!({
437 "data": "This is premium API data that requires payment to access",
438 "timestamp": chrono::Utc::now().to_rfc3339(),
439 "source": "x402-protected-api"
440 }))
441 }
442
443 pub async fn download_handler() -> impl IntoResponse {
445 let content = "This is premium content that requires payment to download.";
446 (StatusCode::OK, content)
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    const PAY_TO: &str = "0x209693Bc6afc0C5328bA36FaF03C514EF312287C";

    #[test]
    fn test_axum_payment_config() {
        let config = AxumPaymentConfig::new(Decimal::from_str("0.0001").unwrap(), PAY_TO)
            .with_description("Test payment")
            .with_testnet(true)
            .with_cors(vec!["http://localhost:3000".to_string()])
            .with_tracing();

        assert_eq!(
            config.base_config.amount,
            Decimal::from_str("0.0001").unwrap()
        );
        assert_eq!(
            config.base_config.description.as_deref(),
            Some("Test payment")
        );
        assert!(config.base_config.testnet);
        assert!(config.axum_options.enable_cors);
        assert_eq!(
            config.axum_options.cors_origins,
            vec!["http://localhost:3000".to_string()]
        );
        assert!(config.axum_options.enable_tracing);
        assert!(config.axum_options.error_handler.is_none());
    }

    #[test]
    fn test_payment_middleware_creation() {
        let middleware =
            create_payment_middleware(Decimal::from_str("0.0001").unwrap(), PAY_TO);

        assert_eq!(
            middleware.config().amount,
            Decimal::from_str("0.0001").unwrap()
        );
    }

    #[test]
    fn test_is_web_browser() {
        let mut headers = HeaderMap::new();
        assert!(!is_web_browser(&headers));

        headers.insert(
            "accept",
            HeaderValue::from_static("text/html,application/xhtml+xml"),
        );
        headers.insert(
            "user-agent",
            HeaderValue::from_static("Mozilla/5.0 (X11; Linux x86_64)"),
        );
        assert!(is_web_browser(&headers));

        // Non-browser user agent.
        headers.insert("user-agent", HeaderValue::from_static("curl/8.4.0"));
        assert!(!is_web_browser(&headers));

        // Browser user agent that does not ask for HTML.
        headers.insert("user-agent", HeaderValue::from_static("Mozilla/5.0"));
        headers.insert("accept", HeaderValue::from_static("application/json"));
        assert!(!is_web_browser(&headers));
    }

    #[test]
    fn test_default_paywall_html() {
        let html = get_default_paywall_html();
        assert!(html.starts_with("<!DOCTYPE html>"));
        assert!(html.contains("<title>Payment Required</title>"));
        assert!(html.contains("X-PAYMENT"));
    }
}