use byteorder::{ByteOrder, NetworkEndian};
use smol_str::SmolStr;
use transformable::{
  BytesTransformError, DurationTransformError, StringTransformError, Transformable,
};
use std::time::Duration;
use memberlist_types::{bytes::Bytes, Node, NodeTransformError, TinyVec};
use super::{LamportTime, LamportTimeTransformError};
bitflags::bitflags! {
  /// Bit flags stored in the `flags` field of [`QueryMessage`] and
  /// [`QueryResponseMessage`].
  ///
  /// The flags are backed by a `u32`; the codecs in this file write the raw
  /// bits with `bits()` and read them back with `from_bits_retain`, so unknown
  /// bits survive a round trip.
  #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
  #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
  #[cfg_attr(feature = "serde", serde(transparent))]
  pub struct QueryFlag: u32 {
    /// Bit 0. Tested by the `ack()` helpers on the message types.
    // NOTE(review): presumably asks the receiver to send an acknowledgement — confirm.
    const ACK = 1 << 0;
    /// Bit 1. Tested by the `no_broadcast()` helpers on the message types.
    // NOTE(review): presumably suppresses re-broadcast of the query — confirm.
    const NO_BROADCAST = 1 << 1;
  }
}
/// A query message.
///
/// Accessors and `with_*` builder-style setters are generated by `viewit`
/// from the per-field attributes below; the `doc` strings inside those
/// attributes become the documentation of the generated methods.
#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))]
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct QueryMessage<I, A> {
  /// The event lamport time.
  #[viewit(
    getter(const, style = "move", attrs(doc = "Returns the event lamport time")),
    setter(const, attrs(doc = "Sets the event lamport time (Builder pattern)"))
  )]
  ltime: LamportTime,
  /// The query id.
  #[viewit(
    getter(const, style = "move", attrs(doc = "Returns the query id")),
    setter(attrs(doc = "Sets the query id (Builder pattern)"))
  )]
  id: u32,
  /// The node the query comes from.
  #[viewit(
    getter(const, attrs(doc = "Returns the from node")),
    setter(attrs(doc = "Sets the from node (Builder pattern)"))
  )]
  from: Node<I, A>,
  /// The potential query filters, each kept as an opaque byte string.
  #[viewit(
    getter(const, attrs(doc = "Returns the potential query filters")),
    setter(attrs(doc = "Sets the potential query filters (Builder pattern)"))
  )]
  filters: TinyVec<Bytes>,
  /// The query flags (see [`QueryFlag`]).
  #[viewit(
    getter(const, style = "move", attrs(doc = "Returns the flags")),
    setter(attrs(doc = "Sets the flags (Builder pattern)"))
  )]
  flags: QueryFlag,
  /// The number of duplicate relayed responses.
  #[viewit(
    getter(
      const,
      style = "move",
      attrs(doc = "Returns the number of duplicate relayed responses")
    ),
    setter(attrs(doc = "Sets the number of duplicate relayed responses (Builder pattern)"))
  )]
  relay_factor: u8,
  /// The maximum time between delivery and response.
  #[viewit(
    getter(
      const,
      style = "move",
      attrs(doc = "Returns the maximum time between delivery and response")
    ),
    setter(attrs(doc = "Sets the maximum time between delivery and response (Builder pattern)"))
  )]
  timeout: Duration,
  /// The name of the query.
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Returns the name of the query")),
    setter(attrs(doc = "Sets the name of the query (Builder pattern)"))
  )]
  name: SmolStr,
  /// The query payload.
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Returns the payload")),
    setter(attrs(doc = "Sets the payload (Builder pattern)"))
  )]
  payload: Bytes,
}
/// Convenience predicates over the message's [`QueryFlag`] bits.
impl<I, A> QueryMessage<I, A> {
  /// Returns `true` when the [`QueryFlag::ACK`] bit is set on this message.
  #[inline]
  pub fn ack(&self) -> bool {
    // `ACK` is a single bit, so "shares a bit with" is the same as "contains".
    self.flags.intersects(QueryFlag::ACK)
  }

