1use std::env::var;
2use std::net::SocketAddr;
3use std::net::TcpListener;
4use std::net::TcpStream;
5use std::net::ToSocketAddrs;
6#[cfg(unix)]
7use std::os::unix::net::UnixListener;
8#[cfg(unix)]
9use std::os::unix::net::UnixStream;
10use std::path::PathBuf;
11
12use super::Error;
13
/// Number of the first file descriptor handed over by socket activation.
///
/// Mirrors systemd's `SD_LISTEN_FDS_START`; descriptors selected through
/// `LISTEN_FDS` / `LISTEN_FDNAMES` are numbered upwards from this value.
const SD_LISTEN_FDS_START: i32 = 3;
15
/// A parsed description of where a server should listen or a client should
/// connect.
///
/// Values are usually obtained by parsing a URI-like string (`fd://`,
/// `unix://`, `tcp://`, `npipe://`) and are then converted into a `Listener`
/// or a `Stream`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Binding {
    /// An already-open file descriptor, identified by its raw number.
    FileDescriptor(i32),

    /// Filesystem path of a Unix domain socket.
    FilePath(PathBuf),

    /// One or more socket addresses for TCP; several entries typically come
    /// from a host name that resolves to more than one address.
    Sockets(Vec<SocketAddr>),

    /// Name of a named pipe, e.g. `\\.\pipe\name`.
    NamedPipe(std::ffi::OsString),
}
53
54impl From<PathBuf> for Binding {
55 fn from(value: PathBuf) -> Self {
56 Binding::FilePath(value)
57 }
58}
59
60impl From<SocketAddr> for Binding {
61 fn from(value: SocketAddr) -> Self {
62 Binding::Sockets(vec![value])
63 }
64}
65
/// A server-side endpoint produced from a `Binding`.
#[derive(Debug)]
pub enum Listener {
    /// A Unix domain socket listener (only available on Unix targets).
    #[cfg(unix)]
    Unix(UnixListener),

    /// A TCP listener.
    Tcp(TcpListener),

    /// A named pipe, carried by name only; nothing is opened at this point.
    NamedPipe(std::ffi::OsString),
}
93
94#[cfg(unix)]
95impl From<UnixListener> for Listener {
96 fn from(listener: UnixListener) -> Self {
97 while let Err(e) = listener.set_nonblocking(true) {
98 if e.kind() != std::io::ErrorKind::WouldBlock {
100 break;
101 }
102 }
103
104 Listener::Unix(listener)
105 }
106}
107
108impl From<TcpListener> for Listener {
109 fn from(listener: TcpListener) -> Self {
110 while let Err(e) = listener.set_nonblocking(true) {
111 if e.kind() != std::io::ErrorKind::WouldBlock {
113 break;
114 }
115 }
116
117 Listener::Tcp(listener)
118 }
119}
120
/// A client-side connection produced from a `Binding`.
#[derive(Debug)]
pub enum Stream {
    /// A Unix domain socket stream (only available on Unix targets).
    #[cfg(unix)]
    Unix(UnixStream),

    /// A TCP stream.
    Tcp(TcpStream),

