1use std::io::{Read, Write};
2use std::net::Shutdown;
3use std::os::fd::AsRawFd;
4use std::os::unix::net::{SocketAddr, UnixListener, UnixStream};
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::Arc;
7
8use libc::c_int;
9use utf8path::Path;
10
/// Shared cancellation state: an atomic flag plus the two ends of a pipe.
///
/// The pipe is created in `ContextState::new`. `cancel` sets the flag and closes the write end;
/// `wait` polls the read end alongside a caller-supplied descriptor.
#[derive(Debug)]
struct ContextState {
    /// Set to `true` by the first call to `cancel`; read by `canceled`.
    cancel: AtomicBool,
    /// Read end of the pipe (`fds[0]`); polled by `wait` and closed on drop.
    rx: c_int,
    /// Write end of the pipe (`fds[1]`); closed by `cancel`.
    tx: c_int,
}
19
20impl ContextState {
21 fn new() -> Result<Self, std::io::Error> {
22 let cancel = AtomicBool::new(false);
23 let mut fds: [c_int; 2] = [-1; 2];
24 unsafe {
25 if libc::pipe(&mut fds as *mut c_int) < 0 {
26 return Err(std::io::Error::last_os_error());
27 }
28 }
29 let rx = fds[0];
30 let tx = fds[1];
31 Ok(ContextState { cancel, rx, tx })
32 }
33
34 fn cancel(&self) {
35 if self.cancel.swap(true, Ordering::AcqRel) {
36 return;
37 }
38 unsafe {
39 libc::close(self.tx);
40 }
41 }
42
43 fn canceled(&self) -> bool {
44 self.cancel.load(Ordering::Relaxed)
45 }
46
47 fn wait(&self, other: &impl AsRawFd) -> Result<(), std::io::Error> {
48 let mut pfd = [
49 libc::pollfd {
50 fd: self.rx,
51 events: libc::POLLERR,
52 revents: 0,
53 },
54 libc::pollfd {
55 fd: other.as_raw_fd(),
56 events: libc::POLLIN | libc::POLLHUP | libc::POLLERR,
57 revents: 0,
58 },
59 ];
60 unsafe {
61 if libc::poll(pfd.as_mut_ptr(), 2, -1) < 0 {
62 return Err(std::io::Error::last_os_error());
63 }
64 }
65 Ok(())
66 }
67}
68
69impl Drop for ContextState {
70 fn drop(&mut self) {
71 self.cancel();
72 unsafe {
73 libc::close(self.rx);
74 }
75 }
76}
77
/// Cloneable handle to a shared `ContextState`.
///
/// Clones share the same state through the `Arc`, so cancelling through any clone is visible
/// through all of them.
#[derive(Clone, Debug)]
pub struct Context {
    /// Reference-counted cancellation state shared by every clone of this handle.
    state: Arc<ContextState>,
}
84
85impl Context {
86 pub fn new() -> Result<Self, std::io::Error> {
87 let state = Arc::new(ContextState::new()?);
88 Ok(Self { state })
89 }
90
91 pub fn cancel(&self) {
92 self.state.cancel();
93 }
94
95 pub fn canceled(&self) -> bool {
96 self.state.canceled()
97 }
98
99 pub fn wait(&self, other: &impl AsRawFd) -> Result<(), std::io::Error> {
100 self.state.wait(other)
101 }
102}
103
/// Client that sends commands over a Unix domain socket.
///
/// Only the socket path is stored; `invoke` opens a new connection for every call.
pub struct Client {
    /// Owned path of the socket that `invoke` connects to.
    path: Path<'static>,
}
109
110impl Client {
111 pub fn new<'a>(path: impl Into<Path<'a>>) -> Result<Self, std::io::Error> {
112 let path = path.into().into_owned();
113 Ok(Client { path })
114 }
115
116 pub fn invoke(&mut self, command: &str) -> Result<String, std::io::Error> {
117 let mut stream = UnixStream::connect(self.path.as_str())?;
118 stream.write_all(command.as_ref())?;
119 stream.shutdown(Shutdown::Write)?;
120 let mut response = vec![];
121 loop {
122 let mut buf = [0u8; 4096];
123 let amt = stream.read(&mut buf)?;
124 if amt == 0 {
125 break;
126 }
127 response.extend(buf[..amt].iter());
128 }
129 String::from_utf8(response).map_err(|_| std::io::Error::other("expected utf8 in response"))
130 }
131}
132
/// Handler that turns a request string into a response string.
///
/// `Server` stores an implementation as `Arc<dyn Invokable>` and calls it from a thread spawned
/// per connection, hence the `Send + Sync` bound.
pub trait Invokable: Send + Sync {
    /// Processes `command` and returns the text to send back to the client.
    fn invoke(&self, command: &str) -> String;
}
138
/// Server that accepts connections on a Unix domain socket and answers each request through an
/// `Invokable`.
pub struct Server {
    /// Path the listener was bound to; shown by the `Debug` implementation.
    path: Path<'static>,
    /// Listening socket that `serve` accepts connections from.
    listener: UnixListener,
    /// Request handler, shared with every per-connection thread.
    invokable: Arc<dyn Invokable>,
}
146
147impl Server {
148 pub fn new<'a, I: Invokable + 'static>(
149 path: impl Into<Path<'a>>,
150 invoke: I,
151 ) -> Result<Self, std::io::Error> {
152 let path = path.into().into_owned();
153 let listener = UnixListener::bind(path.as_str())?;
154 let invokable = Arc::new(invoke);
155 Ok(Server {
156 path,
157 listener,
158 invokable,
159 })
160 }
161
162 pub fn serve(&mut self, context: &Context) -> Result<(), std::io::Error> {
163 loop {
164 context.wait(&self.listener)?;
165 if context.canceled() {
166 break;
167 }
168 let context = context.clone();
169 let invokable = self.invokable.clone();
170 let (socket, addr) = self.listener.accept()?;
171 let context = context.clone();
172 let _handle = std::thread::spawn(move || {
173 Self::serve_one(&context, invokable.as_ref(), socket, addr);
174 });
175 }
180 Ok(())
181 }
182
183 fn serve_one(
184 context: &Context,
185 invokable: &dyn Invokable,
186 mut socket: UnixStream,
187 _: SocketAddr,
188 ) {
189 let mut request = vec![];
190 loop {
191 if let Err(err) = context.wait(&socket) {
192 _ = socket.write_all(format!("error: {err:?}").as_ref());
193 return;
194 }
195 if context.canceled() {
196 _ = socket.write_all("error: server shut down".as_ref());
197 return;
198 }
199 let mut buf = [0u8; 4096];
200 let amt = match socket.read(&mut buf) {
201 Ok(amt) => amt,
202 Err(err) => {
203 _ = socket.write_all(format!("error: could not read: {err:?}").as_ref());
204 return;
205 }
206 };
207 if amt == 0 {
208 break;
209 }
210 request.extend(buf[..amt].iter());
211 if request.len() >= 65536 {
212 _ = socket.write_all("error: request exceeds 65536 bytes".as_ref());
213 return;
214 }
215 }
216 let request = match String::from_utf8(request) {
217 Ok(request) => request,
218 Err(err) => {
219 _ = socket
220 .write_all(format!("error: could not interpret as utf8: {err:?}").as_ref());
221 return;
222 }
223 };
224 let response = invokable.invoke(&request);
225 _ = socket.write_all(response.as_ref());
226 }
227}
228
229impl std::fmt::Debug for Server {
230 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
231 write!(fmt, "Server({:?})", self.path)
232 }
233}