unixconn_rust/
lib.rs

1
2use std::os::unix::net::UnixStream;
3use std::io::{self, Read, Write, BufReader};
4use std::error::Error;
5use bytes::{Bytes, BytesMut, BufMut};
6use uuid::Uuid;
7use std::time::Duration;
8
/// Number of fields in every frame: request id, method name, error, body
/// (the order used by `message_to_bytes` and `parse_message`).
const PROTOCOL_FIELDS: usize = 4;

/// Byte sequence separating the fields of a frame (ASCII "record separator", 0x1E).
fn metadata_delim() -> &'static [u8] {
    b"\x1e"
}

/// Byte terminating each frame on the wire (ASCII "unit separator", 0x1F).
fn message_delim() -> u8 {
    b'\x1f'
}
18
19struct Message {
20    request_id: String,
21    method_name: String,
22    body: Bytes,
23    error: String,
24}
25
/// Blocking client for the request/response protocol spoken over a Unix domain socket.
#[derive(Debug)]
pub struct Client {
    /// Connected stream used for every request made through this client.
    conn: UnixStream,
    /// Request timeout in seconds, converted with `Duration::from_secs` by `do_request`.
    timeout: u64
}
30
31impl Client {
32    pub fn new(address: &str, timeout: u64) -> Result<Self, Box<dyn Error>> {
33        let conn = UnixStream::connect(address)?;
34        Ok(Client { conn, timeout })
35    }
36
37    pub fn close(&self) -> io::Result<()> {
38        self.conn.shutdown(std::net::Shutdown::Both)
39    }
40
41    pub fn do_request(&mut self, method_name: &str, request_body: &[u8]) -> Result<Bytes, Box<dyn Error>> {
42        let request_id = Uuid::new_v4().to_string();
43        let request = Message {
44            request_id: request_id.clone(),
45            method_name: method_name.to_string(),
46            body: Bytes::from(request_body.to_vec()),
47            error: String::new(),
48        };
49
50        let raw_request = message_to_bytes(&request);
51        self.conn.write_all(&raw_request)?;
52        self.conn.set_read_timeout(Some(Duration::from_secs(self.timeout)))?;
53
54        let mut reader = BufReader::new(&self.conn);
55        let message = read_message(&mut reader)?;
56
57        if !message.error.is_empty() {
58            return Err(format!("client response error: {}", message.error).into());
59        }
60
61        if message.request_id != request_id {
62            return Err(format!("client wrong requestID error: {}", message.error).into());
63        }
64
65        Ok(message.body)
66    }
67}
68
69fn parse_message(body: &[u8]) -> Result<Message, Box<dyn Error>> {
70    let parts: Vec<&[u8]> = body.split(|&b| b == metadata_delim()[0]).collect();
71    if parts.len() != PROTOCOL_FIELDS {
72        return Err(format!("error protocol received message with {} parts, expected {}", parts.len(), PROTOCOL_FIELDS).into());
73    }
74
75    Ok(Message {
76        request_id: String::from_utf8(parts[0].to_vec())?,
77        method_name: String::from_utf8(parts[1].to_vec())?,
78        error: String::from_utf8(parts[2].to_vec())?,
79        body: Bytes::from(parts[3].to_vec()),
80    })
81}
82
83fn message_to_bytes(r: &Message) -> Bytes {
84    let mut buffer = BytesMut::new();
85
86    buffer.put(r.request_id.as_bytes());
87    buffer.put(metadata_delim());
88
89    buffer.put(r.method_name.as_bytes());
90    buffer.put(metadata_delim());
91
92    buffer.put(r.error.as_bytes());
93    buffer.put(metadata_delim());
94
95    buffer.put(&r.body[..]);
96    buffer.put_u8(message_delim());
97
98    buffer.freeze()
99}
100
101fn read_message<R: Read>(reader: &mut R) -> Result<Message, Box<dyn Error>> {
102    let mut message_body = Vec::new();
103    let mut byte = [0u8; 1];
104
105    loop {
106        reader.read_exact(&mut byte)?;
107        if byte[0] == message_delim() {
108            break;
109        }
110        message_body.push(byte[0]);
111    }
112
113    parse_message(&message_body)
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends a `getnssusers` request to the live server and prints the reply.
    fn run_client() -> Result<(), Box<dyn Error>> {
        let mut client = Client::new("/tmp/salt-ssd.sock", 10)?;
        let method_name = "getnssusers";
        let request_body = b"";

        let response = client.do_request(method_name, request_body)?;
        match std::str::from_utf8(&response) {
            Ok(s) => println!("Received response: {}", s),
            Err(e) => eprintln!("Response was not valid UTF-8: {}", e),
        }
        client.close()?;

        Ok(())
    }

    // Needs an external daemon, so it is opt-in: `cargo test -- --ignored`.
    // Returning `Result` lets the harness report a failure; the previous
    // `std::process::exit(1)` killed the whole test binary mid-run.
    #[test]
    #[ignore = "requires a server listening on /tmp/salt-ssd.sock"]
    fn it_works() -> Result<(), Box<dyn Error>> {
        run_client()
    }

    /// Answers exactly one request arriving on `server_end` with `reply(request)`.
    fn serve_once(
        server_end: UnixStream,
        reply: impl FnOnce(Message) -> Message + Send + 'static,
    ) -> std::thread::JoinHandle<Result<(), String>> {
        std::thread::spawn(move || {
            let mut stream = server_end;
            let request =
                read_message(&mut BufReader::new(&stream)).map_err(|e| e.to_string())?;
            let response = reply(request);
            stream
                .write_all(&message_to_bytes(&response))
                .map_err(|e| e.to_string())
        })
    }

    #[test]
    fn round_trip_returns_response_body() -> Result<(), Box<dyn Error>> {
        let (client_end, server_end) = UnixStream::pair()?;
        let server = serve_once(server_end, |request| Message {
            request_id: request.request_id,
            method_name: request.method_name,
            body: Bytes::from_static(b"pong"),
            error: String::new(),
        });

        let mut client = Client { conn: client_end, timeout: 5 };
        let body = client.do_request("ping", b"payload")?;
        assert_eq!(&body[..], b"pong");

        server.join().expect("server thread panicked")?;
        Ok(())
    }

    #[test]
    fn mismatched_request_id_is_rejected() -> Result<(), Box<dyn Error>> {
        let (client_end, server_end) = UnixStream::pair()?;
        let server = serve_once(server_end, |request| Message {
            request_id: "not-the-request-id".to_string(),
            method_name: request.method_name,
            body: Bytes::new(),
            error: String::new(),
        });

        let mut client = Client { conn: client_end, timeout: 5 };
        assert!(client.do_request("ping", b"").is_err());

        server.join().expect("server thread panicked")?;
        Ok(())
    }

    #[test]
    fn server_error_is_reported() -> Result<(), Box<dyn Error>> {
        let (client_end, server_end) = UnixStream::pair()?;
        let server = serve_once(server_end, |request| Message {
            request_id: request.request_id,
            method_name: request.method_name,
            body: Bytes::new(),
            error: "boom".to_string(),
        });

        let mut client = Client { conn: client_end, timeout: 5 };
        let err = client.do_request("ping", b"").unwrap_err();
        assert!(err.to_string().contains("boom"));

        server.join().expect("server thread panicked")?;
        Ok(())
    }
}