Skip to main content

russh_extra_core/
forward.rs

1//! Port forwarding domain types.
2
3use std::fmt;
4use std::path::PathBuf;
5use std::str::FromStr;
6
/// Which side of the SSH connection hosts the listener of a forward.
#[non_exhaustive]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ForwardDirection {
    /// The listener is local and connections are forwarded to a remote target.
    Local,
    /// The listener is remote and connections are forwarded to a local target.
    Remote,
}
17
/// A host/port pair naming one side of a TCP forward.
#[non_exhaustive]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TcpEndpoint {
    /// Host part, stored exactly as supplied (no brackets are added or removed).
    host: String,
    /// TCP port number.
    port: u16,
}

impl TcpEndpoint {
    /// Builds an endpoint from anything convertible into a `String` plus a port.
    ///
    /// The host is stored verbatim; no validation is performed.
    pub fn new(host: impl Into<String>, port: u16) -> Self {
        let host = host.into();
        Self { host, port }
    }

    /// Borrows the host part of the endpoint.
    pub fn host(&self) -> &str {
        self.host.as_str()
    }

    /// Returns the port of the endpoint.
    pub fn port(&self) -> u16 {
        self.port
    }
}

impl fmt::Display for TcpEndpoint {
    /// Renders the endpoint as `host:port`.
    ///
    /// A host that itself contains `:` (for example an IPv6 literal) is wrapped
    /// in square brackets so the port separator stays unambiguous.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { host, port } = self;
        let needs_brackets = host.contains(':');
        if !needs_brackets {
            return write!(f, "{host}:{port}");
        }
        write!(f, "[{host}]:{port}")
    }
}
56
57impl FromStr for TcpEndpoint {
58    type Err = crate::Error;
59
60    fn from_str(s: &str) -> Result<Self, Self::Err> {
61        if let Some(rest) = s.strip_prefix('[') {
62            let (host, port_str) = rest.split_once("]:").ok_or_else(|| {
63                crate::Error::invalid_config(format!("invalid TCP endpoint: {s}"))
64            })?;
65            let port: u16 = port_str.parse().map_err(|_| {
66                crate::Error::invalid_config(format!("invalid TCP endpoint port: {s}"))
67            })?;
68            Ok(Self::new(host, port))
69        } else {
70            let (host, port_str) = s.rsplit_once(':').ok_or_else(|| {
71                crate::Error::invalid_config(format!("invalid TCP endpoint: {s}"))
72            })?;
73            let port: u16 = port_str.parse().map_err(|_| {
74                crate::Error::invalid_config(format!("invalid TCP endpoint port: {s}"))
75            })?;
76            Ok(Self::new(host, port))
77        }
78    }
79}
80
81impl From<(&str, u16)> for TcpEndpoint {
82    fn from((host, port): (&str, u16)) -> Self {
83        Self::new(host, port)
84    }
85}
86
87impl From<(String, u16)> for TcpEndpoint {
88    fn from((host, port): (String, u16)) -> Self {
89        Self::new(host, port)
90    }
91}
92
/// Endpoint of a streamlocal (Unix-domain socket) forward, identified by a path.
#[non_exhaustive]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct StreamLocalSpec {
    /// Filesystem path of the socket.
    path: PathBuf,
}
100
101impl StreamLocalSpec {
102    /// Creates a streamlocal endpoint.
103    pub fn new(path: impl Into<PathBuf>) -> Self {
104        let path = path.into();
105        let path = expand_tilde_path(path);
106        Self { path }
107    }
108
109    /// Returns the streamlocal path.
110    pub fn path(&self) -> &std::path::Path {
111        &self.path
112    }
113}
114
115impl fmt::Display for StreamLocalSpec {
116    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117        write!(f, "{}", self.path.display())
118    }
119}
120
121impl FromStr for StreamLocalSpec {
122    type Err = std::convert::Infallible;
123
124    fn from_str(s: &str) -> Result<Self, Self::Err> {
125        Ok(Self::new(s))
126    }
127}
128
129impl From<&str> for StreamLocalSpec {
130    fn from(path: &str) -> Self {
131        Self::new(path)
132    }
133}
134
135impl From<String> for StreamLocalSpec {
136    fn from(path: String) -> Self {
137        Self::new(path)
138    }
139}
140
141impl From<PathBuf> for StreamLocalSpec {
142    fn from(path: PathBuf) -> Self {
143        let path = expand_tilde_path(path);
144        Self { path }
145    }
146}
147
148/// High-level forwarding specification.
149#[non_exhaustive]
150#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
151#[derive(Clone, Debug, Eq, PartialEq)]
152pub enum ForwardSpec {
153    /// TCP forwarding between two endpoints.
154    Tcp {
155        /// Forward direction.
156        direction: ForwardDirection,
157        /// Bind endpoint.
158        bind: TcpEndpoint,
159        /// Target endpoint.
160        target: TcpEndpoint,
161    },
162    /// Streamlocal forwarding between two paths.
163    StreamLocal {
164        /// Forward direction.
165        direction: ForwardDirection,
166        /// Bind endpoint.
167        bind: StreamLocalSpec,
168        /// Target endpoint.
169        target: StreamLocalSpec,
170    },
171}
172
173impl ForwardSpec {
174    /// Creates a local TCP forwarding specification.
175    pub fn local_tcp(bind: impl Into<TcpEndpoint>, target: impl Into<TcpEndpoint>) -> Self {
176        Self::Tcp {
177            direction: ForwardDirection::Local,
178            bind: bind.into(),
179            target: target.into(),
180        }
181    }
182
183    /// Creates a remote TCP forwarding specification.
184    pub fn remote_tcp(bind: impl Into<TcpEndpoint>, target: impl Into<TcpEndpoint>) -> Self {
185        Self::Tcp {
186            direction: ForwardDirection::Remote,
187            bind: bind.into(),
188            target: target.into(),
189        }
190    }
191
192    /// Creates a local streamlocal forwarding specification.
193    pub fn local_streamlocal(
194        bind: impl Into<StreamLocalSpec>,
195        target: impl Into<StreamLocalSpec>,
196    ) -> Self {
197        Self::StreamLocal {
198            direction: ForwardDirection::Local,
199            bind: bind.into(),
200            target: target.into(),
201        }
202    }
203
204    /// Creates a remote streamlocal forwarding specification.
205    pub fn remote_streamlocal(
206        bind: impl Into<StreamLocalSpec>,
207        target: impl Into<StreamLocalSpec>,
208    ) -> Self {
209        Self::StreamLocal {
210            direction: ForwardDirection::Remote,
211            bind: bind.into(),
212            target: target.into(),
213        }
214    }
215}
216
/// Expands a leading `~` in `path` to the current user's home directory.
///
/// Only a bare `~` and paths beginning with `~/` are rewritten. Anything else
/// (including `~user/...` forms and paths that are not valid UTF-8) is returned
/// unchanged.
///
/// The home directory is read from the `HOME` environment variable, falling
/// back to `USERPROFILE` on Windows. A variable that is unset or empty is
/// ignored; if no home directory can be determined the path is returned as-is
/// rather than being rewritten into an empty or relative path.
fn expand_tilde_path(path: PathBuf) -> PathBuf {
    let expanded: Option<PathBuf> = path.to_str().and_then(|text| {
        // `None` means the path is exactly "~"; `Some(rest)` is what follows "~/".
        let rest = if text == "~" {
            None
        } else {
            Some(text.strip_prefix("~/")?)
        };

        // `var_os` keeps a home directory that is not valid UTF-8 usable.
        let home = std::env::var_os("HOME").filter(|value| !value.is_empty());
        #[cfg(target_os = "windows")]
        let home = home
            .or_else(|| std::env::var_os("USERPROFILE").filter(|value| !value.is_empty()));
        let home = PathBuf::from(home?);

        Some(match rest {
            None => home,
            // Joining an absolute path would discard `home`, so extra slashes
            // right after "~/" (as in "~//x") are dropped before joining.
            Some(rest) => home.join(rest.trim_start_matches('/')),
        })
    });

    expanded.unwrap_or(path)
}
235
#[cfg(test)]
mod tests {
    use super::*;

