// teaql_tool_extra/proxy.rs

1use teaql_tool_core::{Result, TeaQLToolError};
2use axum::{extract::Request, response::IntoResponse, Router};
3use reqwest::Client;
4use tokio::runtime::Runtime;
5use std::sync::Arc;
/// Stateless entry point for the HTTP reverse-proxy tool.
///
/// The value carries no configuration of its own; the listen port and the
/// upstream target are supplied to `ProxyTool::start` on every call.
#[derive(Clone, Debug)]
pub struct ProxyTool;
9
10impl ProxyTool {
11    pub fn new() -> Self { Self }
12
13    pub fn start(&self, listen_port: u16, target_url: &str) -> Result<()> {
14        let rt = Runtime::new().map_err(|e| TeaQLToolError::ExecutionError(e.to_string()))?;
15        let target = Arc::new(target_url.trim_end_matches('/').to_string());
16        
17        rt.block_on(async move {
18            let client = Client::new();
19            
20            let app = Router::new().fallback(move |req: Request| {
21                let client = client.clone();
22                let target = target.clone();
23                
24                async move {
25                    let path_query = req.uri().path_and_query().map(|pq| pq.as_str()).unwrap_or("");
26                    let url = format!("{}{}", target, path_query);
27                    
28                    let mut proxy_req = client.request(req.method().clone(), &url);
29                    for (k, v) in req.headers() {
30                        if k != reqwest::header::HOST {
31                            proxy_req = proxy_req.header(k.clone(), v.clone());
32                        }
33                    }
34                    
35                    let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX).await.unwrap_or_default();
36                    proxy_req = proxy_req.body(body_bytes);
37                    
38                    match proxy_req.send().await {
39                        Ok(resp) => {
40                            let mut builder = axum::http::Response::builder().status(resp.status());
41                            for (k, v) in resp.headers() {
42                                builder = builder.header(k.clone(), v.clone());
43                            }
44                            let stream = resp.bytes_stream();
45                            let body = axum::body::Body::from_stream(stream);
46                            builder.body(body).unwrap_or_else(|_| axum::http::Response::new(axum::body::Body::empty()))
47                        }
48                        Err(e) => {
49                            axum::http::Response::builder()
50                                .status(502)
51                                .body(axum::body::Body::from(format!("Bad Gateway: {}", e)))
52                                .unwrap()
53                        }
54                    }
55                }
56            });
57
58            let addr = format!("0.0.0.0:{}", listen_port);
59            let listener = tokio::net::TcpListener::bind(&addr).await.map_err(|e| TeaQLToolError::ExecutionError(e.to_string()))?;
60            println!("Proxy server listening on http://{} -> {}", addr, target_url);
61            axum::serve(listener, app).await.map_err(|e| TeaQLToolError::ExecutionError(e.to_string()))
62        })
63    }
64}
65
66impl Default for ProxyTool {
67    fn default() -> Self { Self::new() }
68}