1use crate::{get_user_name, Message};
2use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
3use interprocess::local_socket::{prelude::*, GenericFilePath, ListenerOptions};
4use log::debug;
5use std::io;
6use std::io::prelude::*;
7use thiserror::Error;
8
9#[derive(Error, Debug)]
10pub enum IpcServerError {
11 #[error("Failed to bind to socket: {0}")]
12 BindError(io::Error),
13 #[error("Failed to resolve socket name: {0}")]
14 SocketNameError(io::Error),
15}
16
17#[derive(Error, Debug)]
18pub enum IpcClientError {
19 #[error("Failed to connect to socket: {0}")]
20 ConnectError(#[from] io::Error),
21 #[error("Failed to resolve socket name: {0}")]
22 SocketNameError(io::Error),
23 #[error("Failed to read from socket: {0}")]
24 ReadError(#[from] IpcStreamReadError),
25 #[error("Failed to write to socket: {0}")]
26 WriteError(#[from] IpcStreamWriteError),
27}
28
29#[derive(Error, Debug)]
30pub enum IpcStreamReadError {
31 #[error("Failed to read from socket: {0}")]
32 ReadError(#[from] io::Error),
33 #[error("Failed to deserialize data from socket: {0}")]
34 DeserializeError(#[from] bincode::Error),
35}
36
37#[derive(Error, Debug)]
38pub enum IpcStreamWriteError {
39 #[error("Failed to write to socket: {0}")]
40 WriteError(#[from] io::Error),
41 #[error("Failed to serialize data for socket: {0}")]
42 SerializeError(#[from] bincode::Error),
43}
44
45pub fn start_ipc_listener<F: FnMut(IpcStream) + Send + 'static>(
53 socket: &str,
54 mut on_connection: F,
55 on_connection_error: Option<fn(io::Error)>,
56) -> Result<(), IpcServerError> {
57 let name = socket
58 .to_fs_name::<GenericFilePath>()
59 .map_err(IpcServerError::SocketNameError)?;
60 let mut options = ListenerOptions::new().name(name.clone());
61 #[cfg(target_family = "unix")]
62 {
63 use interprocess::os::unix::local_socket::ListenerOptionsExt;
64 options = options.mode(0o600);
65 }
66 #[cfg(target_family = "windows")]
67 {
68 use interprocess::os::windows::{
69 local_socket::ListenerOptionsExt, security_descriptor::SecurityDescriptor,
70 };
71 options = options.security_descriptor(SecurityDescriptor::new().unwrap());
72 }
73 let listener = match options.create_sync() {
74 Err(e) => return Err(IpcServerError::BindError(e)),
75 Ok(listener) => listener,
76 };
77
78 let error_handler = move |inc: Result<LocalSocketStream, io::Error>| match inc {
79 Ok(conn) => Some(conn),
80 Err(e) => {
81 if let Some(on_connection_error) = on_connection_error {
82 on_connection_error(e);
83 }
84 None
85 }
86 };
87
88 for stream in listener.incoming().filter_map(error_handler) {
89 let logname = "listener".to_string();
90 let stream = IpcStream { logname, stream };
91 on_connection(stream);
92 }
93
94 Ok(())
95}
96
97fn ipc_client_connect(socket_name: &str) -> Result<LocalSocketStream, IpcClientError> {
99 let name = socket_name
100 .to_fs_name::<GenericFilePath>()
101 .map_err(IpcClientError::SocketNameError)?;
102 LocalSocketStream::connect(name).map_err(IpcClientError::ConnectError)
103}
104
105trait SocketExt {
106 fn read_serde<T: serde::de::DeserializeOwned>(&mut self) -> Result<T, IpcStreamReadError>;
107 fn write_serde<T: serde::Serialize>(&mut self, data: &T) -> Result<(), IpcStreamWriteError>;
108}
109
110impl SocketExt for LocalSocketStream {
111 fn read_serde<T: serde::de::DeserializeOwned>(&mut self) -> Result<T, IpcStreamReadError> {
115 let size = self.read_u32::<LittleEndian>()?;
116
117 let bytes = {
118 let mut bytes = vec![0; size as usize];
119
120 self.read_exact(&mut bytes)?;
121
122 bytes
123 };
124
125 let result: T = bincode::deserialize(&bytes)?;
126
127 Ok(result)
128 }
129
130 fn write_serde<T: serde::Serialize>(&mut self, data: &T) -> Result<(), IpcStreamWriteError> {
134 let bytes = bincode::serialize(data)?;
135
136 self.write_u32::<LittleEndian>(bytes.len() as u32)?;
137 self.write_all(&bytes)?;
138
139 Ok(())
140 }
141}
142
143pub struct IpcStream {
145 logname: String,
146 stream: LocalSocketStream,
147}
148
149impl IpcStream {
150 pub fn connect(logname: &str) -> Result<Self, IpcClientError> {
152 let socket_name = IpcStream::user_socket_name();
153 let mut stream = ipc_client_connect(&socket_name)?;
154 stream.write_serde(&Message::Connect)?;
155 Ok(IpcStream {
156 logname: logname.to_string(),
157 stream,
158 })
159 }
160 pub fn check_connection() -> Result<(), IpcClientError> {
162 IpcStream::connect("check_connection")?;
163 Ok(())
164 }
165 pub fn user_socket_name() -> String {
166 let user = get_user_name().unwrap_or("_".to_string());
167 IpcStream::socket_name(&user)
168 }
169 #[cfg(target_family = "unix")]
170 fn socket_name(user: &str) -> String {
171 let tmpdir = std::env::var("TMPDIR").ok();
172 format!(
173 "{}/shell-compose-{user}.sock",
174 tmpdir.as_deref().unwrap_or("/tmp")
175 )
176 }
177 #[cfg(target_family = "windows")]
178 fn socket_name(user: &str) -> String {
179 format!(r"\\.\pipe\shell-compose-{user}")
180 }
181 pub fn alive(&mut self) -> Result<(), IpcClientError> {
183 self.stream.write_serde(&Message::Connect)?;
184 Ok(())
185 }
186 pub fn send_message(&mut self, message: &Message) -> Result<(), IpcClientError> {
188 debug!(target: &self.logname, "send_message {message:?}");
189 self.stream.write_serde(&message)?;
190 Ok(())
191 }
192 pub fn receive_message(&mut self) -> Result<Message, IpcClientError> {
194 let message = self.stream.read_serde()?;
195 debug!(target: &self.logname, "receive_message {message:?}");
196 Ok(message)
197 }
198 pub fn send_query(&mut self, request: &Message) -> Result<Message, IpcClientError> {
201 self.send_message(request)?;
202 let response = self.receive_message()?;
203 Ok(response)
204 }
205}