rusty_tarantool/tarantool/
dispatch.rs1use core::pin::Pin;
2use std::collections::HashMap;
3use std::io;
4use std::string::ToString;
5use std::sync::{Arc, Mutex, RwLock};
6
7use futures::select;
8use futures::stream::StreamExt;
9use futures::SinkExt;
10use futures_channel::mpsc;
11use futures_channel::oneshot;
12use futures_util::FutureExt;
13
14use tokio::net::TcpStream;
15use tokio::time::{sleep_until, Duration, Instant};
16use tokio_util::time::{delay_queue, DelayQueue};
17use tokio_util::codec::{Decoder, Framed};
18
19use crate::tarantool::codec::{RequestId, TarantoolCodec, TarantoolFramedRequest};
20use crate::tarantool::packets::{AuthPacket, CommandPacket, TarantoolRequest, TarantoolResponse};
21
22pub type TarantoolFramed = Framed<TcpStream, TarantoolCodec>;
23pub type CallbackSender = oneshot::Sender<io::Result<TarantoolResponse>>;
24pub type ReconnectNotifySender = mpsc::UnboundedSender<ClientStatus>;
25
/// Message text for a disconnected client. Not referenced in this file; presumably
/// returned by the client side once its connection is gone.
pub static ERROR_CLIENT_DISCONNECTED: &str = "CLIENT DISCONNECTED!";

/// Message text for a dispatch task that is no longer running. Not referenced in this
/// file; presumably returned by the client side when the dispatcher has stopped.
pub static ERROR_DISPATCH_THREAD_IS_DEAD: &str = "DISPATCH THREAD IS DEAD!";
/// Connection settings for a Tarantool client.
///
/// Build one with [`ClientConfig::new`] and adjust it with the chained `set_*` methods.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ClientConfig {
    /// Server address handed to `TcpStream::connect`.
    addr: String,
    /// User name sent in the authentication request.
    login: String,
    /// Password sent in the authentication request.
    password: String,
    /// Pause, in milliseconds, between a failed connection and the next attempt.
    reconnect_time_ms: u64,
    /// Per-request timeout in milliseconds; `None` means no timeout is scheduled.
    timeout_time_ms: Option<u64>,
}

impl ClientConfig {
    /// Creates a config for `addr` that authenticates as `login` / `password`.
    ///
    /// Defaults: a 10 000 ms reconnect delay and no request timeout.
    pub fn new<A, L, P>(addr: A, login: L, password: P) -> ClientConfig
    where
        A: Into<String>,
        L: Into<String>,
        P: Into<String>,
    {
        let (addr, login, password) = (addr.into(), login.into(), password.into());
        ClientConfig {
            addr,
            login,
            password,
            reconnect_time_ms: 10000,
            timeout_time_ms: None,
        }
    }

    /// Returns this config with a per-request timeout of `timeout_time_ms` milliseconds.
    pub fn set_timeout_time_ms(self, timeout_time_ms: u64) -> ClientConfig {
        ClientConfig {
            timeout_time_ms: Some(timeout_time_ms),
            ..self
        }
    }

