1use std::{
2 collections::HashSet,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6
7use crate::types::FilterRef;
8use async_channel::{Receiver, Sender};
9use async_lock::RwLock;
10use either::Either;
11use futures::{FutureExt, StreamExt, stream::FuturesUnordered};
12use memberlist_core::{
13 CheapClone,
14 bytes::Bytes,
15 proto::{Data, RepeatedDecoder, SmallVec, TinyVec},
16 tracing,
17 transport::{Node, Transport},
18};
19
20use crate::{
21 delegate::Delegate,
22 error::Error,
23 types::{Filter, LamportTime, Member, MemberStatus, QueryMessage, QueryResponseMessage},
24};
25
26use super::Serf;
27
/// The parameters of a query, controlling filtering, acknowledgement,
/// relaying, and how long the query stays open.
#[viewit::viewit(
  vis_all = "pub(crate)",
  getters(vis_all = "pub", style = "ref"),
  setters(vis_all = "pub", prefix = "with")
)]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct QueryParam<I> {
  /// Filters deciding which nodes should process the query.
  #[viewit(
    getter(const, attrs(doc = "Returns the filters of the query")),
    setter(attrs(doc = "Sets the filters of the query"))
  )]
  filters: TinyVec<Filter<I>>,

  /// Whether a delivery acknowledgement is requested from every node
  /// that passes the filters.
  #[viewit(
    getter(
      const,
      style = "move",
      attrs(
        doc = "Returns if we are requesting an delivery acknowledgement from every node that meets the filter requirement. This means nodes the receive the message but do not pass the filters, will not send an ack."
      )
    ),
    setter(attrs(
      doc = "Sets if we are requesting an delivery acknowledgement from every node that meets the filter requirement. This means nodes the receive the message but do not pass the filters, will not send an ack."
    ))
  )]
  request_ack: bool,

  /// Number of duplicate responses to relay back through other nodes.
  #[viewit(
    getter(
      const,
      style = "move",
      attrs(
        doc = "Returns the number of duplicate responses to relay back to the sender through other nodes for redundancy."
      )
    ),
    setter(attrs(
      doc = "Sets the number of duplicate responses to relay back to the sender through other nodes for redundancy."
    ))
  )]
  relay_factor: u8,

  /// How long the query is left open; serialized in human-readable form
  /// when the `serde` feature is enabled.
  #[viewit(
    getter(
      const,
      style = "move",
      attrs(
        doc = "Returns timeout limits how long the query is left open. If not provided, then a default timeout is used based on the configuration of [`Serf`]"
      )
    ),
    setter(attrs(doc = "Sets timeout limits how long the query is left open."))
  )]
  #[cfg_attr(feature = "serde", serde(with = "humantime_serde"))]
  timeout: Duration,
}
94
/// The channels over which acks and responses for a query are delivered.
struct QueryResponseChannel<I, A> {
  /// Channel carrying the node for each received ack;
  /// `None` when the query did not request acks.
  ack_ch: Option<(Sender<Node<I, A>>, Receiver<Node<I, A>>)>,
  /// Channel carrying each node's response payload.
  resp_ch: (Sender<NodeResponse<I, A>>, Receiver<NodeResponse<I, A>>),
}
101
/// Mutable state of a [`QueryResponse`], guarded by a lock in
/// [`QueryResponseInner`].
pub(crate) struct QueryResponseCore<I, A> {
  // Set once `close` is called; no further deliveries are accepted.
  closed: bool,
  // Nodes that have already acked, used to drop duplicate acks.
  acks: HashSet<Node<I, A>>,
  // Nodes that have already responded, used to drop duplicate responses.
  responses: HashSet<Node<I, A>>,
}
107
/// Shared interior of a [`QueryResponse`]; cloned handles all point at
/// the same `Arc<QueryResponseInner>`.
pub(crate) struct QueryResponseInner<I, A> {
  // Dedup/closed state behind an async RwLock (read for checks, write for updates).
  core: RwLock<QueryResponseCore<I, A>>,
  // Delivery channels; created once and never mutated afterwards.
  channel: QueryResponseChannel<I, A>,
}
112
/// Returned for each new query. Collects the acknowledgements and node
/// responses delivered by other members until the query deadline passes
/// or the response is closed.
#[viewit::viewit(vis_all = "pub(crate)")]
#[derive(Clone)]
pub struct QueryResponse<I, A> {
  /// Instant after which incoming deliveries are ignored.
  #[viewit(
    getter(
      style = "move",
      const,
      attrs(doc = "Returns the ending deadline of the query")
    ),
    setter(skip)
  )]
  deadline: Instant,

  /// Identifier of the query this response tracker belongs to.
  #[viewit(
    getter(style = "move", const, attrs(doc = "Returns the id of the query")),
    setter(skip)
  )]
  id: u32,

  /// Lamport time attached to the query.
  #[viewit(
    getter(
      style = "move",
      const,
      attrs(doc = "Returns the Lamport Time of the query")
    ),
    setter(skip)
  )]
  ltime: LamportTime,

  /// Shared mutable state and delivery channels.
  #[viewit(getter(vis = "pub(crate)", const, style = "ref"), setter(skip))]
  inner: Arc<QueryResponseInner<I, A>>,
}
150
151impl<I, A> QueryResponse<I, A> {
152 pub(crate) fn from_query(q: &QueryMessage<I, A>, num_nodes: usize) -> Self {
153 QueryResponse::new(
154 q.id(),
155 q.ltime(),
156 num_nodes,
157 Instant::now() + q.timeout(),
158 q.ack(),
159 )
160 }
161}
162
163impl<I, A> QueryResponse<I, A> {
164 #[inline]
165 pub(crate) fn new(
166 id: u32,
167 ltime: LamportTime,
168 num_nodes: usize,
169 deadline: Instant,
170 ack: bool,
171 ) -> Self {
172 let (ack_ch, acks) = if ack {
173 (
174 Some(async_channel::bounded(num_nodes)),
175 HashSet::with_capacity(num_nodes),
176 )
177 } else {
178 (None, HashSet::new())
179 };
180
181 Self {
182 deadline,
183 id,
184 ltime,
185 inner: Arc::new(QueryResponseInner {
186 core: RwLock::new(QueryResponseCore {
187 closed: false,
188 acks,
189 responses: HashSet::with_capacity(num_nodes),
190 }),
191 channel: QueryResponseChannel {
192 ack_ch,
193 resp_ch: async_channel::bounded(num_nodes),
194 },
195 }),
196 }
197 }
198
199 #[inline]
203 pub fn ack_rx(&self) -> Option<async_channel::Receiver<Node<I, A>>> {
204 self.inner.channel.ack_ch.as_ref().map(|(_, r)| r.clone())
205 }
206
207 #[inline]
210 pub fn response_rx(&self) -> async_channel::Receiver<NodeResponse<I, A>> {
211 self.inner.channel.resp_ch.1.clone()
212 }
213
214 #[inline]
216 pub async fn finished(&self) -> bool {
217 let c = self.inner.core.read().await;
218 c.closed || (Instant::now() > self.deadline)
219 }
220
221 #[inline]
224 pub async fn close(&self) {
225 let mut c = self.inner.core.write().await;
226 if c.closed {
227 return;
228 }
229
230 c.closed = true;
231
232 if let Some((tx, _)) = &self.inner.channel.ack_ch {
233 tx.close();
234 }
235
236 self.inner.channel.resp_ch.0.close();
237 }
238
239 #[inline]
240 pub(crate) async fn handle_query_response<T, D>(
241 &self,
242 resp: QueryResponseMessage<I, A>,
243 _local: &T::Id,
244 #[cfg(feature = "metrics")] metrics_labels: &memberlist_core::proto::MetricLabels,
245 ) where
246 I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
247 A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
248 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
249 T: Transport,
250 {
251 let c = self.inner.core.read().await;
253 if c.closed || (Instant::now() > self.deadline) {
254 return;
255 }
256
257 if resp.ack() {
259 if c.acks.contains(&resp.from) {
261 #[cfg(feature = "metrics")]
262 {
263 metrics::counter!("serf.query.duplicate_acks", metrics_labels.iter()).increment(1);
264 }
265 return;
266 }
267
268 #[cfg(feature = "metrics")]
269 {
270 metrics::counter!("serf.query.acks", metrics_labels.iter()).increment(1);
271 }
272
273 drop(c);
274 if let Err(e) = self.send_ack::<T, D>(&resp).await {
275 tracing::warn!("serf: {}", e);
276 }
277 } else {
278 if c.responses.contains(&resp.from) {
280 #[cfg(feature = "metrics")]
281 {
282 metrics::counter!("serf.query.duplicate_responses", metrics_labels.iter()).increment(1);
283 }
284 return;
285 }
286
287 #[cfg(feature = "metrics")]
288 {
289 metrics::counter!("serf.query.responses", metrics_labels.iter()).increment(1);
290 }
291 drop(c);
292
293 if let Err(e) = self
294 .send_response::<T, D>(NodeResponse {
295 from: resp.from,
296 payload: resp.payload,
297 })
298 .await
299 {
300 tracing::warn!("serf: {}", e);
301 }
302 }
303 }
304
305 #[inline]
307 pub(crate) async fn send_response<T, D>(&self, nr: NodeResponse<I, A>) -> Result<(), Error<T, D>>
308 where
309 I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
310 A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug,
311 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
312 T: Transport,
313 {
314 let mut c = self.inner.core.write().await;
315 if c.responses.contains(&nr.from) {
317 return Ok(());
318 }
319
320 if c.closed {
321 Ok(())
322 } else {
323 let id = nr.from.cheap_clone();
324 futures::select! {
325 _ = self.inner.channel.resp_ch.0.send(nr).fuse() => {
326 c.responses.insert(id);
327 Ok(())
328 },
329 default => {
330 Err(Error::query_response_delivery_failed())
331 }
332 }
333 }
334 }
335
336 #[inline]
338 pub(crate) async fn send_ack<T, D>(
339 &self,
340 nr: &QueryResponseMessage<I, A>,
341 ) -> Result<(), Error<T, D>>
342 where
343 I: Eq + std::hash::Hash + CheapClone,
344 A: Eq + std::hash::Hash + CheapClone,
345 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
346 T: Transport,
347 {
348 let mut c = self.inner.core.write().await;
349 if c.acks.contains(&nr.from) {
351 return Ok(());
352 }
353
354 if c.closed {
355 Ok(())
356 } else if let Some((tx, _)) = &self.inner.channel.ack_ch {
357 futures::select! {
358 _ = tx.send(nr.from.cheap_clone()).fuse() => {
359 c.acks.insert(nr.from.clone());
360 Ok(())
361 },
362 default => {
363 Err(Error::query_response_delivery_failed())
364 }
365 }
366 } else {
367 Ok(())
368 }
369 }
370}
371
/// A single response delivered by one node for a query.
#[viewit::viewit(
  vis_all = "pub(crate)",
  setters(skip),
  getters(vis_all = "pub", style = "ref")
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeResponse<I, A> {
  /// The node that produced this response.
  #[viewit(getter(attrs(doc = "Returns the node that sent the response")))]
  from: Node<I, A>,
  /// The opaque response payload.
  #[viewit(getter(attrs(doc = "Returns the payload of the response")))]
  payload: Bytes,
}
386
387#[inline]
388fn random_members<I, A>(k: usize, mut members: SmallVec<Member<I, A>>) -> SmallVec<Member<I, A>> {
389 let n = members.len();
390 if n == 0 {
391 return SmallVec::new();
392 }
393
394 let rounds = 3 * n;
396 let mut i = 0;
397
398 while i < rounds && i < n {
399 let j = (rand::random::<u32>() as usize) % (n - i) + i;
400 members.swap(i, j);
401 i += 1;
402 if i >= k && i >= rounds {
403 break;
404 }
405 }
406
407 members.truncate(k);
408 members
409}
410
411impl<T, D> Serf<T, D>
412where
413 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
414 T: Transport,
415{
416 pub async fn default_query_timeout(&self) -> Duration {
422 let n = self.inner.memberlist.num_online_members().await;
423 let mut timeout = self.inner.opts.memberlist_options.gossip_interval();
424 timeout *= self.inner.opts.query_timeout_mult as u32;
425 timeout *= ((n + 1) as f64).log10().ceil() as u32; timeout
427 }
428
429 pub async fn default_query_param(&self) -> QueryParam<T::Id> {
431 QueryParam {
432 filters: TinyVec::new(),
433 request_ack: false,
434 relay_factor: 0,
435 timeout: self.default_query_timeout().await,
436 }
437 }
438
439 pub(crate) fn should_process_query(
440 &self,
441 filters: Either<RepeatedDecoder<'_>, &[Filter<T::Id>]>,
442 ) -> Result<bool, memberlist_core::proto::DecodeError> {
443 match filters {
444 Either::Left(filters) => {
445 for filter in filters.iter::<Filter<T::Id>>() {
446 let filter = filter?;
447 match filter {
448 FilterRef::Id(ids) => {
449 let mut found = false;
451 for id in ids.iter::<T::Id>() {
452 let id = id?;
453 if <T::Id as Data>::from_ref(id)?.eq(self.inner.memberlist.local_id()) {
454 found = true;
455 break;
456 }
457 }
458 if !found {
459 return Ok(false);
460 }
461 }
462 FilterRef::Tag(tag) => {
463 let tags = self.inner.opts.tags.load();
465 if !tags.is_empty() {
466 if let Some(expr) = tags.get(tag.tag()) {
467 if let Some(re) = tag.expr() {
468 if !regex::Regex::new(re)
469 .map_err(|_| memberlist_core::proto::DecodeError::custom("invalid regex"))?
470 .is_match(expr)
471 {
472 return Ok(false);
473 }
474 }
475 } else {
476 return Ok(false);
477 }
478 } else {
479 return Ok(false);
480 }
481 }
482 }
483 }
484
485 Ok(true)
486 }
487 Either::Right(filters) => {
488 for filter in filters.iter() {
489 match &filter {
490 Filter::Id(nodes) => {
491 let found = nodes
493 .iter()
494 .any(|n: &T::Id| n.eq(self.inner.memberlist.local_id()));
495 if !found {
496 return Ok(false);
497 }
498 }
499 Filter::Tag(tag) => {
500 let tags = self.inner.opts.tags.load();
502 if !tags.is_empty() {
503 if let Some(expr) = tags.get(tag.tag()) {
504 if let Some(re) = tag.expr() {
505 if !re.is_match(expr) {
506 return Ok(false);
507 }
508 }
509 } else {
510 return Ok(false);
511 }
512 } else {
513 return Ok(false);
514 }
515 }
516 }
517 }
518 Ok(true)
519 }
520 }
521 }
522
523 pub(crate) async fn relay_response(
524 &self,
525 relay_factor: u8,
526 node: Node<T::Id, T::ResolvedAddress>,
527 resp: QueryResponseMessage<T::Id, T::ResolvedAddress>,
528 ) -> Result<(), Error<T, D>> {
529 if relay_factor == 0 {
530 return Ok(());
531 }
532
533 let members = {
537 let members = self.inner.members.read().await;
538 if members.states.len() < relay_factor as usize + 1 {
539 return Ok(());
540 }
541 members
542 .states
543 .iter()
544 .filter_map(|(id, m)| {
545 if m.member.status == MemberStatus::Alive && id != self.inner.memberlist.local_id() {
546 Some(m.member.clone())
547 } else {
548 None
549 }
550 })
551 .collect::<SmallVec<_>>()
552 };
553
554 if members.is_empty() {
555 return Ok(());
556 }
557
558 let encoded_len = crate::types::encoded_relay_message_len(&resp, &node);
560 if encoded_len > self.inner.opts.query_response_size_limit {
561 return Err(Error::relayed_response_too_large(
562 self.inner.opts.query_response_size_limit,
563 ));
564 }
565
566 let raw = crate::types::encode_relay_message_to_bytes(&resp, &node)?;
567
568 let relay_members = random_members(relay_factor as usize, members);
570
571 let futs: FuturesUnordered<_> = relay_members
572 .into_iter()
573 .map(|m| {
574 let raw = raw.clone();
575 async move {
576 self
577 .inner
578 .memberlist
579 .send(m.node.address(), raw)
580 .await
581 .map_err(|e| (m, e))
582 }
583 })
584 .collect();
585
586 let mut errs = TinyVec::new();
587 let stream = StreamExt::filter_map(futs, |res| async move {
588 if let Err((m, e)) = res {
589 Some((m, e))
590 } else {
591 None
592 }
593 });
594 futures::pin_mut!(stream);
595
596 while let Some(err) = stream.next().await {
597 errs.push(err);
598 }
599
600 Ok(())
601 }
602}