use std::sync::Arc;
use byteorder::{ByteOrder, NetworkEndian};
use memberlist_types::CheapClone;
use super::{
DelegateVersion, MemberlistDelegateVersion, MemberlistProtocolVersion, Node, NodeTransformError,
ProtocolVersion, Tags, TagsTransformError, Transformable, UnknownDelegateVersion,
UnknownMemberlistDelegateVersion, UnknownMemberlistProtocolVersion, UnknownProtocolVersion,
};
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, bytemuck::NoUninit)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum MemberStatus {
None = 0,
Alive = 1,
Leaving = 2,
Left = 3,
Failed = 4,
}
impl core::fmt::Display for MemberStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_str())
}
}
impl TryFrom<u8> for MemberStatus {
type Error = UnknownMemberStatus;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(Self::None),
1 => Ok(Self::Alive),
2 => Ok(Self::Leaving),
3 => Ok(Self::Left),
4 => Ok(Self::Failed),
_ => Err(UnknownMemberStatus(value)),
}
}
}
impl MemberStatus {
#[inline]
pub const fn as_str(&self) -> &'static str {
match self {
Self::None => "none",
Self::Alive => "alive",
Self::Leaving => "leaving",
Self::Left => "left",
Self::Failed => "failed",
}
}
}
#[derive(Debug, Copy, Clone, thiserror::Error)]
#[error("Unknown member status: {0}")]
pub struct UnknownMemberStatus(u8);
#[viewit::viewit(setters(prefix = "with"))]
#[derive(Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Member<I, A> {
#[viewit(
getter(const, style = "ref", attrs(doc = "Returns the node")),
setter(attrs(doc = "Sets the node (Builder pattern)"))
)]
node: Node<I, A>,
#[viewit(
getter(const, style = "ref", attrs(doc = "Returns the tags")),
setter(attrs(doc = "Sets the tags (Builder pattern)"))
)]
tags: Arc<Tags>,
#[viewit(
getter(const, style = "ref", attrs(doc = "Returns the status")),
setter(attrs(doc = "Sets the status (Builder pattern)"))
)]
status: MemberStatus,
#[viewit(
getter(const, attrs(doc = "Returns the memberlist protocol version")),
setter(
const,
attrs(doc = "Sets the memberlist protocol version (Builder pattern)")
)
)]
memberlist_protocol_version: MemberlistProtocolVersion,
#[viewit(
getter(const, attrs(doc = "Returns the memberlist delegate version")),
setter(
const,
attrs(doc = "Sets the memberlist delegate version (Builder pattern)")
)
)]
memberlist_delegate_version: MemberlistDelegateVersion,
#[viewit(
getter(const, attrs(doc = "Returns the ruserf protocol version")),
setter(
const,
attrs(doc = "Sets the ruserf protocol version (Builder pattern)")
)
)]
protocol_version: ProtocolVersion,
#[viewit(
getter(const, attrs(doc = "Returns the ruserf delegate version")),
setter(
const,
attrs(doc = "Sets the ruserf delegate version (Builder pattern)")
)
)]
delegate_version: DelegateVersion,
}
impl<I, A> Member<I, A> {
#[inline]
pub fn new(node: Node<I, A>, tags: Tags, status: MemberStatus) -> Self {
Self {
node,
tags: Arc::new(tags),
status,
memberlist_protocol_version: MemberlistProtocolVersion::V1,
memberlist_delegate_version: MemberlistDelegateVersion::V1,
protocol_version: ProtocolVersion::V1,
delegate_version: DelegateVersion::V1,
}
}
}
impl<I: Clone, A: Clone> Clone for Member<I, A> {
fn clone(&self) -> Self {
Self {
node: self.node.clone(),
tags: self.tags.clone(),
status: self.status,
memberlist_protocol_version: self.memberlist_protocol_version,
memberlist_delegate_version: self.memberlist_delegate_version,
protocol_version: self.protocol_version,
delegate_version: self.delegate_version,
}
}
}
impl<I: CheapClone, A: CheapClone> CheapClone for Member<I, A> {
fn cheap_clone(&self) -> Self {
Self {
node: self.node.cheap_clone(),
tags: self.tags.cheap_clone(),
status: self.status,
memberlist_protocol_version: self.memberlist_protocol_version,
memberlist_delegate_version: self.memberlist_delegate_version,
protocol_version: self.protocol_version,
delegate_version: self.delegate_version,
}
}
}
#[derive(thiserror::Error)]
pub enum MemberTransformError<I, A>
where
I: Transformable,
A: Transformable,
{
#[error(transparent)]
Node(#[from] NodeTransformError<I, A>),
#[error(transparent)]
Tags(#[from] TagsTransformError),
#[error(transparent)]
MemberStatus(#[from] UnknownMemberStatus),
#[error("encode buffer too small")]
BufferTooSmall,
#[error("not enough bytes to decode `Member`")]
NotEnoughBytes,
#[error(transparent)]
MemberlistProtocolVersion(#[from] UnknownMemberlistProtocolVersion),
#[error(transparent)]
MemberlistDelegateVersion(#[from] UnknownMemberlistDelegateVersion),
#[error(transparent)]
ProtocolVersion(#[from] UnknownProtocolVersion),
#[error(transparent)]
DelegateVersion(#[from] UnknownDelegateVersion),
}
impl<I, A> core::fmt::Debug for MemberTransformError<I, A>
where
I: Transformable,
A: Transformable,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{}", self)
}
}
impl<I, A> Transformable for Member<I, A>
where
I: Transformable,
A: Transformable,
{
type Error = MemberTransformError<I, A>;
fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
let encoded_len = self.encoded_len();
if dst.len() < encoded_len {
return Err(Self::Error::BufferTooSmall);
}
let mut offset = 0;
NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32);
offset += 4;
offset += self.node.encode(&mut dst[offset..])?;
offset += self.tags.encode(&mut dst[offset..])?;
dst[offset] = self.status as u8;
offset += 1;
dst[offset] = self.memberlist_protocol_version as u8;
offset += 1;
dst[offset] = self.memberlist_delegate_version as u8;
offset += 1;
dst[offset] = self.protocol_version as u8;
offset += 1;
dst[offset] = self.delegate_version as u8;
offset += 1;
debug_assert_eq!(
offset, encoded_len,
"expect write {} bytes, but actually write {} bytes",
offset, encoded_len
);
Ok(offset)
}
fn encoded_len(&self) -> usize {
4 + self.node.encoded_len()
+ self.tags.encoded_len()
+ 1 + 1 + 1 + 1 + 1 }
fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
where
Self: Sized,
{
let src_len = src.len();
if src_len < 4 {
return Err(Self::Error::NotEnoughBytes);
}
let encoded_len = NetworkEndian::read_u32(&src[0..4]) as usize;
if src_len < encoded_len {
return Err(Self::Error::NotEnoughBytes);
}
let mut offset = 4;
let (node_len, node) = Node::decode(&src[offset..])?;
offset += node_len;
let (tags_len, tags) = Tags::decode(&src[offset..])?;
offset += tags_len;
if src_len < offset + 5 {
return Err(Self::Error::NotEnoughBytes);
}
let status = MemberStatus::try_from(src[offset])?;
offset += 1;
let memberlist_protocol_version = MemberlistProtocolVersion::try_from(src[offset])?;
offset += 1;
let memberlist_delegate_version = MemberlistDelegateVersion::try_from(src[offset])?;
offset += 1;
let protocol_version = ProtocolVersion::try_from(src[offset])?;
offset += 1;
let delegate_version = DelegateVersion::try_from(src[offset])?;
offset += 1;
debug_assert_eq!(
offset, encoded_len,
"expect read {} bytes, but actually read {} bytes",
offset, encoded_len
);
Ok((
encoded_len,
Self {
node,
tags: Arc::new(tags),
status,
memberlist_protocol_version,
memberlist_delegate_version,
protocol_version,
delegate_version,
},
))
}
}
#[cfg(test)]
mod tests {
use std::net::SocketAddr;
use rand::{distributions::Alphanumeric, random, thread_rng, Rng};
use smol_str::SmolStr;
use super::*;
impl Member<SmolStr, SocketAddr> {
fn random(num_tags: usize, size: usize) -> Self {
let id = thread_rng()
.sample_iter(Alphanumeric)
.take(size)
.collect::<Vec<u8>>();
let id = String::from_utf8(id).unwrap().into();
let addr = SocketAddr::from(([127, 0, 0, 1], random::<u16>()));
let node = Node::new(id, addr);
let tags = Tags::random(num_tags, size);
Self {
node,
tags: Arc::new(tags),
status: MemberStatus::Alive,
memberlist_protocol_version: MemberlistProtocolVersion::V1,
memberlist_delegate_version: MemberlistDelegateVersion::V1,
protocol_version: ProtocolVersion::V1,
delegate_version: DelegateVersion::V1,
}
}
}
#[test]
fn member_encode_decode() {
futures::executor::block_on(async {
for i in 0..100 {
let filter = Member::random(i % 10, i);
let mut buf = vec![0; filter.encoded_len()];
let encoded_len = filter.encode(&mut buf).unwrap();
assert_eq!(encoded_len, filter.encoded_len());
let (decoded_len, decoded) = Member::<SmolStr, SocketAddr>::decode(&buf).unwrap();
assert_eq!(decoded_len, encoded_len);
assert_eq!(decoded, filter);
let (decoded_len, decoded) =
Member::<SmolStr, SocketAddr>::decode_from_reader(&mut std::io::Cursor::new(&buf))
.unwrap();
assert_eq!(decoded_len, encoded_len);
assert_eq!(decoded, filter);
let (decoded_len, decoded) = Member::<SmolStr, SocketAddr>::decode_from_async_reader(
&mut futures::io::Cursor::new(&buf),
)
.await
.unwrap();
assert_eq!(decoded_len, encoded_len);
assert_eq!(decoded, filter);
}
});
}
}