1use byteorder::{ByteOrder, NetworkEndian};
2use smol_str::SmolStr;
3use transformable::{
4 BytesTransformError, DurationTransformError, StringTransformError, Transformable,
5};
6
7use std::time::Duration;
8
9use memberlist_types::{bytes::Bytes, Node, NodeTransformError, TinyVec};
10
11use super::{LamportTime, LamportTimeTransformError};
12
13bitflags::bitflags! {
14 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
16 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
17 #[cfg_attr(feature = "serde", serde(transparent))]
18 pub struct QueryFlag: u32 {
19 const ACK = 1 << 0;
21 const NO_BROADCAST = 1 << 1;
24 }
25}
26
/// A query message: its lamport time and id, the node it came from, optional
/// filters, flags, a relay factor, a response timeout, a name and an opaque
/// payload.
///
/// Accessors are generated by `viewit`: getters use the field name, setters
/// are prefixed with `with_` (builder style). The binary encoding is defined
/// by this type's `Transformable` impl.
#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))]
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct QueryMessage<I, A> {
  #[viewit(
    getter(const, style = "move", attrs(doc = "Returns the event lamport time")),
    setter(const, attrs(doc = "Sets the event lamport time (Builder pattern)"))
  )]
  ltime: LamportTime,
  #[viewit(
    getter(const, style = "move", attrs(doc = "Returns the query id")),
    setter(attrs(doc = "Sets the query id (Builder pattern)"))
  )]
  id: u32,
  #[viewit(
    getter(const, attrs(doc = "Returns the from node")),
    setter(attrs(doc = "Sets the from node (Builder pattern)"))
  )]
  from: Node<I, A>,
  // Each filter is an opaque byte string here; how it is interpreted is not
  // visible in this file.
  #[viewit(
    getter(const, attrs(doc = "Returns the potential query filters")),
    setter(attrs(doc = "Sets the potential query filters (Builder pattern)"))
  )]
  filters: TinyVec<Bytes>,
  #[viewit(
    getter(const, style = "move", attrs(doc = "Returns the flags")),
    setter(attrs(doc = "Sets the flags (Builder pattern)"))
  )]
  flags: QueryFlag,
  #[viewit(
    getter(
      const,
      style = "move",
      attrs(doc = "Returns the number of duplicate relayed responses")
    ),
    setter(attrs(doc = "Sets the number of duplicate relayed responses (Builder pattern)"))
  )]
  relay_factor: u8,
  #[viewit(
    getter(
      const,
      style = "move",
      attrs(doc = "Returns the maximum time between delivery and response")
    ),
    setter(attrs(doc = "Sets the maximum time between delivery and response (Builder pattern)"))
  )]
  timeout: Duration,
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Returns the name of the query")),
    setter(attrs(doc = "Sets the name of the query (Builder pattern)"))
  )]
  name: SmolStr,
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Returns the payload")),
    setter(attrs(doc = "Sets the payload (Builder pattern)"))
  )]
  payload: Bytes,
}
95
96impl<I, A> QueryMessage<I, A> {
97 #[inline]
99 pub fn ack(&self) -> bool {
100 self.flags.contains(QueryFlag::ACK)
101 }
102
103 #[inline]
105 pub fn no_broadcast(&self) -> bool {
106 self.flags.contains(QueryFlag::NO_BROADCAST)
107 }
108}
109
/// Errors produced while encoding or decoding a [`QueryMessage`].
///
/// `Debug` is not derived; it is written by hand in terms of `Display`
/// (presumably so that `I`/`A` need not implement `Debug`).
#[derive(thiserror::Error)]
pub enum QueryMessageTransformError<I, A>
where
  I: Transformable,
  A: Transformable,
{
  /// The input is too short: it is missing the length prefix, is shorter
  /// than the length it declares, or ends before a fixed-width field.
  #[error("not enough bytes to decode QueryMessage")]
  NotEnoughBytes,
  /// The destination buffer is shorter than the message's encoded length.
  #[error("encode buffer too small")]
  BufferTooSmall,
  /// Encoding or decoding the `from` node failed.
  #[error(transparent)]
  From(#[from] NodeTransformError<I, A>),
  /// Encoding or decoding the lamport time failed.
  #[error(transparent)]
  LamportTime(#[from] LamportTimeTransformError),
  /// Encoding or decoding the payload failed.
  // No `#[from]` here or on `Filters`: both wrap `BytesTransformError`, and
  // two `From` impls for the same source type would conflict, so callers
  // pick the variant explicitly with `map_err`.
  #[error(transparent)]
  Payload(BytesTransformError),

  /// Encoding or decoding one of the filters failed.
  #[error(transparent)]
  Filters(BytesTransformError),

  /// Encoding or decoding the query name failed.
  #[error(transparent)]
  Name(#[from] StringTransformError),

  /// Encoding or decoding the timeout failed.
  #[error(transparent)]
  Timeout(#[from] DurationTransformError),
}
145
146impl<I, A> core::fmt::Debug for QueryMessageTransformError<I, A>
147where
148 I: Transformable,
149 A: Transformable,
150{
151 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
152 write!(f, "{}", self)
153 }
154}
155
156impl<I, A> Transformable for QueryMessage<I, A>
157where
158 I: Transformable,
159 A: Transformable,
160{
161 type Error = QueryMessageTransformError<I, A>;
162
163 fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
164 let encoded_len = self.encoded_len();
165 if dst.len() < encoded_len {
166 return Err(Self::Error::BufferTooSmall);
167 }
168
169 let mut offset = 0;
170 NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32);
171 offset += 4;
172 offset += self.ltime.encode(&mut dst[offset..])?;
173 NetworkEndian::write_u32(&mut dst[offset..], self.id);
174 offset += 4;
175 offset += self.from.encode(&mut dst[offset..])?;
176 NetworkEndian::write_u32(&mut dst[offset..], self.filters.len() as u32);
177 offset += 4;
178 for filter in self.filters.iter() {
179 offset += filter
180 .encode(&mut dst[offset..])
181 .map_err(Self::Error::Filters)?;
182 }
183 NetworkEndian::write_u32(&mut dst[offset..], self.flags.bits());
184 offset += 4;
185 dst[offset] = self.relay_factor;
186 offset += 1;
187 offset += self.timeout.encode(&mut dst[offset..])?;
188 offset += self.name.encode(&mut dst[offset..])?;
189 offset += self
190 .payload
191 .encode(&mut dst[offset..])
192 .map_err(Self::Error::Payload)?;
193
194 debug_assert_eq!(
195 offset, encoded_len,
196 "expect write {} bytes, but actual write {} bytes",
197 encoded_len, offset
198 );
199
200 Ok(offset)
201 }
202
203 fn encoded_len(&self) -> usize {
204 4 + self.ltime.encoded_len()
205 + 4 + self.from.encoded_len()
207 + 4 + self.filters.iter().map(|f| f.encoded_len()).sum::<usize>()
209 + 4 + 1 + self.timeout.encoded_len()
212 + self.name.encoded_len()
213 + self.payload.encoded_len()
214 }
215
216 fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
217 where
218 Self: Sized,
219 {
220 let src_len = src.len();
221 if src.len() < 4 {
222 return Err(Self::Error::NotEnoughBytes);
223 }
224
225 let mut offset = 0;
226 let len = NetworkEndian::read_u32(&src[offset..]) as usize;
227 if src.len() < len {
228 return Err(Self::Error::NotEnoughBytes);
229 }
230 offset += 4;
231
232 let (n, ltime) = LamportTime::decode(&src[offset..])?;
233 offset += n;
234
235 if offset + 4 > src_len {
236 return Err(Self::Error::NotEnoughBytes);
237 }
238
239 let id = NetworkEndian::read_u32(&src[offset..]);
240 offset += 4;
241
242 let (n, from) = Node::decode(&src[offset..])?;
243 offset += n;
244
245 if offset + 4 > src_len {
246 return Err(Self::Error::NotEnoughBytes);
247 }
248
249 let num_filters = NetworkEndian::read_u32(&src[offset..]) as usize;
250 offset += 4;
251
252 let mut filters = TinyVec::with_capacity(num_filters);
253 for _ in 0..num_filters {
254 let (n, filter) = Bytes::decode(&src[offset..]).map_err(Self::Error::Filters)?;
255 filters.push(filter);
256 offset += n;
257 }
258
259 if offset + 4 > src_len {
260 return Err(Self::Error::NotEnoughBytes);
261 }
262
263 let flags = QueryFlag::from_bits_retain(NetworkEndian::read_u32(&src[offset..]));
264 offset += 4;
265
266 if offset + 1 > src_len {
267 return Err(Self::Error::NotEnoughBytes);
268 }
269
270 let relay_factor = src[offset];
271 offset += 1;
272
273 let (n, timeout) = Duration::decode(&src[offset..])?;
274 offset += n;
275
276 let (n, name) = SmolStr::decode(&src[offset..])?;
277 offset += n;
278
279 let (n, payload) = Bytes::decode(&src[offset..]).map_err(Self::Error::Payload)?;
280 offset += n;
281
282 debug_assert_eq!(
283 offset, len,
284 "expect read {} bytes, but actual read {} bytes",
285 len, offset
286 );
287
288 Ok((
289 offset,
290 Self {
291 ltime,
292 id,
293 from,
294 filters,
295 flags,
296 relay_factor,
297 timeout,
298 name,
299 payload,
300 },
301 ))
302 }
303}
304
/// A query response message: its lamport time, the id of the query it refers
/// to, the `from` node that sent it, flags and an opaque payload.
///
/// Accessors are generated by `viewit`: getters use the field name, setters
/// are prefixed with `with_` (builder style). The binary encoding is defined
/// by this type's `Transformable` impl.
#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))]
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct QueryResponseMessage<I, A> {
  // NOTE(review): unlike `QueryMessage`, the `ltime`, `id` and `flags` getters
  // here take no `style = "move"`, so they fall back to / request by-reference
  // getters. Aligning them would change public getter signatures.
  #[viewit(
    getter(const, attrs(doc = "Returns the lamport time for this message")),
    setter(
      const,
      attrs(doc = "Sets the lamport time for this message (Builder pattern)")
    )
  )]
  ltime: LamportTime,
  #[viewit(
    getter(const, attrs(doc = "Returns the query id")),
    setter(attrs(doc = "Sets the query id (Builder pattern)"))
  )]
  id: u32,
  #[viewit(
    getter(const, attrs(doc = "Returns the from node")),
    setter(attrs(doc = "Sets the from node (Builder pattern)"))
  )]
  from: Node<I, A>,
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Returns the flags")),
    setter(attrs(doc = "Sets the flags (Builder pattern)"))
  )]
  flags: QueryFlag,
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Returns the payload")),
    setter(attrs(doc = "Sets the payload (Builder pattern)"))
  )]
  payload: Bytes,
}
344
345impl<I, A> QueryResponseMessage<I, A> {
346 #[inline]
348 pub fn ack(&self) -> bool {
349 self.flags.contains(QueryFlag::ACK)
350 }
351
352 #[inline]
354 pub fn no_broadcast(&self) -> bool {
355 self.flags.contains(QueryFlag::NO_BROADCAST)
356 }
357}
358
/// Errors produced while encoding or decoding a [`QueryResponseMessage`].
///
/// `Debug` is not derived; it is written by hand in terms of `Display`
/// (presumably so that `I`/`A` need not implement `Debug`).
#[derive(thiserror::Error)]
pub enum QueryResponseMessageTransformError<I, A>
where
  I: Transformable,
  A: Transformable,
{
  /// The input is too short: it is missing the length prefix, is shorter
  /// than the length it declares, or ends before a fixed-width field.
  #[error("not enough bytes to decode QueryResponseMessage")]
  NotEnoughBytes,
  /// The destination buffer is shorter than the message's encoded length.
  #[error("encode buffer too small")]
  BufferTooSmall,
  /// Encoding or decoding the `from` node failed.
  #[error(transparent)]
  Node(#[from] NodeTransformError<I, A>),
  /// Encoding or decoding the lamport time failed.
  #[error(transparent)]
  LamportTime(#[from] LamportTimeTransformError),
  /// Encoding or decoding the payload failed. The payload is the only
  /// `Bytes` field, so `#[from]` is unambiguous here.
  #[error(transparent)]
  Payload(#[from] BytesTransformError),
}
382
383impl<I, A> core::fmt::Debug for QueryResponseMessageTransformError<I, A>
384where
385 I: Transformable,
386 A: Transformable,
387{
388 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
389 write!(f, "{}", self)
390 }
391}
392
393impl<I, A> Transformable for QueryResponseMessage<I, A>
394where
395 I: Transformable,
396 A: Transformable,
397{
398 type Error = QueryResponseMessageTransformError<I, A>;
399
400 fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
401 let encoded_len = self.encoded_len();
402 if dst.len() < encoded_len {
403 return Err(Self::Error::BufferTooSmall);
404 }
405
406 let mut offset = 0;
407 NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32);
408 offset += 4;
409 offset += self.ltime.encode(&mut dst[offset..])?;
410 NetworkEndian::write_u32(&mut dst[offset..], self.id);
411 offset += 4;
412 offset += self.from.encode(&mut dst[offset..])?;
413 NetworkEndian::write_u32(&mut dst[offset..], self.flags.bits());
414 offset += 4;
415 offset += self.payload.encode(&mut dst[offset..])?;
416
417 debug_assert_eq!(
418 offset, encoded_len,
419 "expect write {} bytes, but actual write {} bytes",
420 encoded_len, offset
421 );
422
423 Ok(offset)
424 }
425
426 fn encoded_len(&self) -> usize {
427 4 + self.ltime.encoded_len() + 4 + self.from.encoded_len() + 4 + self.payload.encoded_len()
428 }
429
430 fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
431 where
432 Self: Sized,
433 {
434 let src_len = src.len();
435 if src.len() < 4 {
436 return Err(Self::Error::NotEnoughBytes);
437 }
438
439 let mut offset = 0;
440 let len = NetworkEndian::read_u32(&src[offset..]) as usize;
441 if src.len() < len {
442 return Err(Self::Error::NotEnoughBytes);
443 }
444
445 offset += 4;
446 let (n, ltime) = LamportTime::decode(&src[offset..])?;
447 offset += n;
448
449 if offset + 4 > src_len {
450 return Err(Self::Error::NotEnoughBytes);
451 }
452 let id = NetworkEndian::read_u32(&src[offset..]);
453 offset += 4;
454
455 let (n, from) = Node::decode(&src[offset..])?;
456 offset += n;
457
458 if offset + 4 > src_len {
459 return Err(Self::Error::NotEnoughBytes);
460 }
461
462 let flags = QueryFlag::from_bits_retain(NetworkEndian::read_u32(&src[offset..]));
463 offset += 4;
464
465 let (n, payload) = Bytes::decode(&src[offset..])?;
466 offset += n;
467
468 debug_assert_eq!(
469 offset, len,
470 "expect read {} bytes, but actual read {} bytes",
471 len, offset
472 );
473
474 Ok((
475 offset,
476 Self {
477 ltime,
478 id,
479 from,
480 flags,
481 payload,
482 },
483 ))
484 }
485}
486
#[cfg(test)]
mod tests {
  use std::net::SocketAddr;

