Skip to main content

rusty_tarantool/tarantool/
dispatch.rs

1use core::pin::Pin;
2use std::collections::HashMap;
3use std::io;
4use std::string::ToString;
5use std::sync::{Arc, Mutex, RwLock};
6
7use futures::select;
8use futures::stream::StreamExt;
9use futures::SinkExt;
10use futures_channel::mpsc;
11use futures_channel::oneshot;
12use futures_util::FutureExt;
13
14use tokio::net::TcpStream;
15use tokio::time::{sleep_until, Duration, Instant};
16use tokio_util::time::{delay_queue, DelayQueue};
17use tokio_util::codec::{Decoder, Framed};
18
19use crate::tarantool::codec::{RequestId, TarantoolCodec, TarantoolFramedRequest};
20use crate::tarantool::packets::{AuthPacket, CommandPacket, TarantoolRequest, TarantoolResponse};
21
22pub type TarantoolFramed = Framed<TcpStream, TarantoolCodec>;
23pub type CallbackSender = oneshot::Sender<io::Result<TarantoolResponse>>;
24pub type ReconnectNotifySender = mpsc::UnboundedSender<ClientStatus>;
25
/// Error message exported for callers; by its name, reported when the
/// dispatch task is no longer running (its use is outside this file).
pub static ERROR_DISPATCH_THREAD_IS_DEAD: &str = "DISPATCH THREAD IS DEAD!";
/// Error message exported for callers; by its name, reported once the client
/// has been disconnected (its use is outside this file).
pub static ERROR_CLIENT_DISCONNECTED: &str = "CLIENT DISCONNECTED!";

/// Tarantool client config.
///
/// Create it with `ClientConfig::new` and adjust it with the `set_*` builder
/// methods. By default the reconnect delay is 10000 ms and no request timeout
/// is set.
///
/// The `Debug` output redacts the password so a config can be logged without
/// leaking credentials.
///
/// # Examples
/// ```text
/// let client = ClientConfig::new(addr, "rust", "rust")
///            .set_timeout_time_ms(1000)
///            .set_reconnect_time_ms(10000)
///            .build();
/// ```
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ClientConfig {
    /// Server address handed to `TcpStream::connect`, e.g. `"127.0.0.1:3301"`.
    addr: String,
    /// User name sent in the authentication packet.
    login: String,
    /// Password sent in the authentication packet. Never printed by `Debug`.
    password: String,
    /// Delay in milliseconds before reconnecting after a connection failure.
    reconnect_time_ms: u64,
    /// Per-request timeout in milliseconds used to schedule request deadlines;
    /// `None` disables them.
    timeout_time_ms: Option<u64>,
}

impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            .field("addr", &self.addr)
            .field("login", &self.login)
            // Credentials must not end up in logs.
            .field("password", &"<redacted>")
            .field("reconnect_time_ms", &self.reconnect_time_ms)
            .field("timeout_time_ms", &self.timeout_time_ms)
            .finish()
    }
}

50impl ClientConfig {
51    pub fn new<S0, S, S1>(addr: S0, login: S, password: S1) -> ClientConfig
52    where
53        S0: Into<String>,
54        S: Into<String>,
55        S1: Into<String>,
56    {
57        ClientConfig {
58            addr: addr.into(),
59            login: login.into(),
60            password: password.into(),
61            reconnect_time_ms: 10000,
62            timeout_time_ms: None,
63        }
64    }
65
66    pub fn set_timeout_time_ms(mut self, timeout_time_ms: u64) -> ClientConfig {
67        self.timeout_time_ms = Some(timeout_time_ms);
68        self
69    }
70
71    pub fn set_reconnect_time_ms(mut self, reconnect_time_ms: u64) -> ClientConfig {
72        self.reconnect_time_ms = reconnect_time_ms;
73        self
74    }
75}
76
/// Connection state of the client.
///
/// The dispatcher writes it into a shared `RwLock` and pushes a copy to every
/// status subscriber whenever it changes.
#[derive(Debug, Clone)]
pub enum ClientStatus {
    /// Not set by the dispatcher in this file; presumably the value before
    /// `Dispatch::run` starts — TODO confirm with the client code.
    Init,
    /// Set when the dispatch loop starts connecting.
    Connecting,
    /// Not set by the dispatcher in this file.
    Handshaking,
    /// Set after the TCP connection is open and authentication succeeded.
    Connected,
    /// Not set by the dispatcher in this file; carries a description.
    Disconnecting(String),
    /// Set after the connection failed; carries the error description.
    Disconnected(String),
    /// Set once the command queue is closed and the dispatcher has stopped.
    Closed,
}

