wrpc_transport/frame/unix.rs

1//! Unix domain socket transport
2
3use std::path::{Path, PathBuf};
4
5use anyhow::{bail, Context as _};
6use bytes::Bytes;
7use tokio::net::unix::{OwnedReadHalf, OwnedWriteHalf, SocketAddr};
8use tokio::net::{UnixListener, UnixStream};
9use tracing::instrument;
10
11use crate::frame::{invoke, Accept, Incoming, Outgoing};
12use crate::Invoke;
13
14/// [Invoke] implementation in terms of a single [`UnixStream`]
15///
16/// [`Invoke::invoke`] can only be called once on [Invocation],
17/// repeated calls with return an error
18pub struct Invocation(std::sync::Mutex<Option<UnixStream>>);
19
/// Unix domain socket client implementing [Invoke].
///
/// The wrapped `T` names the socket to talk to (an owned or borrowed path, or an
/// owned or borrowed socket address); the `Invoke` impls for `Client` open a new
/// connection to it on each call.
#[derive(Debug, Clone)]
pub struct Client<T>(T);
23
24impl From<PathBuf> for Client<PathBuf> {
25    fn from(path: PathBuf) -> Self {
26        Self(path)
27    }
28}
29
30impl<'a> From<&'a Path> for Client<&'a Path> {
31    fn from(path: &'a Path) -> Self {
32        Self(path)
33    }
34}
35
36impl<'a> From<&'a std::os::unix::net::SocketAddr> for Client<&'a std::os::unix::net::SocketAddr> {
37    fn from(addr: &'a std::os::unix::net::SocketAddr) -> Self {
38        Self(addr)
39    }
40}
41
42impl From<std::os::unix::net::SocketAddr> for Client<std::os::unix::net::SocketAddr> {
43    fn from(addr: std::os::unix::net::SocketAddr) -> Self {
44        Self(addr)
45    }
46}
47
48impl Invoke for Client<PathBuf> {
49    type Context = ();
50    type Outgoing = Outgoing;
51    type Incoming = Incoming;
52
53    #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
54    async fn invoke<P>(
55        &self,
56        (): Self::Context,
57        instance: &str,
58        func: &str,
59        params: Bytes,
60        paths: impl AsRef<[P]> + Send,
61    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
62    where
63        P: AsRef<[Option<usize>]> + Send + Sync,
64    {
65        let stream = UnixStream::connect(&self.0).await?;
66        let (rx, tx) = stream.into_split();
67        invoke(tx, rx, instance, func, params, paths).await
68    }
69}
70
71impl Invoke for Client<&Path> {
72    type Context = ();
73    type Outgoing = Outgoing;
74    type Incoming = Incoming;
75
76    #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
77    async fn invoke<P>(
78        &self,
79        (): Self::Context,
80        instance: &str,
81        func: &str,
82        params: Bytes,
83        paths: impl AsRef<[P]> + Send,
84    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
85    where
86        P: AsRef<[Option<usize>]> + Send + Sync,
87    {
88        let stream = UnixStream::connect(self.0).await?;
89        let (rx, tx) = stream.into_split();
90        invoke(tx, rx, instance, func, params, paths).await
91    }
92}
93
94impl Invoke for Client<&std::os::unix::net::SocketAddr> {
95    type Context = ();
96    type Outgoing = Outgoing;
97    type Incoming = Incoming;
98
99    #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
100    async fn invoke<P>(
101        &self,
102        (): Self::Context,
103        instance: &str,
104        func: &str,
105        params: Bytes,
106        paths: impl AsRef<[P]> + Send,
107    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
108    where
109        P: AsRef<[Option<usize>]> + Send + Sync,
110    {
111        let stream = std::os::unix::net::UnixStream::connect_addr(self.0)?;
112        let stream = UnixStream::from_std(stream)?;
113        let (rx, tx) = stream.into_split();
114        invoke(tx, rx, instance, func, params, paths).await
115    }
116}
117
118impl Invoke for Client<std::os::unix::net::SocketAddr> {
119    type Context = ();
120    type Outgoing = Outgoing;
121    type Incoming = Incoming;
122
123    #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
124    async fn invoke<P>(
125        &self,
126        (): Self::Context,
127        instance: &str,
128        func: &str,
129        params: Bytes,
130        paths: impl AsRef<[P]> + Send,
131    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
132    where
133        P: AsRef<[Option<usize>]> + Send + Sync,
134    {
135        let stream = std::os::unix::net::UnixStream::connect_addr(&self.0)?;
136        let stream = UnixStream::from_std(stream)?;
137        let (rx, tx) = stream.into_split();
138        invoke(tx, rx, instance, func, params, paths).await
139    }
140}
141
142impl From<UnixStream> for Invocation {
143    fn from(stream: UnixStream) -> Self {
144        Self(std::sync::Mutex::new(Some(stream)))
145    }
146}
147
148impl Invoke for Invocation {
149    type Context = ();
150    type Outgoing = Outgoing;
151    type Incoming = Incoming;
152
153    #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
154    async fn invoke<P>(
155        &self,
156        (): Self::Context,
157        instance: &str,
158        func: &str,
159        params: Bytes,
160        paths: impl AsRef<[P]> + Send,
161    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
162    where
163        P: AsRef<[Option<usize>]> + Send + Sync,
164    {
165        let stream = match self.0.lock() {
166            Ok(mut stream) => stream
167                .take()
168                .context("stream was already used for an invocation")?,
169            Err(_) => bail!("stream lock poisoned"),
170        };
171        let (rx, tx) = stream.into_split();
172        invoke(tx, rx, instance, func, params, paths).await
173    }
174}
175
176impl Accept for UnixListener {
177    type Context = SocketAddr;
178    type Outgoing = OwnedWriteHalf;
179    type Incoming = OwnedReadHalf;
180
181    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
182        (&self).accept().await
183    }
184}
185
186impl Accept for &UnixListener {
187    type Context = SocketAddr;
188    type Outgoing = OwnedWriteHalf;
189    type Incoming = OwnedReadHalf;
190
191    #[instrument(level = "trace")]
192    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
193        let (stream, addr) = UnixListener::accept(self).await?;
194        let (rx, tx) = stream.into_split();
195        Ok((addr, tx, rx))
196    }
197}