serf_types/
push_pull.rs

1use byteorder::{ByteOrder, NetworkEndian};
2use indexmap::{IndexMap, IndexSet};
3use memberlist_types::TinyVec;
4use transformable::Transformable;
5
6use super::{LamportTime, LamportTimeTransformError, UserEvents, UserEventsTransformError};
7
8/// Used when doing a state exchange. This
9/// is a relatively large message, but is sent infrequently
10#[viewit::viewit(setters(prefix = "with"))]
11#[derive(Debug, Clone)]
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13#[cfg_attr(
14  feature = "serde",
15  serde(bound(
16    serialize = "I: core::cmp::Eq + core::hash::Hash + serde::Serialize",
17    deserialize = "I: core::cmp::Eq + core::hash::Hash + serde::Deserialize<'de>"
18  ))
19)]
20pub struct PushPullMessage<I> {
21  /// Current node lamport time
22  #[viewit(
23    getter(const, style = "move", attrs(doc = "Returns the lamport time")),
24    setter(const, attrs(doc = "Sets the lamport time (Builder pattern)"))
25  )]
26  ltime: LamportTime,
27  /// Maps the node to its status time
28  #[viewit(
29    getter(
30      const,
31      style = "ref",
32      attrs(doc = "Returns the maps the node to its status time")
33    ),
34    setter(attrs(doc = "Sets the maps the node to its status time (Builder pattern)"))
35  )]
36  status_ltimes: IndexMap<I, LamportTime>,
37  /// List of left nodes
38  #[viewit(
39    getter(const, style = "ref", attrs(doc = "Returns the list of left nodes")),
40    setter(attrs(doc = "Sets the list of left nodes (Builder pattern)"))
41  )]
42  left_members: IndexSet<I>,
43  /// Lamport time for event clock
44  #[viewit(
45    getter(
46      const,
47      style = "move",
48      attrs(doc = "Returns the lamport time for event clock")
49    ),
50    setter(
51      const,
52      attrs(doc = "Sets the lamport time for event clock (Builder pattern)")
53    )
54  )]
55  event_ltime: LamportTime,
56  /// Recent events
57  #[viewit(
58    getter(const, style = "ref", attrs(doc = "Returns the recent events")),
59    setter(attrs(doc = "Sets the recent events (Builder pattern)"))
60  )]
61  events: TinyVec<Option<UserEvents>>,
62  /// Lamport time for query clock
63  #[viewit(
64    getter(
65      const,
66      style = "move",
67      attrs(doc = "Returns the lamport time for query clock")
68    ),
69    setter(
70      const,
71      attrs(doc = "Sets the lamport time for query clock (Builder pattern)")
72    )
73  )]
74  query_ltime: LamportTime,
75}
76
77impl<I> PartialEq for PushPullMessage<I>
78where
79  I: core::hash::Hash + Eq,
80{
81  fn eq(&self, other: &Self) -> bool {
82    self.ltime == other.ltime
83      && self.status_ltimes == other.status_ltimes
84      && self.left_members == other.left_members
85      && self.event_ltime == other.event_ltime
86      && self.events == other.events
87      && self.query_ltime == other.query_ltime
88  }
89}
90
/// Borrowed view of a [`PushPullMessage`], used when doing a state exchange.
///
/// It copies the three lamport times and borrows the message's collections,
/// so the message can be encoded without cloning its contents. This is a
/// relatively large message, but is sent infrequently.
#[viewit::viewit(getters(skip), setters(skip))]
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct PushPullMessageRef<'a, I> {
  /// Current node lamport time
  ltime: LamportTime,
  /// Borrowed map from each node to its status lamport time
  status_ltimes: &'a IndexMap<I, LamportTime>,
  /// Borrowed set of nodes that have left
  left_members: &'a IndexSet<I>,
  /// Lamport time for event clock
  event_ltime: LamportTime,
  /// Borrowed recent events; a `None` slot holds no events
  events: &'a [Option<UserEvents>],
  /// Lamport time for query clock
  query_ltime: LamportTime,
}
110
111impl<I> Clone for PushPullMessageRef<'_, I> {
112  fn clone(&self) -> Self {
113    *self
114  }
115}
116
117impl<I> Copy for PushPullMessageRef<'_, I> {}
118
119impl<'a, I> From<&'a PushPullMessage<I>> for PushPullMessageRef<'a, I> {
120  #[inline]
121  fn from(msg: &'a PushPullMessage<I>) -> Self {
122    Self {
123      ltime: msg.ltime,
124      status_ltimes: &msg.status_ltimes,
125      left_members: &msg.left_members,
126      event_ltime: msg.event_ltime,
127      events: &msg.events,
128      query_ltime: msg.query_ltime,
129    }
130  }
131}
132
133impl<'a, I> From<&'a mut PushPullMessage<I>> for PushPullMessageRef<'a, I> {
134  #[inline]
135  fn from(msg: &'a mut PushPullMessage<I>) -> Self {
136    Self {
137      ltime: msg.ltime,
138      status_ltimes: &msg.status_ltimes,
139      left_members: &msg.left_members,
140      event_ltime: msg.event_ltime,
141      events: &msg.events,
142      query_ltime: msg.query_ltime,
143    }
144  }
145}
146
147impl<I> super::Encodable for PushPullMessageRef<'_, I>
148where
149  I: Transformable,
150{
151  type Error = PushPullMessageTransformError<I>;
152
153  /// Returns the encoded length of the message
154  fn encoded_len(&self) -> usize {
155    4 + Transformable::encoded_len(&self.ltime)
156      + 4
157      + self
158        .status_ltimes
159        .iter()
160        .map(|(k, v)| Transformable::encoded_len(k) + Transformable::encoded_len(v))
161        .sum::<usize>()
162      + 4
163      + self
164        .left_members
165        .iter()
166        .map(Transformable::encoded_len)
167        .sum::<usize>()
168      + Transformable::encoded_len(&self.event_ltime)
169      + 4
170      + self
171        .events
172        .iter()
173        .map(|e| match e {
174          Some(e) => 1 + Transformable::encoded_len(e),
175          None => 1,
176        })
177        .sum::<usize>()
178      + Transformable::encoded_len(&self.query_ltime)
179  }
180
181  /// Encodes the message into the given buffer
182  fn encode(&self, dst: &mut [u8]) -> Result<usize, PushPullMessageTransformError<I>> {
183    let encoded_len = self.encoded_len();
184    if dst.len() < encoded_len {
185      return Err(PushPullMessageTransformError::BufferTooSmall);
186    }
187
188    let mut offset = 0;
189    NetworkEndian::write_u32(&mut dst[offset..offset + 4], encoded_len as u32);
190    offset += 4;
191
192    offset += Transformable::encode(&self.ltime, &mut dst[offset..])?;
193    let len = self.status_ltimes.len() as u32;
194    NetworkEndian::write_u32(&mut dst[offset..offset + 4], len);
195    offset += 4;
196    for (node, ltime) in self.status_ltimes.iter() {
197      offset += Transformable::encode(node, &mut dst[offset..]).map_err(Self::Error::Id)?;
198      offset += Transformable::encode(ltime, &mut dst[offset..])?;
199    }
200
201    let len = self.left_members.len() as u32;
202    NetworkEndian::write_u32(&mut dst[offset..offset + 4], len);
203    offset += 4;
204    for node in self.left_members.iter() {
205      offset += Transformable::encode(node, &mut dst[offset..]).map_err(Self::Error::Id)?;
206    }
207
208    offset += Transformable::encode(&self.event_ltime, &mut dst[offset..])?;
209    let len = self.events.len() as u32;
210    NetworkEndian::write_u32(&mut dst[offset..offset + 4], len);
211    offset += 4;
212    for e in self.events.iter() {
213      match e {
214        Some(e) => {
215          dst[offset] = 1;
216          offset += 1;
217          offset += Transformable::encode(e, &mut dst[offset..])?;
218        }
219        None => {
220          dst[offset] = 0;
221          offset += 1;
222        }
223      }
224    }
225
226    offset += Transformable::encode(&self.query_ltime, &mut dst[offset..])?;
227
228    debug_assert_eq!(
229      offset, encoded_len,
230      "expect write {} bytes, but actual write {} bytes",
231      encoded_len, offset
232    );
233
234    Ok(offset)
235  }
236}
237
/// Error that can occur when transforming a [`PushPullMessage`] or [`PushPullMessageRef`].
///
/// `I` is the node identifier type used as the key of the status map and as
/// the element of the left-members set; its own transform failures are
/// wrapped in the `Id` variant.
#[derive(thiserror::Error)]
pub enum PushPullMessageTransformError<I>
where
  I: Transformable,
{
  /// Not enough bytes to decode [`PushPullMessage`]
  #[error("not enough bytes to decode PushPullMessage")]
  NotEnoughBytes,
  /// The destination buffer given to `encode` is shorter than the encoded length
  #[error("encode buffer too small")]
  BufferTooSmall,
  /// Error transforming a node identifier of type `I`
  #[error(transparent)]
  Id(I::Error),
  /// Fewer left members were decoded than the message declared
  #[error("expect {expect} nodes, but actual decode {got} nodes")]
  MissingLeftMember {
    /// Number of left members the message declared
    expect: usize,
    /// Number of left members actually decoded
    got: usize,
  },
  /// Fewer node status times were decoded than the message declared
  #[error("expect {expect} status time, but actual decode {got} status time")]
  MissingNodeStatusTime {
    /// Number of status entries the message declared
    expect: usize,
    /// Number of status entries actually decoded
    got: usize,
  },
  /// Error transforming [`LamportTime`]
  #[error(transparent)]
  LamportTime(#[from] LamportTimeTransformError),
  /// Error transforming [`UserEvents`]
  #[error(transparent)]
  UserEvents(#[from] UserEventsTransformError),
  /// Fewer event slots were decoded than the message declared
  #[error("expect {expect} events, but actual decode {got} events")]
  MissingEvents {
    /// Number of event slots the message declared
    expect: usize,
    /// Number of event slots actually decoded
    got: usize,
  },
}
284
285impl<I> core::fmt::Debug for PushPullMessageTransformError<I>
286where
287  I: Transformable,
288{
289  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
290    write!(f, "{}", self)
291  }
292}
293
294impl<I> Transformable for PushPullMessage<I>
295where
296  I: Transformable + core::hash::Hash + Eq,
297{
298  type Error = PushPullMessageTransformError<I>;
299
300  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
301    super::Encodable::encode(&PushPullMessageRef::from(self), dst)
302  }
303
304  fn encoded_len(&self) -> usize {
305    super::Encodable::encoded_len(&PushPullMessageRef::from(self))
306  }
307
308  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
309  where
310    Self: Sized,
311  {
312    let src_len = src.len();
313    if src_len < 4 {
314      return Err(PushPullMessageTransformError::NotEnoughBytes);
315    }
316
317    let encoded_len = NetworkEndian::read_u32(&src[..4]) as usize;
318    if src_len < encoded_len {
319      return Err(PushPullMessageTransformError::NotEnoughBytes);
320    }
321
322    let mut offset = 4;
323    let (n, ltime) = LamportTime::decode(&src[offset..])?;
324    offset += n;
325
326    let len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize;
327    offset += 4;
328
329    let mut status_ltimes = IndexMap::with_capacity(len);
330    for _ in 0..len {
331      let (n, node) = I::decode(&src[offset..]).map_err(Self::Error::Id)?;
332      offset += n;
333      let (n, ltime) = LamportTime::decode(&src[offset..])?;
334      offset += n;
335      status_ltimes.insert(node, ltime);
336    }
337
338    let len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize;
339    offset += 4;
340
341    let mut left_members = IndexSet::with_capacity(len);
342    for _ in 0..len {
343      let (n, node) = I::decode(&src[offset..]).map_err(Self::Error::Id)?;
344      offset += n;
345      left_members.insert(node);
346    }
347
348    let (n, event_ltime) = LamportTime::decode(&src[offset..])?;
349    offset += n;
350
351    let len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize;
352    offset += 4;
353
354    let mut events = TinyVec::with_capacity(len);
355    for _ in 0..len {
356      let has_event = src[offset];
357      offset += 1;
358      if has_event == 1 {
359        let (n, event) = UserEvents::decode(&src[offset..])?;
360        offset += n;
361        events.push(Some(event));
362      } else {
363        events.push(None);
364      }
365    }
366
367    let (n, query_ltime) = LamportTime::decode(&src[offset..])?;
368    offset += n;
369
370    debug_assert_eq!(
371      offset, encoded_len,
372      "expect read {} bytes, but actual read {} bytes",
373      encoded_len, offset
374    );
375
376    Ok((
377      encoded_len,
378      PushPullMessage {
379        ltime,
380        status_ltimes,
381        left_members,
382        event_ltime,
383        events,
384        query_ltime,
385      },
386    ))
387  }
388}
389
#[cfg(test)]
mod tests {
  use rand::{distributions::Alphanumeric, thread_rng, Rng};
  use smol_str::SmolStr;

