unix_sock/
lib.rs

1use std::io::{Read, Write};
2use std::net::Shutdown;
3use std::os::fd::AsRawFd;
4use std::os::unix::net::{SocketAddr, UnixListener, UnixStream};
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::Arc;
7
8use libc::c_int;
9use utf8path::Path;
10
11/////////////////////////////////////////// ContextState ///////////////////////////////////////////
12
/// Shared cancellation state: an atomic flag plus a pipe used to wake threads blocked in
/// `poll(2)`.
///
/// Canceling closes the write end (`tx`); threads polling the read end (`rx`) then wake up.
#[derive(Debug)]
struct ContextState {
    /// Flipped to `true` by the first call to `cancel`; never reset.
    cancel: AtomicBool,
    /// Read end of the wake-up pipe; polled by waiters and closed when the state is dropped.
    rx: c_int,
    /// Write end of the wake-up pipe; closing it is the cancellation signal.
    tx: c_int,
}
19
20impl ContextState {
21    fn new() -> Result<Self, std::io::Error> {
22        let cancel = AtomicBool::new(false);
23        let mut fds: [c_int; 2] = [-1; 2];
24        unsafe {
25            if libc::pipe(&mut fds as *mut c_int) < 0 {
26                return Err(std::io::Error::last_os_error());
27            }
28        }
29        let rx = fds[0];
30        let tx = fds[1];
31        Ok(ContextState { cancel, rx, tx })
32    }
33
34    fn cancel(&self) {
35        if self.cancel.swap(true, Ordering::AcqRel) {
36            return;
37        }
38        unsafe {
39            libc::close(self.tx);
40        }
41    }
42
43    fn canceled(&self) -> bool {
44        self.cancel.load(Ordering::Relaxed)
45    }
46
47    fn wait(&self, other: &impl AsRawFd) -> Result<(), std::io::Error> {
48        let mut pfd = [
49            libc::pollfd {
50                fd: self.rx,
51                events: libc::POLLERR,
52                revents: 0,
53            },
54            libc::pollfd {
55                fd: other.as_raw_fd(),
56                events: libc::POLLIN | libc::POLLHUP | libc::POLLERR,
57                revents: 0,
58            },
59        ];
60        unsafe {
61            if libc::poll(pfd.as_mut_ptr(), 2, -1) < 0 {
62                return Err(std::io::Error::last_os_error());
63            }
64        }
65        Ok(())
66    }
67}
68
69impl Drop for ContextState {
70    fn drop(&mut self) {
71        self.cancel();
72        unsafe {
73            libc::close(self.rx);
74        }
75    }
76}
77
78////////////////////////////////////////////// Context /////////////////////////////////////////////
79
80#[derive(Clone, Debug)]
81pub struct Context {
82    state: Arc<ContextState>,
83}
84
85impl Context {
86    pub fn new() -> Result<Self, std::io::Error> {
87        let state = Arc::new(ContextState::new()?);
88        Ok(Self { state })
89    }
90
91    pub fn cancel(&self) {
92        self.state.cancel();
93    }
94
95    pub fn canceled(&self) -> bool {
96        self.state.canceled()
97    }
98
99    pub fn wait(&self, other: &impl AsRawFd) -> Result<(), std::io::Error> {
100        self.state.wait(other)
101    }
102}
103
104////////////////////////////////////////////// Client //////////////////////////////////////////////
105
106pub struct Client {
107    path: Path<'static>,
108}
109
110impl Client {
111    pub fn new<'a>(path: impl Into<Path<'a>>) -> Result<Self, std::io::Error> {
112        let path = path.into().into_owned();
113        Ok(Client { path })
114    }
115
116    pub fn invoke(&mut self, command: &str) -> Result<String, std::io::Error> {
117        let mut stream = UnixStream::connect(self.path.as_str())?;
118        stream.write_all(command.as_ref())?;
119        stream.shutdown(Shutdown::Write)?;
120        let mut response = vec![];
121        loop {
122            let mut buf = [0u8; 4096];
123            let amt = stream.read(&mut buf)?;
124            if amt == 0 {
125                break;
126            }
127            response.extend(buf[..amt].iter());
128        }
129        String::from_utf8(response).map_err(|_| std::io::Error::other("expected utf8 in response"))
130    }
131}
132
133///////////////////////////////////////////// Invokable ////////////////////////////////////////////
134
/// A request handler that the server calls for each successfully read request.
///
/// One handler is shared by all connection threads, hence the thread-safety bounds.
pub trait Invokable: Sync + Send {
    /// Handles one complete request and returns the text to send back to the client.
    fn invoke(&self, command: &str) -> String;
}
138
139////////////////////////////////////////////// Server //////////////////////////////////////////////
140
141pub struct Server {
142    path: Path<'static>,
143    listener: UnixListener,
144    invokable: Arc<dyn Invokable>,
145}
146
147impl Server {
148    pub fn new<'a, I: Invokable + 'static>(
149        path: impl Into<Path<'a>>,
150        invoke: I,
151    ) -> Result<Self, std::io::Error> {
152        let path = path.into().into_owned();
153        let listener = UnixListener::bind(path.as_str())?;
154        let invokable = Arc::new(invoke);
155        Ok(Server {
156            path,
157            listener,
158            invokable,
159        })
160    }
161
162    pub fn serve(&mut self, context: &Context) -> Result<(), std::io::Error> {
163        loop {
164            context.wait(&self.listener)?;
165            if context.canceled() {
166                break;
167            }
168            let context = context.clone();
169            let invokable = self.invokable.clone();
170            let (socket, addr) = self.listener.accept()?;
171            let context = context.clone();
172            let _handle = std::thread::spawn(move || {
173                Self::serve_one(&context, invokable.as_ref(), socket, addr);
174            });
175            // NOTE(rescrv):  We leak handle here and rely upon context being canceled and the
176            // thread exiting quickly to clean things up.  If it takes time, that's not a problem.
177            // If it doesn't happen, that's not a problem.  The lingering thread can only return an
178            // error on its socket---which is what we want.
179        }
180        Ok(())
181    }
182
183    fn serve_one(
184        context: &Context,
185        invokable: &dyn Invokable,
186        mut socket: UnixStream,
187        _: SocketAddr,
188    ) {
189        let mut request = vec![];
190        loop {
191            if let Err(err) = context.wait(&socket) {
192                _ = socket.write_all(format!("error: {err:?}").as_ref());
193                return;
194            }
195            if context.canceled() {
196                _ = socket.write_all("error: server shut down".as_ref());
197                return;
198            }
199            let mut buf = [0u8; 4096];
200            let amt = match socket.read(&mut buf) {
201                Ok(amt) => amt,
202                Err(err) => {
203                    _ = socket.write_all(format!("error: could not read: {err:?}").as_ref());
204                    return;
205                }
206            };
207            if amt == 0 {
208                break;
209            }
210            request.extend(buf[..amt].iter());
211            if request.len() >= 65536 {
212                _ = socket.write_all("error: request exceeds 65536 bytes".as_ref());
213                return;
214            }
215        }
216        let request = match String::from_utf8(request) {
217            Ok(request) => request,
218            Err(err) => {
219                _ = socket
220                    .write_all(format!("error: could not interpret as utf8: {err:?}").as_ref());
221                return;
222            }
223        };
224        let response = invokable.invoke(&request);
225        _ = socket.write_all(response.as_ref());
226    }
227}
228
229impl std::fmt::Debug for Server {
230    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
231        write!(fmt, "Server({:?})", self.path)
232    }
233}