#[cfg(test)]
mod message_test;
use crate::agent::*;
use crate::attributes::*;
use crate::errors::*;
use util::Error;
use rand::Rng;
use std::fmt;
use std::io::{Read, Write};
/// Fixed value that must appear in bytes 4..8 of every message header.
pub(crate) const MAGIC_COOKIE: u32 = 0x2112_A442;
/// Size of an attribute prefix: a 2-byte type followed by a 2-byte length.
pub(crate) const ATTRIBUTE_HEADER_SIZE: usize = 2 + 2;
/// Size of the fixed message header: type (2), length (2), cookie (4), then
/// the transaction ID.
pub(crate) const MESSAGE_HEADER_SIZE: usize = 8 + TRANSACTION_ID_SIZE;
/// Size of a transaction ID in bytes (header bytes 8..20).
pub const TRANSACTION_ID_SIZE: usize = 12;
pub trait Setter {
fn add_to(&self, m: &mut Message) -> Result<(), Error>;
}
pub trait Getter {
fn get_from(&mut self, m: &Message) -> Result<(), Error>;
}
pub trait Checker {
fn check(&self, m: &Message) -> Result<(), Error>;
}
/// Reports whether `b` looks like a message.
///
/// `b` must be at least `MESSAGE_HEADER_SIZE` bytes long and carry
/// `MAGIC_COOKIE` (big-endian) in bytes 4..8. No other header field is
/// validated.
pub fn is_message(b: &[u8]) -> bool {
    if b.len() < MESSAGE_HEADER_SIZE {
        return false;
    }
    b[4..8] == MAGIC_COOKIE.to_be_bytes()[..]
}
/// A message: decoded header fields and attributes, plus the encoded bytes
/// in `raw`.
///
/// `Message::default()` leaves `raw` empty, while `Message::new()`
/// pre-fills a zeroed header.
#[derive(Clone, Debug, Default)]
pub struct Message {
    /// Message type (method + class), encoded in header bytes 0..2.
    pub typ: MessageType,
    /// Length of the attribute section in bytes, as carried in header bytes 2..4.
    pub length: u32,
    /// Transaction ID, stored in header bytes 8..20.
    pub transaction_id: TransactionId,
    /// Attributes in the order they were added or decoded.
    pub attributes: Attributes,
    /// Encoded message bytes.
    pub raw: Vec<u8>,
}
impl fmt::Display for Message {
    /// Formats as `"<type> l=<length> attrs=<count> id=<base64 transaction id>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let attr_count = self.attributes.0.len();
        let encoded_id = base64::encode(&self.transaction_id.0);
        write!(
            f,
            "{} l={} attrs={} id={}",
            self.typ, self.length, attr_count, encoded_id
        )
    }
}
impl PartialEq for Message {
    /// Compares type, transaction ID, length and attributes. The `raw`
    /// buffer is not compared.
    fn eq(&self, other: &Self) -> bool {
        self.typ == other.typ
            && self.transaction_id == other.transaction_id
            && self.length == other.length
            && self.attributes == other.attributes
    }
}
/// Initial capacity reserved for `raw` by `Message::new`; also used to size
/// the read buffer in `Message::read_from`.
const DEFAULT_RAW_CAPACITY: usize = 120;
impl Setter for Message {
fn add_to(&self, b: &mut Message) -> Result<(), Error> {
b.transaction_id = self.transaction_id;
b.write_transaction_id();
Ok(())
}
}
impl Message {
    /// Creates a message whose `raw` buffer holds a zeroed header of
    /// `MESSAGE_HEADER_SIZE` bytes, with `DEFAULT_RAW_CAPACITY` bytes reserved.
    /// All other fields take their `Default` values.
    pub fn new() -> Self {
        let mut raw = Vec::with_capacity(DEFAULT_RAW_CAPACITY);
        raw.extend_from_slice(&[0; MESSAGE_HEADER_SIZE]);
        Message {
            raw,
            ..Default::default()
        }
    }

    /// Returns a copy of the encoded bytes in `raw`. Never fails.
    pub fn marshal_binary(&self) -> Result<Vec<u8>, Error> {
        Ok(self.raw.clone())
    }

    /// Replaces `raw` with `data` and decodes it into the message fields.
    ///
    /// # Errors
    /// Any error returned by [`Message::decode`].
    pub fn unmarshal_binary(&mut self, data: &[u8]) -> Result<(), Error> {
        self.raw.clear();
        self.raw.extend_from_slice(data);
        self.decode()
    }