    // TcpEndpoint: Display

    #[test]
    fn tcp_endpoint_display_ipv4() {
        let rendered = TcpEndpoint::new("192.168.1.1", 22).to_string();
        assert_eq!(rendered, "192.168.1.1:22");
    }

    #[test]
    fn tcp_endpoint_display_ipv6() {
        // A host containing ':' must be bracketed.
        let rendered = TcpEndpoint::new("2001:db8::1", 22).to_string();
        assert_eq!(rendered, "[2001:db8::1]:22");
    }

    #[test]
    fn tcp_endpoint_display_hostname() {
        let rendered = TcpEndpoint::new("example.com", 8080).to_string();
        assert_eq!(rendered, "example.com:8080");
    }

    // TcpEndpoint: FromStr

    #[test]
    fn tcp_endpoint_from_str_ipv4() {
        let parsed = "10.0.0.1:2222".parse::<TcpEndpoint>().unwrap();
        assert_eq!(parsed.host(), "10.0.0.1");
        assert_eq!(parsed.port(), 2222);
    }

    #[test]
    fn tcp_endpoint_from_str_ipv6() {
        let parsed = "[::1]:2200".parse::<TcpEndpoint>().unwrap();
        assert_eq!(parsed.host(), "::1");
        assert_eq!(parsed.port(), 2200);
    }

