salvo_proxy/
unix_sock_client.rs1use std::path::PathBuf;
2use std::time::Duration;
3
4use crate::{Client, HyperRequest, HyperResponse, Proxy, Upstreams};
5
6use hyper::client::conn::http1::handshake;
7use hyper::upgrade::OnUpgrade;
8use salvo_core::http::{ReqBody, ResBody};
9use salvo_core::rt::tokio::TokioIo;
10use salvo_core::{BoxedError, Error};
11use tokio::net::UnixStream;
12use tokio::time::timeout;
13
/// Number of seconds the unix-socket client waits for `UnixStream::connect`
/// before failing the request with a timeout error.
const UNIX_SOCKET_CONNECT_TIMEOUT: u64 = 5;

/// A proxy [`Client`] that forwards each request over a unix domain socket.
///
/// The client carries no configuration: the socket to connect to is parsed per
/// request from the proxied URI path, which must include the socket file path
/// ending in `.sock` (for example `/var/run/app.sock/api`).
#[derive(Clone, Debug, Default)]
pub struct UnixSockClient;
22
23impl<U> Proxy<U, UnixSockClient>
24where
25 U: Upstreams,
26 U::Error: Into<BoxedError>,
27{
28 pub fn use_unix_sock_tunnel(upstreams: U) -> Self {
30 Self::new(upstreams, UnixSockClient)
31 }
32}
33
34impl Client for UnixSockClient {
35 type Error = Error;
36
37 async fn execute(
38 &self,
39 proxied_request: HyperRequest,
40 _request_upgraded: Option<OnUpgrade>,
41 ) -> Result<HyperResponse, Self::Error> {
42 let (unix_sock_path, request_path) = extract_unix_paths(proxied_request.uri())?;
43 let stream = timeout(
44 Duration::from_secs(UNIX_SOCKET_CONNECT_TIMEOUT),
45 UnixStream::connect(unix_sock_path),
46 )
47 .await
48 .map_err(|_| Error::other("Connection to unix socket timed out"))?
49 .map_err(|e| Error::other(format!("Failed to connect to unix socket: {e}")))?;
50 let io = TokioIo::new(stream);
51 let (mut sender, conn) = handshake::<_, ReqBody>(io).await.map_err(Error::other)?;
52 tokio::spawn(async move {
53 if let Err(err) = conn.await {
54 tracing::error!(error = ?err, "Connection failed");
55 }
56 });
57 let (mut parts, body) = proxied_request.into_parts();
58 parts.uri = request_path.parse().map_err(Error::other)?;
59 let final_request = HyperRequest::from_parts(parts, body);
60 let response_future = sender.send_request(final_request);
61 let response = timeout(Duration::from_secs(30), response_future)
62 .await
63 .map_err(|_| Error::other("Request to unix socket timed out"))?
64 .map_err(Error::other)?
65 .map(ResBody::from);
66 Ok(response)
67 }
68}
69
70fn extract_unix_paths(uri: &hyper::Uri) -> Result<(String, String), Error> {
71 let full_path = uri.path();
72 if let Some(sock_end_index) = full_path.find(".sock") {
74 let sock_path_end = sock_end_index + ".sock".len();
75 let sock_path_str = &full_path[..sock_path_end];
76 let sock_path = PathBuf::from(sock_path_str);
77 if sock_path
78 .components()
79 .any(|c| c == std::path::Component::ParentDir)
80 {
81 return Err(Error::other(
82 "Invalid socket path: directory traversal ('..') is not allowed.",
83 ));
84 }
85 let mut request_path = full_path[sock_path_end..].to_string();
86 if request_path.is_empty() {
87 request_path = "/".to_owned();
88 }
89 if let Some(query) = uri.query() {
90 request_path.push('?');
91 request_path.push_str(query);
92 }
93 Ok((sock_path_str.to_owned(), request_path))
94 } else {
95 Err(Error::other(
96 "Could not find a .sock file in the URI path to determine the unix socket path.",
97 ))
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `s` and runs `extract_unix_paths`, panicking if no socket path is found.
    fn split(s: &str) -> (String, String) {
        let uri: hyper::Uri = s.parse().expect("test URI must be valid");
        match extract_unix_paths(&uri) {
            Ok(paths) => paths,
            Err(_) => panic!("expected `{s}` to contain a unix socket path"),
        }
    }

    /// Parses `s` and reports whether `extract_unix_paths` rejects it.
    fn rejects(s: &str) -> bool {
        let uri: hyper::Uri = s.parse().expect("test URI must be valid");
        extract_unix_paths(&uri).is_err()
    }

    #[test]
    fn test_proxy_creation() {
        let upstreams = vec!["http://unix:/var/run/my.sock"];
        let proxy = Proxy::new(upstreams, UnixSockClient);
        assert_eq!(proxy.upstreams().len(), 1);
    }

    #[test]
    fn test_extract_unix_paths_splits_socket_and_target() {
        let (sock, path) = split("/var/run/app.sock/api/v1");
        assert_eq!(sock, "/var/run/app.sock");
        assert_eq!(path, "/api/v1");
    }

    #[test]
    fn test_extract_unix_paths_keeps_query() {
        let (sock, path) = split("/var/run/app.sock/items?page=2&sort=asc");
        assert_eq!(sock, "/var/run/app.sock");
        assert_eq!(path, "/items?page=2&sort=asc");
    }

    #[test]
    fn test_extract_unix_paths_defaults_to_root() {
        let (sock, path) = split("/var/run/app.sock");
        assert_eq!(sock, "/var/run/app.sock");
        assert_eq!(path, "/");

        let (_, path) = split("/var/run/app.sock?x=1");
        assert_eq!(path, "/?x=1");
    }

    #[test]
    fn test_extract_unix_paths_rejects_traversal() {
        assert!(rejects("/var/run/../../tmp/evil.sock/api"));
    }

    #[test]
    fn test_extract_unix_paths_requires_sock() {
        assert!(rejects("/var/run/app/api"));
    }
}