    /// Fills `transaction_id` with random bytes from the thread-local RNG and
    /// writes it into the header in `raw`. Always returns `Ok(())`.
    pub fn new_transaction_id(&mut self) -> Result<(), Error> {
        rand::thread_rng().fill(&mut self.transaction_id.0);
        self.write_transaction_id();
        Ok(())
    }

    /// Empties `raw` and the attribute list and sets `length` to 0.
    /// `typ` and `transaction_id` are left untouched.
    pub fn reset(&mut self) {
        self.raw.clear();
        self.length = 0;
        self.attributes.0.clear();
    }

    /// Ensures `raw` is at least `n` bytes long, zero-filling any new bytes.
    /// When `resize` is true, a longer `raw` is also truncated to exactly `n`.
    fn grow(&mut self, n: usize, resize: bool) {
        // `Vec::resize` both extends (with zeros) and truncates, so one call
        // covers "too short" as well as "too long and resize requested".
        if self.raw.len() < n || resize {
            self.raw.resize(n, 0);
        }
    }

    /// Appends an attribute of type `t` with value `v`.
    ///
    /// Writes the type/length/value triple into `raw` directly after the
    /// current attribute section, truncating anything that followed it.
    /// Zero-pads the value to a multiple of `PADDING` and records the
    /// attribute in `attributes`. Updates `length` and the header length field.
    pub fn add(&mut self, t: AttrType, v: &[u8]) {
        // Layout of the bytes appended to `raw`:
        // [first .. first+2]    type (T)
        // [first+2 .. first+4]  unpadded value length (L)
        // [first+4 .. last]     value (V), followed by padding if needed
        let alloc_size = ATTRIBUTE_HEADER_SIZE + v.len();
        let first = MESSAGE_HEADER_SIZE + self.length as usize;
        let mut last = first + alloc_size;
        self.grow(last, true);
        self.length += alloc_size as u32;

        let buf = &mut self.raw[first..last];
        buf[0..2].copy_from_slice(&t.value().to_be_bytes());
        // `as u16` truncates values longer than u16::MAX bytes.
        buf[2..4].copy_from_slice(&(v.len() as u16).to_be_bytes());
        buf[ATTRIBUTE_HEADER_SIZE..].copy_from_slice(v);

        let attr = RawAttribute {
            typ: t,
            length: v.len() as u16,
            value: v.to_vec(),
        };

        if v.len() % PADDING != 0 {
            let bytes_to_add = nearest_padded_value_length(v.len()) - v.len();
            last += bytes_to_add;
            self.grow(last, true);
            // Zero the padding so no stale bytes from a previous encoding leak.
            self.raw[last - bytes_to_add..last].fill(0);
            self.length += bytes_to_add as u32;
        }
        self.attributes.0.push(attr);
        self.write_length();
    }

    /// Writes the low 16 bits of `length` into header bytes 2..4, growing
    /// `raw` to at least 4 bytes first.
    pub fn write_length(&mut self) {
        self.grow(4, false);
        self.raw[2..4].copy_from_slice(&(self.length as u16).to_be_bytes());
    }

    /// Writes type, length, magic cookie and transaction ID into the first
    /// `MESSAGE_HEADER_SIZE` bytes of `raw`, growing it if needed.
    pub fn write_header(&mut self) {
        self.grow(MESSAGE_HEADER_SIZE, false);
        self.write_type();
        self.write_length();
        self.raw[4..8].copy_from_slice(&MAGIC_COOKIE.to_be_bytes());
        self.write_transaction_id();
    }

    /// Writes `transaction_id` into header bytes 8..20.
    pub fn write_transaction_id(&mut self) {
        // Make sure the header region exists. Without this the slice below
        // panics on an empty `raw` (e.g. `Message::default()` or after `reset`).
        self.grow(MESSAGE_HEADER_SIZE, false);
        self.raw[8..MESSAGE_HEADER_SIZE].copy_from_slice(&self.transaction_id.0);
    }

    /// Re-encodes every entry of `attributes` into `raw` via [`Message::add`],
    /// appending after the current `length`. The attribute list itself is
    /// restored afterwards.
    pub fn write_attributes(&mut self) {
        let attributes = std::mem::take(&mut self.attributes.0);
        for a in &attributes {
            self.add(a.typ, &a.value);
        }
        self.attributes = Attributes(attributes);
    }