    /// A named pipe, carried by name only; nothing is opened at this point.
    NamedPipe(std::ffi::OsString),
}
148
149#[cfg(unix)]
150impl From<UnixStream> for Stream {
151 fn from(stream: UnixStream) -> Self {
152 while let Err(e) = stream.set_nonblocking(true) {
153 if e.kind() != std::io::ErrorKind::WouldBlock {
155 break;
156 }
157 }
158
159 Stream::Unix(stream)
160 }
161}
162
163impl From<TcpStream> for Stream {
164 fn from(stream: TcpStream) -> Self {
165 while let Err(e) = stream.set_nonblocking(true) {
166 if e.kind() != std::io::ErrorKind::WouldBlock {
168 break;
169 }
170 }
171
172 Stream::Tcp(stream)
173 }
174}
175
176impl<'a> std::convert::TryFrom<&'a str> for Binding {
177 type Error = Error;
178
179 fn try_from(s: &'a str) -> Result<Self, Self::Error> {
180 if let Some(name) = s.strip_prefix("fd://") {
181 if name.is_empty() {
182 if let Ok(fds) = var("LISTEN_FDS") {
183 let fds: i32 = fds.parse()?;
184
185 if fds != 1 {
187 return Err(Error::DescriptorOutOfRange(fds));
188 }
189
190 return Ok(Binding::FileDescriptor(SD_LISTEN_FDS_START));
191 } else {
192 return Err(Error::DescriptorsMissing);
193 }
194 }
195 if let Ok(fd) = name.parse() {
196 return Ok(Binding::FileDescriptor(fd));
197 }
198 #[cfg(target_os = "macos")]
199 {
200 let fds = raunch::activate_socket(name).map_err(|_| Error::DescriptorsMissing)?;
201 if fds.len() == 1 {
202 Ok(Binding::FileDescriptor(fds[0]))
203 } else {
204 Err(Error::DescriptorOutOfRange(fds.len() as i32))
205 }
206 }
207 #[cfg(not(target_os = "macos"))]
208 {
209 if let (Ok(names), Ok(fds)) = (var("LISTEN_FDNAMES"), var("LISTEN_FDS")) {
210 let fds: usize = fds.parse()?;
211 for (fd_index, fd_name) in names.split(':').enumerate() {
212 if fd_name == name && fd_index < fds {
213 return Ok(Binding::FileDescriptor(
214 SD_LISTEN_FDS_START + fd_index as i32,
215 ));
216 }
217 }
218 }
219 Err(Error::DescriptorsMissing)
220 }
221 } else if let Some(file) = s.strip_prefix("unix://") {
222 Ok(Binding::FilePath(file.into()))
223 } else if let Some(file) = s.strip_prefix("npipe://") {
224 if let Some('.' | '/' | '\\') = file.chars().next() {
225 Ok(Binding::NamedPipe(file.replace('/', "\\").into()))
226 } else {
227 Ok(Binding::NamedPipe(format!(r"\\.\pipe\{file}").into()))
228 }
229 } else if let Some(addr) = s.strip_prefix("tcp://") {
230 match addr.to_socket_addrs() {
231 Ok(addrs) => Ok(Binding::Sockets(addrs.collect())),
232 Err(err) => return Err(Error::BadAddress(err)),
233 }
234 } else if s.starts_with(r"\\") {
235 Ok(Binding::NamedPipe(s.into()))
236 } else {
237 Err(Error::UnsupportedScheme)
238 }
239 }
240}
241
242impl std::str::FromStr for Binding {
243 type Err = Error;
244
245 fn from_str(s: &str) -> Result<Self, Self::Err> {
246 s.try_into()
247 }
248}
249
250impl TryFrom<Binding> for Listener {
251 type Error = std::io::Error;
252
253 fn try_from(value: Binding) -> Result<Self, Self::Error> {
254 match value {
255 #[cfg(unix)]
256 Binding::FileDescriptor(descriptor) => {
257 use std::os::unix::io::FromRawFd;
258
259 Ok(unsafe { UnixListener::from_raw_fd(descriptor) }.into())
260 }
261 #[cfg(unix)]
262 Binding::FilePath(path) => {
263 let _ = std::fs::remove_file(&path);
265 Ok(UnixListener::bind(path)?.into())
266 }
267 Binding::Sockets(sockets) => Ok(std::net::TcpListener::bind(&*sockets)?.into()),
268 Binding::NamedPipe(pipe) => Ok(Listener::NamedPipe(pipe)),
269 #[cfg(not(unix))]
270 _ => Err(std::io::Error::new(
271 std::io::ErrorKind::Other,
272 Error::UnsupportedScheme,
273 )),
274 }
275 }
276}
277
278impl TryFrom<Binding> for Stream {
279 type Error = std::io::Error;
280
281 fn try_from(value: Binding) -> Result<Self, Self::Error> {
282 match value {
283 #[cfg(unix)]
284 Binding::FileDescriptor(descriptor) => {
285 use std::os::unix::io::FromRawFd;
286
287 Ok(unsafe { UnixStream::from_raw_fd(descriptor) }.into())
288 }
289 #[cfg(unix)]
290 Binding::FilePath(path) => Ok(UnixStream::connect(path)?.into()),
291 Binding::Sockets(sockets) => Ok(std::net::TcpStream::connect(&*sockets)?.into()),
292 Binding::NamedPipe(pipe) => Ok(Self::NamedPipe(pipe)),
293 #[cfg(not(unix))]
294 _ => Err(std::io::Error::new(
295 std::io::ErrorKind::Other,
296 Error::UnsupportedScheme,
297 )),
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    //! Tests for parsing bindings and converting them into listeners.
    //!
    //! Tests that mutate the process environment, bind the shared TCP port
    //! 8081, or trigger resolver lookups are marked `#[serial]` so that they
    //! never overlap with one another.

    #[cfg(unix)]
    use std::os::fd::IntoRawFd;
    use std::str::FromStr;

    use serial_test::serial;

    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// `fd://` with exactly one activated descriptor yields descriptor 3.
    #[test]
    #[serial]
    fn parse_fd() -> TestResult {
        std::env::set_var("LISTEN_FDS", "1");
        let parsed = Binding::from_str("fd://");
        // Clean up before asserting so a failure cannot leak the variable.
        std::env::remove_var("LISTEN_FDS");

        assert_eq!(Binding::FileDescriptor(3), parsed?);

        Ok(())
    }

    /// A file descriptor binding can be turned into a listener.
    #[test]
    #[cfg(unix)]
    #[serial]
    fn fd_to_listener() -> TestResult {
        let file = tempfile::tempfile()?;
        let binding = Binding::FileDescriptor(file.into_raw_fd());
        let result: Result<Listener, _> = binding.try_into();

        assert!(result.is_ok());

        Ok(())
    }

    /// A named descriptor is resolved through `LISTEN_FDNAMES`.
    #[test]
    #[cfg(not(target_os = "macos"))]
    #[serial]
    fn parse_fd_named() -> TestResult {
        std::env::set_var("LISTEN_FDS", "2");
        std::env::set_var("LISTEN_FDNAMES", "other:service-name");
        let parsed = Binding::from_str("fd://service-name");
        std::env::remove_var("LISTEN_FDNAMES");
        std::env::remove_var("LISTEN_FDS");

        assert_eq!(Binding::FileDescriptor(4), parsed?);

        Ok(())
    }

    /// On macOS an unknown socket name is reported as missing descriptors.
    #[test]
    #[cfg(target_os = "macos")]
    #[serial]
    fn parse_fd_named() -> TestResult {
        assert!(matches!(
            Binding::from_str("fd://service-name"),
            Err(Error::DescriptorsMissing)
        ));

        Ok(())
    }

    /// A name listed beyond the `LISTEN_FDS` count is not accepted.
    #[test]
    #[serial]
    fn parse_fd_bad() -> TestResult {
        std::env::set_var("LISTEN_FDS", "1");
        std::env::set_var("LISTEN_FDNAMES", "other:service-name");
        let parsed = Binding::from_str("fd://service-name");
        std::env::remove_var("LISTEN_FDNAMES");
        std::env::remove_var("LISTEN_FDS");

        assert!(matches!(parsed, Err(Error::DescriptorsMissing)));

        Ok(())
    }

    /// An explicit descriptor number is parsed and usable as a listener.
    #[test]
    #[cfg(unix)]
    #[serial]
    fn parse_fd_explicit() -> TestResult {
        let file = tempfile::tempfile()?;

        let raw_fd = file.into_raw_fd();
        let binding: Binding = format!("fd://{raw_fd}").parse()?;
        assert_eq!(Binding::FileDescriptor(raw_fd), binding);

        let result: Result<Listener, _> = binding.try_into();

        assert!(result.is_ok());

        Ok(())
    }

    /// `fd://` rejects any descriptor count other than one.
    #[test]
    #[serial]
    fn parse_fd_fail_unsupported_fds_count() -> TestResult {
        std::env::set_var("LISTEN_FDS", "3");
        let parsed = Binding::from_str("fd://");
        std::env::remove_var("LISTEN_FDS");

        assert!(matches!(parsed, Err(Error::DescriptorOutOfRange(3))));
        Ok(())
    }

    /// A non-numeric `LISTEN_FDS` value surfaces as a bad descriptor.
    #[test]
    #[serial]
    fn parse_fd_fail_not_a_number() -> TestResult {
        std::env::set_var("LISTEN_FDS", "3a");
        let parsed = Binding::from_str("fd://");
        std::env::remove_var("LISTEN_FDS");

        assert!(matches!(parsed, Err(Error::BadDescriptor(_))));
        Ok(())
    }

    /// Without `LISTEN_FDS`, `fd://` reports missing descriptors.
    #[test]
    #[serial]
    fn parse_fd_fail() -> TestResult {
        std::env::remove_var("LISTEN_FDS");
        assert!(matches!(
            Binding::from_str("fd://"),
            Err(Error::DescriptorsMissing)
        ));
        Ok(())
    }

    /// `unix://` parses to a file path; listening works only on Unix.
    #[test]
    fn parse_unix() -> TestResult {
        let binding = "unix:///tmp/test".try_into()?;
        assert_eq!(Binding::FilePath("/tmp/test".into()), binding);

        let result: Result<Listener, _> = binding.try_into();
        assert_eq!(cfg!(unix), result.is_ok());

        // Do not leave the socket file behind.
        drop(result);
        let _ = std::fs::remove_file("/tmp/test");

        Ok(())
    }

    /// A literal IP address parses to a single socket address.
    // Serial: shares port 8081 with `parse_tcp_localhost`.
    #[test]
    #[serial]
    fn parse_tcp() -> TestResult {
        let binding = "tcp://127.0.0.1:8081".try_into()?;
        assert_eq!(
            Binding::from(SocketAddr::from(([127, 0, 0, 1], 8081))),
            binding
        );
        let _: Listener = binding.try_into()?;
        Ok(())
    }

    /// `localhost` resolves to the IPv4 and IPv6 loopback addresses.
    // Serial: shares port 8081 with `parse_tcp` and performs a resolver
    // lookup, which must not overlap with tests mutating the environment.
    #[test]
    #[serial]
    fn parse_tcp_localhost() -> TestResult {
        let mut binding = "tcp://localhost:8081".try_into()?;

        let Binding::Sockets(addrs) = &mut binding else {
            panic!("Address should be parsed to Sockets");
        };

        let mut expected = vec![
            SocketAddr::from(([127, 0, 0, 1], 8081)),
            SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 8081)),
        ];

        // Resolution order is not guaranteed, so compare sorted lists.
        addrs.sort();
        expected.sort();

        assert_eq!(addrs, &expected);

        let _: Listener = binding.try_into()?;
        Ok(())
    }

    /// Malformed or unresolvable addresses are reported as bad addresses.
    // Serial: performs resolver lookups, which must not overlap with tests
    // mutating the environment.
    #[test]
    #[serial]
    fn parse_tcp_fail() -> TestResult {
        assert!(matches!(
            Binding::try_from("tcp://::8080"),
            Err(Error::BadAddress(_))
        ));

        assert!(matches!(
            Binding::try_from("tcp://an-unknown-hostname:8080"),
            Err(Error::BadAddress(_))
        ));

        Ok(())
    }

    /// A full pipe path is accepted verbatim.
    #[test]
    fn parse_pipe() -> TestResult {
        let binding = r"\\.\pipe\test".try_into()?;
        assert_eq!(Binding::NamedPipe(r"\\.\pipe\test".into()), binding);
        let _: Listener = binding.try_into()?;
        Ok(())
    }

    /// A short `npipe://` name is expanded to a full pipe path.
    #[test]
    fn parse_pipe_short() -> TestResult {
        let binding = r"npipe://test".try_into()?;
        assert_eq!(Binding::NamedPipe(r"\\.\pipe\test".into()), binding);
        let _: Listener = binding.try_into()?;
        Ok(())
    }

    /// A forward-slash pipe path is normalised to backslashes.
    #[test]
    fn parse_pipe_long() -> TestResult {
        let binding = r"npipe:////./pipe/test".try_into()?;
        assert_eq!(Binding::NamedPipe(r"\\.\pipe\test".into()), binding);
        let _: Listener = binding.try_into()?;
        Ok(())
    }

    /// A single leading backslash is not a pipe path.
    #[test]
    fn parse_pipe_fail() -> TestResult {
        assert!(matches!(
            Binding::try_from(r"\test"),
            Err(Error::UnsupportedScheme)
        ));
        Ok(())
    }

    /// Unknown schemes are rejected.
    #[test]
    fn parse_unknown_fail() -> TestResult {
        assert!(matches!(
            Binding::try_from("unknown://test"),
            Err(Error::UnsupportedScheme)
        ));
        Ok(())
    }

    /// Binding the same socket path twice succeeds because the stale socket
    /// file from the first bind is removed before the second one.
    #[test]
    #[cfg(unix)]
    #[serial]
    fn listen_on_socket_cleans_the_socket_file() -> TestResult {
        let path = std::env::temp_dir().join("temp-socket");

        let listener: Listener = Binding::FilePath(path.clone()).try_into()?;
        drop(listener);

        // The socket file from the first bind is still on disk at this point.
        let listener: Listener = Binding::FilePath(path.clone()).try_into()?;
        drop(listener);

        // Do not leave the socket file behind.
        let _ = std::fs::remove_file(&path);
        Ok(())
    }

    /// A `PathBuf` converts into a file path binding.
    #[test]
    #[cfg(unix)]
    fn convert_from_pathbuf() {
        let path = std::path::PathBuf::from("/tmp");
        let binding: Binding = path.into();
        assert!(matches!(binding, Binding::FilePath(_)));
    }

    /// A `SocketAddr` converts into a sockets binding.
    #[test]
    fn convert_from_socket() {
        let socket: SocketAddr = ([127, 0, 0, 1], 8080).into();
        let binding: Binding = socket.into();
        assert!(matches!(binding, Binding::Sockets(_)));
    }
}