serf_core/serf/
query.rs

1use std::{
2  collections::HashSet,
3  sync::Arc,
4  time::{Duration, Instant},
5};
6
7use async_channel::{Receiver, Sender};
8use async_lock::RwLock;
9use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
10use memberlist_core::{
11  bytes::{BufMut, Bytes, BytesMut},
12  tracing,
13  transport::{AddressResolver, Id, Node, Transport},
14  types::{OneOrMore, SmallVec, TinyVec},
15  CheapClone,
16};
17
18use crate::{
19  delegate::{Delegate, TransformDelegate},
20  error::Error,
21  types::{
22    Filter, LamportTime, Member, MemberStatus, MessageType, QueryMessage, QueryResponseMessage,
23  },
24};
25
26use super::Serf;
27
28/// Provided to [`Serf::query`] to configure the parameters of the
29/// query. If not provided, sane defaults will be used.
30#[viewit::viewit(
31  vis_all = "pub(crate)",
32  getters(vis_all = "pub", style = "ref"),
33  setters(vis_all = "pub", prefix = "with")
34)]
35#[derive(Debug, Clone)]
36#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
37pub struct QueryParam<I> {
38  /// The filters to apply to the query.
39  #[viewit(
40    getter(const, attrs(doc = "Returns the filters of the query")),
41    setter(attrs(doc = "Sets the filters of the query"))
42  )]
43  filters: OneOrMore<Filter<I>>,
44
45  /// If true, we are requesting an delivery acknowledgement from
46  /// every node that meets the filter requirement. This means nodes
47  /// the receive the message but do not pass the filters, will not
48  /// send an ack.
49  #[viewit(
50    getter(
51      const,
52      style = "move",
53      attrs(
54        doc = "Returns if we are requesting an delivery acknowledgement from every node that meets the filter requirement. This means nodes the receive the message but do not pass the filters, will not send an ack."
55      )
56    ),
57    setter(attrs(
58      doc = "Sets if we are requesting an delivery acknowledgement from every node that meets the filter requirement. This means nodes the receive the message but do not pass the filters, will not send an ack."
59    ))
60  )]
61  request_ack: bool,
62
63  /// Controls the number of duplicate responses to relay
64  /// back to the sender through other nodes for redundancy.
65  #[viewit(
66    getter(
67      const,
68      style = "move",
69      attrs(
70        doc = "Returns the number of duplicate responses to relay back to the sender through other nodes for redundancy."
71      )
72    ),
73    setter(attrs(
74      doc = "Sets the number of duplicate responses to relay back to the sender through other nodes for redundancy."
75    ))
76  )]
77  relay_factor: u8,
78
79  /// The timeout limits how long the query is left open. If not provided,
80  /// then a default timeout is used based on the configuration of Serf
81  #[viewit(
82    getter(
83      const,
84      style = "move",
85      attrs(
86        doc = "Returns timeout limits how long the query is left open. If not provided, then a default timeout is used based on the configuration of [`Serf`]"
87      )
88    ),
89    setter(attrs(doc = "Sets timeout limits how long the query is left open."))
90  )]
91  #[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
92  timeout: Duration,
93}
94
95impl<I> QueryParam<I>
96where
97  I: Id,
98{
99  /// Used to convert the filters into the wire format
100  pub(crate) fn encode_filters<W: TransformDelegate<Id = I>>(
101    &self,
102  ) -> Result<TinyVec<Bytes>, W::Error> {
103    let mut filters = TinyVec::with_capacity(self.filters.len());
104    for filter in self.filters.iter() {
105      filters.push(W::encode_filter(filter)?);
106    }
107
108    Ok(filters)
109  }
110}
111
/// The channel pairs a `QueryResponse` uses to hand acks and responses
/// to the caller of the query. Both senders are closed by `QueryResponse::close`.
struct QueryResponseChannel<I, A> {
  /// Used to send the node for which we've received an ack.
  /// `None` when the query did not request acks.
  ack_ch: Option<(Sender<Node<I, A>>, Receiver<Node<I, A>>)>,
  /// Used to send a response from a node
  resp_ch: (Sender<NodeResponse<I, A>>, Receiver<NodeResponse<I, A>>),
}
118
/// Mutable bookkeeping for a query, kept behind the `RwLock` in
/// `QueryResponseInner`.
pub(crate) struct QueryResponseCore<I, A> {
  /// Set by `QueryResponse::close`; once true, no further deliveries are made.
  closed: bool,
  /// Nodes already recorded as having acked; checked to drop duplicate acks.
  acks: HashSet<Node<I, A>>,
  /// Nodes already recorded as having responded; checked to drop duplicate responses.
  responses: HashSet<Node<I, A>>,
}
124
/// State shared through an `Arc` by every clone of a `QueryResponse`.
pub(crate) struct QueryResponseInner<I, A> {
  /// Closed flag and duplicate-tracking sets.
  core: RwLock<QueryResponseCore<I, A>>,
  /// The ack and response channels.
  channel: QueryResponseChannel<I, A>,
}
129
130/// Returned for each new Query. It is used to collect
131/// Ack's as well as responses and to provide those back to a client.
132#[viewit::viewit(vis_all = "pub(crate)")]
133#[derive(Clone)]
134pub struct QueryResponse<I, A> {
135  /// The duration of the query
136  #[viewit(
137    getter(
138      style = "move",
139      const,
140      attrs(doc = "Returns the ending deadline of the query")
141    ),
142    setter(skip)
143  )]
144  deadline: Instant,
145
146  /// The query id
147  #[viewit(
148    getter(style = "move", const, attrs(doc = "Returns the id of the query")),
149    setter(skip)
150  )]
151  id: u32,
152
153  /// Stores the LTime of the query
154  #[viewit(
155    getter(
156      style = "move",
157      const,
158      attrs(doc = "Returns the Lamport Time of the query")
159    ),
160    setter(skip)
161  )]
162  ltime: LamportTime,
163
164  #[viewit(getter(vis = "pub(crate)", const, style = "ref"), setter(skip))]
165  inner: Arc<QueryResponseInner<I, A>>,
166}
167
168impl<I, A> QueryResponse<I, A> {
169  pub(crate) fn from_query(q: &QueryMessage<I, A>, num_nodes: usize) -> Self {
170    QueryResponse::new(
171      q.id(),
172      q.ltime(),
173      num_nodes,
174      Instant::now() + q.timeout(),
175      q.ack(),
176    )
177  }
178}
179
180impl<I, A> QueryResponse<I, A> {
181  #[inline]
182  pub(crate) fn new(
183    id: u32,
184    ltime: LamportTime,
185    num_nodes: usize,
186    deadline: Instant,
187    ack: bool,
188  ) -> Self {
189    let (ack_ch, acks) = if ack {
190      (
191        Some(async_channel::bounded(num_nodes)),
192        HashSet::with_capacity(num_nodes),
193      )
194    } else {
195      (None, HashSet::new())
196    };
197
198    Self {
199      deadline,
200      id,
201      ltime,
202      inner: Arc::new(QueryResponseInner {
203        core: RwLock::new(QueryResponseCore {
204          closed: false,
205          acks,
206          responses: HashSet::with_capacity(num_nodes),
207        }),
208        channel: QueryResponseChannel {
209          ack_ch,
210          resp_ch: async_channel::bounded(num_nodes),
211        },
212      }),
213    }
214  }
215
216  /// Returns a receiver that can be used to listen for acks.
217  /// Channel will be closed when the query is finished. This is `None`,
218  /// if the query did not specify `request_ack`.
219  #[inline]
220  pub fn ack_rx(&self) -> Option<async_channel::Receiver<Node<I, A>>> {
221    self.inner.channel.ack_ch.as_ref().map(|(_, r)| r.clone())
222  }
223
224  /// Returns a receiver that can be used to listen for responses.
225  /// Channel will be closed when the query is finished.
226  #[inline]
227  pub fn response_rx(&self) -> async_channel::Receiver<NodeResponse<I, A>> {
228    self.inner.channel.resp_ch.1.clone()
229  }
230
231  /// Returns if the query is finished running
232  #[inline]
233  pub async fn finished(&self) -> bool {
234    let c = self.inner.core.read().await;
235    c.closed || (Instant::now() > self.deadline)
236  }
237
238  /// Used to close the query, which will close the underlying
239  /// channels and prevent further deliveries
240  #[inline]
241  pub async fn close(&self) {
242    let mut c = self.inner.core.write().await;
243    if c.closed {
244      return;
245    }
246
247    c.closed = true;
248
249    if let Some((tx, _)) = &self.inner.channel.ack_ch {
250      tx.close();
251    }
252
253    self.inner.channel.resp_ch.0.close();
254  }
255
256  #[inline]
257  pub(crate) async fn handle_query_response<T, D>(
258    &self,
259    resp: QueryResponseMessage<I, A>,
260    _local: &T::Id,
261    #[cfg(feature = "metrics")] metrics_labels: &memberlist_core::types::MetricLabels,
262  ) where
263    I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
264    A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
265    D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
266    T: Transport,
267  {
268    // Check if the query is closed
269    let c = self.inner.core.read().await;
270    if c.closed || (Instant::now() > self.deadline) {
271      return;
272    }
273
274    // Process each type of response
275    if resp.ack() {
276      // Exit early if this is a duplicate ack
277      if c.acks.contains(&resp.from) {
278        #[cfg(feature = "metrics")]
279        {
280          metrics::counter!("serf.query.duplicate_acks", metrics_labels.iter()).increment(1);
281        }
282        return;
283      }
284
285      #[cfg(feature = "metrics")]
286      {
287        metrics::counter!("serf.query.acks", metrics_labels.iter()).increment(1);
288      }
289
290      drop(c);
291      if let Err(e) = self.send_ack::<T, D>(&resp).await {
292        tracing::warn!("serf: {}", e);
293      }
294    } else {
295      // Exit early if this is a duplicate response
296      if c.responses.contains(&resp.from) {
297        #[cfg(feature = "metrics")]
298        {
299          metrics::counter!("serf.query.duplicate_responses", metrics_labels.iter()).increment(1);
300        }
301        return;
302      }
303
304      #[cfg(feature = "metrics")]
305      {
306        metrics::counter!("serf.query.responses", metrics_labels.iter()).increment(1);
307      }
308      drop(c);
309
310      if let Err(e) = self
311        .send_response::<T, D>(NodeResponse {
312          from: resp.from,
313          payload: resp.payload,
314        })
315        .await
316      {
317        tracing::warn!("serf: {}", e);
318      }
319    }
320  }
321
322  /// Sends a response on the response channel ensuring the channel is not closed.
323  #[inline]
324  pub(crate) async fn send_response<T, D>(&self, nr: NodeResponse<I, A>) -> Result<(), Error<T, D>>
325  where
326    I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
327    A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
328    D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
329    T: Transport,
330  {
331    let mut c = self.inner.core.write().await;
332    // Exit early if this is a duplicate ack
333    if c.responses.contains(&nr.from) {
334      return Ok(());
335    }
336
337    if c.closed {
338      Ok(())
339    } else {
340      let id = nr.from.cheap_clone();
341      futures::select! {
342        _ = self.inner.channel.resp_ch.0.send(nr).fuse() => {
343          c.responses.insert(id);
344          Ok(())
345        },
346        default => {
347          Err(Error::query_response_delivery_failed())
348        }
349      }
350    }
351  }
352
353  /// Sends a response on the ack channel ensuring the channel is not closed.
354  #[inline]
355  pub(crate) async fn send_ack<T, D>(
356    &self,
357    nr: &QueryResponseMessage<I, A>,
358  ) -> Result<(), Error<T, D>>
359  where
360    I: Eq + std::hash::Hash + CheapClone,
361    A: Eq + std::hash::Hash + CheapClone,
362    D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
363    T: Transport,
364  {
365    let mut c = self.inner.core.write().await;
366    // Exit early if this is a duplicate ack
367    if c.acks.contains(&nr.from) {
368      return Ok(());
369    }
370
371    if c.closed {
372      Ok(())
373    } else if let Some((tx, _)) = &self.inner.channel.ack_ch {
374      futures::select! {
375        _ = tx.send(nr.from.cheap_clone()).fuse() => {
376          c.acks.insert(nr.from.clone());
377          Ok(())
378        },
379        default => {
380          Err(Error::query_response_delivery_failed())
381        }
382      }
383    } else {
384      Ok(())
385    }
386  }
387}
388
/// Used to represent a single response from a node: the node that
/// answered the query and the raw payload it returned. Values of this type
/// are delivered through the receiver returned by `QueryResponse::response_rx`.
#[viewit::viewit(
  vis_all = "pub(crate)",
  setters(skip),
  getters(vis_all = "pub", style = "ref")
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeResponse<I, A> {
  #[viewit(getter(attrs(doc = "Returns the node that sent the response")))]
  from: Node<I, A>,
  #[viewit(getter(attrs(doc = "Returns the payload of the response")))]
  payload: Bytes,
}
403
404#[inline]
405fn random_members<I, A>(k: usize, mut members: SmallVec<Member<I, A>>) -> SmallVec<Member<I, A>> {
406  let n = members.len();
407  if n == 0 {
408    return SmallVec::new();
409  }
410
411  // The modified Fisher-Yates algorithm, but up to 3*n times to ensure exhaustive search for small n.
412  let rounds = 3 * n;
413  let mut i = 0;
414
415  while i < rounds && i < n {
416    let j = rand::random::<usize>() % (n - i) + i;
417    members.swap(i, j);
418    i += 1;
419    if i >= k && i >= rounds {
420      break;
421    }
422  }
423
424  members.truncate(k);
425  members
426}
427
428impl<T, D> Serf<T, D>
429where
430  D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
431  T: Transport,
432{
433  /// Returns the default timeout value for a query
434  /// Computed as
435  /// ```text
436  /// gossip_interval * query_timeout_mult * log(N+1)
437  /// ```
438  pub async fn default_query_timeout(&self) -> Duration {
439    let n = self.inner.memberlist.num_online_members().await;
440    let mut timeout = self.inner.opts.memberlist_options.gossip_interval();
441    timeout *= self.inner.opts.query_timeout_mult as u32;
442    timeout *= ((n + 1) as f64).log10().ceil() as u32; // Using ceil approximation
443    timeout
444  }
445
446  /// Used to return the default query parameters
447  pub async fn default_query_param(&self) -> QueryParam<T::Id> {
448    QueryParam {
449      filters: OneOrMore::new(),
450      request_ack: false,
451      relay_factor: 0,
452      timeout: self.default_query_timeout().await,
453    }
454  }
455
456  pub(crate) fn should_process_query(&self, filters: &[Bytes]) -> bool {
457    for filter in filters.iter() {
458      if filter.is_empty() {
459        tracing::warn!("serf: empty filter");
460        return false;
461      }
462
463      // Decode the filter
464      let filter = match <D as TransformDelegate>::decode_filter(filter) {
465        Ok((read, filter)) => {
466          tracing::trace!(read=%read, filter=?filter, "serf: decoded filter successully");
467          filter
468        }
469        Err(err) => {
470          tracing::warn!(
471            err = %err,
472            "serf: failed to decode filter"
473          );
474          return false;
475        }
476      };
477
478      match filter {
479        Filter::Id(nodes) => {
480          // Check if we are being targeted
481          let found = nodes.iter().any(|n| n.eq(self.inner.memberlist.local_id()));
482          if !found {
483            return false;
484          }
485        }
486        Filter::Tag { tag, expr: fexpr } => {
487          // Check if we match this regex
488          let tags = self.inner.opts.tags.load();
489          if !tags.is_empty() {
490            if let Some(expr) = tags.get(&tag) {
491              match regex::Regex::new(&fexpr) {
492                Ok(re) => {
493                  if !re.is_match(expr) {
494                    return false;
495                  }
496                }
497                Err(err) => {
498                  tracing::warn!(err=%err, "serf: failed to compile filter regex ({})", fexpr);
499                  return false;
500                }
501              }
502            } else {
503              return false;
504            }
505          } else {
506            return false;
507          }
508        }
509      }
510    }
511    true
512  }
513
514  pub(crate) async fn relay_response(
515    &self,
516    relay_factor: u8,
517    node: Node<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>,
518    resp: QueryResponseMessage<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>,
519  ) -> Result<(), Error<T, D>> {
520    if relay_factor == 0 {
521      return Ok(());
522    }
523
524    // Needs to be worth it; we need to have at least relayFactor *other*
525    // nodes. If you have a tiny cluster then the relayFactor shouldn't
526    // be needed.
527    let members = {
528      let members = self.inner.members.read().await;
529      if members.states.len() < relay_factor as usize + 1 {
530        return Ok(());
531      }
532      members
533        .states
534        .iter()
535        .filter_map(|(id, m)| {
536          if m.member.status == MemberStatus::Alive && id != self.inner.memberlist.local_id() {
537            Some(m.member.clone())
538          } else {
539            None
540          }
541        })
542        .collect::<SmallVec<_>>()
543    };
544
545    if members.is_empty() {
546      return Ok(());
547    }
548
549    // Prep the relay message, which is a wrapped version of the original.
550    // let relay_msg = SerfRelayMessage::new(node, SerfMessage::QueryResponse(resp));
551    let expected_encoded_len = 1
552      + <D as TransformDelegate>::node_encoded_len(&node)
553      + 1
554      + <D as TransformDelegate>::message_encoded_len(&resp); // +1 for relay message type byte, +1 for the message type
555    if expected_encoded_len > self.inner.opts.query_response_size_limit {
556      return Err(Error::relayed_response_too_large(
557        self.inner.opts.query_response_size_limit,
558      ));
559    }
560
561    let mut raw = BytesMut::with_capacity(expected_encoded_len + 1 + 1); // +1 for relay message type byte, +1 for the message type byte
562    raw.put_u8(MessageType::Relay as u8);
563    raw.resize(expected_encoded_len + 1 + 1, 0);
564    let mut encoded = 1;
565    encoded += <D as TransformDelegate>::encode_node(&node, &mut raw[encoded..])
566      .map_err(Error::transform_delegate)?;
567    raw[encoded] = MessageType::QueryResponse as u8;
568    encoded += 1;
569    encoded += <D as TransformDelegate>::encode_message(&resp, &mut raw[encoded..])
570      .map_err(Error::transform_delegate)?;
571
572    debug_assert_eq!(
573      encoded, expected_encoded_len,
574      "expected encoded len {} mismatch the actual encoded len {}",
575      expected_encoded_len, encoded
576    );
577
578    let raw = raw.freeze();
579    // Relay to a random set of peers.
580    let relay_members = random_members(relay_factor as usize, members);
581
582    let futs: FuturesUnordered<_> = relay_members
583      .into_iter()
584      .map(|m| {
585        let raw = raw.clone();
586        async move {
587          self
588            .inner
589            .memberlist
590            .send(m.node.address(), raw)
591            .await
592            .map_err(|e| (m, e))
593        }
594      })
595      .collect();
596
597    let mut errs = TinyVec::new();
598    let stream = StreamExt::filter_map(futs, |res| async move {
599      if let Err((m, e)) = res {
600        Some((m, e))
601      } else {
602        None
603      }
604    });
605    futures::pin_mut!(stream);
606
607    while let Some(err) = stream.next().await {
608      errs.push(err);
609    }
610
611    Ok(())
612  }
613}