    /// Writes the encoded `typ` into header bytes 0..2, growing `raw` to at
    /// least 2 bytes first.
    pub fn write_type(&mut self) {
        self.grow(2, false);
        self.raw[..2].copy_from_slice(&self.typ.value().to_be_bytes());
    }

    /// Sets `typ` and immediately writes it into `raw`.
    pub fn set_type(&mut self, t: MessageType) {
        self.typ = t;
        self.write_type();
    }

    /// Rebuilds `raw` from scratch: header first, then every attribute.
    pub fn encode(&mut self) {
        self.raw.clear();
        // Zero the length before the header is written. Otherwise a message
        // with no attributes keeps a stale length in header bytes 2..4.
        self.length = 0;
        self.write_header();
        self.write_attributes();
    }

    /// Decodes `raw` into `typ`, `length`, `transaction_id` and `attributes`.
    /// Bytes past the length declared in the header are ignored.
    ///
    /// # Errors
    /// - `ERR_UNEXPECTED_HEADER_EOF` if `raw` is shorter than the header.
    /// - An error if the magic cookie does not match `MAGIC_COOKIE`.
    /// - An error if `raw` is shorter than the declared message size.
    /// - An error if an attribute header or padded value is truncated. Header
    ///   fields and attributes decoded before that point stay set.
    pub fn decode(&mut self) -> Result<(), Error> {
        let buf = &self.raw;
        if buf.len() < MESSAGE_HEADER_SIZE {
            return Err(ERR_UNEXPECTED_HEADER_EOF.clone());
        }

        let t = u16::from_be_bytes([buf[0], buf[1]]);
        let size = u16::from_be_bytes([buf[2], buf[3]]) as usize;
        let cookie = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
        let full_size = MESSAGE_HEADER_SIZE + size;

        if cookie != MAGIC_COOKIE {
            return Err(Error::new(format!(
                "{:x} is invalid magic cookie (should be {:x})",
                cookie, MAGIC_COOKIE
            )));
        }
        if buf.len() < full_size {
            return Err(Error::new(format!(
                "buffer length {} is less than {} (expected message size)",
                buf.len(),
                full_size
            )));
        }

        self.typ.read_value(t);
        self.length = size as u32;
        self.transaction_id
            .0
            .copy_from_slice(&buf[8..MESSAGE_HEADER_SIZE]);

        self.attributes.0.clear();
        let mut offset = 0;
        let mut b = &buf[MESSAGE_HEADER_SIZE..full_size];
        while offset < size {
            if b.len() < ATTRIBUTE_HEADER_SIZE {
                return Err(Error::new(format!(
                    "buffer length {} is less than {} (expected header size)",
                    b.len(),
                    ATTRIBUTE_HEADER_SIZE
                )));
            }
            let mut a = RawAttribute {
                typ: compat_attr_type(u16::from_be_bytes([b[0], b[1]])),
                length: u16::from_be_bytes([b[2], b[3]]),
                ..Default::default()
            };
            let a_l = a.length as usize;
            // On the wire the value occupies its padded length.
            let a_buff_l = nearest_padded_value_length(a_l);
            b = &b[ATTRIBUTE_HEADER_SIZE..];
            offset += ATTRIBUTE_HEADER_SIZE;
            if b.len() < a_buff_l {
                return Err(Error::new(format!(
                    "buffer length {} is less than {} (expected value size for {})",
                    b.len(),
                    a_buff_l,
                    a.typ
                )));
            }
            a.value = b[..a_l].to_vec();
            offset += a_buff_l;
            b = &b[a_buff_l..];
            self.attributes.0.push(a);
        }
        Ok(())
    }

