1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
//! Tarantool Binary Protocol implementation without actual transport layer.
//!
//! According to Sans-I/O pattern this implementation processes requests and responses
//! by outputting the corresponding bytes into incoming and outgoing buffers.
//! And it provides an API for the upper client layer to get data from these buffers.
//!
//! See [`super::client`] if you need a fully functional Tarantool client.

pub mod api;
pub mod codec;

use std::cmp;
use std::collections::HashMap;
use std::io::{self, Cursor, Read, Seek};
use std::str::Utf8Error;
use std::vec::Drain;

use api::Request;

/// Error returned by [`Protocol`].
///
/// Most variants wrap a lower-level error and are created through the
/// `#[from]` conversions, so `?` can be applied directly to the wrapped types.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Wraps a [`Utf8Error`].
    #[error("utf8 error: {0}")]
    Utf8(#[from] Utf8Error),
    /// Wraps a MessagePack value write error.
    #[error("failed to encode: {0}")]
    Encode(#[from] rmp::encode::ValueWriteError),
    /// Wraps a MessagePack value read error.
    #[error("failed to decode: {0}")]
    Decode(#[from] rmp::decode::ValueReadError),
    /// Wraps a MessagePack numeric value read error.
    #[error("failed to decode: {0}")]
    DecodeNum(#[from] rmp::decode::NumValueReadError),
    /// The server answered with a non-zero status code; carries the decoded
    /// [`ResponseError`].
    #[error("service responded with error: {0}")]
    Response(#[from] ResponseError),
    /// Wraps an [`io::Error`], e.g. one raised while reading a message body.
    #[error("io error: {0}")]
    Io(#[from] io::Error),
    /// The decoded message size prefix was `0`
    /// (see [`Protocol::process_incoming`]).
    #[error("message size hint is 0")]
    ZeroSizeHint,
    /// The response body lacks the DATA field that call/eval responses require.
    // NOTE(review): not produced in this file, presumably raised by the
    // response decoders in `api`/`codec` — confirm.
    #[error("DATA not found in response body but is required for call/eval")]
    ResponseDataNotFound,
    /// This type of error is only present in codec
    /// due to ToTupleBuffer implementation using generic Tarantool Error.
    /// So should be related to encode/decode.
    // TODO: remove this variant when ToTupleBuffer switches to more specific errors
    #[error("encode/decode error: {0}")]
    Tarantool(Box<crate::error::Error>),
}

// TODO: remove this conversion when ToTupleBuffer switches to more specific errors
impl From<crate::error::Error> for Error {
    fn from(err: crate::error::Error) -> Self {
        Error::Tarantool(Box::new(err))
    }
}

/// Unique identifier of the sent message on this connection.
/// It is used to retrieve response for the corresponding request.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SyncIndex(u64);

impl SyncIndex {
    /// Returns the current index and advances `self` to the following one.
    ///
    /// The counter wraps around to `0` after `u64::MAX`. The index is only an
    /// identifier, so wrapping is preferable to the previous behaviour, which
    /// panicked in builds with overflow checks and wrapped silently otherwise.
    pub fn next_index(&mut self) -> Self {
        let current = *self;
        self.0 = self.0.wrapping_add(1);
        current
    }
}

/// Error returned from the Tarantool server.
///
/// It represents an error with which Tarantool server
/// answers to the client in case of faulty request or an error
/// during request execution on the server side.
///
/// Its `Display` output is exactly the stored `message`.
#[derive(Debug, thiserror::Error, Clone)]
#[error("{message}")]
pub struct ResponseError {
    /// Error text shown by `Display`.
    // NOTE(review): presumably the message text sent by the server —
    // confirm against `codec::decode_error`.
    pub(crate) message: String,
}

/// Handshake stage of a [`Protocol`].
///
/// The stage decides how the next incoming message is interpreted and whether
/// queued requests may be moved to the outgoing buffer.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum State {
    /// Awaits greeting
    Init,
    /// Awaits auth
    /// (entered after the greeting when credentials are configured)
    Auth,
    /// Ready to accept new messages
    Ready,
}

/// Configuration of [`Protocol`].
///
/// The default configuration has no credentials; in that case the protocol
/// becomes ready right after the greeting, without sending an auth request.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct Config {
    /// (user, password)
    ///
    /// When set, an auth request is generated once the greeting is received.
    pub creds: Option<(String, String)>,
    // TODO: add buffer limits here
}

/// A sans-io connection handler.
///
/// Buffers incoming and outgoing bytes and provides an API for
/// a client implementation to:
/// - Input requests
/// - Get processed responses
/// - Retrieve outgoing bytes
/// - Input incoming bytes
#[derive(Debug)]
pub struct Protocol {
    /// Current handshake stage.
    state: State,
    /// `Some(n)` when the next incoming chunk is a message of `n` bytes,
    /// `None` when the next incoming chunk is the message size prefix.
    msg_size_hint: Option<usize>,
    /// Bytes that are ready to be handed to the transport layer.
    outgoing: Vec<u8>,
    /// Encoded requests held back until the protocol is ready.
    pending_outgoing: Vec<u8>,
    /// Index that will be assigned to the next outgoing message.
    sync: SyncIndex,
    // TODO: limit incoming size
    /// Received responses keyed by the sync index from their header:
    /// raw body bytes on success or the error reported by the server.
    incoming: HashMap<SyncIndex, Result<Vec<u8>, ResponseError>>,
    /// (user, password)
    creds: Option<(String, String)>,
}

impl Default for Protocol {
    fn default() -> Self {
        Self::new()
    }
}

impl Protocol {
    /// Creates a [`Protocol`] that uses the default [`Config`], i.e. without credentials.
    ///
    /// The new instance awaits the server greeting and has empty buffers.
    pub fn new() -> Self {
        // Greeting is exactly 128 bytes
        const GREETING_SIZE: usize = 128;

        Self {
            state: State::Init,
            msg_size_hint: Some(GREETING_SIZE),
            outgoing: Vec::new(),
            pending_outgoing: Vec::new(),
            sync: SyncIndex(0),
            incoming: HashMap::new(),
            creds: None,
        }
    }

    /// Construct [`Protocol`] with custom values for [`Config`].
    pub fn with_config(config: Config) -> Self {
        let mut protocol = Self::new();
        protocol.creds = config.creds;
        protocol
    }

    /// Tells whether the [`Protocol`] has completed the initialization and
    /// authorization stages.
    ///
    /// Requests may be submitted regardless of the result: while the protocol
    /// is not ready they are queued and get processed once auth is done.
    pub fn is_ready(&self) -> bool {
        self.state == State::Ready
    }

    /// Encodes `request` and buffers the generated outgoing bytes.
    /// The bytes can later be retrieved with [`Protocol::drain_outgoing_data`].
    ///
    /// Requests may be submitted regardless of [`Self::is_ready`]: while the
    /// protocol is not ready they are queued and get processed once auth is done.
    ///
    /// Returns the [`SyncIndex`] assigned to the request; use it to fetch the
    /// response with [`Protocol::take_response`].
    pub fn send_request(&mut self, request: &impl Request) -> Result<SyncIndex, Error> {
        // Append right after whatever is already queued
        let write_pos = self.pending_outgoing.len() as u64;
        let mut cursor = Cursor::new(&mut self.pending_outgoing);
        cursor.set_position(write_pos);
        // TODO: limit the pending vec size
        write_to_buffer(&mut cursor, self.sync, request)?;
        self.process_pending_data();
        // The index is consumed only after the request was encoded successfully
        let assigned = self.sync.next_index();
        Ok(assigned)
    }

    /// Take existing response by [`SyncIndex`].
    pub fn take_response<R: Request>(
        &mut self,
        sync: SyncIndex,
        request: &R,
    ) -> Option<Result<R::Response, Error>> {
        let response = match self.incoming.remove(&sync)? {
            Ok(response) => response,
            Err(err) => return Some(Err(err.into())),
        };
        Some(request.decode_body(&mut Cursor::new(response)))
    }

    /// Discards the response stored under `sync`, if any.
    /// Does nothing when there is no such response.
    pub fn drop_response(&mut self, sync: SyncIndex) {
        // The removed entry (if present) is intentionally thrown away
        let _ = self.incoming.remove(&sync);
    }

    /// Number of bytes the client should read from transport before the next
    /// call to [`Protocol::process_incoming`].
    pub fn read_size_hint(&self) -> usize {
        // Without a known message size the next thing to read is the U32
        // message size hint: 5 bytes, the 1st of which is a marker
        const SIZE_PREFIX_LEN: usize = 5;

        self.msg_size_hint.unwrap_or(SIZE_PREFIX_LEN)
    }

    /// Processes incoming bytes received over transport layer.
    ///
    /// Should be used together with [`Protocol::read_size_hint`] e.g:
    /// 1. Call `read_size_hint` and get its value.
    ///    It is the number of bytes a client implementation should read from transport.
    /// 2. Read the required number of bytes from transport
    /// 3. Call [`Protocol::process_incoming`] with these bytes.
    ///
    /// Returns a [`SyncIndex`] if non-technical message was received.
    /// This message can be retreived by this index with [`Protocol::take_response`].
    pub fn process_incoming<R: Read + Seek>(
        &mut self,
        chunk: &mut R,
    ) -> Result<Option<SyncIndex>, Error> {
        if self.msg_size_hint.is_some() {
            // Message size hint was already read at previous call - now processing message
            self.msg_size_hint = None;
            self.process_message(chunk)
        } else {
            // Message was read at previous call - now reading size hint
            let hint = rmp::decode::read_u32(chunk)?;
            if hint > 0 {
                self.msg_size_hint = Some(hint as usize);
                Ok(None)
            } else {
                Err(Error::ZeroSizeHint)
            }
        }
    }

    fn process_message<R: Read + Seek>(
        &mut self,
        message: &mut R,
    ) -> Result<Option<SyncIndex>, Error> {
        let sync = match self.state {
            State::Init => {
                let salt = codec::decode_greeting(message)?;
                if let Some((user, pass)) = self.creds.as_ref() {
                    // Auth
                    self.state = State::Auth;
                    // Write straight to outgoing, it should be empty
                    debug_assert!(self.outgoing.is_empty());
                    let mut buf = Cursor::new(&mut self.outgoing);
                    let sync = self.sync.next_index();
                    write_to_buffer(
                        &mut buf,
                        sync,
                        &api::Auth {
                            user,
                            pass,
                            salt: &salt,
                        },
                    )?;
                } else {
                    // No auth
                    self.state = State::Ready;
                }
                None
            }
            State::Auth => {
                let header = codec::decode_header(message)?;
                if header.status_code != 0 {
                    return Err(codec::decode_error(message)?.into());
                }
                self.state = State::Ready;
                None
            }
            State::Ready => {
                let header = codec::decode_header(message)?;
                let response = if header.status_code != 0 {
                    Err(codec::decode_error(message)?)
                } else {
                    let mut buf = Vec::new();
                    message.read_to_end(&mut buf)?;
                    Ok(buf)
                };
                self.incoming.insert(header.sync, response);
                Some(header.sync)
            }
        };
        self.process_pending_data();
        Ok(sync)
    }

    /// Returns the number of outgoing bytes that are ready to be sent, i.e.
    /// how many bytes [`Protocol::drain_outgoing_data`] can currently yield.
    ///
    /// Requests that are still queued because the protocol is not ready
    /// are not counted.
    pub fn ready_outgoing_len(&self) -> usize {
        self.outgoing.len()
    }

    /// Removes buffered outgoing data from the front of the buffer and returns it.
    ///
    /// At most `max` bytes are drained when a limit is given, otherwise
    /// everything that is ready. The returned bytes can then be sent through
    /// a transport layer to a Tarantool server.
    pub fn drain_outgoing_data(&mut self, max: Option<usize>) -> Drain<u8> {
        let available = self.ready_outgoing_len();
        let bound = match max {
            Some(limit) => cmp::min(available, limit),
            None => available,
        };
        self.outgoing.drain(..bound)
    }

    /// Moves all queued request bytes to the outgoing buffer.
    /// Does nothing until the protocol is ready.
    fn process_pending_data(&mut self) {
        if !self.is_ready() {
            return;
        }
        // TODO: limit the ready vec size
        self.outgoing.append(&mut self.pending_outgoing);
    }
}

fn write_to_buffer(
    buffer: &mut Cursor<&mut Vec<u8>>,
    sync: SyncIndex,
    request: &impl Request,
) -> Result<(), Error> {
    // write MSG_SIZE placeholder
    let msg_start_offset = buffer.position();
    rmp::encode::write_u32(buffer, 0)?;

    // write message payload
    let payload_start_offset = buffer.position();
    request.encode(buffer, sync)?;
    let payload_end_offset = buffer.position();

    // calculate and write MSG_SIZE
    buffer.set_position(msg_start_offset);
    rmp::encode::write_u32(buffer, (payload_end_offset - payload_start_offset) as u32)?;
    buffer.set_position(payload_end_offset);

    Ok(())
}

// Tests have to be run in Tarantool environment due to `ToTupleBuffer` using `crate::Error` which contains `LuaError`
// and therefore lua symbols
#[cfg(feature = "internal_test")]
mod tests {
    use super::*;

    /// Builds a 128-byte greeting: two 64-byte lines, each terminated by `\n`,
    /// with the salt at the start of the second line and zero bytes as filler.
    ///
    /// See [tarantool docs](https://www.tarantool.io/en/doc/latest/dev_guide/internals/iproto/authentication/#greeting-message).
    fn fake_greeting() -> Vec<u8> {
        // First line: 63 filler bytes and a line break
        let mut bytes = vec![0u8; 63];
        bytes.push(b'\n');
        // Second line: salt, filler up to 127 bytes in total and a line break
        bytes.extend_from_slice(b"QK2HoFZGXTXBq2vFj7soCsHqTo6PGTF575ssUBAJLAI=");
        bytes.resize(127, 0);
        bytes.push(b'\n');
        bytes
    }

    #[crate::test(tarantool = "crate")]
    fn connection_established() {
        // A fresh protocol waits for the 128-byte greeting
        let mut protocol = Protocol::new();
        assert!(!protocol.is_ready());
        assert_eq!(protocol.msg_size_hint, Some(128));
        assert_eq!(protocol.read_size_hint(), 128);

        let mut greeting = Cursor::new(fake_greeting());
        protocol.process_incoming(&mut greeting).unwrap();

        // Without credentials the greeting alone makes the protocol ready,
        // and the next thing to read is a 5-byte size prefix
        assert_eq!(protocol.msg_size_hint, None);
        assert_eq!(protocol.read_size_hint(), 5);
        assert!(protocol.is_ready());
    }

    #[crate::test(tarantool = "crate")]
    fn send_bytes_generated() {
        let mut protocol = Protocol::new();
        let mut greeting = Cursor::new(fake_greeting());
        protocol.process_incoming(&mut greeting).unwrap();

        // Once ready, a request immediately produces outgoing bytes
        protocol.send_request(&api::Ping).unwrap();
        assert!(protocol.ready_outgoing_len() > 0);
    }
}