// shell_compose/ipc.rs

1use crate::{get_user_name, Message};
2use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
3use interprocess::local_socket::{prelude::*, GenericFilePath, ListenerOptions};
4use log::debug;
5use std::io;
6use std::io::prelude::*;
7use thiserror::Error;
8
9#[derive(Error, Debug)]
10pub enum IpcServerError {
11    #[error("Failed to bind to socket: {0}")]
12    BindError(io::Error),
13    #[error("Failed to resolve socket name: {0}")]
14    SocketNameError(io::Error),
15}
16
17#[derive(Error, Debug)]
18pub enum IpcClientError {
19    #[error("Failed to connect to socket: {0}")]
20    ConnectError(#[from] io::Error),
21    #[error("Failed to resolve socket name: {0}")]
22    SocketNameError(io::Error),
23    #[error("Failed to read from socket: {0}")]
24    ReadError(#[from] IpcStreamReadError),
25    #[error("Failed to write to socket: {0}")]
26    WriteError(#[from] IpcStreamWriteError),
27}
28
29#[derive(Error, Debug)]
30pub enum IpcStreamReadError {
31    #[error("Failed to read from socket: {0}")]
32    ReadError(#[from] io::Error),
33    #[error("Failed to deserialize data from socket: {0}")]
34    DeserializeError(#[from] bincode::Error),
35}
36
37#[derive(Error, Debug)]
38pub enum IpcStreamWriteError {
39    #[error("Failed to write to socket: {0}")]
40    WriteError(#[from] io::Error),
41    #[error("Failed to serialize data for socket: {0}")]
42    SerializeError(#[from] bincode::Error),
43}
44
45/// Listen for incoming connections on the given socket.
46///
47/// # Arguments
48///
49/// * `socket` - The socket name to listen on.
50/// * `on_connection` - A function that will be invoked for each incoming connection.
51/// * `on_connection_error` - An optional function that will be invoked if there is an error accepting a connection.
52pub fn start_ipc_listener<F: FnMut(IpcStream) + Send + 'static>(
53    socket: &str,
54    mut on_connection: F,
55    on_connection_error: Option<fn(io::Error)>,
56) -> Result<(), IpcServerError> {
57    let name = socket
58        .to_fs_name::<GenericFilePath>()
59        .map_err(IpcServerError::SocketNameError)?;
60    let mut options = ListenerOptions::new().name(name.clone());
61    #[cfg(target_family = "unix")]
62    {
63        use interprocess::os::unix::local_socket::ListenerOptionsExt;
64        options = options.mode(0o600);
65    }
66    #[cfg(target_family = "windows")]
67    {
68        use interprocess::os::windows::{
69            local_socket::ListenerOptionsExt, security_descriptor::SecurityDescriptor,
70        };
71        options = options.security_descriptor(SecurityDescriptor::new().unwrap());
72    }
73    let listener = match options.create_sync() {
74        Err(e) => return Err(IpcServerError::BindError(e)),
75        Ok(listener) => listener,
76    };
77
78    let error_handler = move |inc: Result<LocalSocketStream, io::Error>| match inc {
79        Ok(conn) => Some(conn),
80        Err(e) => {
81            if let Some(on_connection_error) = on_connection_error {
82                on_connection_error(e);
83            }
84            None
85        }
86    };
87
88    for stream in listener.incoming().filter_map(error_handler) {
89        let logname = "listener".to_string();
90        let stream = IpcStream { logname, stream };
91        on_connection(stream);
92    }
93
94    Ok(())
95}
96
97/// Connect to the socket and return the stream.
98fn ipc_client_connect(socket_name: &str) -> Result<LocalSocketStream, IpcClientError> {
99    let name = socket_name
100        .to_fs_name::<GenericFilePath>()
101        .map_err(IpcClientError::SocketNameError)?;
102    LocalSocketStream::connect(name).map_err(IpcClientError::ConnectError)
103}
104
105trait SocketExt {
106    fn read_serde<T: serde::de::DeserializeOwned>(&mut self) -> Result<T, IpcStreamReadError>;
107    fn write_serde<T: serde::Serialize>(&mut self, data: &T) -> Result<(), IpcStreamWriteError>;
108}
109
110impl SocketExt for LocalSocketStream {
111    /// Read a serializable object from the socket.
112    ///
113    /// This reads a `u32` in little endian, then reads that many bytes from the socket, then deserializes the data using `bincode::deserialize`.
114    fn read_serde<T: serde::de::DeserializeOwned>(&mut self) -> Result<T, IpcStreamReadError> {
115        let size = self.read_u32::<LittleEndian>()?;
116
117        let bytes = {
118            let mut bytes = vec![0; size as usize];
119
120            self.read_exact(&mut bytes)?;
121
122            bytes
123        };
124
125        let result: T = bincode::deserialize(&bytes)?;
126
127        Ok(result)
128    }
129
130    /// Write a serializable object to the socket.
131    ///
132    /// This serializes the data using `bincode::serialize`, writes the length of the serialized data as a `u32` in little endian, then writes the serialized data.
133    fn write_serde<T: serde::Serialize>(&mut self, data: &T) -> Result<(), IpcStreamWriteError> {
134        let bytes = bincode::serialize(data)?;
135
136        self.write_u32::<LittleEndian>(bytes.len() as u32)?;
137        self.write_all(&bytes)?;
138
139        Ok(())
140    }
141}
142
143/// Communication stream
144pub struct IpcStream {
145    logname: String,
146    stream: LocalSocketStream,
147}
148
149impl IpcStream {
150    /// Connects to the socket and return the stream
151    pub fn connect(logname: &str) -> Result<Self, IpcClientError> {
152        let socket_name = IpcStream::user_socket_name();
153        let mut stream = ipc_client_connect(&socket_name)?;
154        stream.write_serde(&Message::Connect)?;
155        Ok(IpcStream {
156            logname: logname.to_string(),
157            stream,
158        })
159    }
160    /// Check socket connection
161    pub fn check_connection() -> Result<(), IpcClientError> {
162        IpcStream::connect("check_connection")?;
163        Ok(())
164    }
165    pub fn user_socket_name() -> String {
166        let user = get_user_name().unwrap_or("_".to_string());
167        IpcStream::socket_name(&user)
168    }
169    #[cfg(target_family = "unix")]
170    fn socket_name(user: &str) -> String {
171        let tmpdir = std::env::var("TMPDIR").ok();
172        format!(
173            "{}/shell-compose-{user}.sock",
174            tmpdir.as_deref().unwrap_or("/tmp")
175        )
176    }
177    #[cfg(target_family = "windows")]
178    fn socket_name(user: &str) -> String {
179        format!(r"\\.\pipe\shell-compose-{user}")
180    }
181    /// Check stream
182    pub fn alive(&mut self) -> Result<(), IpcClientError> {
183        self.stream.write_serde(&Message::Connect)?;
184        Ok(())
185    }
186    /// Send Message.
187    pub fn send_message(&mut self, message: &Message) -> Result<(), IpcClientError> {
188        debug!(target: &self.logname, "send_message {message:?}");
189        self.stream.write_serde(&message)?;
190        Ok(())
191    }
192    /// Receive Message.
193    pub fn receive_message(&mut self) -> Result<Message, IpcClientError> {
194        let message = self.stream.read_serde()?;
195        debug!(target: &self.logname, "receive_message {message:?}");
196        Ok(message)
197    }
198    /// Send a message and immediately read response message,
199    /// blocking until a response is received.
200    pub fn send_query(&mut self, request: &Message) -> Result<Message, IpcClientError> {
201        self.send_message(request)?;
202        let response = self.receive_message()?;
203        Ok(response)
204    }
205}