service_binding/
service.rs

1use std::env::var;
2use std::net::SocketAddr;
3use std::net::TcpListener;
4use std::net::TcpStream;
5use std::net::ToSocketAddrs;
6#[cfg(unix)]
7use std::os::unix::net::UnixListener;
8#[cfg(unix)]
9use std::os::unix::net::UnixStream;
10use std::path::PathBuf;
11
12use super::Error;
13
/// Number of the first file descriptor handed over through `LISTEN_FDS`.
///
/// Descriptors looked up by name in `LISTEN_FDNAMES` are numbered upwards
/// from this value.
// NOTE(review): presumably mirrors systemd's `SD_LISTEN_FDS_START` — confirm.
const SD_LISTEN_FDS_START: i32 = 3;
15
/// Service binding.
///
/// Describes which mechanism the service should use to bind its
/// listener.
///
/// # Examples
///
/// Note that since the `tcp` protocol can use an address, the `Sockets`
/// binding will contain all IP addresses that the address resolves to.
///
/// ```
/// # use service_binding::Binding;
/// # fn main() -> testresult::TestResult {
/// let binding = "tcp://127.0.0.1:8080".try_into()?;
/// assert_eq!(
///     Binding::Sockets(vec![([127, 0, 0, 1], 8080).into()]),
///     binding
/// );
/// # Ok(()) }
/// ```
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Binding {
    /// The service should be bound to this explicit, already opened file
    /// descriptor. This mechanism is used by socket activation.
    FileDescriptor(i32),

    /// The service should be bound to a Unix domain socket file under
    /// the specified path.
    FilePath(PathBuf),

    /// The service should be bound to the first of these TCP socket
    /// addresses for which the binding succeeds.
    Sockets(Vec<SocketAddr>),

    /// Windows Named Pipe, identified by its name.
    NamedPipe(std::ffi::OsString),
}
53
54impl From<PathBuf> for Binding {
55    fn from(value: PathBuf) -> Self {
56        Binding::FilePath(value)
57    }
58}
59
60impl From<SocketAddr> for Binding {
61    fn from(value: SocketAddr) -> Self {
62        Binding::Sockets(vec![value])
63    }
64}
65
/// Opened service listener.
///
/// This structure contains an already open listener. Note that the
/// listeners are switched to non-blocking mode when they are wrapped
/// (on a best-effort basis: a failure to change the mode is not reported).
///
/// # Examples
///
/// ```
/// # use service_binding::{Binding, Listener};
/// # fn main() -> testresult::TestResult {
/// let binding: Binding = "tcp://127.0.0.1:8080".parse()?;
/// let listener = binding.try_into()?;
/// assert!(matches!(listener, Listener::Tcp(_)));
/// # Ok(()) }
/// ```
#[derive(Debug)]
pub enum Listener {
    /// Listener for a Unix domain socket.
    #[cfg(unix)]
    Unix(UnixListener),

    /// Listener for a TCP socket.
    Tcp(TcpListener),

    /// Named Pipe. Only the pipe name is carried; opening the pipe is
    /// left to the consumer of this value.
    NamedPipe(std::ffi::OsString),
}
93
94#[cfg(unix)]
95impl From<UnixListener> for Listener {
96    fn from(listener: UnixListener) -> Self {
97        while let Err(e) = listener.set_nonblocking(true) {
98            // retry WouldBlock errors
99            if e.kind() != std::io::ErrorKind::WouldBlock {
100                break;
101            }
102        }
103
104        Listener::Unix(listener)
105    }
106}
107
108impl From<TcpListener> for Listener {
109    fn from(listener: TcpListener) -> Self {
110        while let Err(e) = listener.set_nonblocking(true) {
111            // retry WouldBlock errors
112            if e.kind() != std::io::ErrorKind::WouldBlock {
113                break;
114            }
115        }
116
117        Listener::Tcp(listener)
118    }
119}
120
/// Client service connection.
///
/// This structure contains an already open stream. Note that the
/// streams are switched to non-blocking mode when they are wrapped
/// (on a best-effort basis: a failure to change the mode is not reported).
///
/// # Examples
///
/// ```no_run
/// # use service_binding::{Binding, Stream};
/// # fn main() -> testresult::TestResult {
/// let binding: Binding = "tcp://127.0.0.1:8080".parse()?;
/// let stream = binding.try_into()?;
/// assert!(matches!(stream, Stream::Tcp(_)));
/// # Ok(()) }
/// ```
#[derive(Debug)]
pub enum Stream {
    /// Stream for a Unix domain socket.
    #[cfg(unix)]
    Unix(UnixStream),

