1use byteorder::{ByteOrder, NetworkEndian};
2use indexmap::{IndexMap, IndexSet};
3use memberlist_types::TinyVec;
4use transformable::Transformable;
5
6use super::{LamportTime, LamportTimeTransformError, UserEvents, UserEventsTransformError};
7
#[viewit::viewit(setters(prefix = "with"))]
/// Message carrying Lamport clock values, per-node status times, the set of
/// left nodes, and recent user events.
///
/// NOTE(review): the name suggests this is exchanged during push/pull state
/// synchronization — confirm against the callers.
///
/// Getters and `with_*` builder setters for the fields are generated by
/// `viewit` from the per-field `#[viewit(...)]` attributes below.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// The serde bounds on `I` are spelled out by hand (replacing the ones serde
// would infer), requiring `I: Eq + Hash` in both directions.
#[cfg_attr(
    feature = "serde",
    serde(bound(
        serialize = "I: core::cmp::Eq + core::hash::Hash + serde::Serialize",
        deserialize = "I: core::cmp::Eq + core::hash::Hash + serde::Deserialize<'de>"
    ))
)]
pub struct PushPullMessage<I> {
    /// The Lamport time.
    #[viewit(
        getter(const, style = "move", attrs(doc = "Returns the lamport time")),
        setter(const, attrs(doc = "Sets the lamport time (Builder pattern)"))
    )]
    ltime: LamportTime,
    /// Maps each node id to its status Lamport time (insertion-ordered).
    #[viewit(
        getter(
            const,
            style = "ref",
            attrs(doc = "Returns the maps the node to its status time")
        ),
        setter(attrs(doc = "Sets the maps the node to its status time (Builder pattern)"))
    )]
    status_ltimes: IndexMap<I, LamportTime>,
    /// Ids of the nodes that have left (insertion-ordered set).
    #[viewit(
        getter(const, style = "ref", attrs(doc = "Returns the list of left nodes")),
        setter(attrs(doc = "Sets the list of left nodes (Builder pattern)"))
    )]
    left_members: IndexSet<I>,
    /// The Lamport time of the event clock.
    #[viewit(
        getter(
            const,
            style = "move",
            attrs(doc = "Returns the lamport time for event clock")
        ),
        setter(
            const,
            attrs(doc = "Sets the lamport time for event clock (Builder pattern)")
        )
    )]
    event_ltime: LamportTime,
    /// Recent events; each slot is either `Some(UserEvents)` or empty (`None`).
    /// On the wire each slot is a presence byte optionally followed by the events.
    #[viewit(
        getter(const, style = "ref", attrs(doc = "Returns the recent events")),
        setter(attrs(doc = "Sets the recent events (Builder pattern)"))
    )]
    events: TinyVec<Option<UserEvents>>,
    /// The Lamport time of the query clock.
    #[viewit(
        getter(
            const,
            style = "move",
            attrs(doc = "Returns the lamport time for query clock")
        ),
        setter(
            const,
            attrs(doc = "Sets the lamport time for query clock (Builder pattern)")
        )
    )]
    query_ltime: LamportTime,
}
76
77impl<I> PartialEq for PushPullMessage<I>
78where
79 I: core::hash::Hash + Eq,
80{
81 fn eq(&self, other: &Self) -> bool {
82 self.ltime == other.ltime
83 && self.status_ltimes == other.status_ltimes
84 && self.left_members == other.left_members
85 && self.event_ltime == other.event_ltime
86 && self.events == other.events
87 && self.query_ltime == other.query_ltime
88 }
89}
90
#[viewit::viewit(getters(skip), setters(skip))]
/// Borrowed view of a [`PushPullMessage`].
///
/// The Lamport times are copied and the collections are borrowed, so the
/// owned message can be encoded without cloning its maps, sets, or events.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct PushPullMessageRef<'a, I> {
    /// The Lamport time (copied).
    ltime: LamportTime,
    /// Borrowed map from node id to its status Lamport time.
    status_ltimes: &'a IndexMap<I, LamportTime>,
    /// Borrowed set of left node ids.
    left_members: &'a IndexSet<I>,
    /// The Lamport time of the event clock (copied).
    event_ltime: LamportTime,
    /// Borrowed event slots; `None` marks an empty slot.
    events: &'a [Option<UserEvents>],
    /// The Lamport time of the query clock (copied).
    query_ltime: LamportTime,
}
110
111impl<I> Clone for PushPullMessageRef<'_, I> {
112 fn clone(&self) -> Self {
113 *self
114 }
115}
116
117impl<I> Copy for PushPullMessageRef<'_, I> {}
118
119impl<'a, I> From<&'a PushPullMessage<I>> for PushPullMessageRef<'a, I> {
120 #[inline]
121 fn from(msg: &'a PushPullMessage<I>) -> Self {
122 Self {
123 ltime: msg.ltime,
124 status_ltimes: &msg.status_ltimes,
125 left_members: &msg.left_members,
126 event_ltime: msg.event_ltime,
127 events: &msg.events,
128 query_ltime: msg.query_ltime,
129 }
130 }
131}
132
133impl<'a, I> From<&'a mut PushPullMessage<I>> for PushPullMessageRef<'a, I> {
134 #[inline]
135 fn from(msg: &'a mut PushPullMessage<I>) -> Self {
136 Self {
137 ltime: msg.ltime,
138 status_ltimes: &msg.status_ltimes,
139 left_members: &msg.left_members,
140 event_ltime: msg.event_ltime,
141 events: &msg.events,
142 query_ltime: msg.query_ltime,
143 }
144 }
145}
146
147impl<I> super::Encodable for PushPullMessageRef<'_, I>
148where
149 I: Transformable,
150{
151 type Error = PushPullMessageTransformError<I>;
152
153 fn encoded_len(&self) -> usize {
155 4 + Transformable::encoded_len(&self.ltime)
156 + 4
157 + self
158 .status_ltimes
159 .iter()
160 .map(|(k, v)| Transformable::encoded_len(k) + Transformable::encoded_len(v))
161 .sum::<usize>()
162 + 4
163 + self
164 .left_members
165 .iter()
166 .map(Transformable::encoded_len)
167 .sum::<usize>()
168 + Transformable::encoded_len(&self.event_ltime)
169 + 4
170 + self
171 .events
172 .iter()
173 .map(|e| match e {
174 Some(e) => 1 + Transformable::encoded_len(e),
175 None => 1,
176 })
177 .sum::<usize>()
178 + Transformable::encoded_len(&self.query_ltime)
179 }
180
181 fn encode(&self, dst: &mut [u8]) -> Result<usize, PushPullMessageTransformError<I>> {
183 let encoded_len = self.encoded_len();
184 if dst.len() < encoded_len {
185 return Err(PushPullMessageTransformError::BufferTooSmall);
186 }
187
188 let mut offset = 0;
189 NetworkEndian::write_u32(&mut dst[offset..offset + 4], encoded_len as u32);
190 offset += 4;
191
192 offset += Transformable::encode(&self.ltime, &mut dst[offset..])?;
193 let len = self.status_ltimes.len() as u32;
194 NetworkEndian::write_u32(&mut dst[offset..offset + 4], len);
195 offset += 4;
196 for (node, ltime) in self.status_ltimes.iter() {
197 offset += Transformable::encode(node, &mut dst[offset..]).map_err(Self::Error::Id)?;
198 offset += Transformable::encode(ltime, &mut dst[offset..])?;
199 }
200
201 let len = self.left_members.len() as u32;
202 NetworkEndian::write_u32(&mut dst[offset..offset + 4], len);
203 offset += 4;
204 for node in self.left_members.iter() {
205 offset += Transformable::encode(node, &mut dst[offset..]).map_err(Self::Error::Id)?;
206 }
207
208 offset += Transformable::encode(&self.event_ltime, &mut dst[offset..])?;
209 let len = self.events.len() as u32;
210 NetworkEndian::write_u32(&mut dst[offset..offset + 4], len);
211 offset += 4;
212 for e in self.events.iter() {
213 match e {
214 Some(e) => {
215 dst[offset] = 1;
216 offset += 1;
217 offset += Transformable::encode(e, &mut dst[offset..])?;
218 }
219 None => {
220 dst[offset] = 0;
221 offset += 1;
222 }
223 }
224 }
225
226 offset += Transformable::encode(&self.query_ltime, &mut dst[offset..])?;
227
228 debug_assert_eq!(
229 offset, encoded_len,
230 "expect write {} bytes, but actual write {} bytes",
231 encoded_len, offset
232 );
233
234 Ok(offset)
235 }
236}
237
/// Errors produced when encoding or decoding a [`PushPullMessage`].
///
/// `I` is the node id type; its own transform error is wrapped by [`Self::Id`].
#[derive(thiserror::Error)]
pub enum PushPullMessageTransformError<I>
where
    I: Transformable,
{
    /// Decoding found fewer bytes than the message requires.
    #[error("not enough bytes to decode PushPullMessage")]
    NotEnoughBytes,
    /// The destination buffer given to `encode` is smaller than the encoded length.
    #[error("encode buffer too small")]
    BufferTooSmall,
    /// Encoding or decoding a node id of type `I` failed.
    #[error(transparent)]
    Id(I::Error),
    /// Fewer left members were decoded than expected.
    ///
    /// NOTE(review): not constructed by the encode/decode code in this module — confirm usage.
    #[error("expect {expect} nodes, but actual decode {got} nodes")]
    MissingLeftMember {
        /// Expected number of left members.
        expect: usize,
        /// Number of left members actually decoded.
        got: usize,
    },
    /// Fewer node status times were decoded than expected.
    ///
    /// NOTE(review): not constructed by the encode/decode code in this module — confirm usage.
    #[error("expect {expect} status time, but actual decode {got} status time")]
    MissingNodeStatusTime {
        /// Expected number of status entries.
        expect: usize,
        /// Number of status entries actually decoded.
        got: usize,
    },
    /// Encoding or decoding one of the Lamport times failed.
    #[error(transparent)]
    LamportTime(#[from] LamportTimeTransformError),
    /// Encoding or decoding an event batch failed.
    #[error(transparent)]
    UserEvents(#[from] UserEventsTransformError),
    /// Fewer event slots were decoded than expected.
    ///
    /// NOTE(review): not constructed by the encode/decode code in this module — confirm usage.
    #[error("expect {expect} events, but actual decode {got} events")]
    MissingEvents {
        /// Expected number of event slots.
        expect: usize,
        /// Number of event slots actually decoded.
        got: usize,
    },
}
284
285impl<I> core::fmt::Debug for PushPullMessageTransformError<I>
286where
287 I: Transformable,
288{
289 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
290 write!(f, "{}", self)
291 }
292}
293
294impl<I> Transformable for PushPullMessage<I>
295where
296 I: Transformable + core::hash::Hash + Eq,
297{
298 type Error = PushPullMessageTransformError<I>;
299
300 fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
301 super::Encodable::encode(&PushPullMessageRef::from(self), dst)
302 }
303
304 fn encoded_len(&self) -> usize {
305 super::Encodable::encoded_len(&PushPullMessageRef::from(self))
306 }
307
308 fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
309 where
310 Self: Sized,
311 {
312 let src_len = src.len();
313 if src_len < 4 {
314 return Err(PushPullMessageTransformError::NotEnoughBytes);
315 }
316
317 let encoded_len = NetworkEndian::read_u32(&src[..4]) as usize;
318 if src_len < encoded_len {
319 return Err(PushPullMessageTransformError::NotEnoughBytes);
320 }
321
322 let mut offset = 4;
323 let (n, ltime) = LamportTime::decode(&src[offset..])?;
324 offset += n;
325
326 let len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize;
327 offset += 4;
328
329 let mut status_ltimes = IndexMap::with_capacity(len);
330 for _ in 0..len {
331 let (n, node) = I::decode(&src[offset..]).map_err(Self::Error::Id)?;
332 offset += n;
333 let (n, ltime) = LamportTime::decode(&src[offset..])?;
334 offset += n;
335 status_ltimes.insert(node, ltime);
336 }
337
338 let len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize;
339 offset += 4;
340
341 let mut left_members = IndexSet::with_capacity(len);
342 for _ in 0..len {
343 let (n, node) = I::decode(&src[offset..]).map_err(Self::Error::Id)?;
344 offset += n;
345 left_members.insert(node);
346 }
347
348 let (n, event_ltime) = LamportTime::decode(&src[offset..])?;
349 offset += n;
350
351 let len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize;
352 offset += 4;
353
354 let mut events = TinyVec::with_capacity(len);
355 for _ in 0..len {
356 let has_event = src[offset];
357 offset += 1;
358 if has_event == 1 {
359 let (n, event) = UserEvents::decode(&src[offset..])?;
360 offset += n;
361 events.push(Some(event));
362 } else {
363 events.push(None);
364 }
365 }
366
367 let (n, query_ltime) = LamportTime::decode(&src[offset..])?;
368 offset += n;
369
370 debug_assert_eq!(
371 offset, encoded_len,
372 "expect read {} bytes, but actual read {} bytes",
373 encoded_len, offset
374 );
375
376 Ok((
377 encoded_len,
378 PushPullMessage {
379 ltime,
380 status_ltimes,
381 left_members,
382 event_ltime,
383 events,
384 query_ltime,
385 },
386 ))
387 }
388}
389
#[cfg(test)]
mod tests {
    use rand::{distributions::Alphanumeric, thread_rng, Rng};
    use smol_str::SmolStr;

