rust_x402/
proxy.rs

1//! Proxy server implementation for x402 payments
2//!
3//! This module provides a reverse proxy server that adds x402 payment protection
4//! to any existing HTTP service.
5
6use crate::middleware::PaymentMiddlewareConfig;
7use crate::types::{FacilitatorConfig, Network};
8use crate::{Result, X402Error};
9use axum::{
10    extract::State,
11    http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode},
12    response::{IntoResponse, Response},
13    routing::any,
14    Router,
15};
16use rust_decimal::Decimal;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::str::FromStr;
20use tower::ServiceBuilder;
21use tower_http::trace::TraceLayer;
22use tracing::{info, warn};
23
24/// Configuration for the proxy server
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ProxyConfig {
27    /// Target URL to proxy requests to
28    pub target_url: String,
29    /// Payment amount in decimal units (e.g., 0.01 for 1 cent)
30    pub amount: f64,
31    /// Recipient wallet address
32    pub pay_to: String,
33    /// Payment description
34    pub description: Option<String>,
35    /// MIME type of the expected response
36    pub mime_type: Option<String>,
37    /// Maximum timeout in seconds
38    pub max_timeout_seconds: u32,
39    /// Facilitator URL
40    pub facilitator_url: String,
41    /// Whether to use testnet
42    pub testnet: bool,
43    /// Additional headers to forward to target
44    pub headers: HashMap<String, String>,
45    /// CDP API credentials (optional)
46    pub cdp_api_key_id: Option<String>,
47    pub cdp_api_key_secret: Option<String>,
48}
49
50impl Default for ProxyConfig {
51    fn default() -> Self {
52        Self {
53            target_url: String::new(),
54            amount: 0.0001,
55            pay_to: String::new(),
56            description: None,
57            mime_type: None,
58            max_timeout_seconds: 60,
59            facilitator_url: "https://x402.org/facilitator".to_string(),
60            testnet: true,
61            headers: HashMap::new(),
62            cdp_api_key_id: None,
63            cdp_api_key_secret: None,
64        }
65    }
66}
67
68impl ProxyConfig {
69    /// Load configuration from a JSON file
70    pub fn from_file(path: &str) -> Result<Self> {
71        let content = std::fs::read_to_string(path)
72            .map_err(|e| X402Error::config(format!("Failed to read config file: {}", e)))?;
73
74        let config: ProxyConfig = serde_json::from_str(&content)
75            .map_err(|e| X402Error::config(format!("Failed to parse config file: {}", e)))?;
76
77        config.validate()?;
78        Ok(config)
79    }
80
81    /// Load configuration from environment variables
82    pub fn from_env() -> Result<Self> {
83        let mut config = Self::default();
84
85        if let Ok(target_url) = std::env::var("TARGET_URL") {
86            config.target_url = target_url;
87        }
88
89        if let Ok(amount) = std::env::var("AMOUNT") {
90            config.amount = amount
91                .parse()
92                .map_err(|e| X402Error::config(format!("Invalid AMOUNT: {}", e)))?;
93        }
94
95        if let Ok(pay_to) = std::env::var("PAY_TO") {
96            config.pay_to = pay_to;
97        }
98
99        if let Ok(description) = std::env::var("DESCRIPTION") {
100            config.description = Some(description);
101        }
102
103        if let Ok(facilitator_url) = std::env::var("FACILITATOR_URL") {
104            config.facilitator_url = facilitator_url;
105        }
106
107        if let Ok(testnet) = std::env::var("TESTNET") {
108            config.testnet = testnet
109                .parse()
110                .map_err(|e| X402Error::config(format!("Invalid TESTNET: {}", e)))?;
111        }
112
113        if let Ok(cdp_api_key_id) = std::env::var("CDP_API_KEY_ID") {
114            config.cdp_api_key_id = Some(cdp_api_key_id);
115        }
116
117        if let Ok(cdp_api_key_secret) = std::env::var("CDP_API_KEY_SECRET") {
118            config.cdp_api_key_secret = Some(cdp_api_key_secret);
119        }
120
121        config.validate()?;
122        Ok(config)
123    }
124
125    /// Validate the configuration
126    pub fn validate(&self) -> Result<()> {
127        if self.target_url.is_empty() {
128            return Err(X402Error::config("TARGET_URL is required"));
129        }
130
131        if self.pay_to.is_empty() {
132            return Err(X402Error::config("PAY_TO is required"));
133        }
134
135        if self.amount <= 0.0 {
136            return Err(X402Error::config("AMOUNT must be positive"));
137        }
138
139        // Validate target URL
140        url::Url::parse(&self.target_url)
141            .map_err(|e| X402Error::config(format!("Invalid TARGET_URL: {}", e)))?;
142
143        // Validate facilitator URL
144        url::Url::parse(&self.facilitator_url)
145            .map_err(|e| X402Error::config(format!("Invalid FACILITATOR_URL: {}", e)))?;
146
147        Ok(())
148    }
149
150    /// Convert to payment middleware config
151    pub fn to_payment_config(&self) -> Result<PaymentMiddlewareConfig> {
152        let amount = Decimal::from_str(&self.amount.to_string())
153            .map_err(|e| X402Error::config(format!("Invalid amount: {}", e)))?;
154
155        let mut facilitator_config = FacilitatorConfig::new(&self.facilitator_url);
156
157        // Set up CDP authentication if credentials are provided
158        if let (Some(api_key_id), Some(api_key_secret)) =
159            (&self.cdp_api_key_id, &self.cdp_api_key_secret)
160        {
161            if !api_key_id.is_empty() && !api_key_secret.is_empty() {
162                let auth_headers =
163                    crate::facilitator::coinbase::create_auth_headers(api_key_id, api_key_secret);
164                facilitator_config = facilitator_config.with_auth_headers(Box::new(auth_headers));
165            }
166        }
167
168        let _network = if self.testnet {
169            Network::Testnet
170        } else {
171            Network::Mainnet
172        };
173
174        // Normalize pay_to to lowercase to avoid EIP-55 checksum mismatches
175        let pay_to_normalized = self.pay_to.to_lowercase();
176        let mut config = PaymentMiddlewareConfig::new(amount, &pay_to_normalized)
177            .with_facilitator_config(facilitator_config)
178            .with_testnet(self.testnet)
179            .with_max_timeout_seconds(self.max_timeout_seconds);
180
181        if let Some(description) = &self.description {
182            config = config.with_description(description);
183        }
184
185        if let Some(mime_type) = &self.mime_type {
186            config = config.with_mime_type(mime_type);
187        }
188
189        Ok(config)
190    }
191}
192
193/// Proxy server state
194#[derive(Clone)]
195pub struct ProxyState {
196    config: ProxyConfig,
197    client: reqwest::Client,
198}
199
200impl ProxyState {
201    pub fn new(config: ProxyConfig) -> Result<Self> {
202        let client = reqwest::Client::builder()
203            .timeout(std::time::Duration::from_secs(30))
204            .build()
205            .map_err(|e| X402Error::config(format!("Failed to create HTTP client: {}", e)))?;
206
207        Ok(Self { config, client })
208    }
209}
210
211/// Create a proxy server with x402 payment protection
212pub fn create_proxy_server(config: ProxyConfig) -> Result<Router> {
213    let state = ProxyState::new(config.clone())?;
214
215    let app = Router::new()
216        .route("/*path", any(proxy_handler))
217        .with_state(state);
218
219    Ok(app)
220}
221
222/// Create a proxy server with tracing middleware
223pub fn create_proxy_server_with_tracing(config: ProxyConfig) -> Result<Router> {
224    let state = ProxyState::new(config.clone())?;
225
226    let app = Router::new()
227        .route("/*path", any(proxy_handler))
228        .with_state(state)
229        .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
230
231    Ok(app)
232}
233
234/// Create a proxy server with x402 payment middleware
235pub fn create_proxy_server_with_payment(config: ProxyConfig) -> Result<Router> {
236    let state = ProxyState::new(config.clone())?;
237
238    // Create payment middleware from config
239    let payment_config = config.to_payment_config()?;
240    let payment_middleware = crate::middleware::PaymentMiddleware::new(
241        payment_config.amount,
242        payment_config.pay_to.clone(),
243    )
244    .with_facilitator_config(payment_config.facilitator_config.clone())
245    .with_testnet(payment_config.testnet)
246    .with_description(
247        payment_config
248            .description
249            .as_deref()
250            .unwrap_or("Proxy payment"),
251    );
252
253    let app = Router::new()
254        .route("/*path", any(proxy_handler_with_payment))
255        .with_state(state)
256        .layer(axum::middleware::from_fn_with_state(
257            payment_middleware,
258            payment_middleware_handler,
259        ));
260
261    Ok(app)
262}
263
264/// Payment middleware handler for proxy
265async fn payment_middleware_handler(
266    State(middleware): State<crate::middleware::PaymentMiddleware>,
267    request: axum::extract::Request,
268    next: axum::middleware::Next,
269) -> impl axum::response::IntoResponse {
270    match middleware.process_payment(request, next).await {
271        Ok(result) => match result {
272            crate::middleware::PaymentResult::Success { response, .. } => response,
273            crate::middleware::PaymentResult::PaymentRequired { response } => response,
274            crate::middleware::PaymentResult::VerificationFailed { response } => response,
275            crate::middleware::PaymentResult::SettlementFailed { response } => response,
276        },
277        Err(e) => (
278            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
279            axum::Json(serde_json::json!({
280                "error": format!("Payment processing error: {}", e),
281                "x402Version": 1
282            })),
283        )
284            .into_response(),
285    }
286}
287
288/// Proxy handler with payment protection that forwards requests to the target server
289async fn proxy_handler_with_payment(
290    State(state): State<ProxyState>,
291    request: axum::extract::Request,
292) -> std::result::Result<Response, StatusCode> {
293    // This handler is called after payment middleware has verified the payment
294    proxy_handler(State(state), request).await
295}
296
297/// Proxy handler that forwards requests to the target server
298async fn proxy_handler(
299    State(state): State<ProxyState>,
300    request: axum::extract::Request,
301) -> std::result::Result<Response, StatusCode> {
302    #[cfg(feature = "streaming")]
303    {
304        proxy_handler_with_streaming(State(state), request).await
305    }
306    #[cfg(not(feature = "streaming"))]
307    {
308        proxy_handler_without_streaming(State(state), request).await
309    }
310}
311
/// Forwards a request to the target server, streaming bodies where possible.
///
/// The upstream URL is the configured target URL (trailing slashes removed) followed by
/// the incoming path and query. Request bodies are streamed when the request is
/// multipart or chunked and buffered otherwise; response bodies are streamed when the
/// upstream response is chunked and buffered otherwise.
///
/// # Errors
/// * `400 Bad Request` when the request body cannot be read.
/// * `502 Bad Gateway` when the upstream request fails or its body cannot be read.
/// * `500 Internal Server Error` when the outgoing response cannot be assembled.
#[cfg(feature = "streaming")]
async fn proxy_handler_with_streaming(
    State(state): State<ProxyState>,
    request: axum::extract::Request,
) -> std::result::Result<Response, StatusCode> {
    use axum::body::Body;
    use futures_util::StreamExt;
    use reqwest::Body as ReqwestBody;

    // Transfer-coding names are case-insensitive, so compare in lowercase.
    let is_chunked = |headers: &HeaderMap| {
        headers
            .get("transfer-encoding")
            .and_then(|v| v.to_str().ok())
            .map(|v| v.to_ascii_lowercase().contains("chunked"))
            .unwrap_or(false)
    };

    let client = &state.client;

    // The request path always starts with `/`; drop trailing slashes from the base so
    // a target such as `https://host/` does not produce `https://host//path`.
    let base_url = state.config.target_url.trim_end_matches('/');
    let path = request.uri().path();
    let full_url = match request.uri().query() {
        Some(query) if !query.is_empty() => format!("{}{}?{}", base_url, path, query),
        _ => format!("{}{}", base_url, path),
    };

    info!("Proxying request to: {}", full_url);

    // Create a new request to the target server
    let mut target_request = client.request(request.method().clone(), &full_url);

    // Copy essential headers
    target_request = copy_essential_headers(request.headers(), target_request);

    // Add custom headers from config; report entries that are not valid HTTP headers
    // instead of dropping them silently (the value is not logged, it may be a secret).
    for (key, value) in &state.config.headers {
        match (HeaderName::try_from(key), HeaderValue::try_from(value)) {
            (Ok(name), Ok(val)) => target_request = target_request.header(name, val),
            _ => warn!("Skipping invalid configured header: {}", key),
        }
    }

    // Handle request body with streaming support
    let (parts, body) = request.into_parts();

    // Media types are case-insensitive as well.
    let is_multipart = parts
        .headers
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .map(|ct| ct.to_ascii_lowercase().starts_with("multipart/"))
        .unwrap_or(false);
    let is_streaming = is_chunked(&parts.headers);

    if is_multipart || is_streaming {
        // For multipart or streaming requests, stream the body
        let body_stream = body.into_data_stream();
        target_request = target_request.body(ReqwestBody::wrap_stream(body_stream));
    } else {
        // For regular requests, buffer the body
        let body_bytes = axum::body::to_bytes(body, usize::MAX)
            .await
            .map_err(|_| StatusCode::BAD_REQUEST)?;

        if !body_bytes.is_empty() {
            target_request = target_request.body(body_bytes);
        }
    }

    // Execute the request
    let response = target_request.send().await.map_err(|e| {
        warn!("Failed to execute proxy request: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    // Check if response is streaming
    let response_is_streaming = is_chunked(response.headers());

    // Copy status and headers before the body consumes the response
    let mut response_builder = Response::builder().status(response.status());
    for (name, value) in response.headers() {
        response_builder = response_builder.header(name, value);
    }

    if response_is_streaming {
        // Stream the response body
        let response_stream = response
            .bytes_stream()
            .map(|result| result.map_err(axum::Error::new));
        response_builder
            .body(Body::from_stream(response_stream))
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
    } else {
        // Buffer the response body
        let body_bytes = response
            .bytes()
            .await
            .map_err(|_| StatusCode::BAD_GATEWAY)?;
        response_builder
            .body(Body::from(body_bytes))
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
    }
}
439
440#[cfg(not(feature = "streaming"))]
441async fn proxy_handler_without_streaming(
442    State(state): State<ProxyState>,
443    request: axum::extract::Request,
444) -> std::result::Result<Response, StatusCode> {
445    let target_url = &state.config.target_url;
446    let client = &state.client;
447
448    // Extract the path from the request
449    let path = request.uri().path();
450    let query = request.uri().query().unwrap_or("");
451
452    // Build the target URL
453    let full_url = if query.is_empty() {
454        format!("{}{}", target_url, path)
455    } else {
456        format!("{}{}?{}", target_url, path, query)
457    };
458
459    info!("Proxying request to: {}", full_url);
460
461    // Create a new request to the target server
462    let method =
463        Method::from_str(request.method().as_str()).map_err(|_| StatusCode::BAD_REQUEST)?;
464
465    let mut target_request = client.request(method, &full_url);
466
467    // Copy essential headers
468    target_request = copy_essential_headers(request.headers(), target_request);
469
470    // Add custom headers from config
471    for (key, value) in &state.config.headers {
472        if let (Ok(name), Ok(val)) = (HeaderName::try_from(key), HeaderValue::try_from(value)) {
473            target_request = target_request.header(name, val);
474        }
475    }
476
477    // Copy request body (must buffer since streaming not available)
478    let body = axum::body::to_bytes(request.into_body(), usize::MAX)
479        .await
480        .map_err(|_| StatusCode::BAD_REQUEST)?;
481
482    if !body.is_empty() {
483        target_request = target_request.body(body);
484    }
485
486    // Execute the request
487    let response = target_request.send().await.map_err(|e| {
488        warn!("Failed to execute proxy request: {}", e);
489        StatusCode::BAD_GATEWAY
490    })?;
491
492    // Convert response
493    let status = response.status();
494    let headers = response.headers().clone();
495    let body = response
496        .bytes()
497        .await
498        .map_err(|_| StatusCode::BAD_GATEWAY)?;
499
500    let mut response_builder = Response::builder().status(status);
501
502    // Copy response headers
503    for (key, value) in headers.iter() {
504        if let Ok(header_name) = HeaderName::try_from(key.as_str()) {
505            response_builder = response_builder.header(header_name, value);
506        }
507    }
508
509    response_builder
510        .body(body.into())
511        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
512}
513
514/// Copy essential headers from the original request to the target request
515fn copy_essential_headers(
516    source_headers: &HeaderMap,
517    target_request: reqwest::RequestBuilder,
518) -> reqwest::RequestBuilder {
519    let essential_headers = [
520        "user-agent",
521        "accept",
522        "accept-language",
523        "accept-encoding",
524        "content-type",
525        "content-length",
526        "authorization",
527        "x-requested-with",
528    ];
529
530    let mut request = target_request;
531
532    for header_name in &essential_headers {
533        if let Some(value) = source_headers.get(*header_name) {
534            if let Ok(name) = HeaderName::try_from(*header_name) {
535                request = request.header(name, value);
536            }
537        }
538    }
539
540    request
541}
542
543/// Run a proxy server with the given configuration
544pub async fn run_proxy_server(config: ProxyConfig, port: u16) -> Result<()> {
545    let app = create_proxy_server_with_tracing(config)?;
546
547    let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port))
548        .await
549        .map_err(|e| X402Error::config(format!("Failed to bind to port {}: {}", port, e)))?;
550
551    info!("🚀 Proxy server running on port {}", port);
552    info!("💰 All requests will require payment");
553
554    axum::serve(listener, app)
555        .await
556        .map_err(|e| X402Error::config(format!("Server error: {}", e)))?;
557
558    Ok(())
559}
560
#[cfg(test)]
mod tests {
    use super::*;

    /// Recipient address shared by the tests that need a non-empty `pay_to`.
    const PAY_TO: &str = "0x1234567890123456789012345678901234567890";

    /// A configuration that passes validation; tests override single fields from it.
    fn valid_config() -> ProxyConfig {
        ProxyConfig {
            target_url: "https://example.com".to_string(),
            pay_to: PAY_TO.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_proxy_config_default() {
        let config = ProxyConfig::default();
        assert_eq!(config.amount, 0.0001);
        assert!(config.testnet);
        assert_eq!(config.facilitator_url, "https://x402.org/facilitator");
    }

    #[test]
    fn test_proxy_config_validation() {
        let config = valid_config();

        let result = config.validate();
        assert!(result.is_ok(), "Valid config should pass validation");

        // Verify the config values are preserved
        assert_eq!(config.target_url, "https://example.com");
        assert_eq!(config.pay_to, PAY_TO);
        assert!(config.testnet, "Default should be testnet");
    }

    #[test]
    fn test_proxy_config_validation_missing_target() {
        let config = ProxyConfig::default();
        let result = config.validate();
        assert!(
            result.is_err(),
            "Config without target URL should fail validation"
        );

        // Verify the specific error message
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("TARGET_URL is required"),
            "Error should mention TARGET_URL is required - actual: {}",
            error_msg
        );
    }

    #[test]
    fn test_proxy_config_validation_invalid_url() {
        let config = ProxyConfig {
            target_url: "not-a-url".to_string(),
            ..valid_config()
        };

        let result = config.validate();
        assert!(
            result.is_err(),
            "Config with invalid URL should fail validation"
        );

        // The failure must be attributed to the target URL specifically
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("Invalid TARGET_URL"),
            "Error should mention Invalid TARGET_URL - actual: {}",
            error_msg
        );
    }

    #[test]
    fn test_proxy_config_to_payment_config() {
        let config = ProxyConfig {
            amount: 0.01,
            description: Some("Test payment".to_string()),
            ..valid_config()
        };

        let payment_config = config.to_payment_config().unwrap();
        assert_eq!(payment_config.pay_to, PAY_TO);
        assert!(payment_config.testnet);
    }

    #[test]
    fn test_copy_essential_headers() {
        use axum::http::HeaderMap;

        let mut headers = HeaderMap::new();
        headers.insert("user-agent", "test-agent".parse().unwrap());
        headers.insert("accept", "application/json".parse().unwrap());
        headers.insert("content-type", "multipart/form-data".parse().unwrap());
        headers.insert("authorization", "Bearer token123".parse().unwrap());
        // Not on the allow-list, so it must not reach the upstream request.
        headers.insert("cookie", "session=abc".parse().unwrap());

        let client = reqwest::Client::new();
        let built = copy_essential_headers(&headers, client.get("https://example.com"))
            .build()
            .expect("request with copied headers should build");

        let forwarded = |name: &str| built.headers().get(name).and_then(|v| v.to_str().ok());
        assert_eq!(forwarded("user-agent"), Some("test-agent"));
        assert_eq!(forwarded("accept"), Some("application/json"));
        assert_eq!(forwarded("content-type"), Some("multipart/form-data"));
        assert_eq!(forwarded("authorization"), Some("Bearer token123"));
        assert_eq!(forwarded("cookie"), None);

        // Test with empty headers: nothing is copied
        let empty_headers = HeaderMap::new();
        let client2 = reqwest::Client::new();
        let built_empty =
            copy_essential_headers(&empty_headers, client2.get("https://example.com"))
                .build()
                .expect("request without copied headers should build");
        assert!(built_empty.headers().get("user-agent").is_none());
        assert!(built_empty.headers().get("authorization").is_none());
    }

    #[test]
    fn test_proxy_config_validation_missing_pay_to() {
        let config = ProxyConfig {
            pay_to: String::new(), // Empty pay_to
            ..valid_config()
        };

        let result = config.validate();
        assert!(
            result.is_err(),
            "Config without pay_to address should fail validation"
        );

        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("PAY_TO is required"),
            "Error should mention PAY_TO is required - actual: {}",
            error_msg
        );
    }

    #[test]
    fn test_proxy_config_validation_invalid_amount() {
        let config = ProxyConfig {
            amount: -0.001, // Negative amount
            ..valid_config()
        };

        let result = config.validate();
        assert!(
            result.is_err(),
            "Config with negative amount should fail validation"
        );

        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("AMOUNT must be positive"),
            "Error should mention AMOUNT must be positive - actual: {}",
            error_msg
        );
    }

    #[test]
    fn test_proxy_config_validation_zero_amount() {
        let config = ProxyConfig {
            amount: 0.0, // Zero amount
            ..valid_config()
        };

        let result = config.validate();
        assert!(
            result.is_err(),
            "Config with zero amount should fail validation"
        );

        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("AMOUNT must be positive"),
            "Error should mention AMOUNT must be positive - actual: {}",
            error_msg
        );
    }
}