    /// Stream for a TCP socket.
    Tcp(TcpStream),

    /// Named Pipe. Only the pipe name is carried; connecting to the pipe
    /// is left to the consumer of this value.
    NamedPipe(std::ffi::OsString),
}
148
149#[cfg(unix)]
150impl From<UnixStream> for Stream {
151    fn from(stream: UnixStream) -> Self {
152        while let Err(e) = stream.set_nonblocking(true) {
153            // retry WouldBlock errors
154            if e.kind() != std::io::ErrorKind::WouldBlock {
155                break;
156            }
157        }
158
159        Stream::Unix(stream)
160    }
161}
162
163impl From<TcpStream> for Stream {
164    fn from(stream: TcpStream) -> Self {
165        while let Err(e) = stream.set_nonblocking(true) {
166            // retry WouldBlock errors
167            if e.kind() != std::io::ErrorKind::WouldBlock {
168                break;
169            }
170        }
171
172        Stream::Tcp(stream)
173    }
174}
175
176impl<'a> std::convert::TryFrom<&'a str> for Binding {
177    type Error = Error;
178
179    fn try_from(s: &'a str) -> Result<Self, Self::Error> {
180        if let Some(name) = s.strip_prefix("fd://") {
181            if name.is_empty() {
182                if let Ok(fds) = var("LISTEN_FDS") {
183                    let fds: i32 = fds.parse()?;
184
185                    // we support only one socket for now
186                    if fds != 1 {
187                        return Err(Error::DescriptorOutOfRange(fds));
188                    }
189
190                    return Ok(Binding::FileDescriptor(SD_LISTEN_FDS_START));
191                } else {
192                    return Err(Error::DescriptorsMissing);
193                }
194            }
195            if let Ok(fd) = name.parse() {
196                return Ok(Binding::FileDescriptor(fd));
197            }
198            #[cfg(target_os = "macos")]
199            {
200                let fds = raunch::activate_socket(name).map_err(|_| Error::DescriptorsMissing)?;
201                if fds.len() == 1 {
202                    Ok(Binding::FileDescriptor(fds[0]))
203                } else {
204                    Err(Error::DescriptorOutOfRange(fds.len() as i32))
205                }
206            }
207            #[cfg(not(target_os = "macos"))]
208            {
209                if let (Ok(names), Ok(fds)) = (var("LISTEN_FDNAMES"), var("LISTEN_FDS")) {
210                    let fds: usize = fds.parse()?;
211                    for (fd_index, fd_name) in names.split(':').enumerate() {
212                        if fd_name == name && fd_index < fds {
213                            return Ok(Binding::FileDescriptor(
214                                SD_LISTEN_FDS_START + fd_index as i32,
215                            ));
216                        }
217                    }
218                }
219                Err(Error::DescriptorsMissing)
220            }
221        } else if let Some(file) = s.strip_prefix("unix://") {
222            Ok(Binding::FilePath(file.into()))
223        } else if let Some(file) = s.strip_prefix("npipe://") {
224            if let Some('.' | '/' | '\\') = file.chars().next() {
225                Ok(Binding::NamedPipe(file.replace('/', "\\").into()))
226            } else {
227                Ok(Binding::NamedPipe(format!(r"\\.\pipe\{file}").into()))
228            }
229        } else if let Some(addr) = s.strip_prefix("tcp://") {
230            match addr.to_socket_addrs() {
231                Ok(addrs) => Ok(Binding::Sockets(addrs.collect())),
232                Err(err) => return Err(Error::BadAddress(err)),
233            }
234        } else if s.starts_with(r"\\") {
235            Ok(Binding::NamedPipe(s.into()))
236        } else {
237            Err(Error::UnsupportedScheme)
238        }
239    }
240}
241
242impl std::str::FromStr for Binding {
243    type Err = Error;
244
245    fn from_str(s: &str) -> Result<Self, Self::Err> {
246        s.try_into()
247    }
248}
249
250impl TryFrom<Binding> for Listener {
251    type Error = std::io::Error;
252
253    fn try_from(value: Binding) -> Result<Self, Self::Error> {
254        match value {
255            #[cfg(unix)]
256            Binding::FileDescriptor(descriptor) => {
257                use std::os::unix::io::FromRawFd;
258
259                Ok(unsafe { UnixListener::from_raw_fd(descriptor) }.into())
260            }
261            #[cfg(unix)]
262            Binding::FilePath(path) => {
263                // ignore errors if the file does not exist
264                let _ = std::fs::remove_file(&path);
265                Ok(UnixListener::bind(path)?.into())
266            }
267            Binding::Sockets(sockets) => Ok(std::net::TcpListener::bind(&*sockets)?.into()),
268            Binding::NamedPipe(pipe) => Ok(Listener::NamedPipe(pipe)),
269            #[cfg(not(unix))]
270            _ => Err(std::io::Error::new(
271                std::io::ErrorKind::Other,
272                Error::UnsupportedScheme,
273            )),
274        }
275    }
276}
277
278impl TryFrom<Binding> for Stream {
279    type Error = std::io::Error;
280
281    fn try_from(value: Binding) -> Result<Self, Self::Error> {
282        match value {
283            #[cfg(unix)]
284            Binding::FileDescriptor(descriptor) => {
285                use std::os::unix::io::FromRawFd;
286
287                Ok(unsafe { UnixStream::from_raw_fd(descriptor) }.into())
288            }
289            #[cfg(unix)]
290            Binding::FilePath(path) => Ok(UnixStream::connect(path)?.into()),
291            Binding::Sockets(sockets) => Ok(std::net::TcpStream::connect(&*sockets)?.into()),
292            Binding::NamedPipe(pipe) => Ok(Self::NamedPipe(pipe)),
293            #[cfg(not(unix))]
294            _ => Err(std::io::Error::new(
295                std::io::ErrorKind::Other,
296                Error::UnsupportedScheme,
297            )),
298        }
299    }
300}
301
// Unit tests for parsing bindings and turning them into listeners.
//
// Tests that touch process-wide state (environment variables, fixed TCP
// ports, shared socket paths) are marked `#[serial]` so that they never run
// concurrently with each other.
#[cfg(test)]
mod tests {
    #[cfg(unix)]
    use std::os::fd::IntoRawFd;
    use std::str::FromStr;