    #[test]
    fn tcp_endpoint_from_str_hostname() {
        let parsed = "db.internal:5432".parse::<TcpEndpoint>().unwrap();
        assert_eq!(parsed.host(), "db.internal");
        assert_eq!(parsed.port(), 5432);
    }

    // TcpEndpoint: parse then display

    #[test]
    fn tcp_endpoint_display_round_trip_ipv4() {
        let text = "127.0.0.1:8022";
        let rendered = text.parse::<TcpEndpoint>().unwrap().to_string();
        assert_eq!(rendered, text);
    }

    #[test]
    fn tcp_endpoint_display_round_trip_ipv6() {
        let text = "[2001:db8::1]:22";
        let rendered = text.parse::<TcpEndpoint>().unwrap().to_string();
        assert_eq!(rendered, text);
    }

    // TcpEndpoint: rejected input

    #[test]
    fn tcp_endpoint_from_str_invalid_missing_port() {
        assert!("host".parse::<TcpEndpoint>().is_err());
    }

    #[test]
    fn tcp_endpoint_from_str_invalid_bad_port() {
        assert!("host:abc".parse::<TcpEndpoint>().is_err());
    }

    // StreamLocalSpec

    #[test]
    fn streamlocal_spec_display() {
        let rendered = StreamLocalSpec::new("/tmp/app.sock").to_string();
        assert_eq!(rendered, "/tmp/app.sock");
    }

    #[test]
    fn streamlocal_spec_from_str() {
        let parsed = "/var/run/service.sock".parse::<StreamLocalSpec>().unwrap();
        assert_eq!(parsed.path(), std::path::Path::new("/var/run/service.sock"));
    }

    #[test]
    fn streamlocal_spec_display_round_trip() {
        let text = "/tmp/my-app.sock";
        let rendered = text.parse::<StreamLocalSpec>().unwrap().to_string();
        assert_eq!(rendered, text);
    }

    #[test]
    fn streamlocal_spec_tilde_expansion() {
        // Nothing to check when HOME is unset or empty.
        let home = match std::env::var("HOME") {
            Ok(value) if !value.is_empty() => value,
            _ => return,
        };
        let rendered = StreamLocalSpec::new("~/myapp/agent.sock").to_string();
        assert_eq!(rendered, format!("{home}/myapp/agent.sock"));
    }
}