88pub struct Dispatch {
89    config: ClientConfig,
90    // engine: DispatchEngine,
91    command_receiver: mpsc::UnboundedReceiver<(CommandPacket, CallbackSender)>,
92    is_command_receiver_closed: bool,
93    awaiting_callbacks: HashMap<RequestId, CallbackSender>,
94    notify_callbacks: Arc<Mutex<Vec<ReconnectNotifySender>>>,
95
96    buffered_command: Option<TarantoolFramedRequest>,
97    command_counter: RequestId,
98
99    timeout_time_ms: Option<u64>,
100    timeout_queue: DelayQueue<RequestId>,
101    timeout_id_to_key: HashMap<RequestId, delay_queue::Key>,
102
103    status: Arc<RwLock<ClientStatus>>,
104}
105
106impl Dispatch {
107    pub fn new(
108        config: ClientConfig,
109        command_receiver: mpsc::UnboundedReceiver<(CommandPacket, CallbackSender)>,
110        status: Arc<RwLock<ClientStatus>>,
111        notify_callbacks: Arc<Mutex<Vec<ReconnectNotifySender>>>,
112    ) -> Dispatch {
113        let timeout_time_ms = config.timeout_time_ms;
114        Dispatch {
115            config,
116            command_receiver,
117            is_command_receiver_closed:false,
118            buffered_command: None,
119            awaiting_callbacks: HashMap::new(),
120            notify_callbacks,
121            command_counter: 3,
122            timeout_time_ms,
123            timeout_queue: DelayQueue::new(),
124            timeout_id_to_key: HashMap::new(),
125            status,
126        }
127    }
128
129    ///send status notification to all subscribers
130    async fn send_notify(&mut self, status: &ClientStatus) {
131        let callbacks: Vec<ReconnectNotifySender> =
132            self.notify_callbacks.lock().unwrap().split_off(0);
133        let mut filtered_callbacks: Vec<ReconnectNotifySender> = Vec::new();
134        for mut callback in callbacks {
135            if callback.send(status.clone()).await.is_ok() {
136                filtered_callbacks.push(callback);
137            }
138        }
139
140        self.notify_callbacks
141            .lock()
142            .unwrap()
143            .extend(filtered_callbacks.iter().cloned());
144    }
145
146    ///send command from buffer. if not success, return command to buffer and initiate reconnect
147    async fn try_send_buffered_command(&mut self, sink: &mut TarantoolFramed) -> io::Result<()> {
148        if let Some(command) = self.buffered_command.take() {
149            if let Err(e) = Pin::new(sink).send(command.clone()).await {
150                //return command to buffer
151                self.buffered_command = Some(command);
152                return Err(io::Error::new(
153                    io::ErrorKind::ConnectionAborted,
154                    e.to_string(),
155                ));
156            }
157        }
158        Ok(())
159    }
160
161    ///send error to all awaiting callbacks
162    fn send_error_to_all(&mut self, error_description: String) {
163        for (_, callback_sender) in self.awaiting_callbacks.drain() {
164            let _res = callback_sender.send(Err(io::Error::new(
165                io::ErrorKind::Other,
166                error_description.clone(),
167            )));
168        }
169        self.buffered_command = None;
170
171        if self.timeout_time_ms.is_some() {
172            self.timeout_id_to_key.clear();
173            self.timeout_queue.clear();
174        }
175
176        if !self.is_command_receiver_closed {
177            while let Ok(Some((_, callback_sender))) = self.command_receiver.try_next() {
178                let _res = callback_sender.send(Err(io::Error::new(
179                    io::ErrorKind::Other,
180                    error_description.clone(),
181                )));
182            }
183        }
184    }
185
186    ///process command - send to tarantool, store callback
187    async fn process_command(
188        &mut self,
189        command: Option<(CommandPacket, CallbackSender)>,
190        sink: &mut TarantoolFramed,
191    ) -> io::Result<()> {
192        self.try_send_buffered_command(sink).await?;
193
194        match command {
195            Some((command_packet, callback_sender)) => {
196                let request_id = self.increment_command_counter();
197                self.awaiting_callbacks.insert(request_id, callback_sender);
198                self.buffered_command =
199                    Some((request_id, TarantoolRequest::Command(command_packet)));
200                if let Some(timeout_time_ms) = self.timeout_time_ms {
201                    let delay_key = self.timeout_queue.insert_at(
202                        request_id,
203                        Instant::now() + Duration::from_millis(timeout_time_ms),
204                    );
205                    self.timeout_id_to_key.insert(request_id, delay_key);
206                }
207                //if return disconnected - retry
208                self.try_send_buffered_command(sink).await
209            }
210            None => {
211                self.is_command_receiver_closed = true;
212                //inbound sink is finished. close coroutine
213                Err(io::Error::new(
214                    io::ErrorKind::InvalidInput,
215                    "inbound commands queue is over",
216                ))
217            }
218        }
219    }
220
221    ///process tarantool response
222    async fn process_tarantool_response(
223        &mut self,
224        response: Option<io::Result<(RequestId, io::Result<TarantoolResponse>)>>,
225    ) -> io::Result<()> {
226        debug!("receive command! {:?} ", response);
227        match response {
228            Some(Ok((request_id, Ok(command_packet)))) => {
229                debug!("receive command! {} {:?} ", request_id, command_packet);
230                if self.timeout_time_ms.is_some() {
231                    if let Some(delay_key) = self.timeout_id_to_key.remove(&request_id) {
232                        self.timeout_queue.remove(&delay_key);
233                    }
234                }
235                if let Some(callback) = self.awaiting_callbacks.remove(&request_id) {
236                    let _send_res = callback.send(Ok(command_packet));
237                }
238
239                Ok(())
240            },
241            Some(Ok((request_id, Err(e)))) => {
242                debug!("receive command! {} {:?} ", request_id, e);
243                if self.timeout_time_ms.is_some() {
244                    if let Some(delay_key) = self.timeout_id_to_key.remove(&request_id) {
245                        self.timeout_queue.remove(&delay_key);
246                    }
247                }
248                if let Some(callback) = self.awaiting_callbacks.remove(&request_id) {
249                    let _send_res = callback.send(Err(e));
250                }
251
252                Ok(())
253            },
254            None => Err(io::Error::new(
255                io::ErrorKind::ConnectionAborted,
256                "return none from stream!",
257            )),
258            _ => Ok(()),
259        }
260    }
261
262    fn increment_command_counter(&mut self) -> RequestId {
263        self.command_counter += 1;
264        self.command_counter
265    }
266
267    fn clean_command_counter(&mut self) {
268        self.command_counter = 3;
269    }
270
271    async fn set_status(&mut self, client_status: ClientStatus) {
272        self.send_notify(&client_status).await;
273        *(self.status.write().unwrap()) = client_status;
274    }
275
276    ///main dispatch look function
277    pub async fn run(&mut self) {
278        self.set_status(ClientStatus::Connecting).await;
279        loop {
280            match self.connect_and_process_commands().await {
281                Ok(()) => {
282                    //finish
283                    return;
284                }
285                Err(e) => {
286                    self.set_status(ClientStatus::Disconnected(e.to_string()))
287                        .await;
288                    self.send_error_to_all(e.to_string());
289                    sleep_until(Instant::now() + Duration::from_millis(self.config.reconnect_time_ms)).await;
290                }
291            }
292
293            if self.is_command_receiver_closed {
294                self.set_status(ClientStatus::Closed).await;
295                return;
296            }
297        }
298    }
299
300    async fn connect_and_process_commands(&mut self) -> io::Result<()> {
301        let tcp_stream = TcpStream::connect(self.config.addr.clone()).await?;
302        let mut framed_io = self.auth(tcp_stream).await?;
303        self.set_status(ClientStatus::Connected).await;
304        loop {
305            select! {
306                tarantool_response = framed_io.next().fuse() => {
307                    self.process_tarantool_response(tarantool_response).await?
308                }
309                command = self.command_receiver.next() => {
310                    self.process_command(command, &mut framed_io).await?
311                }
312            }
313        }
314    }
315
316    async fn auth(&mut self, tcp_stream: TcpStream) -> io::Result<TarantoolFramed> {
317        let codec: TarantoolCodec = Default::default();
318        let mut framed_io = codec.framed(tcp_stream);
319        let _first_response = framed_io.next().await;
320        // println!("Received first packet {:?}", first_response);
321        framed_io
322            .send((
323                2,
324                TarantoolRequest::Auth(AuthPacket {
325                    login: self.config.login.clone(),
326                    password: self.config.password.clone(),
327                }),
328            ))
329            .await?;
330        let auth_response = framed_io.next().await;
331        match auth_response {
332            Some(Ok((_, Err(e)))) => Err(io::Error::new(
333                io::ErrorKind::PermissionDenied,
334                e.to_string(),
335            )),
336            _ => {
337                self.clean_command_counter();
338                Ok(framed_io)
339            }
340        }
341    }
342}