serf_rpc/
lib.rs

1use std::{
2    collections::HashMap,
3    fmt::Debug,
4    net::{SocketAddr, TcpStream},
5};
6
7use std::io;
8use std::sync::{Arc, Mutex};
9
10use io::{BufReader, Write};
11use protocol::RequestHeader;
12use serde::de::DeserializeOwned;
13
14const MAX_IPC_VERSION: u32 = 1;
15
16mod coordinates;
17mod members;
18mod request;
19mod stream;
20
21pub mod protocol;
22
23pub use request::RPCRequest;
24pub use stream::RPCStream;
25
26/// A wrapper allowing reading a Seq response.
27///
28/// This is an internal implementation detail, but public because it is exposed in traits.
29#[doc(hidden)]
30pub struct SeqRead<'a>(&'a mut BufReader<TcpStream>);
31impl<'a> SeqRead<'a> {
32    fn read_msg<T: DeserializeOwned + Debug>(mut self) -> T {
33        // annoyingly, we pretty much have to panic, because otherwise the reader is left in an invalid state
34        rmp_serde::from_read(&mut self.0).unwrap()
35    }
36}
37
38trait SeqHandler: 'static + Send + Sync {
39    fn handle(&self, res: RPCResult<SeqRead>);
40    /// are we expecting more than one response?
41    fn streaming(&self) -> bool {
42        false
43    }
44}
45
/// Result alias used by the RPC client: the error is a plain `String`, and the
/// success type defaults to `()`.
type RPCResult<T = ()> = std::result::Result<T, String>;
47
48#[derive(Clone)]
49pub struct Client {
50    dispatch: Arc<Mutex<DispatchMap>>,
51    tx: std::sync::mpsc::Sender<Vec<u8>>,
52}
53
54struct DispatchMap {
55    map: HashMap<u64, Arc<dyn SeqHandler>>,
56    next_seq: u64,
57}
58
59impl Client {
60    /// Connect to hub.
61    ///
62    /// Waits for handshake, and optionally for authentication if an auth key is provided.
63    pub async fn connect(rpc_addr: SocketAddr, auth_key: Option<&str>) -> RPCResult<Self> {
64        let (tx, rx) = std::sync::mpsc::channel();
65
66        let dispatch = Arc::new(Mutex::new(DispatchMap {
67            map: HashMap::new(),
68            next_seq: 0,
69        }));
70
71        let client = Client { dispatch, tx };
72
73        let dispatch = Arc::downgrade(&client.dispatch);
74
75        std::thread::spawn(move || {
76            let mut stream = TcpStream::connect(rpc_addr).unwrap();
77
78            // clone the stream to create a reader
79            let mut reader = BufReader::new(stream.try_clone().unwrap());
80
81            // write loop
82            std::thread::spawn(move || {
83                while let Ok(buf) = rx.recv() {
84                    stream.write_all(&buf).unwrap();
85                }
86            });
87
88            // read loop
89            while let Some(dispatch) = dispatch.upgrade() {
90                let protocol::ResponseHeader { seq, error } =
91                    rmp_serde::from_read(&mut reader).unwrap();
92
93                let seq_handler = {
94                    let mut dispatch = dispatch.lock().unwrap();
95                    match dispatch.map.get(&seq) {
96                        Some(v) => {
97                            if v.streaming() {
98                                v.clone()
99                            } else {
100                                dispatch.map.remove(&seq).unwrap()
101                            }
102                        }
103                        None => {
104                            // response with no handler, ignore
105                            continue;
106                        }
107                    }
108                };
109
110                let res = if error.is_empty() {
111                    Ok(SeqRead(&mut reader))
112                } else {
113                    Err(error)
114                };
115
116                seq_handler.handle(res);
117            }
118        });
119
120        client.handshake(MAX_IPC_VERSION).await?;
121
122        if let Some(auth_key) = auth_key {
123            client.auth(auth_key).await?;
124        }
125
126        return Ok(client);
127    }
128
129    fn deregister_seq_handler(&self, seq: u64) -> Option<Arc<dyn SeqHandler>> {
130        self.dispatch.lock().unwrap().map.remove(&seq)
131    }
132
133    /// Send a command, optionally registering a handler for responses.
134    ///
135    /// Returns the sequence number.
136    fn send_command(&self, cmd: SerializedCommand, handler: Option<Arc<dyn SeqHandler>>) -> u64 {
137        let seq = {
138            let mut dispatch = self.dispatch.lock().unwrap();
139
140            let seq = dispatch.next_seq;
141            dispatch.next_seq += 1;
142
143            if let Some(handler) = handler {
144                dispatch.map.insert(seq, handler);
145            }
146
147            seq
148        };
149
150        let mut buf = rmp_serde::encode::to_vec_named(&RequestHeader {
151            command: cmd.name,
152            seq,
153        })
154        .unwrap();
155        buf.extend_from_slice(&cmd.body);
156
157        self.tx.send(buf).unwrap();
158
159        seq
160    }
161
162    pub async fn current_node_name(&self) -> RPCResult<String> {
163        Ok(self.stats().await?.agent.name)
164    }
165}
166
/// A command that is ready to be sent: its name plus an already-encoded body.
struct SerializedCommand {
    /// Command name placed in the request header.
    name: &'static str,
    /// Bytes appended verbatim after the encoded request header.
    body: Vec<u8>,
}
171
172/// A trait for types that can be deserialized as the response to a command
173///
174/// This is an internal implementation detail, but public because it is exposed in traits.
175#[doc(hidden)]
176pub trait RPCResponse: Sized + Send + 'static {
177    fn read_from(read: SeqRead<'_>) -> RPCResult<Self>;
178}
179
180impl RPCResponse for () {
181    fn read_from(_: SeqRead<'_>) -> RPCResult<Self> {
182        Ok(())
183    }
184}