use alloc::{vec, vec::Vec};
use nom::bytes::streaming::take;
use nom::combinator::verify;
use nom::error::{make_error, ErrorKind};
use nom::number::streaming::{be_u16, be_u8};
use nom::{Err, IResult};
use nom_derive::Parse;
use crate::tls_alert::*;
use crate::tls_handshake::*;
/// A single TLS message extracted from a record.
///
/// The parsers in this file produce the `ChangeCipherSpec`, `Alert`,
/// `ApplicationData` and `Heartbeat` variants; `Handshake` wraps a
/// [`TlsMessageHandshake`] parsed elsewhere.
#[derive(Clone, Debug, PartialEq)]
pub enum TlsMessage<'a> {
/// A handshake message.
Handshake(TlsMessageHandshake<'a>),
/// A ChangeCipherSpec message. It carries no data: the parser checks that
/// its single body byte is `0x01` and then discards it.
ChangeCipherSpec,
/// An alert message.
Alert(TlsMessageAlert),
/// Application data, kept as an undecoded byte slice borrowed from the input.
ApplicationData(TlsMessageApplicationData<'a>),
/// A heartbeat message.
Heartbeat(TlsMessageHeartbeat<'a>),
}
/// Application data carried in a TLS record.
///
/// The bytes are not interpreted at this layer; they are borrowed from the
/// parser input without copying.
#[derive(Clone, Debug, PartialEq)]
pub struct TlsMessageApplicationData<'a> {
/// Raw application data bytes.
pub blob: &'a [u8],
}
/// A parsed TLS heartbeat message.
#[derive(Clone, Debug, PartialEq)]
pub struct TlsMessageHeartbeat<'a> {
/// The 1-byte heartbeat message type.
pub heartbeat_type: TlsHeartbeatMessageType,
/// The 16-bit payload length field, exactly as declared in the message.
pub payload_len: u16,
/// The payload bytes, borrowed from the parser input.
pub payload: &'a [u8],
}
/// Parse a ChangeCipherSpec message body.
///
/// The body is a single byte that must be `0x01`. Any other value is rejected
/// by `verify` with a nom error, and an empty input yields `Incomplete`. On
/// success the byte is consumed and [`TlsMessage::ChangeCipherSpec`] is
/// returned together with the remaining input.
pub fn parse_tls_message_changecipherspec(i: &[u8]) -> IResult<&[u8], TlsMessage> {
    verify(be_u8, |byte: &u8| *byte == 0x01)(i)
        .map(|(rest, _ccs_byte)| (rest, TlsMessage::ChangeCipherSpec))
}
/// Parse an Alert message body into [`TlsMessage::Alert`].
///
/// Decoding is delegated to the `Parse` implementation of [`TlsMessageAlert`].
/// Any error it reports is returned unchanged.
pub fn parse_tls_message_alert(i: &[u8]) -> IResult<&[u8], TlsMessage> {
    TlsMessageAlert::parse(i).map(|(rest, alert)| (rest, TlsMessage::Alert(alert)))
}
/// Wrap the entire input as [`TlsMessage::ApplicationData`].
///
/// Application data is opaque here, so every input byte becomes the blob.
/// Nothing is left over, and this function never fails.
pub fn parse_tls_message_applicationdata(i: &[u8]) -> IResult<&[u8], TlsMessage> {
    let nothing_left: &[u8] = &[];
    let data = TlsMessageApplicationData { blob: i };
    Ok((nothing_left, TlsMessage::ApplicationData(data)))
}
/// Parse a Heartbeat message body into a one-element vector holding
/// [`TlsMessage::Heartbeat`].
///
/// Wire layout: a 1-byte heartbeat type, a big-endian `u16` payload length,
/// then the payload followed by padding. The whole message fills the record,
/// whose length the caller passes as `tls_plaintext_len`.
///
/// The returned `payload` slice covers everything after the 3-byte header up
/// to the end of the record (payload *and* padding). `payload_len` keeps the
/// length declared on the wire. A declared length larger than the data
/// actually sent (the Heartbleed pattern, CVE-2014-0160) can therefore be
/// detected by comparing `payload_len` with `payload.len()`.
///
/// # Errors
///
/// - `Err::Error` with `ErrorKind::Verify` if `tls_plaintext_len` is smaller
///   than the 3-byte header.
/// - `Err::Incomplete` if `i` holds fewer bytes than the record length requires.
pub fn parse_tls_message_heartbeat(
    i: &[u8],
    tls_plaintext_len: u16,
) -> IResult<&[u8], Vec<TlsMessage>> {
    let (i, heartbeat_type) = TlsHeartbeatMessageType::parse(i)?;
    let (i, payload_len) = be_u16(i)?;
    // A record shorter than the fixed header cannot hold a heartbeat. This
    // guard also keeps the subtraction below from underflowing.
    if tls_plaintext_len < 3 {
        return Err(Err::Error(make_error(i, ErrorKind::Verify)));
    }
    // Bound the read by the record length, not by the peer-controlled
    // `payload_len`. Trusting the declared length over-reads (or stalls with
    // `Incomplete`) when it exceeds the data sent, and leaves the padding
    // unconsumed when it is honest.
    let (i, payload) = take(usize::from(tls_plaintext_len - 3))(i)?; // payload + padding
    let v = vec![TlsMessage::Heartbeat(TlsMessageHeartbeat {
        heartbeat_type,
        payload_len,
        payload,
    })];
    Ok((i, v))
}