serf_types/
join.rs

1use byteorder::{ByteOrder, NetworkEndian};
2use transformable::utils::encoded_u64_varint_len;
3
4use crate::LamportTimeTransformError;
5
6use super::{LamportTime, Transformable};
7
8/// The message broadcasted after we join to
9/// associated the node with a lamport clock
10#[viewit::viewit(setters(prefix = "with"))]
11#[derive(Debug, Clone, Eq, PartialEq)]
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13pub struct JoinMessage<I> {
14  /// The lamport time
15  #[viewit(
16    getter(const, attrs(doc = "Returns the lamport time for this message")),
17    setter(
18      const,
19      attrs(doc = "Sets the lamport time for this message (Builder pattern)")
20    )
21  )]
22  ltime: LamportTime,
23  /// The id of the node
24  #[viewit(
25    getter(const, style = "ref", attrs(doc = "Returns the node")),
26    setter(attrs(doc = "Sets the node (Builder pattern)"))
27  )]
28  id: I,
29}
30
31impl<I> JoinMessage<I> {
32  /// Create a new join message
33  pub fn new(ltime: LamportTime, id: I) -> Self {
34    Self { ltime, id }
35  }
36
37  /// Set the lamport time
38  #[inline]
39  pub fn set_ltime(&mut self, ltime: LamportTime) -> &mut Self {
40    self.ltime = ltime;
41    self
42  }
43
44  /// Set the id of the node
45  #[inline]
46  pub fn set_id(&mut self, id: I) -> &mut Self {
47    self.id = id;
48    self
49  }
50}
51
52/// Error that can occur when transforming a JoinMessage
53#[derive(thiserror::Error)]
54pub enum JoinMessageTransformError<I: Transformable> {
55  /// Not enough bytes to decode JoinMessage
56  #[error("not enough bytes to decode JoinMessage")]
57  NotEnoughBytes,
58  /// Encode buffer too small
59  #[error("encode buffer too small")]
60  EncodeBufferTooSmall,
61  /// Error transforming Id
62  #[error(transparent)]
63  Id(I::Error),
64
65  /// Error transforming LamportTime
66  #[error(transparent)]
67  LamportTime(#[from] LamportTimeTransformError),
68}
69
70impl<I: Transformable> core::fmt::Debug for JoinMessageTransformError<I> {
71  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
72    write!(f, "{}", self)
73  }
74}
75
76impl<I> Transformable for JoinMessage<I>
77where
78  I: Transformable,
79{
80  type Error = JoinMessageTransformError<I>;
81
82  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
83    let encoded_len = self.encoded_len();
84    if dst.len() < encoded_len {
85      return Err(Self::Error::EncodeBufferTooSmall);
86    }
87
88    let mut offset = 0;
89    NetworkEndian::write_u32(&mut dst[offset..offset + 4], encoded_len as u32);
90    offset += 4;
91
92    offset += self.ltime.encode(&mut dst[offset..])?;
93    offset += self
94      .id
95      .encode(&mut dst[offset..])
96      .map_err(Self::Error::Id)?;
97
98    debug_assert_eq!(
99      offset, encoded_len,
100      "expect write {} bytes, but actual write {} bytes",
101      encoded_len, offset
102    );
103    Ok(offset)
104  }
105
106  fn encoded_len(&self) -> usize {
107    4 + encoded_u64_varint_len(self.ltime.0) + self.id.encoded_len()
108  }
109
110  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
111  where
112    Self: Sized,
113  {
114    if src.len() < 4 {
115      return Err(Self::Error::NotEnoughBytes);
116    }
117
118    let encoded_len = NetworkEndian::read_u32(&src[..4]) as usize;
119    if src.len() < encoded_len {
120      return Err(Self::Error::NotEnoughBytes);
121    }
122
123    let mut offset = 4;
124    let (n, ltime) = LamportTime::decode(&src[offset..])?;
125    offset += n;
126
127    let (n, id) = I::decode(&src[offset..]).map_err(Self::Error::Id)?;
128    offset += n;
129
130    debug_assert_eq!(
131      offset, encoded_len,
132      "expect read {} bytes, but actual read {} bytes",
133      encoded_len, offset
134    );
135    Ok((encoded_len, Self { ltime, id }))
136  }
137}
138
#[cfg(test)]
mod tests {
  use rand::{distributions::Alphanumeric, thread_rng, Rng};
  use smol_str::SmolStr;

  use super::*;

  impl JoinMessage<SmolStr> {
    /// Builds a message with a random lamport time and a random
    /// alphanumeric id that is `size` characters long.
    fn random(size: usize) -> Self {
      let id = thread_rng()
        .sample_iter(Alphanumeric)
        .take(size)
        .collect::<Vec<u8>>();
      let id = String::from_utf8(id).unwrap().into();

      Self {
        ltime: LamportTime::random(),
        id,
      }
    }
  }

  /// Round-trips random messages through the slice, blocking-reader and
  /// async-reader decoders, and checks that truncated input is rejected.
  #[test]
  fn test_transform_encode_decode() {
    futures::executor::block_on(async {
      for i in 0..100 {
        let msg = JoinMessage::random(i);
        let mut buf = vec![0; msg.encoded_len()];
        let encoded_len = msg.encode(&mut buf).unwrap();
        assert_eq!(encoded_len, msg.encoded_len());

        let (decoded_len, decoded) = JoinMessage::<SmolStr>::decode(&buf).unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);

        // A buffer that ends before the declared frame length must be rejected.
        assert!(matches!(
          JoinMessage::<SmolStr>::decode(&buf[..encoded_len - 1]),
          Err(JoinMessageTransformError::NotEnoughBytes)
        ));

        let (decoded_len, decoded) =
          JoinMessage::<SmolStr>::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);

        let (decoded_len, decoded) =
          JoinMessage::<SmolStr>::decode_from_async_reader(&mut futures::io::Cursor::new(&buf))
            .await
            .unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);
      }
    });
  }
}