  use super::*;

  /// Builds a random alphanumeric identifier of exactly `len` characters.
  fn random_id(len: usize) -> SmolStr {
    let bytes = thread_rng()
      .sample_iter(Alphanumeric)
      .take(len)
      .collect::<Vec<u8>>();
    String::from_utf8(bytes).unwrap().into()
  }

  impl PushPullMessage<SmolStr> {
    /// Builds a random message whose collections are each filled from `size`
    /// draws (fewer map/set entries if random ids collide). Every odd event
    /// slot holds events and every even slot is empty.
    fn random(size: usize) -> Self {
      let status_ltimes: IndexMap<SmolStr, LamportTime> = (0..size)
        .map(|_| (random_id(size), LamportTime::random()))
        .collect();

      let left_members: IndexSet<SmolStr> = (0..size).map(|_| random_id(size)).collect();

      let mut events = TinyVec::new();
      for slot in 0..size {
        events.push((slot % 2 != 0).then(|| UserEvents::random(size, size % 10)));
      }

      Self {
        ltime: LamportTime::random(),
        status_ltimes,
        left_members,
        event_ltime: LamportTime::random(),
        events,
        query_ltime: LamportTime::random(),
      }
    }
  }

  #[test]
  fn test_push_pull_message_transform() {
    futures::executor::block_on(async {
      for size in 0..100 {
        let msg = PushPullMessage::random(size);
        let expected_len = msg.encoded_len();

        let mut buf = vec![0; expected_len];
        let written = msg.encode(&mut buf).unwrap();
        assert_eq!(written, expected_len);

        // Slice decoding.
        let (read, decoded) = PushPullMessage::<SmolStr>::decode(&buf).unwrap();
        assert_eq!(read, written);
        assert_eq!(decoded, msg);

        // Blocking reader decoding.
        let (read, decoded) =
          PushPullMessage::<SmolStr>::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap();
        assert_eq!(read, written);
        assert_eq!(decoded, msg);

        // Async reader decoding.
        let (read, decoded) =
          PushPullMessage::<SmolStr>::decode_from_async_reader(&mut futures::io::Cursor::new(&buf))
            .await
            .unwrap();
        assert_eq!(read, written);
        assert_eq!(decoded, msg);
      }
    });
  }
}