    use serial_test::serial;

    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    // `fd://` with exactly one announced descriptor yields descriptor 3.
    #[test]
    #[serial]
    fn parse_fd() -> TestResult {
        std::env::set_var("LISTEN_FDS", "1");
        let binding = "fd://".parse()?;
        assert_eq!(Binding::FileDescriptor(3), binding);

        Ok(())
    }

    // A file descriptor binding can be converted into a listener.
    #[test]
    #[cfg(unix)]
    #[serial]
    fn fd_to_listener() -> TestResult {
        let file = tempfile::tempfile()?;
        let binding = Binding::FileDescriptor(file.into_raw_fd());
        let result: Result<Listener, _> = binding.try_into();

        // UnixListener is supported only on Unix platforms
        assert_eq!(cfg!(unix), result.is_ok());

        Ok(())
    }

    #[test]
    // on non-macOS systems this reads environment variables
    #[cfg(not(target_os = "macos"))]
    #[serial]
    fn parse_fd_named() -> TestResult {
        std::env::set_var("LISTEN_FDS", "2");
        std::env::set_var("LISTEN_FDNAMES", "other:service-name");
        let binding = "fd://service-name".parse()?;
        assert_eq!(Binding::FileDescriptor(4), binding);
        std::env::remove_var("LISTEN_FDNAMES");

        Ok(())
    }

