1use std::{
2 collections::HashSet,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6
7use async_channel::{Receiver, Sender};
8use async_lock::RwLock;
9use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
10use memberlist_core::{
11 bytes::{BufMut, Bytes, BytesMut},
12 tracing,
13 transport::{AddressResolver, Id, Node, Transport},
14 types::{OneOrMore, SmallVec, TinyVec},
15 CheapClone,
16};
17
18use crate::{
19 delegate::{Delegate, TransformDelegate},
20 error::Error,
21 types::{
22 Filter, LamportTime, Member, MemberStatus, MessageType, QueryMessage, QueryResponseMessage,
23 },
24};
25
26use super::Serf;
27
/// Parameters that customize how a query is issued: which nodes it targets,
/// whether acks are requested, how many times responses are relayed and how
/// long the query stays open.
///
/// Getters and `with_*` setters for every field are generated by `viewit`.
#[viewit::viewit(
  vis_all = "pub(crate)",
  getters(vis_all = "pub", style = "ref"),
  setters(vis_all = "pub", prefix = "with")
)]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct QueryParam<I> {
  /// The filters of the query.
  #[viewit(
    getter(const, attrs(doc = "Returns the filters of the query")),
    setter(attrs(doc = "Sets the filters of the query"))
  )]
  filters: OneOrMore<Filter<I>>,

  /// Whether a delivery acknowledgement is requested from every node that
  /// meets the filter requirement.
  #[viewit(
    getter(
      const,
      style = "move",
      attrs(
        doc = "Returns if we are requesting an delivery acknowledgement from every node that meets the filter requirement. This means nodes the receive the message but do not pass the filters, will not send an ack."
      )
    ),
    setter(attrs(
      doc = "Sets if we are requesting an delivery acknowledgement from every node that meets the filter requirement. This means nodes the receive the message but do not pass the filters, will not send an ack."
    ))
  )]
  request_ack: bool,

  /// The number of duplicate responses to relay back to the sender through
  /// other nodes for redundancy.
  #[viewit(
    getter(
      const,
      style = "move",
      attrs(
        doc = "Returns the number of duplicate responses to relay back to the sender through other nodes for redundancy."
      )
    ),
    setter(attrs(
      doc = "Sets the number of duplicate responses to relay back to the sender through other nodes for redundancy."
    ))
  )]
  relay_factor: u8,

  /// Limits how long the query is left open.
  #[viewit(
    getter(
      const,
      style = "move",
      attrs(
        doc = "Returns timeout limits how long the query is left open. If not provided, then a default timeout is used based on the configuration of [`Serf`]"
      )
    ),
    setter(attrs(doc = "Sets timeout limits how long the query is left open."))
  )]
  #[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
  timeout: Duration,
}
94
95impl<I> QueryParam<I>
96where
97 I: Id,
98{
99 pub(crate) fn encode_filters<W: TransformDelegate<Id = I>>(
101 &self,
102 ) -> Result<TinyVec<Bytes>, W::Error> {
103 let mut filters = TinyVec::with_capacity(self.filters.len());
104 for filter in self.filters.iter() {
105 filters.push(W::encode_filter(filter)?);
106 }
107
108 Ok(filters)
109 }
110}
111
/// The channels through which acks and responses of a query are handed to
/// whoever holds the receivers.
struct QueryResponseChannel<I, A> {
  /// Sender/receiver pair carrying the nodes that acked the query.
  /// `None` when the query was created without ack tracking.
  ack_ch: Option<(Sender<Node<I, A>>, Receiver<Node<I, A>>)>,
  /// Sender/receiver pair carrying the responses of the query.
  resp_ch: (Sender<NodeResponse<I, A>>, Receiver<NodeResponse<I, A>>),
}
118
/// Mutable bookkeeping of a query; stored behind the `RwLock` in
/// `QueryResponseInner`.
pub(crate) struct QueryResponseCore<I, A> {
  /// `true` once the query has been closed.
  closed: bool,
  /// Nodes whose ack has already been delivered (used to drop duplicates).
  acks: HashSet<Node<I, A>>,
  /// Nodes whose response has already been delivered (used to drop duplicates).
  responses: HashSet<Node<I, A>>,
}
124
/// Shared state of a `QueryResponse`: the lock-protected bookkeeping plus the
/// delivery channels.
pub(crate) struct QueryResponseInner<I, A> {
  /// Closed flag and de-duplication sets, guarded by an async read/write lock.
  core: RwLock<QueryResponseCore<I, A>>,
  /// Ack and response channels of the query.
  channel: QueryResponseChannel<I, A>,
}
129
/// Handle to an in-flight query, used to observe its acks and responses.
///
/// Cloning is cheap with respect to the shared state: all clones point at the
/// same `Arc`-wrapped inner state.
#[viewit::viewit(vis_all = "pub(crate)")]
#[derive(Clone)]
pub struct QueryResponse<I, A> {
  /// The ending deadline of the query.
  #[viewit(
    getter(
      style = "move",
      const,
      attrs(doc = "Returns the ending deadline of the query")
    ),
    setter(skip)
  )]
  deadline: Instant,

  /// The id of the query.
  #[viewit(
    getter(style = "move", const, attrs(doc = "Returns the id of the query")),
    setter(skip)
  )]
  id: u32,

  /// The Lamport time of the query.
  #[viewit(
    getter(
      style = "move",
      const,
      attrs(doc = "Returns the Lamport Time of the query")
    ),
    setter(skip)
  )]
  ltime: LamportTime,

  /// Shared bookkeeping and channels of the query.
  #[viewit(getter(vis = "pub(crate)", const, style = "ref"), setter(skip))]
  inner: Arc<QueryResponseInner<I, A>>,
}
167
168impl<I, A> QueryResponse<I, A> {
169 pub(crate) fn from_query(q: &QueryMessage<I, A>, num_nodes: usize) -> Self {
170 QueryResponse::new(
171 q.id(),
172 q.ltime(),
173 num_nodes,
174 Instant::now() + q.timeout(),
175 q.ack(),
176 )
177 }
178}
179
180impl<I, A> QueryResponse<I, A> {
181 #[inline]
182 pub(crate) fn new(
183 id: u32,
184 ltime: LamportTime,
185 num_nodes: usize,
186 deadline: Instant,
187 ack: bool,
188 ) -> Self {
189 let (ack_ch, acks) = if ack {
190 (
191 Some(async_channel::bounded(num_nodes)),
192 HashSet::with_capacity(num_nodes),
193 )
194 } else {
195 (None, HashSet::new())
196 };
197
198 Self {
199 deadline,
200 id,
201 ltime,
202 inner: Arc::new(QueryResponseInner {
203 core: RwLock::new(QueryResponseCore {
204 closed: false,
205 acks,
206 responses: HashSet::with_capacity(num_nodes),
207 }),
208 channel: QueryResponseChannel {
209 ack_ch,
210 resp_ch: async_channel::bounded(num_nodes),
211 },
212 }),
213 }
214 }
215
216 #[inline]
220 pub fn ack_rx(&self) -> Option<async_channel::Receiver<Node<I, A>>> {
221 self.inner.channel.ack_ch.as_ref().map(|(_, r)| r.clone())
222 }
223
224 #[inline]
227 pub fn response_rx(&self) -> async_channel::Receiver<NodeResponse<I, A>> {
228 self.inner.channel.resp_ch.1.clone()
229 }
230
231 #[inline]
233 pub async fn finished(&self) -> bool {
234 let c = self.inner.core.read().await;
235 c.closed || (Instant::now() > self.deadline)
236 }
237
238 #[inline]
241 pub async fn close(&self) {
242 let mut c = self.inner.core.write().await;
243 if c.closed {
244 return;
245 }
246
247 c.closed = true;
248
249 if let Some((tx, _)) = &self.inner.channel.ack_ch {
250 tx.close();
251 }
252
253 self.inner.channel.resp_ch.0.close();
254 }
255
256 #[inline]
257 pub(crate) async fn handle_query_response<T, D>(
258 &self,
259 resp: QueryResponseMessage<I, A>,
260 _local: &T::Id,
261 #[cfg(feature = "metrics")] metrics_labels: &memberlist_core::types::MetricLabels,
262 ) where
263 I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
264 A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
265 D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
266 T: Transport,
267 {
268 let c = self.inner.core.read().await;
270 if c.closed || (Instant::now() > self.deadline) {
271 return;
272 }
273
274 if resp.ack() {
276 if c.acks.contains(&resp.from) {
278 #[cfg(feature = "metrics")]
279 {
280 metrics::counter!("serf.query.duplicate_acks", metrics_labels.iter()).increment(1);
281 }
282 return;
283 }
284
285 #[cfg(feature = "metrics")]
286 {
287 metrics::counter!("serf.query.acks", metrics_labels.iter()).increment(1);
288 }
289
290 drop(c);
291 if let Err(e) = self.send_ack::<T, D>(&resp).await {
292 tracing::warn!("serf: {}", e);
293 }
294 } else {
295 if c.responses.contains(&resp.from) {
297 #[cfg(feature = "metrics")]
298 {
299 metrics::counter!("serf.query.duplicate_responses", metrics_labels.iter()).increment(1);
300 }
301 return;
302 }
303
304 #[cfg(feature = "metrics")]
305 {
306 metrics::counter!("serf.query.responses", metrics_labels.iter()).increment(1);
307 }
308 drop(c);
309
310 if let Err(e) = self
311 .send_response::<T, D>(NodeResponse {
312 from: resp.from,
313 payload: resp.payload,
314 })
315 .await
316 {
317 tracing::warn!("serf: {}", e);
318 }
319 }
320 }
321
322 #[inline]
324 pub(crate) async fn send_response<T, D>(&self, nr: NodeResponse<I, A>) -> Result<(), Error<T, D>>
325 where
326 I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
327 A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
328 D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
329 T: Transport,
330 {
331 let mut c = self.inner.core.write().await;
332 if c.responses.contains(&nr.from) {
334 return Ok(());
335 }
336
337 if c.closed {
338 Ok(())
339 } else {
340 let id = nr.from.cheap_clone();
341 futures::select! {
342 _ = self.inner.channel.resp_ch.0.send(nr).fuse() => {
343 c.responses.insert(id);
344 Ok(())
345 },
346 default => {
347 Err(Error::query_response_delivery_failed())
348 }
349 }
350 }
351 }
352
353 #[inline]
355 pub(crate) async fn send_ack<T, D>(
356 &self,
357 nr: &QueryResponseMessage<I, A>,
358 ) -> Result<(), Error<T, D>>
359 where
360 I: Eq + std::hash::Hash + CheapClone,
361 A: Eq + std::hash::Hash + CheapClone,
362 D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
363 T: Transport,
364 {
365 let mut c = self.inner.core.write().await;
366 if c.acks.contains(&nr.from) {
368 return Ok(());
369 }
370
371 if c.closed {
372 Ok(())
373 } else if let Some((tx, _)) = &self.inner.channel.ack_ch {
374 futures::select! {
375 _ = tx.send(nr.from.cheap_clone()).fuse() => {
376 c.acks.insert(nr.from.clone());
377 Ok(())
378 },
379 default => {
380 Err(Error::query_response_delivery_failed())
381 }
382 }
383 } else {
384 Ok(())
385 }
386 }
387}
388
/// A single response to a query: the node that sent it and the payload it
/// returned.
#[viewit::viewit(
  vis_all = "pub(crate)",
  setters(skip),
  getters(vis_all = "pub", style = "ref")
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeResponse<I, A> {
  /// The node that sent the response.
  #[viewit(getter(attrs(doc = "Returns the node that sent the response")))]
  from: Node<I, A>,
  /// The payload of the response.
  #[viewit(getter(attrs(doc = "Returns the payload of the response")))]
  payload: Bytes,
}
403
404#[inline]
405fn random_members<I, A>(k: usize, mut members: SmallVec<Member<I, A>>) -> SmallVec<Member<I, A>> {
406 let n = members.len();
407 if n == 0 {
408 return SmallVec::new();
409 }
410
411 let rounds = 3 * n;
413 let mut i = 0;
414
415 while i < rounds && i < n {
416 let j = rand::random::<usize>() % (n - i) + i;
417 members.swap(i, j);
418 i += 1;
419 if i >= k && i >= rounds {
420 break;
421 }
422 }
423
424 members.truncate(k);
425 members
426}
427
428impl<T, D> Serf<T, D>
429where
430 D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
431 T: Transport,
432{
433 pub async fn default_query_timeout(&self) -> Duration {
439 let n = self.inner.memberlist.num_online_members().await;
440 let mut timeout = self.inner.opts.memberlist_options.gossip_interval();
441 timeout *= self.inner.opts.query_timeout_mult as u32;
442 timeout *= ((n + 1) as f64).log10().ceil() as u32; timeout
444 }
445
446 pub async fn default_query_param(&self) -> QueryParam<T::Id> {
448 QueryParam {
449 filters: OneOrMore::new(),
450 request_ack: false,
451 relay_factor: 0,
452 timeout: self.default_query_timeout().await,
453 }
454 }
455
456 pub(crate) fn should_process_query(&self, filters: &[Bytes]) -> bool {
457 for filter in filters.iter() {
458 if filter.is_empty() {
459 tracing::warn!("serf: empty filter");
460 return false;
461 }
462
463 let filter = match <D as TransformDelegate>::decode_filter(filter) {
465 Ok((read, filter)) => {
466 tracing::trace!(read=%read, filter=?filter, "serf: decoded filter successully");
467 filter
468 }
469 Err(err) => {
470 tracing::warn!(
471 err = %err,
472 "serf: failed to decode filter"
473 );
474 return false;
475 }
476 };
477
478 match filter {
479 Filter::Id(nodes) => {
480 let found = nodes.iter().any(|n| n.eq(self.inner.memberlist.local_id()));
482 if !found {
483 return false;
484 }
485 }
486 Filter::Tag { tag, expr: fexpr } => {
487 let tags = self.inner.opts.tags.load();
489 if !tags.is_empty() {
490 if let Some(expr) = tags.get(&tag) {
491 match regex::Regex::new(&fexpr) {
492 Ok(re) => {
493 if !re.is_match(expr) {
494 return false;
495 }
496 }
497 Err(err) => {
498 tracing::warn!(err=%err, "serf: failed to compile filter regex ({})", fexpr);
499 return false;
500 }
501 }
502 } else {
503 return false;
504 }
505 } else {
506 return false;
507 }
508 }
509 }
510 }
511 true
512 }
513
514 pub(crate) async fn relay_response(
515 &self,
516 relay_factor: u8,
517 node: Node<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>,
518 resp: QueryResponseMessage<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>,
519 ) -> Result<(), Error<T, D>> {
520 if relay_factor == 0 {
521 return Ok(());
522 }
523
524 let members = {
528 let members = self.inner.members.read().await;
529 if members.states.len() < relay_factor as usize + 1 {
530 return Ok(());
531 }
532 members
533 .states
534 .iter()
535 .filter_map(|(id, m)| {
536 if m.member.status == MemberStatus::Alive && id != self.inner.memberlist.local_id() {
537 Some(m.member.clone())
538 } else {
539 None
540 }
541 })
542 .collect::<SmallVec<_>>()
543 };
544
545 if members.is_empty() {
546 return Ok(());
547 }
548
549 let expected_encoded_len = 1
552 + <D as TransformDelegate>::node_encoded_len(&node)
553 + 1
554 + <D as TransformDelegate>::message_encoded_len(&resp); if expected_encoded_len > self.inner.opts.query_response_size_limit {
556 return Err(Error::relayed_response_too_large(
557 self.inner.opts.query_response_size_limit,
558 ));
559 }
560
561 let mut raw = BytesMut::with_capacity(expected_encoded_len + 1 + 1); raw.put_u8(MessageType::Relay as u8);
563 raw.resize(expected_encoded_len + 1 + 1, 0);
564 let mut encoded = 1;
565 encoded += <D as TransformDelegate>::encode_node(&node, &mut raw[encoded..])
566 .map_err(Error::transform_delegate)?;
567 raw[encoded] = MessageType::QueryResponse as u8;
568 encoded += 1;
569 encoded += <D as TransformDelegate>::encode_message(&resp, &mut raw[encoded..])
570 .map_err(Error::transform_delegate)?;
571
572 debug_assert_eq!(
573 encoded, expected_encoded_len,
574 "expected encoded len {} mismatch the actual encoded len {}",
575 expected_encoded_len, encoded
576 );
577
578 let raw = raw.freeze();
579 let relay_members = random_members(relay_factor as usize, members);
581
582 let futs: FuturesUnordered<_> = relay_members
583 .into_iter()
584 .map(|m| {
585 let raw = raw.clone();
586 async move {
587 self
588 .inner
589 .memberlist
590 .send(m.node.address(), raw)
591 .await
592 .map_err(|e| (m, e))
593 }
594 })
595 .collect();
596
597 let mut errs = TinyVec::new();
598 let stream = StreamExt::filter_map(futs, |res| async move {
599 if let Err((m, e)) = res {
600 Some((m, e))
601 } else {
602 None
603 }
604 });
605 futures::pin_mut!(stream);
606
607 while let Some(err) = stream.next().await {
608 errs.push(err);
609 }
610
611 Ok(())
612 }
613}