  use rand::{distributions::Alphanumeric, random, thread_rng, Rng};

  use super::*;

  /// Returns `len` random alphanumeric ASCII bytes.
  fn random_alphanumeric(len: usize) -> Vec<u8> {
    thread_rng().sample_iter(Alphanumeric).take(len).collect()
  }

  /// Returns a random alphanumeric string of `len` characters.
  fn random_string(len: usize) -> String {
    String::from_utf8(random_alphanumeric(len)).expect("alphanumeric bytes are valid UTF-8")
  }

  /// Builds a loopback node with a random `len`-character id and a random port.
  fn random_node(len: usize) -> Node<SmolStr, SocketAddr> {
    let addr = SocketAddr::from(([127, 0, 0, 1], random::<u16>()));
    Node::new(SmolStr::from(random_string(len)), addr)
  }

  /// Returns a random combination of the defined `QueryFlag` bits, so the
  /// round-trip tests exercise the flags field instead of always encoding 0.
  fn random_flags() -> QueryFlag {
    QueryFlag::from_bits_truncate(random::<u32>())
  }

  impl QueryMessage<SmolStr, SocketAddr> {
    /// Random query whose strings and byte fields are `size` long, with
    /// `num_filters` filters.
    fn random(size: usize, num_filters: usize) -> Self {
      Self {
        ltime: LamportTime::random(),
        id: random(),
        from: random_node(size),
        filters: (0..num_filters)
          .map(|_| Bytes::from(random_alphanumeric(size)))
          .collect(),
        flags: random_flags(),
        relay_factor: random(),
        timeout: Duration::from_secs(random::<u64>()),
        name: SmolStr::from(random_string(size)),
        payload: Bytes::from(random_alphanumeric(size)),
      }
    }
  }