    #[test]
    // on macOS the test will attempt launchd system activation but since
    // the plist file is not present it will fail
    #[cfg(target_os = "macos")]
    #[serial]
    fn parse_fd_named() -> TestResult {
        assert!(matches!(
            Binding::from_str("fd://service-name"),
            Err(Error::DescriptorsMissing)
        ));

        Ok(())
    }

    // A name listed beyond the announced descriptor count is not found.
    #[test]
    #[serial]
    fn parse_fd_bad() -> TestResult {
        std::env::set_var("LISTEN_FDS", "1"); // should be "2"
        std::env::set_var("LISTEN_FDNAMES", "other:service-name");
        assert!(matches!(
            Binding::from_str("fd://service-name"),
            Err(Error::DescriptorsMissing)
        ));
        std::env::remove_var("LISTEN_FDNAMES");

        Ok(())
    }

    // `fd://<number>` selects that explicit descriptor.
    #[test]
    #[cfg(unix)]
    #[serial]
    fn parse_fd_explicit() -> TestResult {
        let file = tempfile::tempfile()?;

        let raw_fd = file.into_raw_fd();
        let binding = format!("fd://{raw_fd}").parse()?;
        assert_eq!(Binding::FileDescriptor(raw_fd), binding);

        let result: Result<Listener, _> = binding.try_into();

        // UnixListener is supported only on Unix platforms
        assert_eq!(cfg!(unix), result.is_ok());

        Ok(())
    }

    // Only a single announced descriptor is supported by plain `fd://`.
    #[test]
    #[serial]
    fn parse_fd_fail_unsupported_fds_count() -> TestResult {
        std::env::set_var("LISTEN_FDS", "3");
        assert!(matches!(
            Binding::from_str("fd://"),
            Err(Error::DescriptorOutOfRange(3))
        ));
        Ok(())
    }

    // A non-numeric `LISTEN_FDS` value is reported as a bad descriptor.
    #[test]
    #[serial]
    fn parse_fd_fail_not_a_number() -> TestResult {
        std::env::set_var("LISTEN_FDS", "3a");
        assert!(matches!(
            Binding::from_str("fd://"),
            Err(Error::BadDescriptor(_))
        ));
        Ok(())
    }

    // Without `LISTEN_FDS` there is nothing to bind to.
    #[test]
    #[serial]
    fn parse_fd_fail() -> TestResult {
        std::env::remove_var("LISTEN_FDS");
        assert!(matches!(
            Binding::from_str("fd://"),
            Err(Error::DescriptorsMissing)
        ));
        Ok(())
    }

    // `unix://` yields a file path binding.
    #[test]
    fn parse_unix() -> TestResult {
        let binding = "unix:///tmp/test".try_into()?;
        assert_eq!(Binding::FilePath("/tmp/test".into()), binding);

        let result: Result<Listener, _> = binding.try_into();
        // UnixListener is supported only on Unix platforms
        if cfg!(unix) {
            assert!(result.is_ok());
        } else {
            assert!(result.is_err());
        }

        Ok(())
    }

    // Serialized because it binds the same fixed port as
    // `parse_tcp_localhost`; running both at once could make one bind fail.
    #[test]
    #[serial]
    fn parse_tcp() -> TestResult {
        let binding = "tcp://127.0.0.1:8081".try_into()?;
        assert_eq!(
            Binding::from(SocketAddr::from(([127, 0, 0, 1], 8081))),
            binding
        );
        let _: Listener = binding.try_into()?;
        Ok(())
    }

