rusty_tarantool/tarantool/codec.rs

1use crate::tarantool::packets::{Code, Key, TarantoolRequest, TarantoolResponse};
2use crate::tarantool::tools::{
3    decode_serde, get_map_value, make_auth_digest, map_err_to_io, parse_msgpack_map,
4    serialize_to_buf_mut, write_u32_to_slice, SafeBytesMutWriter,
5};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use rmp::encode;
8use rmpv::{self, decode, Value};
9use std::io;
10use std::io::Cursor;
11use std::str;
12use tokio_util::codec::{Decoder, Encoder};
13
14pub type RequestId = u64;
15pub type TarantoolFramedRequest = (RequestId, TarantoolRequest);
16
/// Prefix the server greeting is expected to start with.
const GREETINGS_HEADER: &str = "Tarantool";
/// Byte length of [`GREETINGS_HEADER`], derived from it so the two can never disagree.
const GREETINGS_HEADER_LENGTH: usize = GREETINGS_HEADER.len();
19
/// Tokio framed codec for Tarantool.
///
/// Use it when you want to install the codec on a tokio framed transport yourself.
/// The codec is stateful: it first consumes the server greeting (keeping the salt
/// found there) and only afterwards decodes regular response frames.
#[derive(Debug)]
pub struct TarantoolCodec {
    /// `true` once the greeting has been consumed; until then `decode` waits for it.
    is_greetings_received: bool,
    /// Salt bytes copied out of the greeting, later used to compute the auth digest.
    salt: Option<Vec<u8>>,
}
29
30impl Default for TarantoolCodec {
31    fn default() -> Self {
32        TarantoolCodec {
33            is_greetings_received: false,
34            salt: None,
35        }
36    }
37}
38
39impl Decoder for TarantoolCodec {
40    type Item = (RequestId, io::Result<TarantoolResponse>);
41    type Error = io::Error;
42
43    fn decode(
44        &mut self,
45        buf: &mut BytesMut,
46    ) -> io::Result<Option<(RequestId, io::Result<TarantoolResponse>)>> {
47        if !self.is_greetings_received {
48            if buf.len() < 128 {
49                Ok(None)
50            } else {
51                self.is_greetings_received = true;
52                decode_greetings(self, buf)
53            }
54        } else if buf.len() < 5 {
55            Ok(None)
56        } else {
57            let size: usize = decode_serde(&buf[0..5])?;
58
59            if buf.len() - 5 < size {
60                Ok(None)
61            } else {
62                Ok(Some(parse_response(buf, size)?))
63            }
64        }
65    }
66}
67
68fn parse_response(
69    buf: &mut BytesMut,
70    size: usize,
71) -> io::Result<(RequestId, io::Result<TarantoolResponse>)> {
72    //    buf.split_to(5);
73    buf.advance(5);
74    let response_body = buf.split_to(size);
75    let mut r = Cursor::new(response_body);
76
77    let headers = decode::read_value(&mut r).map_err(map_err_to_io)?;
78    let (code, sync) = parse_headers(headers)?;
79    let mut response_fields = parse_msgpack_map(r)?;
80
81    match code {
82        0 => {
83            Ok((
84                sync,
85                Ok(TarantoolResponse::new_full_response(
86                    code,
87                    response_fields.remove(&(Key::DATA as u64)).unwrap_or_default(),
88                    response_fields.remove(&(Key::METADATA as u64)),
89                    response_fields.remove(&(Key::SQL_INFO as u64))
90                    // search_key_in_msgpack_map(r, Key::DATA as u64)?,
91                )),
92            ))
93        },
94        _ => {
95            let response_data =
96                TarantoolResponse::new_short_response(
97                    code,
98                    response_fields.remove(&(Key::ERROR as u64)).unwrap_or_default()
99                );
100            let s: String = response_data.decode()?;
101            error!("Tarantool ERROR >> {:?}", s);
102            Ok((sync, Err(io::Error::new(io::ErrorKind::Other, s))))
103        }
104    }
105}
106
107pub fn parse_headers(headers: Value) -> Result<(u64, u64), io::Error> {
108    match headers {
109        Value::Map(headers_vec) => Ok((
110            get_map_value(&headers_vec, Key::CODE as u64)?,
111            get_map_value(&headers_vec, Key::SYNC as u64)?,
112        )),
113        _ => Err(io::Error::new(
114            io::ErrorKind::Other,
115            "Incorrect headers msg pack type!",
116        )),
117    }
118}
119
120impl Encoder<(RequestId, TarantoolRequest)> for TarantoolCodec {
121    // type Item = ;
122    type Error = io::Error;
123
124    fn encode(
125        &mut self,
126        command: (RequestId, TarantoolRequest),
127        dst: &mut BytesMut,
128    ) -> Result<(), Self::Error> {
129        match command {
130            (sync_id, TarantoolRequest::Auth(packet)) => {
131                info!("send auth_packet={:?}", packet);
132                let digest =
133                    make_auth_digest(self.salt.clone().unwrap(), packet.password.as_bytes())?;
134
135                create_packet(
136                    dst,
137                    Code::AUTH,
138                    sync_id,
139                    None,
140                    vec![
141                        (Key::USER_NAME, Value::from(packet.login)),
142                        (
143                            Key::TUPLE,
144                            Value::Array(vec![
145                                Value::from("chap-sha1"),
146                                Value::from(&digest as &[u8]),
147                            ]),
148                        ),
149                    ],
150                    vec![],
151                )
152            }
153            (sync_id, TarantoolRequest::Command(packet)) => {
154                debug!("send normal packet={:?}", packet);
155                create_packet(
156                    dst,
157                    packet.code,
158                    sync_id,
159                    None,
160                    packet.internal_fields,
161                    packet.command_field,
162                )
163            }
164        }
165    }
166}
167
168fn decode_greetings(
169    codec: &mut TarantoolCodec,
170    buf: &mut BytesMut,
171) -> io::Result<Option<(RequestId, io::Result<TarantoolResponse>)>> {
172    let header = buf.split_to(GREETINGS_HEADER_LENGTH);
173    let test = str::from_utf8(&header).map_err(map_err_to_io)?;
174
175    let res = match test {
176        GREETINGS_HEADER => Ok(Some((0, Ok(TarantoolResponse::new_short_response(0, Bytes::new()))))),
177        _ => Err(io::Error::new(io::ErrorKind::Other, "Unknown header!")),
178    };
179    //    buf.split_to(64 - GREETINGS_HEADER_LENGTH);
180    buf.advance(64 - GREETINGS_HEADER_LENGTH);
181    let salt_buf = buf.split_to(64);
182    codec.salt = Some(salt_buf.to_vec());
183
184    res
185}
186
187fn create_packet(
188    buf: &mut BytesMut,
189    code: Code,
190    sync_id: u64,
191    schema_id: Option<u64>,
192    data: Vec<(Key, Value)>,
193    additional_data: Vec<(Key, Vec<u8>)>,
194) -> io::Result<()> {
195    let mut header_vec = vec![
196        (Value::from(Key::CODE as u8), Value::from(code as u8)),
197        (Value::from(Key::SYNC as u8), Value::from(sync_id)),
198    ];
199
200    if let Some(schema_id_v) = schema_id {
201        header_vec.push((Value::from(Key::SCHEMA_ID as u8), Value::from(schema_id_v)))
202    }
203
204    buf.reserve(5);
205    let start_position = buf.len() + 1;
206    buf.put_slice(&[0xce, 0x00, 0x00, 0x00, 0x00]);
207    {
208        let mut writer = SafeBytesMutWriter::writer(buf);
209
210        serialize_to_buf_mut(&mut writer, &Value::Map(header_vec))?;
211        encode::write_map_len(
212            &mut writer,
213            data.len() as u32 + (additional_data.len() as u32),
214        )?;
215        for (ref key, ref val) in data {
216            rmpv::encode::write_value(&mut writer, &Value::from((*key) as u8))?;
217            rmpv::encode::write_value(&mut writer, val)?;
218        }
219        for (ref key, ref val) in additional_data {
220            rmpv::encode::write_value(&mut writer, &Value::from((*key) as u8))?;
221            io::Write::write(&mut writer, val)?;
222        }
223    }
224
225    let len = (buf.len() - start_position - 4) as u32;
226    write_u32_to_slice(&mut buf[start_position..start_position + 4], len);
227
228    Ok(())
229}