  /// Returns `true` when the [`QueryFlag::NO_BROADCAST`] bit is set on this message.
  #[inline]
  pub fn no_broadcast(&self) -> bool {
    // `NO_BROADCAST` is a single bit, so "shares a bit with" is the same as "contains".
    self.flags.intersects(QueryFlag::NO_BROADCAST)
  }
}
/// Errors produced while encoding or decoding a [`QueryMessage`].
///
/// `Debug` is implemented by hand elsewhere in this file instead of being
/// derived.
#[derive(thiserror::Error)]
pub enum QueryMessageTransformError<I, A>
where
  I: Transformable,
  A: Transformable,
{
  /// The input slice ended before a fixed-width field could be read.
  #[error("not enough bytes to decode QueryMessage")]
  NotEnoughBytes,
  /// The destination slice is shorter than `encoded_len()`.
  #[error("encode buffer too small")]
  BufferTooSmall,
  /// Failure while transforming the `from` node.
  #[error(transparent)]
  From(#[from] NodeTransformError<I, A>),
  /// Failure while transforming the lamport time.
  #[error(transparent)]
  LamportTime(#[from] LamportTimeTransformError),
  /// Failure while transforming the payload.
  // No `#[from]`: `Payload` and `Filters` wrap the same error type, so the
  // codec selects the variant explicitly with `map_err`.
  #[error(transparent)]
  Payload(BytesTransformError),
  /// Failure while transforming one of the filters.
  #[error(transparent)]
  Filters(BytesTransformError),
  /// Failure while transforming the query name.
  #[error(transparent)]
  Name(#[from] StringTransformError),
  /// Failure while transforming the timeout.
  #[error(transparent)]
  Timeout(#[from] DurationTransformError),
}
/// `Debug` prints exactly the `Display` text of the error.
// NOTE(review): hand-written rather than derived, presumably so that `I` and
// `A` do not need `Debug` bounds — confirm.
impl<I, A> core::fmt::Debug for QueryMessageTransformError<I, A>
where
  I: Transformable,
  A: Transformable,
{
  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
    // Render through the `Display` impl and forward the result to `f`.
    f.write_fmt(format_args!("{}", self))
  }
}
/// Binary wire codec for [`QueryMessage`].
///
/// Layout, in order (fixed-width integers are written in network byte order):
///
/// 1. `u32` – total encoded length of the message, this prefix included
/// 2. `ltime` – via its own `Transformable` impl
/// 3. `u32` – `id`
/// 4. `from` – via its own `Transformable` impl
/// 5. `u32` – number of filters, followed by every filter encoded as `Bytes`
/// 6. `u32` – raw `flags` bits
/// 7. `u8` – `relay_factor`
/// 8. `timeout`, `name`, `payload` – each via its own `Transformable` impl
impl<I, A> Transformable for QueryMessage<I, A>
where
  I: Transformable,
  A: Transformable,
{
  type Error = QueryMessageTransformError<I, A>;

  /// Encodes the message into the front of `dst` and returns the number of
  /// bytes written (always `self.encoded_len()`).
  ///
  /// # Errors
  ///
  /// Returns `BufferTooSmall` when `dst` is shorter than `encoded_len()`, or
  /// the wrapped error of whichever field failed to encode.
  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
    let encoded_len = self.encoded_len();
    if dst.len() < encoded_len {
      return Err(Self::Error::BufferTooSmall);
    }
    let mut offset = 0;
    // Length prefix.
    // NOTE(review): `as u32` silently truncates for messages above `u32::MAX` bytes.
    NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32);
    offset += 4;
    offset += self.ltime.encode(&mut dst[offset..])?;
    NetworkEndian::write_u32(&mut dst[offset..], self.id);
    offset += 4;
    offset += self.from.encode(&mut dst[offset..])?;
    // Filter count, then the filters themselves.
    NetworkEndian::write_u32(&mut dst[offset..], self.filters.len() as u32);
    offset += 4;
    for filter in self.filters.iter() {
      offset += filter
        .encode(&mut dst[offset..])
        .map_err(Self::Error::Filters)?;
    }
    NetworkEndian::write_u32(&mut dst[offset..], self.flags.bits());
    offset += 4;
    dst[offset] = self.relay_factor;
    offset += 1;
    offset += self.timeout.encode(&mut dst[offset..])?;
    offset += self.name.encode(&mut dst[offset..])?;
    offset += self
      .payload
      .encode(&mut dst[offset..])
      .map_err(Self::Error::Payload)?;
    // Sanity check (debug builds only): `encoded_len()` must agree with what
    // was actually written.
    debug_assert_eq!(
      offset, encoded_len,
      "expect write {} bytes, but actual write {} bytes",
      encoded_len, offset
    );
    Ok(offset)
  }

  /// Returns the exact number of bytes `encode` writes for this message.
  fn encoded_len(&self) -> usize {
    4 + self.ltime.encoded_len() // length prefix + lamport time
      + 4 + self.from.encoded_len() // id + from node
      + 4 + self.filters.iter().map(|f| f.encoded_len()).sum::<usize>() // filter count + filters
      + 4 + 1 + self.timeout.encoded_len() // flags + relay factor + timeout
      + self.name.encoded_len()
      + self.payload.encoded_len()
  }

  /// Decodes a message from the front of `src`, returning the number of bytes
  /// consumed together with the message.
  ///
  /// # Errors
  ///
  /// Returns `NotEnoughBytes` when `src` is shorter than the length prefix
  /// claims or ends before a fixed-width field, or the wrapped error of
  /// whichever field failed to decode.
  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
  where
    Self: Sized,
  {
    let src_len = src.len();
    if src.len() < 4 {
      return Err(Self::Error::NotEnoughBytes);
    }
    let mut offset = 0;
    let len = NetworkEndian::read_u32(&src[offset..]) as usize;
    if src.len() < len {
      return Err(Self::Error::NotEnoughBytes);
    }
    offset += 4;
    let (n, ltime) = LamportTime::decode(&src[offset..])?;
    offset += n;
    if offset + 4 > src_len {
      return Err(Self::Error::NotEnoughBytes);
    }
    let id = NetworkEndian::read_u32(&src[offset..]);
    offset += 4;
    let (n, from) = Node::decode(&src[offset..])?;
    offset += n;
    if offset + 4 > src_len {
      return Err(Self::Error::NotEnoughBytes);
    }
    let num_filters = NetworkEndian::read_u32(&src[offset..]) as usize;
    offset += 4;
    // `num_filters` comes straight from the wire, so it must not drive the
    // allocation size on its own: a few hostile bytes could otherwise request
    // room for billions of entries. Cap the pre-allocation at the number of
    // bytes left in `src`; the capacity is only a hint, so a legitimate
    // message still decodes every filter it declares.
    let remaining = src_len - offset;
    let mut filters = TinyVec::with_capacity(num_filters.min(remaining));
    for _ in 0..num_filters {
      let (n, filter) = Bytes::decode(&src[offset..]).map_err(Self::Error::Filters)?;
      filters.push(filter);
      offset += n;
    }
    if offset + 4 > src_len {
      return Err(Self::Error::NotEnoughBytes);
    }
    // Unknown bits are kept rather than rejected.
    let flags = QueryFlag::from_bits_retain(NetworkEndian::read_u32(&src[offset..]));
    offset += 4;
    if offset + 1 > src_len {
      return Err(Self::Error::NotEnoughBytes);
    }
    let relay_factor = src[offset];
    offset += 1;
    let (n, timeout) = Duration::decode(&src[offset..])?;
    offset += n;
    let (n, name) = SmolStr::decode(&src[offset..])?;
    offset += n;
    let (n, payload) = Bytes::decode(&src[offset..]).map_err(Self::Error::Payload)?;
    offset += n;
    // Sanity check (debug builds only): the bytes consumed must match the
    // length prefix.
    debug_assert_eq!(
      offset, len,
      "expect read {} bytes, but actual read {} bytes",
      len, offset
    );
    Ok((
      offset,
      Self {
        ltime,
        id,
        from,
        filters,
        flags,
        relay_factor,
        timeout,
        name,
        payload,
      },
    ))
  }
}
/// A query response message.
///
/// Accessors and `with_*` builder-style setters are generated by `viewit`
/// from the per-field attributes below.
// NOTE(review): none of the getters here override the struct-level
// `style = "ref"`, so the getters for the `Copy` fields (`ltime`, `id`,
// `flags`) presumably return references, unlike the by-value getters on
// `QueryMessage` — confirm whether that asymmetry is intended.
#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))]
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct QueryResponseMessage<I, A> {
  /// The lamport time for this message.
  #[viewit(
    getter(const, attrs(doc = "Returns the lamport time for this message")),
    setter(
      const,
      attrs(doc = "Sets the lamport time for this message (Builder pattern)")
    )
  )]
  ltime: LamportTime,
  /// The query id.
  #[viewit(
    getter(const, attrs(doc = "Returns the query id")),
    setter(attrs(doc = "Sets the query id (Builder pattern)"))
  )]
  id: u32,
  /// The node the response comes from.
  #[viewit(
    getter(const, attrs(doc = "Returns the from node")),
    setter(attrs(doc = "Sets the from node (Builder pattern)"))
  )]
  from: Node<I, A>,
  /// The response flags (see [`QueryFlag`]).
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Returns the flags")),
    setter(attrs(doc = "Sets the flags (Builder pattern)"))
  )]
  flags: QueryFlag,
  /// The response payload.
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Returns the payload")),
    setter(attrs(doc = "Sets the payload (Builder pattern)"))
  )]
  payload: Bytes,
}
/// Convenience predicates over the response's [`QueryFlag`] bits.
impl<I, A> QueryResponseMessage<I, A> {
  /// Returns `true` when the [`QueryFlag::ACK`] bit is set on this response.
  #[inline]
  pub fn ack(&self) -> bool {
    // `ACK` is a single bit, so "shares a bit with" is the same as "contains".
    self.flags.intersects(QueryFlag::ACK)
  }