    // Serialized because it binds the same fixed port as `parse_tcp`.
    #[test]
    #[serial]
    fn parse_tcp_localhost() -> TestResult {
        let mut binding = "tcp://localhost:8081".try_into()?;

        let Binding::Sockets(addrs) = &mut binding else {
            panic!("Address should be parsed to Sockets");
        };

        let mut expected = vec![
            SocketAddr::from(([127, 0, 0, 1], 8081)),
            SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 8081)),
        ];

        // Sort both vectors for testing equality as the ordering may be different
        addrs.sort();
        expected.sort();

        assert_eq!(addrs, &expected);

        let _: Listener = binding.try_into()?;
        Ok(())
    }

    // Malformed and unresolvable addresses are reported as bad addresses.
    #[test]
    fn parse_tcp_fail() -> TestResult {
        assert!(matches!(
            Binding::try_from("tcp://::8080"),
            Err(Error::BadAddress(_))
        ));

        assert!(matches!(
            Binding::try_from("tcp://an-unknown-hostname:8080"),
            Err(Error::BadAddress(_))
        ));

        Ok(())
    }

    // A raw pipe name starting with `\\` is accepted verbatim.
    #[test]
    fn parse_pipe() -> TestResult {
        let binding = r"\\.\pipe\test".try_into()?;
        assert_eq!(Binding::NamedPipe(r"\\.\pipe\test".into()), binding);
        let _: Listener = binding.try_into()?;
        Ok(())
    }

    // A bare `npipe://` name is placed under `\\.\pipe\`.
    #[test]
    fn parse_pipe_short() -> TestResult {
        let binding = r"npipe://test".try_into()?;
        assert_eq!(Binding::NamedPipe(r"\\.\pipe\test".into()), binding);
        let _: Listener = binding.try_into()?;
        Ok(())
    }

    // A full `npipe://` path has its forward slashes turned into backslashes.
    #[test]
    fn parse_pipe_long() -> TestResult {
        let binding = r"npipe:////./pipe/test".try_into()?;
        assert_eq!(Binding::NamedPipe(r"\\.\pipe\test".into()), binding);
        let _: Listener = binding.try_into()?;
        Ok(())
    }

    // A single leading backslash is not a pipe name.
    #[test]
    fn parse_pipe_fail() -> TestResult {
        assert!(matches!(
            Binding::try_from(r"\test"),
            Err(Error::UnsupportedScheme)
        ));
        Ok(())
    }

    // Unknown schemes are rejected.
    #[test]
    fn parse_unknown_fail() -> TestResult {
        assert!(matches!(
            Binding::try_from("unknown://test"),
            Err(Error::UnsupportedScheme)
        ));
        Ok(())
    }

    // Binding twice to the same path works because a stale socket file is
    // removed before binding.
    #[test]
    #[cfg(unix)]
    #[serial]
    fn listen_on_socket_cleans_the_socket_file() -> TestResult {
        let dir = std::env::temp_dir().join("temp-socket");
        let binding = Binding::FilePath(dir);
        let listener: Listener = binding.try_into().unwrap();
        drop(listener);
        // create a second listener from the same path
        let dir = std::env::temp_dir().join("temp-socket");
        let binding = Binding::FilePath(dir);
        let listener: Listener = binding.try_into().unwrap();
        drop(listener);
        Ok(())
    }

    // `PathBuf` converts into a file path binding.
    #[test]
    #[cfg(unix)]
    fn convert_from_pathbuf() {
        let path = std::path::PathBuf::from("/tmp");
        let binding: Binding = path.into();
        assert!(matches!(binding, Binding::FilePath(_)));
    }

    // `SocketAddr` converts into a sockets binding.
    #[test]
    fn convert_from_socket() {
        let socket: SocketAddr = ([127, 0, 0, 1], 8080).into();
        let binding: Binding = socket.into();
        assert!(matches!(binding, Binding::Sockets(_)));
    }
}