shell_compose/
ipc.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
use crate::Message;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use interprocess::local_socket::{prelude::*, GenericNamespaced, ListenerOptions};
use log::debug;
use std::io;
use std::io::prelude::*;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum IpcServerError {
    #[error("Failed to bind to socket: {0}")]
    BindError(io::Error),
    #[error("Failed to resolve socket name: {0}")]
    SocketNameError(io::Error),
}

#[derive(Error, Debug)]
pub enum IpcClientError {
    #[error("Failed to connect to socket: {0}")]
    ConnectError(#[from] io::Error),
    #[error("Failed to resolve socket name: {0}")]
    SocketNameError(io::Error),
    #[error("Failed to read from socket: {0}")]
    ReadError(#[from] IpcStreamReadError),
    #[error("Failed to write to socket: {0}")]
    WriteError(#[from] IpcStreamWriteError),
}

#[derive(Error, Debug)]
pub enum IpcStreamReadError {
    #[error("Failed to read from socket: {0}")]
    ReadError(#[from] io::Error),
    #[error("Failed to deserialize data from socket: {0}")]
    DeserializeError(#[from] bincode::Error),
}

#[derive(Error, Debug)]
pub enum IpcStreamWriteError {
    #[error("Failed to write to socket: {0}")]
    WriteError(#[from] io::Error),
    #[error("Failed to serialize data for socket: {0}")]
    SerializeError(#[from] bincode::Error),
}

/// Listen for incoming connections on the given socket.
///
/// # Arguments
///
/// * `socket` - The socket name to listen on.
/// * `on_connection` - A function that will be invoked for each incoming connection.
/// * `on_connection_error` - An optional function that will be invoked if there is an error accepting a connection.
pub fn start_ipc_listener<F: FnMut(IpcStream) + Send + 'static>(
    socket: &str,
    mut on_connection: F,
    on_connection_error: Option<fn(io::Error)>,
) -> Result<(), IpcServerError> {
    let name = socket
        .to_ns_name::<GenericNamespaced>()
        .map_err(IpcServerError::SocketNameError)?;
    let listener = match ListenerOptions::new().name(name.clone()).create_sync() {
        Err(e) => return Err(IpcServerError::BindError(e)),
        Ok(listener) => listener,
    };

    let error_handler = move |inc: Result<LocalSocketStream, io::Error>| match inc {
        Ok(conn) => Some(conn),
        Err(e) => {
            if let Some(on_connection_error) = on_connection_error {
                on_connection_error(e);
            }
            None
        }
    };

    for stream in listener.incoming().filter_map(error_handler) {
        let logname = "listener".to_string();
        let stream = IpcStream { logname, stream };
        on_connection(stream);
    }

    Ok(())
}

/// Connect to the socket and return the stream.
fn ipc_client_connect(socket_name: &str) -> Result<LocalSocketStream, IpcClientError> {
    let name = socket_name
        .to_ns_name::<GenericNamespaced>()
        .map_err(IpcClientError::SocketNameError)?;
    LocalSocketStream::connect(name).map_err(IpcClientError::ConnectError)
}

pub trait SocketExt {
    fn read_serde<T: serde::de::DeserializeOwned>(&mut self) -> Result<T, IpcStreamReadError>;
    fn write_serde<T: serde::Serialize>(&mut self, data: &T) -> Result<(), IpcStreamWriteError>;
}

impl SocketExt for LocalSocketStream {
    /// Read a serializable object from the socket.
    ///
    /// This reads a `u32` in little endian, then reads that many bytes from the socket, then deserializes the data using `bincode::deserialize`.
    fn read_serde<T: serde::de::DeserializeOwned>(&mut self) -> Result<T, IpcStreamReadError> {
        let size = self.read_u32::<LittleEndian>()?;

        let bytes = {
            let mut bytes = vec![0; size as usize];

            self.read_exact(&mut bytes)?;

            bytes
        };

        let result: T = bincode::deserialize(&bytes)?;

        Ok(result)
    }

    /// Write a serializable object to the socket.
    ///
    /// This serializes the data using `bincode::serialize`, writes the length of the serialized data as a `u32` in little endian, then writes the serialized data.
    fn write_serde<T: serde::Serialize>(&mut self, data: &T) -> Result<(), IpcStreamWriteError> {
        let bytes = bincode::serialize(data)?;

        self.write_u32::<LittleEndian>(bytes.len() as u32)?;
        self.write_all(&bytes)?;

        Ok(())
    }
}

/// Communication stream
pub struct IpcStream {
    logname: String,
    stream: LocalSocketStream,
}

impl IpcStream {
    /// Connects to the socket and return the stream
    pub fn connect(logname: &str, socket_name: &str) -> Result<Self, IpcClientError> {
        let mut stream = ipc_client_connect(socket_name)?;
        stream.write_serde(&Message::Connect)?;
        Ok(IpcStream {
            logname: logname.to_string(),
            stream,
        })
    }
    /// Check socket connection
    pub fn check_connection(socket_name: &str) -> Result<(), IpcClientError> {
        IpcStream::connect("check_connection", socket_name)?;
        Ok(())
    }
    /// Send Message.
    pub fn send_message(&mut self, message: &Message) -> Result<(), IpcClientError> {
        debug!(target: &self.logname, "send_message {message:?}");
        self.stream.write_serde(&message)?;
        Ok(())
    }
    /// Receive Message.
    pub fn receive_message(&mut self) -> Result<Message, IpcClientError> {
        let message = self.stream.read_serde()?;
        debug!(target: &self.logname, "receive_message {message:?}");
        Ok(message)
    }
    /// Send a message and immediately read response message,
    /// blocking until a response is received.
    pub fn send_query(&mut self, request: &Message) -> Result<Message, IpcClientError> {
        self.send_message(request)?;
        let response = self.receive_message()?;
        Ok(response)
    }
}