1use std::{
2 collections::HashMap,
3 fmt::Debug,
4 net::{SocketAddr, TcpStream},
5};
6
7use std::io;
8use std::sync::{Arc, Mutex};
9
10use io::{BufReader, Write};
11use protocol::RequestHeader;
12use serde::de::DeserializeOwned;
13
/// IPC protocol version sent in the handshake performed by `Client::connect`.
const MAX_IPC_VERSION: u32 = 1;
15
16mod coordinates;
17mod members;
18mod request;
19mod stream;
20
21pub mod protocol;
22
23pub use request::RPCRequest;
24pub use stream::RPCStream;
25
26#[doc(hidden)]
30pub struct SeqRead<'a>(&'a mut BufReader<TcpStream>);
31impl<'a> SeqRead<'a> {
32 fn read_msg<T: DeserializeOwned + Debug>(mut self) -> T {
33 rmp_serde::from_read(&mut self.0).unwrap()
35 }
36}
37
38trait SeqHandler: 'static + Send + Sync {
39 fn handle(&self, res: RPCResult<SeqRead>);
40 fn streaming(&self) -> bool {
42 false
43 }
44}
45
/// Result type used by the RPC client; errors are plain strings, such as the
/// error text carried in a response header.
type RPCResult<T = ()> = std::result::Result<T, String>;
47
48#[derive(Clone)]
49pub struct Client {
50 dispatch: Arc<Mutex<DispatchMap>>,
51 tx: std::sync::mpsc::Sender<Vec<u8>>,
52}
53
54struct DispatchMap {
55 map: HashMap<u64, Arc<dyn SeqHandler>>,
56 next_seq: u64,
57}
58
59impl Client {
60 pub async fn connect(rpc_addr: SocketAddr, auth_key: Option<&str>) -> RPCResult<Self> {
64 let (tx, rx) = std::sync::mpsc::channel();
65
66 let dispatch = Arc::new(Mutex::new(DispatchMap {
67 map: HashMap::new(),
68 next_seq: 0,
69 }));
70
71 let client = Client { dispatch, tx };
72
73 let dispatch = Arc::downgrade(&client.dispatch);
74
75 std::thread::spawn(move || {
76 let mut stream = TcpStream::connect(rpc_addr).unwrap();
77
78 let mut reader = BufReader::new(stream.try_clone().unwrap());
80
81 std::thread::spawn(move || {
83 while let Ok(buf) = rx.recv() {
84 stream.write_all(&buf).unwrap();
85 }
86 });
87
88 while let Some(dispatch) = dispatch.upgrade() {
90 let protocol::ResponseHeader { seq, error } =
91 rmp_serde::from_read(&mut reader).unwrap();
92
93 let seq_handler = {
94 let mut dispatch = dispatch.lock().unwrap();
95 match dispatch.map.get(&seq) {
96 Some(v) => {
97 if v.streaming() {
98 v.clone()
99 } else {
100 dispatch.map.remove(&seq).unwrap()
101 }
102 }
103 None => {
104 continue;
106 }
107 }
108 };
109
110 let res = if error.is_empty() {
111 Ok(SeqRead(&mut reader))
112 } else {
113 Err(error)
114 };
115
116 seq_handler.handle(res);
117 }
118 });
119
120 client.handshake(MAX_IPC_VERSION).await?;
121
122 if let Some(auth_key) = auth_key {
123 client.auth(auth_key).await?;
124 }
125
126 return Ok(client);
127 }
128
129 fn deregister_seq_handler(&self, seq: u64) -> Option<Arc<dyn SeqHandler>> {
130 self.dispatch.lock().unwrap().map.remove(&seq)
131 }
132
133 fn send_command(&self, cmd: SerializedCommand, handler: Option<Arc<dyn SeqHandler>>) -> u64 {
137 let seq = {
138 let mut dispatch = self.dispatch.lock().unwrap();
139
140 let seq = dispatch.next_seq;
141 dispatch.next_seq += 1;
142
143 if let Some(handler) = handler {
144 dispatch.map.insert(seq, handler);
145 }
146
147 seq
148 };
149
150 let mut buf = rmp_serde::encode::to_vec_named(&RequestHeader {
151 command: cmd.name,
152 seq,
153 })
154 .unwrap();
155 buf.extend_from_slice(&cmd.body);
156
157 self.tx.send(buf).unwrap();
158
159 seq
160 }
161
162 pub async fn current_node_name(&self) -> RPCResult<String> {
163 Ok(self.stats().await?.agent.name)
164 }
165}
166
/// An RPC command ready to be framed: its name plus its already-encoded body.
struct SerializedCommand {
    /// Command name written into the request header.
    name: &'static str,
    /// Encoded body bytes, appended verbatim after the request header.
    body: Vec<u8>,
}
171
172#[doc(hidden)]
176pub trait RPCResponse: Sized + Send + 'static {
177 fn read_from(read: SeqRead<'_>) -> RPCResult<Self>;
178}
179
180impl RPCResponse for () {
181 fn read_from(_: SeqRead<'_>) -> RPCResult<Self> {
182 Ok(())
183 }
184}