    /// Writes all of `raw` to `writer` and returns the number of bytes written.
    ///
    /// # Errors
    /// Any I/O error from `writer`.
    pub fn write_to<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
        // `Write::write` may accept only part of the buffer. `write_all`
        // keeps writing until everything is out or an error occurs.
        writer.write_all(&self.raw)?;
        Ok(self.raw.len())
    }

    /// Performs a single `read` from `reader` into a fresh buffer, stores the
    /// bytes read in `raw` and decodes them. Returns the number of bytes read.
    ///
    /// The buffer is at least `DEFAULT_RAW_CAPACITY` bytes, or `raw`'s current
    /// capacity if larger. A message must therefore arrive in one read and
    /// fit in that buffer.
    ///
    /// # Errors
    /// Any I/O error from `reader`, or any error from [`Message::decode`].
    pub fn read_from<R: Read>(&mut self, reader: &mut R) -> Result<usize, Error> {
        // Honour a capacity the caller reserved on `raw`, so messages larger
        // than DEFAULT_RAW_CAPACITY can be read.
        let mut t_buf = vec![0; self.raw.capacity().max(DEFAULT_RAW_CAPACITY)];
        let n = reader.read(&mut t_buf)?;
        t_buf.truncate(n);
        self.raw = t_buf;
        self.decode()?;
        Ok(n)
    }

    /// Replaces `raw` with `t_buf` and decodes it. Returns `t_buf.len()`.
    ///
    /// # Errors
    /// Any error from [`Message::decode`].
    pub fn write(&mut self, t_buf: &[u8]) -> Result<usize, Error> {
        self.raw.clear();
        self.raw.extend_from_slice(t_buf);
        self.decode()?;
        Ok(t_buf.len())
    }

    /// Copies `raw` into `b` and decodes it there.
    ///
    /// # Errors
    /// Any error from [`Message::decode`] on `b`.
    pub fn clone_to(&self, b: &mut Message) -> Result<(), Error> {
        b.raw.clear();
        b.raw.extend_from_slice(&self.raw);
        b.decode()
    }

    /// Returns true if any attribute has type `t`.
    pub fn contains(&self, t: AttrType) -> bool {
        self.attributes.0.iter().any(|a| a.typ == t)
    }

    /// Returns the value of the attribute of type `t` found by `Attributes::get`.
    ///
    /// # Errors
    /// `ERR_ATTRIBUTE_NOT_FOUND` if no such attribute exists.
    pub fn get(&self, t: AttrType) -> Result<Vec<u8>, Error> {
        let (v, ok) = self.attributes.get(t);
        if ok {
            Ok(v.value)
        } else {
            Err(ERR_ATTRIBUTE_NOT_FOUND.clone())
        }
    }

    /// Resets the message, writes a fresh header, then applies each setter in
    /// order, stopping at the first error.
    pub fn build(&mut self, setters: &[Box<dyn Setter>]) -> Result<(), Error> {
        self.reset();
        self.write_header();
        for s in setters {
            s.add_to(self)?;
        }
        Ok(())
    }

    /// Runs every checker against this message, stopping at the first error.
    pub fn check<C: Checker>(&self, checkers: &[C]) -> Result<(), Error> {
        for c in checkers {
            c.check(self)?;
        }
        Ok(())
    }

    /// Runs every getter against this message, stopping at the first error.
    pub fn parse<G: Getter>(&self, getters: &mut [G]) -> Result<(), Error> {
        for c in getters {
            c.get_from(self)?;
        }
        Ok(())
    }
}
/// Two-bit message class, split across bits 4 and 8 of the encoded type.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct MessageClass(u8);
/// Request class (`0b00`).
pub const CLASS_REQUEST: MessageClass = MessageClass(0b00);
/// Indication class (`0b01`).
pub const CLASS_INDICATION: MessageClass = MessageClass(0b01);
/// Success-response class (`0b10`).
pub const CLASS_SUCCESS_RESPONSE: MessageClass = MessageClass(0b10);
/// Error-response class (`0b11`).
pub const CLASS_ERROR_RESPONSE: MessageClass = MessageClass(0b11);
impl fmt::Display for MessageClass {
    /// Writes a lowercase class name, or `"unknown message class"` for
    /// values outside the four defined classes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match *self {
            CLASS_REQUEST => "request",
            CLASS_INDICATION => "indication",
            CLASS_SUCCESS_RESPONSE => "success response",
            CLASS_ERROR_RESPONSE => "error response",
            _ => "unknown message class",
        })
    }
}
/// Message method: a 12-bit value that `MessageType::value` interleaves with
/// the class bits.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Method(u16);
pub const METHOD_BINDING: Method = Method(0x0001);
pub const METHOD_ALLOCATE: Method = Method(0x0003);
pub const METHOD_REFRESH: Method = Method(0x0004);
pub const METHOD_SEND: Method = Method(0x0006);
pub const METHOD_DATA: Method = Method(0x0007);
pub const METHOD_CREATE_PERMISSION: Method = Method(0x0008);
pub const METHOD_CHANNEL_BIND: Method = Method(0x0009);
pub const METHOD_CONNECT: Method = Method(0x000a);
pub const METHOD_CONNECTION_BIND: Method = Method(0x000b);
pub const METHOD_CONNECTION_ATTEMPT: Method = Method(0x000c);
impl fmt::Display for Method {
    /// Writes the method's name (e.g. `"Binding"`). Methods without a named
    /// constant are written as `0x` followed by their lowercase hex value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            METHOD_BINDING => "Binding",
            METHOD_ALLOCATE => "Allocate",
            METHOD_REFRESH => "Refresh",
            METHOD_SEND => "Send",
            METHOD_DATA => "Data",
            METHOD_CREATE_PERMISSION => "CreatePermission",
            METHOD_CHANNEL_BIND => "ChannelBind",
            METHOD_CONNECT => "Connect",
            METHOD_CONNECTION_BIND => "ConnectionBind",
            METHOD_CONNECTION_ATTEMPT => "ConnectionAttempt",
            // Format the hex fallback only when it is needed, instead of
            // allocating it on every call.
            _ => return write!(f, "0x{:x}", self.0),
        };
        f.write_str(name)
    }
}
/// A message type: a method paired with a class.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct MessageType {
    /// The message method (for example Binding).
    pub method: Method,
    /// The message class (request, indication, success or error response).
    pub class: MessageClass,
}