    use super::*;

    /// Generates a random alphanumeric id of exactly `len` characters.
    fn random_id(len: usize) -> SmolStr {
        let bytes: Vec<u8> = thread_rng().sample_iter(Alphanumeric).take(len).collect();
        String::from_utf8(bytes).unwrap().into()
    }

    impl PushPullMessage<SmolStr> {
        /// Builds a random message.
        ///
        /// It has up to `size` status entries and up to `size` left members,
        /// each id `size` characters long; duplicate ids collapse. It also has
        /// `size` event slots, with even slots empty and odd slots holding a
        /// random batch.
        fn random(size: usize) -> Self {
            let status_ltimes: IndexMap<SmolStr, LamportTime> = (0..size)
                .map(|_| (random_id(size), LamportTime::random()))
                .collect();

            let left_members: IndexSet<SmolStr> = (0..size).map(|_| random_id(size)).collect();

            let mut events = TinyVec::new();
            for slot in 0..size {
                events.push((slot % 2 == 1).then(|| UserEvents::random(size, size % 10)));
            }

            Self {
                ltime: LamportTime::random(),
                status_ltimes,
                left_members,
                event_ltime: LamportTime::random(),
                events,
                query_ltime: LamportTime::random(),
            }
        }
    }

    #[test]
    fn test_push_pull_message_transform() {
        futures::executor::block_on(async {
            for size in 0..100 {
                let msg = PushPullMessage::<SmolStr>::random(size);
                let expected_len = msg.encoded_len();

                let mut buf = vec![0; expected_len];
                let written = msg.encode(&mut buf).unwrap();
                assert_eq!(written, expected_len);

                // Decode from an in-memory slice.
                let (read, decoded) = PushPullMessage::<SmolStr>::decode(&buf).unwrap();
                assert_eq!(read, written);
                assert_eq!(decoded, msg);

                // Decode from a blocking reader.
                let mut reader = std::io::Cursor::new(&buf);
                let (read, decoded) =
                    PushPullMessage::<SmolStr>::decode_from_reader(&mut reader).unwrap();
                assert_eq!(read, written);
                assert_eq!(decoded, msg);

                // Decode from an async reader.
                let mut async_reader = futures::io::Cursor::new(&buf);
                let (read, decoded) =
                    PushPullMessage::<SmolStr>::decode_from_async_reader(&mut async_reader)
                        .await
                        .unwrap();
                assert_eq!(read, written);
                assert_eq!(decoded, msg);
            }
        });
    }
}