salvo_proxy/
unix_sock_client.rs

1use std::path::PathBuf;
2use std::time::Duration;
3
4use crate::{Client, HyperRequest, HyperResponse, Proxy, Upstreams};
5
6use hyper::client::conn::http1::handshake;
7use hyper::upgrade::OnUpgrade;
8use salvo_core::http::{ReqBody, ResBody};
9use salvo_core::rt::tokio::TokioIo;
10use salvo_core::{BoxedError, Error};
11use tokio::net::UnixStream;
12use tokio::time::timeout;
13
/// Maximum time, in seconds, to wait while connecting to the upstream Unix socket.
const UNIX_SOCKET_CONNECT_TIMEOUT: u64 = 5;

/// A proxy [`Client`] that sends requests to an HTTP/1.1 server listening on a
/// Unix domain socket.
///
/// The socket is located through the proxied request's URI path: the leading
/// part of the path, up to and including the `.sock` file name, is used as the
/// filesystem path of the socket, and the remainder of the path (plus any query
/// string) becomes the request target sent over that socket. For example, a
/// request for `/var/run/app.sock/api/items?page=2` is sent as
/// `/api/items?page=2` over `/var/run/app.sock`. Socket paths containing a `..`
/// component are rejected.
///
/// A new connection to the socket is opened for every request.
#[derive(Default, Clone, Debug)]
pub struct UnixSockClient;
22
23impl<U> Proxy<U, UnixSockClient>
24where
25    U: Upstreams,
26    U::Error: Into<BoxedError>,
27{
28    /// Create a new `Proxy` that tunnels connections to a Unix socket.
29    pub fn use_unix_sock_tunnel(upstreams: U) -> Self {
30        Self::new(upstreams, UnixSockClient)
31    }
32}
33
34impl Client for UnixSockClient {
35    type Error = Error;
36
37    async fn execute(
38        &self,
39        proxied_request: HyperRequest,
40        _request_upgraded: Option<OnUpgrade>,
41    ) -> Result<HyperResponse, Self::Error> {
42        let (unix_sock_path, request_path) = extract_unix_paths(proxied_request.uri())?;
43        let stream = timeout(
44            Duration::from_secs(UNIX_SOCKET_CONNECT_TIMEOUT),
45            UnixStream::connect(unix_sock_path),
46        )
47        .await
48        .map_err(|_| Error::other("Connection to unix socket timed out"))?
49        .map_err(|e| Error::other(format!("Failed to connect to unix socket: {e}")))?;
50        let io = TokioIo::new(stream);
51        let (mut sender, conn) = handshake::<_, ReqBody>(io).await.map_err(Error::other)?;
52        tokio::spawn(async move {
53            if let Err(err) = conn.await {
54                tracing::error!(error = ?err, "Connection failed");
55            }
56        });
57        let (mut parts, body) = proxied_request.into_parts();
58        parts.uri = request_path.parse().map_err(Error::other)?;
59        let final_request = HyperRequest::from_parts(parts, body);
60        let response_future = sender.send_request(final_request);
61        let response = timeout(Duration::from_secs(30), response_future)
62            .await
63            .map_err(|_| Error::other("Request to unix socket timed out"))?
64            .map_err(Error::other)?
65            .map(ResBody::from);
66        Ok(response)
67    }
68}
69
70fn extract_unix_paths(uri: &hyper::Uri) -> Result<(String, String), Error> {
71    let full_path = uri.path();
72    // Assume the path contains a unix socket path ending with ".sock"
73    if let Some(sock_end_index) = full_path.find(".sock") {
74        let sock_path_end = sock_end_index + ".sock".len();
75        let sock_path_str = &full_path[..sock_path_end];
76        let sock_path = PathBuf::from(sock_path_str);
77        if sock_path
78            .components()
79            .any(|c| c == std::path::Component::ParentDir)
80        {
81            return Err(Error::other(
82                "Invalid socket path: directory traversal ('..') is not allowed.",
83            ));
84        }
85        let mut request_path = full_path[sock_path_end..].to_string();
86        if request_path.is_empty() {
87            request_path = "/".to_owned();
88        }
89        if let Some(query) = uri.query() {
90            request_path.push('?');
91            request_path.push_str(query);
92        }
93        Ok((sock_path_str.to_owned(), request_path))
94    } else {
95        Err(Error::other(
96            "Could not find a .sock file in the URI path to determine the unix socket path.",
97        ))
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `extract_unix_paths` on `uri`, discarding the error details.
    fn split(uri: &str) -> Option<(String, String)> {
        let uri: hyper::Uri = uri.parse().expect("test URI should be valid");
        extract_unix_paths(&uri).ok()
    }

    /// Builds the expected `(socket path, request target)` pair.
    fn pair(sock: &str, request: &str) -> Option<(String, String)> {
        Some((sock.to_owned(), request.to_owned()))
    }

    #[test]
    fn test_proxy_creation() {
        let upstreams = vec!["http://unix:/var/run/my.sock"];
        let proxy = Proxy::new(upstreams, UnixSockClient);
        assert_eq!(proxy.upstreams().len(), 1);
    }

    #[test]
    fn test_extract_unix_paths_splits_socket_and_request_path() {
        assert_eq!(
            split("/var/run/app.sock/api/items"),
            pair("/var/run/app.sock", "/api/items")
        );
        assert_eq!(
            split("http://localhost/var/run/app.sock/api?page=2"),
            pair("/var/run/app.sock", "/api?page=2")
        );
    }

    #[test]
    fn test_extract_unix_paths_defaults_to_root() {
        assert_eq!(split("/var/run/app.sock"), pair("/var/run/app.sock", "/"));
        assert_eq!(split("/var/run/app.sock?x=1"), pair("/var/run/app.sock", "/?x=1"));
    }

    #[test]
    fn test_extract_unix_paths_rejects_invalid_paths() {
        assert_eq!(split("/var/run/app/items"), None);
        assert_eq!(split("/var/run/../app.sock/items"), None);
    }
}
111}