1use std::fmt;
4use std::path::PathBuf;
5use std::str::FromStr;
6
/// Direction of a port forward.
///
/// `Local` values are produced by [`ForwardSpec::local_tcp`] and
/// [`ForwardSpec::local_streamlocal`]; `Remote` values by the matching `remote_*`
/// constructors. (Presumably these mirror OpenSSH's `-L` / `-R` options — confirm
/// against the code that executes a [`ForwardSpec`].)
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[non_exhaustive]
pub enum ForwardDirection {
    /// Local forwarding.
    Local,
    /// Remote forwarding.
    Remote,
}
17
/// A TCP endpoint: a host (name or IP literal) together with a port.
///
/// Formats as `host:port`, or as `[host]:port` when the host contains `:`
/// (e.g. an IPv6 literal), so the port suffix stays unambiguous.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[non_exhaustive]
pub struct TcpEndpoint {
    host: String,
    port: u16,
}

impl TcpEndpoint {
    /// Creates an endpoint for `host` and `port`.
    ///
    /// `host` is stored verbatim. Pass IPv6 literals without brackets; the
    /// `Display` impl adds them.
    pub fn new(host: impl Into<String>, port: u16) -> Self {
        let host: String = host.into();
        TcpEndpoint { host, port }
    }

    /// Returns the host exactly as stored.
    pub fn host(&self) -> &str {
        self.host.as_str()
    }

    /// Returns the port number.
    pub fn port(&self) -> u16 {
        self.port
    }
}

