1
2use std::os::unix::net::UnixStream;
3use std::io::{self, Read, Write, BufReader};
4use std::error::Error;
5use bytes::{Bytes, BytesMut, BufMut};
6use uuid::Uuid;
7use std::time::Duration;
8
/// Number of metadata-delimited fields a wire message must contain:
/// request ID, method name, error text and body.
const PROTOCOL_FIELDS: usize = 4;
10
/// Returns the one-byte separator placed between the fields of a message
/// (`0x1E`, the ASCII "record separator" control character).
fn metadata_delim() -> &'static [u8] {
    b"\x1E"
}
14
/// Returns the byte that terminates a complete message on the wire
/// (`0x1F`, the ASCII "unit separator" control character).
fn message_delim() -> u8 {
    const END_OF_MESSAGE: u8 = 0x1F;
    END_OF_MESSAGE
}
18
/// A single protocol message, used for both requests and responses.
///
/// On the wire the fields are written in the order request ID, method name,
/// error, body (see `message_to_bytes` / `parse_message`).
struct Message {
    /// Identifier of the request; `Client::do_request` generates a UUID v4
    /// string for it and expects the response to carry the same value.
    request_id: String,
    /// Name of the method being invoked.
    method_name: String,
    /// Raw payload bytes.
    body: Bytes,
    /// Error text; empty in requests, and a non-empty value in a response is
    /// reported by `Client::do_request` as a failure.
    error: String,
}
25
/// Blocking client that exchanges delimiter-framed messages over a Unix
/// domain socket.
pub struct Client {
    /// Connected socket used for both sending requests and reading responses.
    conn: UnixStream,
    /// Read timeout, in seconds, applied to the socket for each request.
    timeout: u64
}
30
31impl Client {
32 pub fn new(address: &str, timeout: u64) -> Result<Self, Box<dyn Error>> {
33 let conn = UnixStream::connect(address)?;
34 Ok(Client { conn, timeout })
35 }
36
37 pub fn close(&self) -> io::Result<()> {
38 self.conn.shutdown(std::net::Shutdown::Both)
39 }
40
41 pub fn do_request(&mut self, method_name: &str, request_body: &[u8]) -> Result<Bytes, Box<dyn Error>> {
42 let request_id = Uuid::new_v4().to_string();
43 let request = Message {
44 request_id: request_id.clone(),
45 method_name: method_name.to_string(),
46 body: Bytes::from(request_body.to_vec()),
47 error: String::new(),
48 };
49
50 let raw_request = message_to_bytes(&request);
51 self.conn.write_all(&raw_request)?;
52 self.conn.set_read_timeout(Some(Duration::from_secs(self.timeout)))?;
53
54 let mut reader = BufReader::new(&self.conn);
55 let message = read_message(&mut reader)?;
56
57 if !message.error.is_empty() {
58 return Err(format!("client response error: {}", message.error).into());
59 }
60
61 if message.request_id != request_id {
62 return Err(format!("client wrong requestID error: {}", message.error).into());
63 }
64
65 Ok(message.body)
66 }
67}
68
69fn parse_message(body: &[u8]) -> Result<Message, Box<dyn Error>> {
70 let parts: Vec<&[u8]> = body.split(|&b| b == metadata_delim()[0]).collect();
71 if parts.len() != PROTOCOL_FIELDS {
72 return Err(format!("error protocol received message with {} parts, expected {}", parts.len(), PROTOCOL_FIELDS).into());
73 }
74
75 Ok(Message {
76 request_id: String::from_utf8(parts[0].to_vec())?,
77 method_name: String::from_utf8(parts[1].to_vec())?,
78 error: String::from_utf8(parts[2].to_vec())?,
79 body: Bytes::from(parts[3].to_vec()),
80 })
81}
82
83fn message_to_bytes(r: &Message) -> Bytes {
84 let mut buffer = BytesMut::new();
85
86 buffer.put(r.request_id.as_bytes());
87 buffer.put(metadata_delim());
88
89 buffer.put(r.method_name.as_bytes());
90 buffer.put(metadata_delim());
91
92 buffer.put(r.error.as_bytes());
93 buffer.put(metadata_delim());
94
95 buffer.put(&r.body[..]);
96 buffer.put_u8(message_delim());
97
98 buffer.freeze()
99}
100
101fn read_message<R: Read>(reader: &mut R) -> Result<Message, Box<dyn Error>> {
102 let mut message_body = Vec::new();
103 let mut byte = [0u8; 1];
104
105 loop {
106 reader.read_exact(&mut byte)?;
107 if byte[0] == message_delim() {
108 break;
109 }
110 message_body.push(byte[0]);
111 }
112
113 parse_message(&message_body)
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Performs one `getnssusers` round trip against the socket at
    /// `/tmp/salt-ssd.sock` and prints the response.
    ///
    /// Something must be listening on that socket for this to succeed.
    fn run_client() -> Result<(), Box<dyn Error>> {
        let mut client = Client::new("/tmp/salt-ssd.sock", 10)?;
        let method_name = "getnssusers";
        let request_body = b"";

        let response = client.do_request(method_name, request_body)?;
        match std::str::from_utf8(&response) {
            Ok(s) => println!("Received response: {}", s),
            Err(e) => eprintln!("Response was not valid UTF-8: {}", e),
        }
        client.close()?;

        Ok(())
    }

    #[test]
    fn it_works() {
        // Panic rather than exit the process: exiting would terminate the
        // whole test harness instead of marking this one test as failed.
        if let Err(e) = run_client() {
            panic!("Error: {}", e);
        }
    }
}