serf_types/
query.rs

1use byteorder::{ByteOrder, NetworkEndian};
2use smol_str::SmolStr;
3use transformable::{
4  BytesTransformError, DurationTransformError, StringTransformError, Transformable,
5};
6
7use std::time::Duration;
8
9use memberlist_types::{bytes::Bytes, Node, NodeTransformError, TinyVec};
10
11use super::{LamportTime, LamportTimeTransformError};
12
13bitflags::bitflags! {
14  /// Flags for query message
15  #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
16  #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
17  #[cfg_attr(feature = "serde", serde(transparent))]
18  pub struct QueryFlag: u32 {
19    /// Ack flag is used to force receiver to send an ack back
20    const ACK = 1 << 0;
21    /// NoBroadcast is used to prevent re-broadcast of a query.
22    /// this can be used to selectively send queries to individual members
23    const NO_BROADCAST = 1 << 1;
24  }
25}
26
27/// Query message
28#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))]
29#[derive(Debug, Clone, Eq, PartialEq)]
30#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
31pub struct QueryMessage<I, A> {
32  /// Event lamport time
33  #[viewit(
34    getter(const, style = "move", attrs(doc = "Returns the event lamport time")),
35    setter(const, attrs(doc = "Sets the event lamport time (Builder pattern)"))
36  )]
37  ltime: LamportTime,
38  /// query id, randomly generated
39  #[viewit(
40    getter(const, style = "move", attrs(doc = "Returns the query id")),
41    setter(attrs(doc = "Sets the query id (Builder pattern)"))
42  )]
43  id: u32,
44  /// source node
45  #[viewit(
46    getter(const, attrs(doc = "Returns the from node")),
47    setter(attrs(doc = "Sets the from node (Builder pattern)"))
48  )]
49  from: Node<I, A>,
50  /// Potential query filters
51  #[viewit(
52    getter(const, attrs(doc = "Returns the potential query filters")),
53    setter(attrs(doc = "Sets the potential query filters (Builder pattern)"))
54  )]
55  filters: TinyVec<Bytes>,
56  /// Used to provide various flags
57  #[viewit(
58    getter(const, style = "move", attrs(doc = "Returns the flags")),
59    setter(attrs(doc = "Sets the flags (Builder pattern)"))
60  )]
61  flags: QueryFlag,
62  /// Used to set the number of duplicate relayed responses
63  #[viewit(
64    getter(
65      const,
66      style = "move",
67      attrs(doc = "Returns the number of duplicate relayed responses")
68    ),
69    setter(attrs(doc = "Sets the number of duplicate relayed responses (Builder pattern)"))
70  )]
71  relay_factor: u8,
72  /// Maximum time between delivery and response
73  #[viewit(
74    getter(
75      const,
76      style = "move",
77      attrs(doc = "Returns the maximum time between delivery and response")
78    ),
79    setter(attrs(doc = "Sets the maximum time between delivery and response (Builder pattern)"))
80  )]
81  timeout: Duration,
82  /// Query nqme
83  #[viewit(
84    getter(const, style = "ref", attrs(doc = "Returns the name of the query")),
85    setter(attrs(doc = "Sets the name of the query (Builder pattern)"))
86  )]
87  name: SmolStr,
88  /// Query payload
89  #[viewit(
90    getter(const, style = "ref", attrs(doc = "Returns the payload")),
91    setter(attrs(doc = "Sets the payload (Builder pattern)"))
92  )]
93  payload: Bytes,
94}
95
96impl<I, A> QueryMessage<I, A> {
97  /// Checks if the ack flag is set
98  #[inline]
99  pub fn ack(&self) -> bool {
100    self.flags.contains(QueryFlag::ACK)
101  }
102
103  /// Checks if the no broadcast flag is set
104  #[inline]
105  pub fn no_broadcast(&self) -> bool {
106    self.flags.contains(QueryFlag::NO_BROADCAST)
107  }
108}
109
110/// Error that can occur when transforming a [`QueryMessage`].
111#[derive(thiserror::Error)]
112pub enum QueryMessageTransformError<I, A>
113where
114  I: Transformable,
115  A: Transformable,
116{
117  /// Not enough bytes to decode QueryMessage
118  #[error("not enough bytes to decode QueryMessage")]
119  NotEnoughBytes,
120  /// Encode buffer too small
121  #[error("encode buffer too small")]
122  BufferTooSmall,
123  /// Error transforming `from` field
124  #[error(transparent)]
125  From(#[from] NodeTransformError<I, A>),
126  /// Error transforming `ltime` field
127  #[error(transparent)]
128  LamportTime(#[from] LamportTimeTransformError),
129  /// Error transforming `payload` field
130  #[error(transparent)]
131  Payload(BytesTransformError),
132
133  /// Error transforming `filters` field
134  #[error(transparent)]
135  Filters(BytesTransformError),
136
137  /// Error transforming `name` field
138  #[error(transparent)]
139  Name(#[from] StringTransformError),
140
141  /// Error transforming `timeout` field
142  #[error(transparent)]
143  Timeout(#[from] DurationTransformError),
144}
145
146impl<I, A> core::fmt::Debug for QueryMessageTransformError<I, A>
147where
148  I: Transformable,
149  A: Transformable,
150{
151  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
152    write!(f, "{}", self)
153  }
154}
155
156impl<I, A> Transformable for QueryMessage<I, A>
157where
158  I: Transformable,
159  A: Transformable,
160{
161  type Error = QueryMessageTransformError<I, A>;
162
163  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
164    let encoded_len = self.encoded_len();
165    if dst.len() < encoded_len {
166      return Err(Self::Error::BufferTooSmall);
167    }
168
169    let mut offset = 0;
170    NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32);
171    offset += 4;
172    offset += self.ltime.encode(&mut dst[offset..])?;
173    NetworkEndian::write_u32(&mut dst[offset..], self.id);
174    offset += 4;
175    offset += self.from.encode(&mut dst[offset..])?;
176    NetworkEndian::write_u32(&mut dst[offset..], self.filters.len() as u32);
177    offset += 4;
178    for filter in self.filters.iter() {
179      offset += filter
180        .encode(&mut dst[offset..])
181        .map_err(Self::Error::Filters)?;
182    }
183    NetworkEndian::write_u32(&mut dst[offset..], self.flags.bits());
184    offset += 4;
185    dst[offset] = self.relay_factor;
186    offset += 1;
187    offset += self.timeout.encode(&mut dst[offset..])?;
188    offset += self.name.encode(&mut dst[offset..])?;
189    offset += self
190      .payload
191      .encode(&mut dst[offset..])
192      .map_err(Self::Error::Payload)?;
193
194    debug_assert_eq!(
195      offset, encoded_len,
196      "expect write {} bytes, but actual write {} bytes",
197      encoded_len, offset
198    );
199
200    Ok(offset)
201  }
202
203  fn encoded_len(&self) -> usize {
204    4 + self.ltime.encoded_len()
205      + 4 // id
206      + self.from.encoded_len()
207      + 4 // num filters
208      + self.filters.iter().map(|f| f.encoded_len()).sum::<usize>()
209      + 4 // flags
210      + 1 // relay_factor
211      + self.timeout.encoded_len()
212      + self.name.encoded_len()
213      + self.payload.encoded_len()
214  }
215
216  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
217  where
218    Self: Sized,
219  {
220    let src_len = src.len();
221    if src.len() < 4 {
222      return Err(Self::Error::NotEnoughBytes);
223    }
224
225    let mut offset = 0;
226    let len = NetworkEndian::read_u32(&src[offset..]) as usize;
227    if src.len() < len {
228      return Err(Self::Error::NotEnoughBytes);
229    }
230    offset += 4;
231
232    let (n, ltime) = LamportTime::decode(&src[offset..])?;
233    offset += n;
234
235    if offset + 4 > src_len {
236      return Err(Self::Error::NotEnoughBytes);
237    }
238
239    let id = NetworkEndian::read_u32(&src[offset..]);
240    offset += 4;
241
242    let (n, from) = Node::decode(&src[offset..])?;
243    offset += n;
244
245    if offset + 4 > src_len {
246      return Err(Self::Error::NotEnoughBytes);
247    }
248
249    let num_filters = NetworkEndian::read_u32(&src[offset..]) as usize;
250    offset += 4;
251
252    let mut filters = TinyVec::with_capacity(num_filters);
253    for _ in 0..num_filters {
254      let (n, filter) = Bytes::decode(&src[offset..]).map_err(Self::Error::Filters)?;
255      filters.push(filter);
256      offset += n;
257    }
258
259    if offset + 4 > src_len {
260      return Err(Self::Error::NotEnoughBytes);
261    }
262
263    let flags = QueryFlag::from_bits_retain(NetworkEndian::read_u32(&src[offset..]));
264    offset += 4;
265
266    if offset + 1 > src_len {
267      return Err(Self::Error::NotEnoughBytes);
268    }
269
270    let relay_factor = src[offset];
271    offset += 1;
272
273    let (n, timeout) = Duration::decode(&src[offset..])?;
274    offset += n;
275
276    let (n, name) = SmolStr::decode(&src[offset..])?;
277    offset += n;
278
279    let (n, payload) = Bytes::decode(&src[offset..]).map_err(Self::Error::Payload)?;
280    offset += n;
281
282    debug_assert_eq!(
283      offset, len,
284      "expect read {} bytes, but actual read {} bytes",
285      len, offset
286    );
287
288    Ok((
289      offset,
290      Self {
291        ltime,
292        id,
293        from,
294        filters,
295        flags,
296        relay_factor,
297        timeout,
298        name,
299        payload,
300      },
301    ))
302  }
303}
304
305/// Query response message
306#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))]
307#[derive(Debug, Clone, Eq, PartialEq)]
308#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
309pub struct QueryResponseMessage<I, A> {
310  /// Event lamport time
311  #[viewit(
312    getter(const, attrs(doc = "Returns the lamport time for this message")),
313    setter(
314      const,
315      attrs(doc = "Sets the lamport time for this message (Builder pattern)")
316    )
317  )]
318  ltime: LamportTime,
319  /// query id
320  #[viewit(
321    getter(const, attrs(doc = "Returns the query id")),
322    setter(attrs(doc = "Sets the query id (Builder pattern)"))
323  )]
324  id: u32,
325  /// node
326  #[viewit(
327    getter(const, attrs(doc = "Returns the from node")),
328    setter(attrs(doc = "Sets the from node (Builder pattern)"))
329  )]
330  from: Node<I, A>,
331  /// Used to provide various flags
332  #[viewit(
333    getter(const, style = "ref", attrs(doc = "Returns the flags")),
334    setter(attrs(doc = "Sets the flags (Builder pattern)"))
335  )]
336  flags: QueryFlag,
337  /// Optional response payload
338  #[viewit(
339    getter(const, style = "ref", attrs(doc = "Returns the payload")),
340    setter(attrs(doc = "Sets the payload (Builder pattern)"))
341  )]
342  payload: Bytes,
343}
344
345impl<I, A> QueryResponseMessage<I, A> {
346  /// Checks if the ack flag is set
347  #[inline]
348  pub fn ack(&self) -> bool {
349    self.flags.contains(QueryFlag::ACK)
350  }
351
352  /// Checks if the no broadcast flag is set
353  #[inline]
354  pub fn no_broadcast(&self) -> bool {
355    self.flags.contains(QueryFlag::NO_BROADCAST)
356  }
357}
358
359/// Error that can occur when transforming a [`QueryResponseMessage`].
360#[derive(thiserror::Error)]
361pub enum QueryResponseMessageTransformError<I, A>
362where
363  I: Transformable,
364  A: Transformable,
365{
366  /// Not enough bytes to decode QueryResponseMessage
367  #[error("not enough bytes to decode QueryResponseMessage")]
368  NotEnoughBytes,
369  /// Encode buffer too small
370  #[error("encode buffer too small")]
371  BufferTooSmall,
372  /// Error transforming Node
373  #[error(transparent)]
374  Node(#[from] NodeTransformError<I, A>),
375  /// Error transforming LamportTime
376  #[error(transparent)]
377  LamportTime(#[from] LamportTimeTransformError),
378  /// Error transforming payload
379  #[error(transparent)]
380  Payload(#[from] BytesTransformError),
381}
382
383impl<I, A> core::fmt::Debug for QueryResponseMessageTransformError<I, A>
384where
385  I: Transformable,
386  A: Transformable,
387{
388  fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
389    write!(f, "{}", self)
390  }
391}
392
393impl<I, A> Transformable for QueryResponseMessage<I, A>
394where
395  I: Transformable,
396  A: Transformable,
397{
398  type Error = QueryResponseMessageTransformError<I, A>;
399
400  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
401    let encoded_len = self.encoded_len();
402    if dst.len() < encoded_len {
403      return Err(Self::Error::BufferTooSmall);
404    }
405
406    let mut offset = 0;
407    NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32);
408    offset += 4;
409    offset += self.ltime.encode(&mut dst[offset..])?;
410    NetworkEndian::write_u32(&mut dst[offset..], self.id);
411    offset += 4;
412    offset += self.from.encode(&mut dst[offset..])?;
413    NetworkEndian::write_u32(&mut dst[offset..], self.flags.bits());
414    offset += 4;
415    offset += self.payload.encode(&mut dst[offset..])?;
416
417    debug_assert_eq!(
418      offset, encoded_len,
419      "expect write {} bytes, but actual write {} bytes",
420      encoded_len, offset
421    );
422
423    Ok(offset)
424  }
425
426  fn encoded_len(&self) -> usize {
427    4 + self.ltime.encoded_len() + 4 + self.from.encoded_len() + 4 + self.payload.encoded_len()
428  }
429
430  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
431  where
432    Self: Sized,
433  {
434    let src_len = src.len();
435    if src.len() < 4 {
436      return Err(Self::Error::NotEnoughBytes);
437    }
438
439    let mut offset = 0;
440    let len = NetworkEndian::read_u32(&src[offset..]) as usize;
441    if src.len() < len {
442      return Err(Self::Error::NotEnoughBytes);
443    }
444
445    offset += 4;
446    let (n, ltime) = LamportTime::decode(&src[offset..])?;
447    offset += n;
448
449    if offset + 4 > src_len {
450      return Err(Self::Error::NotEnoughBytes);
451    }
452    let id = NetworkEndian::read_u32(&src[offset..]);
453    offset += 4;
454
455    let (n, from) = Node::decode(&src[offset..])?;
456    offset += n;
457
458    if offset + 4 > src_len {
459      return Err(Self::Error::NotEnoughBytes);
460    }
461
462    let flags = QueryFlag::from_bits_retain(NetworkEndian::read_u32(&src[offset..]));
463    offset += 4;
464
465    let (n, payload) = Bytes::decode(&src[offset..])?;
466    offset += n;
467
468    debug_assert_eq!(
469      offset, len,
470      "expect read {} bytes, but actual read {} bytes",
471      len, offset
472    );
473
474    Ok((
475      offset,
476      Self {
477        ltime,
478        id,
479        from,
480        flags,
481        payload,
482      },
483    ))
484  }
485}
486
#[cfg(test)]
mod tests {
  use std::net::SocketAddr;