impl fmt::Display for TcpEndpoint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A host containing `:` would make the trailing `:port` ambiguous, so wrap it.
        let (open, close) = if self.host.contains(':') {
            ("[", "]")
        } else {
            ("", "")
        };
        write!(f, "{open}{}{close}:{}", self.host, self.port)
    }
}
56
57impl FromStr for TcpEndpoint {
58 type Err = crate::Error;
59
60 fn from_str(s: &str) -> Result<Self, Self::Err> {
61 if let Some(rest) = s.strip_prefix('[') {
62 let (host, port_str) = rest.split_once("]:").ok_or_else(|| {
63 crate::Error::invalid_config(format!("invalid TCP endpoint: {s}"))
64 })?;
65 let port: u16 = port_str.parse().map_err(|_| {
66 crate::Error::invalid_config(format!("invalid TCP endpoint port: {s}"))
67 })?;
68 Ok(Self::new(host, port))
69 } else {
70 let (host, port_str) = s.rsplit_once(':').ok_or_else(|| {
71 crate::Error::invalid_config(format!("invalid TCP endpoint: {s}"))
72 })?;
73 let port: u16 = port_str.parse().map_err(|_| {
74 crate::Error::invalid_config(format!("invalid TCP endpoint port: {s}"))
75 })?;
76 Ok(Self::new(host, port))
77 }
78 }
79}
80
81impl From<(&str, u16)> for TcpEndpoint {
82 fn from((host, port): (&str, u16)) -> Self {
83 Self::new(host, port)
84 }
85}
86
87impl From<(String, u16)> for TcpEndpoint {
88 fn from((host, port): (String, u16)) -> Self {
89 Self::new(host, port)
90 }
91}
92
93#[non_exhaustive]
95#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
96#[derive(Clone, Debug, Eq, Hash, PartialEq)]
97pub struct StreamLocalSpec {
98 path: PathBuf,
99}
100
101impl StreamLocalSpec {
102 pub fn new(path: impl Into<PathBuf>) -> Self {
104 let path = path.into();
105 let path = expand_tilde_path(path);
106 Self { path }
107 }
108
109 pub fn path(&self) -> &std::path::Path {
111 &self.path
112 }
113}
114
115impl fmt::Display for StreamLocalSpec {
116 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117 write!(f, "{}", self.path.display())
118 }
119}
120
121impl FromStr for StreamLocalSpec {
122 type Err = std::convert::Infallible;
123
124 fn from_str(s: &str) -> Result<Self, Self::Err> {
125 Ok(Self::new(s))
126 }
127}
128
129impl From<&str> for StreamLocalSpec {
130 fn from(path: &str) -> Self {
131 Self::new(path)
132 }
133}
134
135impl From<String> for StreamLocalSpec {
136 fn from(path: String) -> Self {
137 Self::new(path)
138 }
139}
140
141impl From<PathBuf> for StreamLocalSpec {
142 fn from(path: PathBuf) -> Self {
143 let path = expand_tilde_path(path);
144 Self { path }
145 }
146}
147
148#[non_exhaustive]
150#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
151#[derive(Clone, Debug, Eq, PartialEq)]
152pub enum ForwardSpec {
153 Tcp {
155 direction: ForwardDirection,
157 bind: TcpEndpoint,
159 target: TcpEndpoint,
161 },
162 StreamLocal {
164 direction: ForwardDirection,
166 bind: StreamLocalSpec,
168 target: StreamLocalSpec,
170 },
171}
172
173impl ForwardSpec {
174 pub fn local_tcp(bind: impl Into<TcpEndpoint>, target: impl Into<TcpEndpoint>) -> Self {
176 Self::Tcp {
177 direction: ForwardDirection::Local,
178 bind: bind.into(),
179 target: target.into(),
180 }
181 }
182
183 pub fn remote_tcp(bind: impl Into<TcpEndpoint>, target: impl Into<TcpEndpoint>) -> Self {
185 Self::Tcp {
186 direction: ForwardDirection::Remote,
187 bind: bind.into(),
188 target: target.into(),
189 }
190 }
191
192 pub fn local_streamlocal(
194 bind: impl Into<StreamLocalSpec>,
195 target: impl Into<StreamLocalSpec>,
196 ) -> Self {
197 Self::StreamLocal {
198 direction: ForwardDirection::Local,
199 bind: bind.into(),
200 target: target.into(),
201 }
202 }
203
204 pub fn remote_streamlocal(
206 bind: impl Into<StreamLocalSpec>,
207 target: impl Into<StreamLocalSpec>,
208 ) -> Self {
209 Self::StreamLocal {
210 direction: ForwardDirection::Remote,
211 bind: bind.into(),
212 target: target.into(),
213 }
214 }
215}
216
/// Expands a leading `~` in `path` to the current user's home directory.
///
/// Only a bare `~`, or `~` immediately followed by a path separator (for example
/// `~/run/agent.sock`), is expanded. Everything else is returned unchanged:
/// `~user/...` forms, a `~` that is not the first character, and paths that are
/// not valid UTF-8.
///
/// The home directory is read from `HOME`, falling back to `USERPROFILE` on
/// Windows. Empty values are treated as unset. If no home directory is
/// available, `path` is returned unchanged rather than turned into a relative
/// path.
fn expand_tilde_path(path: PathBuf) -> PathBuf {
    let rest = match path.to_str().and_then(|s| s.strip_prefix('~')) {
        // `~` alone or `~<separator>...`; `~user` is deliberately left alone.
        Some(rest) if rest.is_empty() || rest.starts_with(std::path::is_separator) => rest,
        _ => return path,
    };

    // An empty variable must not count as a home directory: pushing onto ""
    // would silently yield a relative path. `var_os` also accepts non-UTF-8 homes.
    let non_empty_var = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());
    #[cfg(target_os = "windows")]
    let home = non_empty_var("HOME").or_else(|| non_empty_var("USERPROFILE"));
    #[cfg(not(target_os = "windows"))]
    let home = non_empty_var("HOME");

    let Some(home) = home else {
        return path;
    };

    // Strip every leading separator. `PathBuf::push` with an absolute path
    // replaces the base, so `~//x` would otherwise expand to `/x`, not `$HOME/x`.
    let rest = rest.trim_start_matches(std::path::is_separator);
    let mut expanded = PathBuf::from(home);
    if !rest.is_empty() {
        expanded.push(rest);
    }
    expanded
}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `s` as a [`TcpEndpoint`], panicking if it is rejected.
    fn parse_endpoint(s: &str) -> TcpEndpoint {
        s.parse().unwrap()
    }

    #[test]
    fn tcp_endpoint_display_ipv4() {
        assert_eq!(TcpEndpoint::new("192.168.1.1", 22).to_string(), "192.168.1.1:22");
    }

    #[test]
    fn tcp_endpoint_display_ipv6() {
        assert_eq!(TcpEndpoint::new("2001:db8::1", 22).to_string(), "[2001:db8::1]:22");
    }

    #[test]
    fn tcp_endpoint_display_hostname() {
        assert_eq!(TcpEndpoint::new("example.com", 8080).to_string(), "example.com:8080");
    }

    #[test]
    fn tcp_endpoint_from_str_ipv4() {
        let parsed = parse_endpoint("10.0.0.1:2222");
        assert_eq!((parsed.host(), parsed.port()), ("10.0.0.1", 2222));
    }

    #[test]
    fn tcp_endpoint_from_str_ipv6() {
        let parsed = parse_endpoint("[::1]:2200");
        assert_eq!((parsed.host(), parsed.port()), ("::1", 2200));
    }

    #[test]
    fn tcp_endpoint_from_str_hostname() {
        let parsed = parse_endpoint("db.internal:5432");
        assert_eq!((parsed.host(), parsed.port()), ("db.internal", 5432));
    }

    #[test]
    fn tcp_endpoint_display_round_trip_ipv4() {
        let text = "127.0.0.1:8022";
        assert_eq!(parse_endpoint(text).to_string(), text);
    }

    #[test]
    fn tcp_endpoint_display_round_trip_ipv6() {
        let text = "[2001:db8::1]:22";
        assert_eq!(parse_endpoint(text).to_string(), text);
    }

    #[test]
    fn tcp_endpoint_from_str_invalid_missing_port() {
        assert!("host".parse::<TcpEndpoint>().is_err());
    }

    #[test]
    fn tcp_endpoint_from_str_invalid_bad_port() {
        assert!("host:abc".parse::<TcpEndpoint>().is_err());
    }

    #[test]
    fn streamlocal_spec_display() {
        assert_eq!(StreamLocalSpec::new("/tmp/app.sock").to_string(), "/tmp/app.sock");
    }

    #[test]
    fn streamlocal_spec_from_str() {
        let parsed = "/var/run/service.sock".parse::<StreamLocalSpec>().unwrap();
        assert_eq!(parsed.path(), std::path::Path::new("/var/run/service.sock"));
    }

    #[test]
    fn streamlocal_spec_display_round_trip() {
        let text = "/tmp/my-app.sock";
        let parsed = text.parse::<StreamLocalSpec>().unwrap();
        assert_eq!(parsed.to_string(), text);
    }

    #[test]
    fn streamlocal_spec_tilde_expansion() {
        let home = match std::env::var("HOME") {
            Ok(home) if !home.is_empty() => home,
            // Without a home directory there is nothing to expand against.
            _ => return,
        };
        let spec = StreamLocalSpec::new("~/myapp/agent.sock");
        assert_eq!(spec.to_string(), format!("{home}/myapp/agent.sock"));
    }
}