    /// Returns this config with the reconnect delay set to `reconnect_time_ms` milliseconds.
    pub fn set_reconnect_time_ms(self, reconnect_time_ms: u64) -> ClientConfig {
        ClientConfig {
            reconnect_time_ms,
            ..self
        }
    }
}
76
/// Connection state published by the dispatcher to the shared status cell and subscribers.
#[derive(Debug, Clone)]
pub enum ClientStatus {
    /// Starting state; not assigned in this file (presumably the client's initial value).
    Init,
    /// The dispatcher has started trying to connect.
    Connecting,
    /// Handshake in progress; not assigned in this file.
    Handshaking,
    /// Authenticated and processing commands.
    Connected,
    /// Disconnect in progress, with a reason; not assigned in this file.
    Disconnecting(String),
    /// The connection failed; carries the error text.
    Disconnected(String),
    /// The command channel is closed and the dispatcher has stopped.
    Closed,
}
87
88pub struct Dispatch {
89 config: ClientConfig,
90 command_receiver: mpsc::UnboundedReceiver<(CommandPacket, CallbackSender)>,
92 is_command_receiver_closed: bool,
93 awaiting_callbacks: HashMap<RequestId, CallbackSender>,
94 notify_callbacks: Arc<Mutex<Vec<ReconnectNotifySender>>>,
95
96 buffered_command: Option<TarantoolFramedRequest>,
97 command_counter: RequestId,
98
99 timeout_time_ms: Option<u64>,
100 timeout_queue: DelayQueue<RequestId>,
101 timeout_id_to_key: HashMap<RequestId, delay_queue::Key>,
102
103 status: Arc<RwLock<ClientStatus>>,
104}
105
106impl Dispatch {
107 pub fn new(
108 config: ClientConfig,
109 command_receiver: mpsc::UnboundedReceiver<(CommandPacket, CallbackSender)>,
110 status: Arc<RwLock<ClientStatus>>,
111 notify_callbacks: Arc<Mutex<Vec<ReconnectNotifySender>>>,
112 ) -> Dispatch {
113 let timeout_time_ms = config.timeout_time_ms;
114 Dispatch {
115 config,
116 command_receiver,
117 is_command_receiver_closed:false,
118 buffered_command: None,
119 awaiting_callbacks: HashMap::new(),
120 notify_callbacks,
121 command_counter: 3,
122 timeout_time_ms,
123 timeout_queue: DelayQueue::new(),
124 timeout_id_to_key: HashMap::new(),
125 status,
126 }
127 }
128
129 async fn send_notify(&mut self, status: &ClientStatus) {
131 let callbacks: Vec<ReconnectNotifySender> =
132 self.notify_callbacks.lock().unwrap().split_off(0);
133 let mut filtered_callbacks: Vec<ReconnectNotifySender> = Vec::new();
134 for mut callback in callbacks {
135 if callback.send(status.clone()).await.is_ok() {
136 filtered_callbacks.push(callback);
137 }
138 }
139
140 self.notify_callbacks
141 .lock()
142 .unwrap()
143 .extend(filtered_callbacks.iter().cloned());
144 }
145
146 async fn try_send_buffered_command(&mut self, sink: &mut TarantoolFramed) -> io::Result<()> {
148 if let Some(command) = self.buffered_command.take() {
149 if let Err(e) = Pin::new(sink).send(command.clone()).await {
150 self.buffered_command = Some(command);
152 return Err(io::Error::new(
153 io::ErrorKind::ConnectionAborted,
154 e.to_string(),
155 ));
156 }
157 }
158 Ok(())
159 }
160
161 fn send_error_to_all(&mut self, error_description: String) {
163 for (_, callback_sender) in self.awaiting_callbacks.drain() {
164 let _res = callback_sender.send(Err(io::Error::new(
165 io::ErrorKind::Other,
166 error_description.clone(),
167 )));
168 }
169 self.buffered_command = None;
170
171 if self.timeout_time_ms.is_some() {
172 self.timeout_id_to_key.clear();
173 self.timeout_queue.clear();
174 }
175
176 if !self.is_command_receiver_closed {
177 while let Ok(Some((_, callback_sender))) = self.command_receiver.try_next() {
178 let _res = callback_sender.send(Err(io::Error::new(
179 io::ErrorKind::Other,
180 error_description.clone(),
181 )));
182 }
183 }
184 }
185
186 async fn process_command(
188 &mut self,
189 command: Option<(CommandPacket, CallbackSender)>,
190 sink: &mut TarantoolFramed,
191 ) -> io::Result<()> {
192 self.try_send_buffered_command(sink).await?;
193
194 match command {
195 Some((command_packet, callback_sender)) => {
196 let request_id = self.increment_command_counter();
197 self.awaiting_callbacks.insert(request_id, callback_sender);
198 self.buffered_command =
199 Some((request_id, TarantoolRequest::Command(command_packet)));
200 if let Some(timeout_time_ms) = self.timeout_time_ms {
201 let delay_key = self.timeout_queue.insert_at(
202 request_id,
203 Instant::now() + Duration::from_millis(timeout_time_ms),
204 );
205 self.timeout_id_to_key.insert(request_id, delay_key);
206 }
207 self.try_send_buffered_command(sink).await
209 }
210 None => {
211 self.is_command_receiver_closed = true;
212 Err(io::Error::new(
214 io::ErrorKind::InvalidInput,
215 "inbound commands queue is over",
216 ))
217 }
218 }
219 }
220
221 async fn process_tarantool_response(
223 &mut self,
224 response: Option<io::Result<(RequestId, io::Result<TarantoolResponse>)>>,
225 ) -> io::Result<()> {
226 debug!("receive command! {:?} ", response);
227 match response {
228 Some(Ok((request_id, Ok(command_packet)))) => {
229 debug!("receive command! {} {:?} ", request_id, command_packet);
230 if self.timeout_time_ms.is_some() {
231 if let Some(delay_key) = self.timeout_id_to_key.remove(&request_id) {
232 self.timeout_queue.remove(&delay_key);
233 }
234 }
235 if let Some(callback) = self.awaiting_callbacks.remove(&request_id) {
236 let _send_res = callback.send(Ok(command_packet));
237 }
238
239 Ok(())
240 },
241 Some(Ok((request_id, Err(e)))) => {
242 debug!("receive command! {} {:?} ", request_id, e);
243 if self.timeout_time_ms.is_some() {
244 if let Some(delay_key) = self.timeout_id_to_key.remove(&request_id) {
245 self.timeout_queue.remove(&delay_key);
246 }
247 }
248 if let Some(callback) = self.awaiting_callbacks.remove(&request_id) {
249 let _send_res = callback.send(Err(e));
250 }
251
252 Ok(())
253 },
254 None => Err(io::Error::new(
255 io::ErrorKind::ConnectionAborted,
256 "return none from stream!",
257 )),
258 _ => Ok(()),
259 }
260 }
261
262 fn increment_command_counter(&mut self) -> RequestId {
263 self.command_counter += 1;
264 self.command_counter
265 }
266
267 fn clean_command_counter(&mut self) {
268 self.command_counter = 3;
269 }
270
271 async fn set_status(&mut self, client_status: ClientStatus) {
272 self.send_notify(&client_status).await;
273 *(self.status.write().unwrap()) = client_status;
274 }
275
276 pub async fn run(&mut self) {
278 self.set_status(ClientStatus::Connecting).await;
279 loop {
280 match self.connect_and_process_commands().await {
281 Ok(()) => {
282 return;
284 }
285 Err(e) => {
286 self.set_status(ClientStatus::Disconnected(e.to_string()))
287 .await;
288 self.send_error_to_all(e.to_string());
289 sleep_until(Instant::now() + Duration::from_millis(self.config.reconnect_time_ms)).await;
290 }
291 }
292
293 if self.is_command_receiver_closed {
294 self.set_status(ClientStatus::Closed).await;
295 return;
296 }
297 }
298 }
299
300 async fn connect_and_process_commands(&mut self) -> io::Result<()> {
301 let tcp_stream = TcpStream::connect(self.config.addr.clone()).await?;
302 let mut framed_io = self.auth(tcp_stream).await?;
303 self.set_status(ClientStatus::Connected).await;
304 loop {
305 select! {
306 tarantool_response = framed_io.next().fuse() => {
307 self.process_tarantool_response(tarantool_response).await?
308 }
309 command = self.command_receiver.next() => {
310 self.process_command(command, &mut framed_io).await?
311 }
312 }
313 }
314 }
315
316 async fn auth(&mut self, tcp_stream: TcpStream) -> io::Result<TarantoolFramed> {
317 let codec: TarantoolCodec = Default::default();
318 let mut framed_io = codec.framed(tcp_stream);
319 let _first_response = framed_io.next().await;
320 framed_io
322 .send((
323 2,
324 TarantoolRequest::Auth(AuthPacket {
325 login: self.config.login.clone(),
326 password: self.config.password.clone(),
327 }),
328 ))
329 .await?;
330 let auth_response = framed_io.next().await;
331 match auth_response {
332 Some(Ok((_, Err(e)))) => Err(io::Error::new(
333 io::ErrorKind::PermissionDenied,
334 e.to_string(),
335 )),
336 _ => {
337 self.clean_command_counter();
338 Ok(framed_io)
339 }
340 }
341 }
342}