use std::default;
use std::fmt::{self, Display, Formatter};
use bytes::{Buf, Bytes};
use uuid::Uuid;
use crate::error::Error;
use crate::io::{Decode, Encode};
#[derive(Debug, Default)]
pub(crate) struct PreLogin {
pub(crate) version: Version,
pub(crate) encryption: Encrypt,
pub(crate) instance: Option<String>,
pub(crate) thread_id: Option<u32>,
pub(crate) trace_id: Option<TraceId>,
pub(crate) multiple_active_result_sets: Option<bool>,
}
impl<'de> Decode<'de> for PreLogin {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
let mut version = None;
let mut encryption = None;
let mut instance = None;
let mut thread_id = None;
let mut trace_id = None;
let mut multiple_active_result_sets = None;
let mut offsets = buf.clone();
loop {
let token = offsets.get_u8();
match PreLoginOptionToken::get(token) {
Some(token) => {
let offset = offsets.get_u16() as usize;
let size = offsets.get_u16() as usize;
let mut data = &buf[offset..offset + size];
match token {
PreLoginOptionToken::Version => {
let major = data.get_u8();
let minor = data.get_u8();
let build = data.get_u16();
let sub_build = data.get_u16();
version = Some(Version {
major,
minor,
build,
sub_build,
});
}
PreLoginOptionToken::Encryption => {
encryption = Some(Encrypt::from(data.get_u8()));
}
PreLoginOptionToken::Instance => {
instance = Some(String::from_utf8_lossy(&data[..size - 1]).to_string());
}
PreLoginOptionToken::ThreadId => {
thread_id = Some(data.get_u32_le());
}
PreLoginOptionToken::MultipleActiveResultSets => {
multiple_active_result_sets = Some(data.get_u8() != 0);
}
PreLoginOptionToken::TraceId => {
let connection_id = Uuid::from_u128(data.get_u128());
let activity_id = Uuid::from_u128(data.get_u128());
let activity_seq = data.get_u32();
trace_id = Some(TraceId {
connection_id,
activity_id,
activity_seq,
});
}
}
}
None if token == 0xff => {
break;
}
None => {
return Err(err_protocol!(
"PRELOGIN: unexpected login option token: 0x{:02?}",
token
)
.into());
}
}
}
let version =
version.ok_or(err_protocol!("PRELOGIN: missing required `version` option"))?;
let encryption = encryption.ok_or(err_protocol!(
"PRELOGIN: missing required `encryption` option"
))?;
Ok(Self {
version,
encryption,
instance,
thread_id,
trace_id,
multiple_active_result_sets,
})
}
}
impl Encode<'_> for PreLogin {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
use PreLoginOptionToken::*;
let num_options = 2
+ self.instance.as_ref().map_or(0, |_| 1)
+ self.thread_id.map_or(0, |_| 1)
+ self.trace_id.as_ref().map_or(0, |_| 1)
+ self.multiple_active_result_sets.map_or(0, |_| 1);
let len_offsets = (num_options * 5) + 1;
let mut offsets = buf.len();
let mut offset = u16::try_from(len_offsets).unwrap();
buf.resize(buf.len() + len_offsets, 0);
let end_offsets = buf.len() - 1;
buf[end_offsets] = 0xff;
Version.put(buf, &mut offsets, &mut offset, 6);
self.version.encode(buf);
Encryption.put(buf, &mut offsets, &mut offset, 1);
buf.push(u8::from(self.encryption));
if let Some(name) = &self.instance {
Instance.put(
buf,
&mut offsets,
&mut offset,
u16::try_from(name.len() + 1).unwrap(),
);
buf.extend_from_slice(name.as_bytes());
buf.push(b'\0');
}
if let Some(id) = self.thread_id {
ThreadId.put(buf, &mut offsets, &mut offset, 4);
buf.extend_from_slice(&id.to_le_bytes());
}
if let Some(trace) = &self.trace_id {
ThreadId.put(buf, &mut offsets, &mut offset, 36);
buf.extend_from_slice(trace.connection_id.as_bytes());
buf.extend_from_slice(trace.activity_id.as_bytes());
buf.extend_from_slice(&trace.activity_seq.to_be_bytes());
}
if let Some(mars) = &self.multiple_active_result_sets {
MultipleActiveResultSets.put(buf, &mut offsets, &mut offset, 1);
buf.push(*mars as u8);
}
}
}
#[derive(Debug, Copy, Clone)]
#[repr(u8)]
enum PreLoginOptionToken {
Version = 0x00,
Encryption = 0x01,
Instance = 0x02,
ThreadId = 0x03,
MultipleActiveResultSets = 0x04,
TraceId = 0x05,
}
impl PreLoginOptionToken {
fn put(self, buf: &mut Vec<u8>, pos: &mut usize, offset: &mut u16, len: u16) {
buf[*pos] = self as u8;
*pos += 1;
buf[*pos..(*pos + 2)].copy_from_slice(&offset.to_be_bytes());
*pos += 2;
buf[*pos..(*pos + 2)].copy_from_slice(&len.to_be_bytes());
*pos += 2;
*offset += len;
}
fn get(b: u8) -> Option<Self> {
Some(match b {
0x00 => PreLoginOptionToken::Version,
0x01 => PreLoginOptionToken::Encryption,
0x02 => PreLoginOptionToken::Instance,
0x03 => PreLoginOptionToken::ThreadId,
0x04 => PreLoginOptionToken::MultipleActiveResultSets,
0x05 => PreLoginOptionToken::TraceId,
_ => {
return None;
}
})
}
}
#[derive(Debug)]
pub(crate) struct TraceId {
pub(crate) connection_id: Uuid,
pub(crate) activity_id: Uuid,
pub(crate) activity_seq: u32,
}
#[derive(Debug)]
pub(crate) struct Version {
pub(crate) major: u8,
pub(crate) minor: u8,
pub(crate) build: u16,
pub(crate) sub_build: u16,
}
impl Default for Version {
fn default() -> Self {
Self {
major: env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(),
minor: env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(),
build: env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(),
sub_build: 0,
}
}
}
impl Version {
fn encode(&self, buf: &mut Vec<u8>) {
buf.push(self.major);
buf.push(self.minor);
buf.extend(&self.build.to_be_bytes());
buf.extend(&self.sub_build.to_be_bytes());
}
}
impl Display for Version {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "v{}.{}.{}", self.major, self.minor, self.build)
}
}
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Encrypt {
Off = 0x00,
#[default]
On = 0x01,
NotSupported = 0x02,
Required = 0x03,
}
impl From<u8> for Encrypt {
fn from(value: u8) -> Self {
match value {
0x00 => Encrypt::Off,
0x01 => Encrypt::On,
0x02 => Encrypt::NotSupported,
0x03 => Encrypt::Required,
_ => Encrypt::Off,
}
}
}
impl From<Encrypt> for u8 {
fn from(value: Encrypt) -> Self {
value as u8
}
}
#[test]
fn test_encode_pre_login() {
let mut buf = Vec::new();
let pre_login = PreLogin {
version: Version {
major: 9,
minor: 0,
build: 0,
sub_build: 0,
},
encryption: Encrypt::On,
instance: Some("".to_string()),
thread_id: Some(0x00000DB8),
multiple_active_result_sets: Some(true),
..Default::default()
};
#[rustfmt::skip]
let expected = vec![
0x00, 0x00, 0x1A, 0x00, 0x06, 0x01, 0x00, 0x20, 0x00, 0x01, 0x02, 0x00, 0x21, 0x00,
0x01, 0x03, 0x00, 0x22, 0x00, 0x04, 0x04, 0x00, 0x26, 0x00, 0x01, 0xFF, 0x09, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0xB8, 0x0D, 0x00, 0x00, 0x01
];
pre_login.encode(&mut buf);
assert_eq!(expected, buf);
}
#[test]
fn test_decode_pre_login() {
#[rustfmt::skip]
let buffer = Bytes::from_static(&[
0, 0, 11, 0, 6, 1, 0, 17, 0, 1, 255,
14, 0, 12, 209, 0, 0, 0,
]);
let pre_login = PreLogin::decode(buffer).unwrap();
assert_eq!(pre_login.version.major, 14);
assert_eq!(pre_login.version.minor, 0);
assert_eq!(pre_login.version.build, 3281);
assert_eq!(pre_login.version.sub_build, 0);
assert_eq!(u8::from(pre_login.encryption), 0);
}