  use rand::{distributions::Alphanumeric, random, thread_rng, Rng};

  use super::*;

  /// Returns `size` random alphanumeric ASCII bytes.
  fn random_alphanumeric(size: usize) -> Vec<u8> {
    thread_rng()
      .sample_iter(Alphanumeric)
      .take(size)
      .collect::<Vec<u8>>()
  }

  /// Builds a node with a random `size`-byte id on a random localhost port.
  fn random_node(size: usize) -> Node<SmolStr, SocketAddr> {
    let id = String::from_utf8(random_alphanumeric(size)).unwrap().into();
    let addr = SocketAddr::from(([127, 0, 0, 1], random::<u16>()));
    Node::new(id, addr)
  }

  impl QueryMessage<SmolStr, SocketAddr> {
    fn random(size: usize, num_filters: usize) -> Self {
      let filters = (0..num_filters)
        .map(|_| Bytes::from(random_alphanumeric(size)))
        .collect();
      let name = SmolStr::from(String::from_utf8(random_alphanumeric(size)).unwrap());
      Self {
        ltime: LamportTime::random(),
        id: random(),
        from: random_node(size),
        filters,
        // Arbitrary bits (including unknown ones) must survive a round trip.
        flags: QueryFlag::from_bits_retain(random()),
        relay_factor: random(),
        timeout: Duration::from_secs(random::<u64>()),
        name,
        payload: Bytes::from(random_alphanumeric(size)),
      }
    }
  }