  impl QueryResponseMessage<SmolStr, SocketAddr> {
    /// Random response whose node id and payload are `size` long.
    fn random(size: usize) -> Self {
      Self {
        ltime: LamportTime::random(),
        id: random(),
        from: random_node(size),
        flags: random_flags(),
        payload: Bytes::from(random_alphanumeric(size)),
      }
    }
  }

  #[test]
  fn test_query_response_transform() {
    futures::executor::block_on(async {
      for i in 0..100 {
        let msg = QueryResponseMessage::<SmolStr, SocketAddr>::random(i);
        let mut buf = vec![0; msg.encoded_len()];
        let encoded_len = msg.encode(&mut buf).unwrap();
        assert_eq!(encoded_len, msg.encoded_len());

        let (decoded_len, decoded) =
          QueryResponseMessage::<SmolStr, SocketAddr>::decode(&buf).unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);

        let (decoded_len, decoded) =
          QueryResponseMessage::<SmolStr, SocketAddr>::decode_from_reader(
            &mut std::io::Cursor::new(&buf),
          )
          .unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);

        let (decoded_len, decoded) =
          QueryResponseMessage::<SmolStr, SocketAddr>::decode_from_async_reader(
            &mut futures::io::Cursor::new(&buf),
          )
          .await
          .unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);
      }
    });
  }

  #[test]
  fn test_query_message_transform() {
    futures::executor::block_on(async {
      for i in 0..100 {
        let msg = QueryMessage::<SmolStr, SocketAddr>::random(i, i % 10);
        let mut buf = vec![0; msg.encoded_len()];
        let encoded_len = msg.encode(&mut buf).unwrap();
        assert_eq!(encoded_len, msg.encoded_len());

        let (decoded_len, decoded) = QueryMessage::<SmolStr, SocketAddr>::decode(&buf).unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);

        let (decoded_len, decoded) =
          QueryMessage::<SmolStr, SocketAddr>::decode_from_reader(&mut std::io::Cursor::new(&buf))
            .unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);

        let (decoded_len, decoded) = QueryMessage::<SmolStr, SocketAddr>::decode_from_async_reader(
          &mut futures::io::Cursor::new(&buf),
        )
        .await
        .unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, msg);
      }
    });
  }

  #[test]
  fn test_flag_accessors() {
    // (flags, expected ack(), expected no_broadcast())
    let cases = [
      (QueryFlag::empty(), false, false),
      (QueryFlag::ACK, true, false),
      (QueryFlag::NO_BROADCAST, false, true),
      (QueryFlag::ACK | QueryFlag::NO_BROADCAST, true, true),
    ];

    let mut query = QueryMessage::<SmolStr, SocketAddr>::random(8, 1);
    let mut response = QueryResponseMessage::<SmolStr, SocketAddr>::random(8);
    for (flags, ack, no_broadcast) in cases {
      query.flags = flags;
      response.flags = flags;
      assert_eq!(query.ack(), ack);
      assert_eq!(query.no_broadcast(), no_broadcast);
      assert_eq!(response.ack(), ack);
      assert_eq!(response.no_broadcast(), no_broadcast);
    }
  }
}