  /// Returns `true` when the [`QueryFlag::NO_BROADCAST`] bit is set on this response.
  #[inline]
  pub fn no_broadcast(&self) -> bool {
    // `NO_BROADCAST` is a single bit, so "shares a bit with" is the same as "contains".
    self.flags.intersects(QueryFlag::NO_BROADCAST)
  }
}
/// Errors produced while encoding or decoding a [`QueryResponseMessage`].
///
/// `Debug` is implemented by hand elsewhere in this file instead of being
/// derived.
#[derive(thiserror::Error)]
pub enum QueryResponseMessageTransformError<I, A>
where
  I: Transformable,
  A: Transformable,
{
  /// The input slice ended before a fixed-width field could be read.
  #[error("not enough bytes to decode QueryResponseMessage")]
  NotEnoughBytes,
  /// The destination slice is shorter than `encoded_len()`.
  #[error("encode buffer too small")]
  BufferTooSmall,
  /// Failure while transforming the `from` node.
  #[error(transparent)]
  Node(#[from] NodeTransformError<I, A>),
  /// Failure while transforming the lamport time.
  #[error(transparent)]
  LamportTime(#[from] LamportTimeTransformError),
  /// Failure while transforming the payload.
  #[error(transparent)]
  Payload(#[from] BytesTransformError),
}
/// `Debug` prints exactly the `Display` text of the error.
// NOTE(review): hand-written rather than derived, presumably so that `I` and
// `A` do not need `Debug` bounds — confirm.
impl<I, A> core::fmt::Debug for QueryResponseMessageTransformError<I, A>
where
  I: Transformable,
  A: Transformable,
{
  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
    // Render through the `Display` impl and forward the result to `f`.
    f.write_fmt(format_args!("{}", self))
  }
}
/// Binary wire codec for [`QueryResponseMessage`].
///
/// Layout, in order (fixed-width integers are written in network byte order):
/// `u32` total length (prefix included), `ltime`, `u32` id, `from`,
/// `u32` raw flag bits, `payload`.
impl<I, A> Transformable for QueryResponseMessage<I, A>
where
  I: Transformable,
  A: Transformable,
{
  type Error = QueryResponseMessageTransformError<I, A>;

  /// Encodes the message into the front of `dst` and returns the number of
  /// bytes written. Fails with `BufferTooSmall` when `dst` is shorter than
  /// `encoded_len()`, or with the wrapped error of the field that failed.
  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
    let total = self.encoded_len();
    if total > dst.len() {
      return Err(Self::Error::BufferTooSmall);
    }

    // Length prefix occupies the first four bytes.
    NetworkEndian::write_u32(&mut dst[..], total as u32);
    let mut cursor = 4;

    cursor += self.ltime.encode(&mut dst[cursor..])?;

    NetworkEndian::write_u32(&mut dst[cursor..], self.id);
    cursor += 4;

    cursor += self.from.encode(&mut dst[cursor..])?;

    NetworkEndian::write_u32(&mut dst[cursor..], self.flags.bits());
    cursor += 4;

    cursor += self.payload.encode(&mut dst[cursor..])?;

    // Debug builds only: the size computation and the writer must agree.
    debug_assert_eq!(
      cursor, total,
      "expect write {} bytes, but actual write {} bytes",
      total, cursor
    );
    Ok(cursor)
  }

  /// Returns the exact number of bytes `encode` writes for this message.
  fn encoded_len(&self) -> usize {
    // Length prefix, id and flags are four bytes each.
    const FIXED_WIDTH: usize = 4 + 4 + 4;
    FIXED_WIDTH + self.ltime.encoded_len() + self.from.encoded_len() + self.payload.encoded_len()
  }

  /// Decodes a message from the front of `src`, returning the number of bytes
  /// consumed together with the message. Fails with `NotEnoughBytes` when
  /// `src` is shorter than the length prefix claims or ends before a
  /// fixed-width field, or with the wrapped error of the field that failed.
  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
  where
    Self: Sized,
  {
    let available = src.len();
    if available < 4 {
      return Err(Self::Error::NotEnoughBytes);
    }

    let declared = NetworkEndian::read_u32(src) as usize;
    if declared > available {
      return Err(Self::Error::NotEnoughBytes);
    }
    let mut cursor = 4;

    // Reads a network-order `u32` starting at `at`, or yields `None` when
    // fewer than four bytes are left.
    let u32_at = |at: usize| (at + 4 <= available).then(|| NetworkEndian::read_u32(&src[at..]));

    let (read, ltime) = LamportTime::decode(&src[cursor..])?;
    cursor += read;

    let id = u32_at(cursor).ok_or(Self::Error::NotEnoughBytes)?;
    cursor += 4;

    let (read, from) = Node::decode(&src[cursor..])?;
    cursor += read;

    // Unknown bits are kept rather than rejected.
    let raw_flags = u32_at(cursor).ok_or(Self::Error::NotEnoughBytes)?;
    let flags = QueryFlag::from_bits_retain(raw_flags);
    cursor += 4;

    let (read, payload) = Bytes::decode(&src[cursor..])?;
    cursor += read;

    // Debug builds only: bytes consumed must match the length prefix.
    debug_assert_eq!(
      cursor, declared,
      "expect read {} bytes, but actual read {} bytes",
      declared, cursor
    );

    let message = Self {
      ltime,
      id,
      from,
      flags,
      payload,
    };
    Ok((cursor, message))
  }
}
#[cfg(test)]
mod tests {
  use std::net::SocketAddr;
  use rand::{distributions::Alphanumeric, random, thread_rng, Rng};
  use super::*;

  /// Produces `size` random ASCII alphanumeric bytes.
  fn random_alphanumeric(size: usize) -> Vec<u8> {
    thread_rng().sample_iter(Alphanumeric).take(size).collect()
  }

  impl QueryMessage<SmolStr, SocketAddr> {
    /// Builds a message whose variable-length fields (node id, every filter,
    /// name, payload) are `size` random alphanumeric bytes each, with
    /// `num_filters` filters and no flags set.
    fn random(size: usize, num_filters: usize) -> Self {
      let ltime = LamportTime::random();
      let id = random();
      let node_id = SmolStr::from(String::from_utf8(random_alphanumeric(size)).unwrap());
      let addr = SocketAddr::from(([127, 0, 0, 1], random::<u16>()));
      let from = Node::new(node_id, addr);
      let filters = (0..num_filters)
        .map(|_| Bytes::from(random_alphanumeric(size)))
        .collect();
      let relay_factor = random();
      let timeout = Duration::from_secs(random::<u64>());
      let name = SmolStr::from(String::from_utf8(random_alphanumeric(size)).unwrap());
      let payload = Bytes::from(random_alphanumeric(size));
      Self {
        ltime,
        id,
        from,
        filters,
        flags: QueryFlag::empty(),
        relay_factor,
        timeout,
        name,
        payload,
      }
    }
  }

  impl QueryResponseMessage<SmolStr, SocketAddr> {
    /// Builds a response whose node id and payload are `size` random
    /// alphanumeric bytes each, with no flags set.
    fn random(size: usize) -> Self {
      let id = random();
      let node_id = SmolStr::from(String::from_utf8(random_alphanumeric(size)).unwrap());
      let addr = SocketAddr::from(([127, 0, 0, 1], random::<u16>()));
      let from = Node::new(node_id, addr);
      let payload = random_alphanumeric(size);
      Self {
        ltime: LamportTime::random(),
        id,
        from,
        flags: QueryFlag::empty(),
        payload: Bytes::from(payload),
      }
    }
  }

  /// Round-trips random responses through the slice, blocking-reader and
  /// async-reader decoders.
  #[test]
  fn test_query_response_transform() {
    futures::executor::block_on(async {
      for size in 0..100 {
        let msg = QueryResponseMessage::random(size);
        let mut buf = vec![0; msg.encoded_len()];
        let written = msg.encode(&mut buf).unwrap();
        assert_eq!(written, msg.encoded_len());

        // Every decoder must consume exactly what was written and reproduce
        // the original value.
        let assert_roundtrip =
          |(read, decoded): (usize, QueryResponseMessage<SmolStr, SocketAddr>)| {
            assert_eq!(read, written);
            assert_eq!(decoded, msg);
          };

        assert_roundtrip(QueryResponseMessage::<SmolStr, SocketAddr>::decode(&buf).unwrap());
        assert_roundtrip(
          QueryResponseMessage::<SmolStr, SocketAddr>::decode_from_reader(
            &mut std::io::Cursor::new(&buf),
          )
          .unwrap(),
        );
        assert_roundtrip(
          QueryResponseMessage::<SmolStr, SocketAddr>::decode_from_async_reader(
            &mut futures::io::Cursor::new(&buf),
          )
          .await
          .unwrap(),
        );
      }
    });
  }

  /// Round-trips random queries (0 to 9 filters) through the slice,
  /// blocking-reader and async-reader decoders.
  #[test]
  fn test_query_message_transform() {
    futures::executor::block_on(async {
      for size in 0..100 {
        let msg = QueryMessage::random(size, size % 10);
        let mut buf = vec![0; msg.encoded_len()];
        let written = msg.encode(&mut buf).unwrap();
        assert_eq!(written, msg.encoded_len());

        // Every decoder must consume exactly what was written and reproduce
        // the original value.
        let assert_roundtrip = |(read, decoded): (usize, QueryMessage<SmolStr, SocketAddr>)| {
          assert_eq!(read, written);
          assert_eq!(decoded, msg);
        };

        assert_roundtrip(QueryMessage::<SmolStr, SocketAddr>::decode(&buf).unwrap());
        assert_roundtrip(
          QueryMessage::<SmolStr, SocketAddr>::decode_from_reader(&mut std::io::Cursor::new(&buf))
            .unwrap(),
        );
        assert_roundtrip(
          QueryMessage::<SmolStr, SocketAddr>::decode_from_async_reader(
            &mut futures::io::Cursor::new(&buf),
          )
          .await
          .unwrap(),
        );
      }
    });
  }
}