  impl QueryResponseMessage<SmolStr, SocketAddr> {
    fn random(size: usize) -> Self {
      Self {
        ltime: LamportTime::random(),
        id: random(),
        from: random_node(size),
        // Arbitrary bits (including unknown ones) must survive a round trip.
        flags: QueryFlag::from_bits_retain(random()),
        payload: Bytes::from(random_alphanumeric(size)),
      }
    }
  }

  #[test]
  fn test_query_response_transform() {
    futures::executor::block_on(async {
      for i in 0..100 {
        let msg = QueryResponseMessage::random(i);
        let mut buf = vec![0; msg.encoded_len()];
        let encoded_len = msg.encode(&mut buf).unwrap();
        assert_eq!(encoded_len, msg.encoded_len());

        let (decoded_len, decoded) =
          QueryResponseMessage::<SmolStr, SocketAddr>::decode(&buf).unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);

        let (decoded_len, decoded) =
          QueryResponseMessage::<SmolStr, SocketAddr>::decode_from_reader(
            &mut std::io::Cursor::new(&buf),
          )
          .unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);

        let (decoded_len, decoded) =
          QueryResponseMessage::<SmolStr, SocketAddr>::decode_from_async_reader(
            &mut futures::io::Cursor::new(&buf),
          )
          .await
          .unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);
      }
    });
  }

  #[test]
  fn test_query_message_transform() {
    futures::executor::block_on(async {
      for i in 0..100 {
        let msg = QueryMessage::random(i, i % 10);
        let mut buf = vec![0; msg.encoded_len()];
        let encoded_len = msg.encode(&mut buf).unwrap();
        assert_eq!(encoded_len, msg.encoded_len());

        let (decoded_len, decoded) = QueryMessage::<SmolStr, SocketAddr>::decode(&buf).unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);

        let (decoded_len, decoded) =
          QueryMessage::<SmolStr, SocketAddr>::decode_from_reader(&mut std::io::Cursor::new(&buf))
            .unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);

        let (decoded_len, decoded) = QueryMessage::<SmolStr, SocketAddr>::decode_from_async_reader(
          &mut futures::io::Cursor::new(&buf),
        )
        .await
        .unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);
      }
    });
  }
}