/// Binding request.
pub const BINDING_REQUEST: MessageType = MessageType {
    class: CLASS_REQUEST,
    method: METHOD_BINDING,
};
/// Binding success response.
pub const BINDING_SUCCESS: MessageType = MessageType {
    class: CLASS_SUCCESS_RESPONSE,
    method: METHOD_BINDING,
};
/// Binding error response.
pub const BINDING_ERROR: MessageType = MessageType {
    class: CLASS_ERROR_RESPONSE,
    method: METHOD_BINDING,
};
impl fmt::Display for MessageType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} {}", self.method, self.class)
}
}
/// Method bits 0-3; they keep their position in the encoded type.
const METHOD_ABITS: u16 = 0b0000_0000_1111;
/// Method bits 4-6; shifted left by `METHOD_BSHIFT` to skip class bit C0.
const METHOD_BBITS: u16 = 0b0000_0111_0000;
/// Method bits 7-11; shifted left by `METHOD_DSHIFT` to skip both class bits.
const METHOD_DBITS: u16 = 0b1111_1000_0000;
const METHOD_BSHIFT: u16 = 1;
const METHOD_DSHIFT: u16 = 2;

const FIRST_BIT: u16 = 0b01;
const SECOND_BIT: u16 = 0b10;
/// Low class bit, moved to bit 4 of the encoded type.
const C0BIT: u16 = FIRST_BIT;
/// High class bit, moved from bit 1 to bit 8 of the encoded type.
const C1BIT: u16 = SECOND_BIT;
const CLASS_C0SHIFT: u16 = 4;
const CLASS_C1SHIFT: u16 = 7;
impl Setter for MessageType {
fn add_to(&self, m: &mut Message) -> Result<(), Error> {
m.set_type(*self);
Ok(())
}
}
impl MessageType {
    /// Pairs a method with a class.
    pub fn new(method: Method, class: MessageClass) -> Self {
        MessageType { method, class }
    }

    /// Encodes the type into its 14-bit wire value.
    ///
    /// Method bits 0-3 stay in place, bits 4-6 move to 5-7 and bits 7-11 move
    /// to 9-13. The class bits go to bit 4 (C0) and bit 8 (C1). Method bits
    /// above 11 and class bits above 1 are dropped by the masks.
    pub fn value(&self) -> u16 {
        let m = self.method.0;
        let method_bits = (m & METHOD_ABITS)
            | ((m & METHOD_BBITS) << METHOD_BSHIFT)
            | ((m & METHOD_DBITS) << METHOD_DSHIFT);

        let cls = u16::from(self.class.0);
        let class_bits = ((cls & C0BIT) << CLASS_C0SHIFT) | ((cls & C1BIT) << CLASS_C1SHIFT);

        // All bit groups are disjoint, so OR-ing equals adding them.
        method_bits | class_bits
    }

    /// Decodes a wire value produced by [`MessageType::value`] into `method`
    /// and `class`, overwriting both.
    pub fn read_value(&mut self, value: u16) {
        let low_class = (value >> CLASS_C0SHIFT) & C0BIT;
        let high_class = (value >> CLASS_C1SHIFT) & C1BIT;
        self.class = MessageClass((low_class | high_class) as u8);

        let low_method = value & METHOD_ABITS;
        let mid_method = (value >> METHOD_BSHIFT) & METHOD_BBITS;
        let high_method = (value >> METHOD_DSHIFT) & METHOD_DBITS;
        self.method = Method(low_method | mid_method | high_method);
    }
}