serf_core/serf/
query.rs

1use std::{
2  collections::HashSet,
3  sync::Arc,
4  time::{Duration, Instant},
5};
6
7use crate::types::FilterRef;
8use async_channel::{Receiver, Sender};
9use async_lock::RwLock;
10use either::Either;
11use futures::{FutureExt, StreamExt, stream::FuturesUnordered};
12use memberlist_core::{
13  CheapClone,
14  bytes::Bytes,
15  proto::{Data, RepeatedDecoder, SmallVec, TinyVec},
16  tracing,
17  transport::{Node, Transport},
18};
19
20use crate::{
21  delegate::Delegate,
22  error::Error,
23  types::{Filter, LamportTime, Member, MemberStatus, QueryMessage, QueryResponseMessage},
24};
25
26use super::Serf;
27
28/// Provided to [`Serf::query`] to configure the parameters of the
29/// query. If not provided, sane defaults will be used.
30#[viewit::viewit(
31  vis_all = "pub(crate)",
32  getters(vis_all = "pub", style = "ref"),
33  setters(vis_all = "pub", prefix = "with")
34)]
35#[derive(Debug, Clone)]
36#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
37pub struct QueryParam<I> {
38  /// The filters to apply to the query.
39  #[viewit(
40    getter(const, attrs(doc = "Returns the filters of the query")),
41    setter(attrs(doc = "Sets the filters of the query"))
42  )]
43  filters: TinyVec<Filter<I>>,
44
45  /// If true, we are requesting an delivery acknowledgement from
46  /// every node that meets the filter requirement. This means nodes
47  /// the receive the message but do not pass the filters, will not
48  /// send an ack.
49  #[viewit(
50    getter(
51      const,
52      style = "move",
53      attrs(
54        doc = "Returns if we are requesting an delivery acknowledgement from every node that meets the filter requirement. This means nodes the receive the message but do not pass the filters, will not send an ack."
55      )
56    ),
57    setter(attrs(
58      doc = "Sets if we are requesting an delivery acknowledgement from every node that meets the filter requirement. This means nodes the receive the message but do not pass the filters, will not send an ack."
59    ))
60  )]
61  request_ack: bool,
62
63  /// Controls the number of duplicate responses to relay
64  /// back to the sender through other nodes for redundancy.
65  #[viewit(
66    getter(
67      const,
68      style = "move",
69      attrs(
70        doc = "Returns the number of duplicate responses to relay back to the sender through other nodes for redundancy."
71      )
72    ),
73    setter(attrs(
74      doc = "Sets the number of duplicate responses to relay back to the sender through other nodes for redundancy."
75    ))
76  )]
77  relay_factor: u8,
78
79  /// The timeout limits how long the query is left open. If not provided,
80  /// then a default timeout is used based on the configuration of Serf
81  #[viewit(
82    getter(
83      const,
84      style = "move",
85      attrs(
86        doc = "Returns timeout limits how long the query is left open. If not provided, then a default timeout is used based on the configuration of [`Serf`]"
87      )
88    ),
89    setter(attrs(doc = "Sets timeout limits how long the query is left open."))
90  )]
91  #[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
92  timeout: Duration,
93}
94
95struct QueryResponseChannel<I, A> {
96  /// Used to send the name of a node for which we've received an ack
97  ack_ch: Option<(Sender<Node<I, A>>, Receiver<Node<I, A>>)>,
98  /// Used to send a response from a node
99  resp_ch: (Sender<NodeResponse<I, A>>, Receiver<NodeResponse<I, A>>),
100}
101
102pub(crate) struct QueryResponseCore<I, A> {
103  closed: bool,
104  acks: HashSet<Node<I, A>>,
105  responses: HashSet<Node<I, A>>,
106}
107
108pub(crate) struct QueryResponseInner<I, A> {
109  core: RwLock<QueryResponseCore<I, A>>,
110  channel: QueryResponseChannel<I, A>,
111}
112
113/// Returned for each new Query. It is used to collect
114/// Ack's as well as responses and to provide those back to a client.
115#[viewit::viewit(vis_all = "pub(crate)")]
116#[derive(Clone)]
117pub struct QueryResponse<I, A> {
118  /// The duration of the query
119  #[viewit(
120    getter(
121      style = "move",
122      const,
123      attrs(doc = "Returns the ending deadline of the query")
124    ),
125    setter(skip)
126  )]
127  deadline: Instant,
128
129  /// The query id
130  #[viewit(
131    getter(style = "move", const, attrs(doc = "Returns the id of the query")),
132    setter(skip)
133  )]
134  id: u32,
135
136  /// Stores the LTime of the query
137  #[viewit(
138    getter(
139      style = "move",
140      const,
141      attrs(doc = "Returns the Lamport Time of the query")
142    ),
143    setter(skip)
144  )]
145  ltime: LamportTime,
146
147  #[viewit(getter(vis = "pub(crate)", const, style = "ref"), setter(skip))]
148  inner: Arc<QueryResponseInner<I, A>>,
149}
150
151impl<I, A> QueryResponse<I, A> {
152  pub(crate) fn from_query(q: &QueryMessage<I, A>, num_nodes: usize) -> Self {
153    QueryResponse::new(
154      q.id(),
155      q.ltime(),
156      num_nodes,
157      Instant::now() + q.timeout(),
158      q.ack(),
159    )
160  }
161}
162
163impl<I, A> QueryResponse<I, A> {
164  #[inline]
165  pub(crate) fn new(
166    id: u32,
167    ltime: LamportTime,
168    num_nodes: usize,
169    deadline: Instant,
170    ack: bool,
171  ) -> Self {
172    let (ack_ch, acks) = if ack {
173      (
174        Some(async_channel::bounded(num_nodes)),
175        HashSet::with_capacity(num_nodes),
176      )
177    } else {
178      (None, HashSet::new())
179    };
180
181    Self {
182      deadline,
183      id,
184      ltime,
185      inner: Arc::new(QueryResponseInner {
186        core: RwLock::new(QueryResponseCore {
187          closed: false,
188          acks,
189          responses: HashSet::with_capacity(num_nodes),
190        }),
191        channel: QueryResponseChannel {
192          ack_ch,
193          resp_ch: async_channel::bounded(num_nodes),
194        },
195      }),
196    }
197  }
198
199  /// Returns a receiver that can be used to listen for acks.
200  /// Channel will be closed when the query is finished. This is `None`,
201  /// if the query did not specify `request_ack`.
202  #[inline]
203  pub fn ack_rx(&self) -> Option<async_channel::Receiver<Node<I, A>>> {
204    self.inner.channel.ack_ch.as_ref().map(|(_, r)| r.clone())
205  }
206
207  /// Returns a receiver that can be used to listen for responses.
208  /// Channel will be closed when the query is finished.
209  #[inline]
210  pub fn response_rx(&self) -> async_channel::Receiver<NodeResponse<I, A>> {
211    self.inner.channel.resp_ch.1.clone()
212  }
213
214  /// Returns if the query is finished running
215  #[inline]
216  pub async fn finished(&self) -> bool {
217    let c = self.inner.core.read().await;
218    c.closed || (Instant::now() > self.deadline)
219  }
220
221  /// Used to close the query, which will close the underlying
222  /// channels and prevent further deliveries
223  #[inline]
224  pub async fn close(&self) {
225    let mut c = self.inner.core.write().await;
226    if c.closed {
227      return;
228    }
229
230    c.closed = true;
231
232    if let Some((tx, _)) = &self.inner.channel.ack_ch {
233      tx.close();
234    }
235
236    self.inner.channel.resp_ch.0.close();
237  }
238
239  #[inline]
240  pub(crate) async fn handle_query_response<T, D>(
241    &self,
242    resp: QueryResponseMessage<I, A>,
243    _local: &T::Id,
244    #[cfg(feature = "metrics")] metrics_labels: &memberlist_core::proto::MetricLabels,
245  ) where
246    I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
247    A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
248    D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
249    T: Transport,
250  {
251    // Check if the query is closed
252    let c = self.inner.core.read().await;
253    if c.closed || (Instant::now() > self.deadline) {
254      return;
255    }
256
257    // Process each type of response
258    if resp.ack() {
259      // Exit early if this is a duplicate ack
260      if c.acks.contains(&resp.from) {
261        #[cfg(feature = "metrics")]
262        {
263          metrics::counter!("serf.query.duplicate_acks", metrics_labels.iter()).increment(1);
264        }
265        return;
266      }
267
268      #[cfg(feature = "metrics")]
269      {
270        metrics::counter!("serf.query.acks", metrics_labels.iter()).increment(1);
271      }
272
273      drop(c);
274      if let Err(e) = self.send_ack::<T, D>(&resp).await {
275        tracing::warn!("serf: {}", e);
276      }
277    } else {
278      // Exit early if this is a duplicate response
279      if c.responses.contains(&resp.from) {
280        #[cfg(feature = "metrics")]
281        {
282          metrics::counter!("serf.query.duplicate_responses", metrics_labels.iter()).increment(1);
283        }
284        return;
285      }
286
287      #[cfg(feature = "metrics")]
288      {
289        metrics::counter!("serf.query.responses", metrics_labels.iter()).increment(1);
290      }
291      drop(c);
292
293      if let Err(e) = self
294        .send_response::<T, D>(NodeResponse {
295          from: resp.from,
296          payload: resp.payload,
297        })
298        .await
299      {
300        tracing::warn!("serf: {}", e);
301      }
302    }
303  }
304
305  /// Sends a response on the response channel ensuring the channel is not closed.
306  #[inline]
307  pub(crate) async fn send_response<T, D>(&self, nr: NodeResponse<I, A>) -> Result<(), Error<T, D>>
308  where
309    I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
310    A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
311    D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
312    T: Transport,
313  {
314    let mut c = self.inner.core.write().await;
315    // Exit early if this is a duplicate ack
316    if c.responses.contains(&nr.from) {
317      return Ok(());
318    }
319
320    if c.closed {
321      Ok(())
322    } else {
323      let id = nr.from.cheap_clone();
324      futures::select! {
325        _ = self.inner.channel.resp_ch.0.send(nr).fuse() => {
326          c.responses.insert(id);
327          Ok(())
328        },
329        default => {
330          Err(Error::query_response_delivery_failed())
331        }
332      }
333    }
334  }
335
336  /// Sends a response on the ack channel ensuring the channel is not closed.
337  #[inline]
338  pub(crate) async fn send_ack<T, D>(
339    &self,
340    nr: &QueryResponseMessage<I, A>,
341  ) -> Result<(), Error<T, D>>
342  where
343    I: Eq + std::hash::Hash + CheapClone,
344    A: Eq + std::hash::Hash + CheapClone,
345    D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
346    T: Transport,
347  {
348    let mut c = self.inner.core.write().await;
349    // Exit early if this is a duplicate ack
350    if c.acks.contains(&nr.from) {
351      return Ok(());
352    }
353
354    if c.closed {
355      Ok(())
356    } else if let Some((tx, _)) = &self.inner.channel.ack_ch {
357      futures::select! {
358        _ = tx.send(nr.from.cheap_clone()).fuse() => {
359          c.acks.insert(nr.from.clone());
360          Ok(())
361        },
362        default => {
363          Err(Error::query_response_delivery_failed())
364        }
365      }
366    } else {
367      Ok(())
368    }
369  }
370}
371
372/// Used to represent a single response from a node
373#[viewit::viewit(
374  vis_all = "pub(crate)",
375  setters(skip),
376  getters(vis_all = "pub", style = "ref")
377)]
378#[derive(Debug, Clone, PartialEq, Eq, Hash)]
379#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
380pub struct NodeResponse<I, A> {
381  #[viewit(getter(attrs(doc = "Returns the node that sent the response")))]
382  from: Node<I, A>,
383  #[viewit(getter(attrs(doc = "Returns the payload of the response")))]
384  payload: Bytes,
385}
386
387#[inline]
388fn random_members<I, A>(k: usize, mut members: SmallVec<Member<I, A>>) -> SmallVec<Member<I, A>> {
389  let n = members.len();
390  if n == 0 {
391    return SmallVec::new();
392  }
393
394  // The modified Fisher-Yates algorithm, but up to 3*n times to ensure exhaustive search for small n.
395  let rounds = 3 * n;
396  let mut i = 0;
397
398  while i < rounds && i < n {
399    let j = (rand::random::<u32>() as usize) % (n - i) + i;
400    members.swap(i, j);
401    i += 1;
402    if i >= k && i >= rounds {
403      break;
404    }
405  }
406
407  members.truncate(k);
408  members
409}
410
411impl<T, D> Serf<T, D>
412where
413  D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
414  T: Transport,
415{
416  /// Returns the default timeout value for a query
417  /// Computed as
418  /// ```text
419  /// gossip_interval * query_timeout_mult * log(N+1)
420  /// ```
421  pub async fn default_query_timeout(&self) -> Duration {
422    let n = self.inner.memberlist.num_online_members().await;
423    let mut timeout = self.inner.opts.memberlist_options.gossip_interval();
424    timeout *= self.inner.opts.query_timeout_mult as u32;
425    timeout *= ((n + 1) as f64).log10().ceil() as u32; // Using ceil approximation
426    timeout
427  }
428
429  /// Used to return the default query parameters
430  pub async fn default_query_param(&self) -> QueryParam<T::Id> {
431    QueryParam {
432      filters: TinyVec::new(),
433      request_ack: false,
434      relay_factor: 0,
435      timeout: self.default_query_timeout().await,
436    }
437  }
438
439  pub(crate) fn should_process_query(
440    &self,
441    filters: Either<RepeatedDecoder<'_>, &[Filter<T::Id>]>,
442  ) -> Result<bool, memberlist_core::proto::DecodeError> {
443    match filters {
444      Either::Left(filters) => {
445        for filter in filters.iter::<Filter<T::Id>>() {
446          let filter = filter?;
447          match filter {
448            FilterRef::Id(ids) => {
449              // Check if we are being targeted
450              let mut found = false;
451              for id in ids.iter::<T::Id>() {
452                let id = id?;
453                if <T::Id as Data>::from_ref(id)?.eq(self.inner.memberlist.local_id()) {
454                  found = true;
455                  break;
456                }
457              }
458              if !found {
459                return Ok(false);
460              }
461            }
462            FilterRef::Tag(tag) => {
463              // Check if we match this regex
464              let tags = self.inner.opts.tags.load();
465              if !tags.is_empty() {
466                if let Some(expr) = tags.get(tag.tag()) {
467                  if let Some(re) = tag.expr() {
468                    if !regex::Regex::new(re)
469                      .map_err(|_| memberlist_core::proto::DecodeError::custom("invalid regex"))?
470                      .is_match(expr)
471                    {
472                      return Ok(false);
473                    }
474                  }
475                } else {
476                  return Ok(false);
477                }
478              } else {
479                return Ok(false);
480              }
481            }
482          }
483        }
484
485        Ok(true)
486      }
487      Either::Right(filters) => {
488        for filter in filters.iter() {
489          match &filter {
490            Filter::Id(nodes) => {
491              // Check if we are being targeted
492              let found = nodes
493                .iter()
494                .any(|n: &T::Id| n.eq(self.inner.memberlist.local_id()));
495              if !found {
496                return Ok(false);
497              }
498            }
499            Filter::Tag(tag) => {
500              // Check if we match this regex
501              let tags = self.inner.opts.tags.load();
502              if !tags.is_empty() {
503                if let Some(expr) = tags.get(tag.tag()) {
504                  if let Some(re) = tag.expr() {
505                    if !re.is_match(expr) {
506                      return Ok(false);
507                    }
508                  }
509                } else {
510                  return Ok(false);
511                }
512              } else {
513                return Ok(false);
514              }
515            }
516          }
517        }
518        Ok(true)
519      }
520    }
521  }
522
523  pub(crate) async fn relay_response(
524    &self,
525    relay_factor: u8,
526    node: Node<T::Id, T::ResolvedAddress>,
527    resp: QueryResponseMessage<T::Id, T::ResolvedAddress>,
528  ) -> Result<(), Error<T, D>> {
529    if relay_factor == 0 {
530      return Ok(());
531    }
532
533    // Needs to be worth it; we need to have at least relayFactor *other*
534    // nodes. If you have a tiny cluster then the relayFactor shouldn't
535    // be needed.
536    let members = {
537      let members = self.inner.members.read().await;
538      if members.states.len() < relay_factor as usize + 1 {
539        return Ok(());
540      }
541      members
542        .states
543        .iter()
544        .filter_map(|(id, m)| {
545          if m.member.status == MemberStatus::Alive && id != self.inner.memberlist.local_id() {
546            Some(m.member.clone())
547          } else {
548            None
549          }
550        })
551        .collect::<SmallVec<_>>()
552    };
553
554    if members.is_empty() {
555      return Ok(());
556    }
557
558    // Prep the relay message, which is a wrapped version of the original.
559    let encoded_len = crate::types::encoded_relay_message_len(&resp, &node);
560    if encoded_len > self.inner.opts.query_response_size_limit {
561      return Err(Error::relayed_response_too_large(
562        self.inner.opts.query_response_size_limit,
563      ));
564    }
565
566    let raw = crate::types::encode_relay_message_to_bytes(&resp, &node)?;
567
568    // Relay to a random set of peers.
569    let relay_members = random_members(relay_factor as usize, members);
570
571    let futs: FuturesUnordered<_> = relay_members
572      .into_iter()
573      .map(|m| {
574        let raw = raw.clone();
575        async move {
576          self
577            .inner
578            .memberlist
579            .send(m.node.address(), raw)
580            .await
581            .map_err(|e| (m, e))
582        }
583      })
584      .collect();
585
586    let mut errs = TinyVec::new();
587    let stream = StreamExt::filter_map(futs, |res| async move {
588      if let Err((m, e)) = res {
589        Some((m, e))
590      } else {
591        None
592      }
593    });
594    futures::pin_mut!(stream);
595
596    while let Some(err) = stream.next().await {
597      errs.push(err);
598    }
599
600    Ok(())
601  }
602}