Skip to main content

we_trust_sqlserver/
codec.rs

1use bytes::{Buf, BufMut, BytesMut};
2use tokio_util::codec::{Decoder, Encoder};
3use yykv_types::DsError;
4
/// A single TDS packet: the fixed 8-byte header fields plus the bytes that follow it.
///
/// Multi-byte header fields are big-endian on the wire.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TdsPacket {
    /// Packet type byte (header offset 0).
    pub packet_type: u8,
    /// Status byte (header offset 1).
    pub status: u8,
    /// Total packet length in bytes, header included (header offsets 2..4).
    ///
    /// Filled in when decoding; `TdsCodec`'s encoder recomputes the on-wire value
    /// from `payload` and does not read this field.
    pub length: u16,
    /// SPID field (header offsets 4..6).
    pub spid: u16,
    /// Packet id byte (header offset 6).
    pub packet_id: u8,
    /// Window byte (header offset 7).
    pub window: u8,
    /// Bytes following the 8-byte header.
    pub payload: Vec<u8>,
}
16
/// Stateless codec that frames a byte stream into [`TdsPacket`]s and serializes them back.
#[derive(Debug, Clone, Copy, Default)]
pub struct TdsCodec;
18
19impl Decoder for TdsCodec {
20    type Item = TdsPacket;
21    type Error = DsError;
22
23    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
24        if src.len() < 8 {
25            return Ok(None);
26        }
27
28        let length = u16::from_be_bytes([src[2], src[3]]) as usize;
29        if src.len() < length {
30            src.reserve(length);
31            return Ok(None);
32        }
33
34        let packet_type = src[0];
35        let status = src[1];
36        let spid = u16::from_be_bytes([src[4], src[5]]);
37        let packet_id = src[6];
38        let window = src[7];
39
40        src.advance(8);
41        let payload = src.split_to(length - 8).to_vec();
42
43        Ok(Some(TdsPacket {
44            packet_type,
45            status,
46            length: length as u16,
47            spid,
48            packet_id,
49            window,
50            payload,
51        }))
52    }
53}
54
55impl Encoder<TdsPacket> for TdsCodec {
56    type Error = DsError;
57
58    fn encode(&mut self, item: TdsPacket, dst: &mut BytesMut) -> Result<(), Self::Error> {
59        dst.put_u8(item.packet_type);
60        dst.put_u8(item.status);
61        dst.put_u16(8 + item.payload.len() as u16);
62        dst.put_u16(item.spid);
63        dst.put_u8(item.packet_id);
64        dst.put_u8(item.window);
65        dst.put_slice(&item.payload);
66        Ok(())
67    }
68}