1use std::time::Duration;
2
3use async_channel::Receiver;
4use memberlist_core::{
5 agnostic_lite::RuntimeLite,
6 bytes::Bytes,
7 delegate::NodeDelegate,
8 transport::MaybeResolvedAddress,
9 types::{OneOrMore, TinyVec},
10};
11use serf_types::{
12 MessageType, Node, PushPullMessage, QueryFlag, QueryMessage, SerfMessage, UserEvent,
13 UserEventMessage,
14};
15use smol_str::SmolStr;
16
17use crate::{
18 delegate::TransformDelegate,
19 event::{CrateEvent, CrateEventType, MemberEvent, MemberEventType},
20 types::Epoch,
21};
22
23use super::*;
24
25pub(crate) mod serf;
26
27fn test_config() -> Options {
28 let mut opts = Options::new();
29 opts.memberlist_options = opts
30 .memberlist_options
31 .with_gossip_interval(Duration::from_millis(5))
32 .with_probe_interval(Duration::from_millis(50))
33 .with_probe_timeout(Duration::from_millis(25))
34 .with_timeout(Duration::from_millis(100))
35 .with_suspicion_mult(1);
36 opts
37 .with_reap_interval(Duration::from_secs(1))
38 .with_reconnect_interval(Duration::from_millis(100))
39 .with_reconnect_timeout(Duration::from_micros(1))
40 .with_tombstone_timeout(Duration::from_micros(1))
41}
42
43async fn wait_until_num_nodes<T, D>(desired_nodes: usize, serfs: &[Serf<T, D>])
44where
45 D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
46 T: Transport,
47{
48 let start = Epoch::now();
49 loop {
50 <T::Runtime as RuntimeLite>::sleep(Duration::from_millis(25)).await;
51 let mut conds = Vec::with_capacity(serfs.len());
52 for (idx, s) in serfs.iter().enumerate() {
53 let n = s.num_members().await;
54 if n == desired_nodes {
55 conds.push(true);
56 continue;
57 }
58
59 if start.elapsed() > Duration::from_secs(7) {
60 panic!("s{} got {} expected {}", idx + 1, n, desired_nodes);
61 }
62 }
63 if conds.len() == serfs.len() {
64 break;
65 }
66 }
67}
68
69async fn wait_until_intent_queue_len<T, D>(desired_len: usize, serfs: &[Serf<T, D>])
70where
71 D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
72 T: Transport,
73{
74 let start = Epoch::now();
75 loop {
76 <T::Runtime as RuntimeLite>::sleep(Duration::from_millis(25)).await;
77 let mut conds = Vec::with_capacity(serfs.len());
78 for (idx, s) in serfs.iter().enumerate() {
79 let stats = s.stats().await;
80 if stats.get_intent_queue() == desired_len {
81 conds.push(true);
82 continue;
83 }
84
85 if start.elapsed() > Duration::from_secs(7) {
86 panic!(
87 "s{} got {} expected {}",
88 idx + 1,
89 stats.get_intent_queue(),
90 desired_len
91 );
92 }
93 }
94 if conds.len() == serfs.len() {
95 break;
96 }
97 }
98}
99
100async fn test_events<T, D>(
103 rx: Receiver<CrateEvent<T, D>>,
104 node: T::Id,
105 expected: Vec<CrateEventType>,
106) where
107 D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
108 T: Transport,
109{
110 let mut actual = Vec::with_capacity(expected.len());
111
112 loop {
113 futures::select! {
114 event = rx.recv().fuse() => {
115 let event = event.unwrap();
116 match event {
117 CrateEvent::Member(MemberEvent { ty, members }) => {
118 let mut found = false;
119
120 for m in members.iter() {
121 if node.eq(m.node.id()) {
122 found = true;
123 break;
124 }
125 }
126
127 if found {
128 actual.push(CrateEventType::Member(ty));
129 }
130 }
131 _ => continue,
132 }
133 }
134 _ = <T::Runtime as RuntimeLite>::sleep(Duration::from_millis(10)).fuse() => {
135 break;
136 }
137 }
138 }
139
140 assert_eq!(actual, expected, "bad events for node {:?}", node);
141}
142
143async fn test_user_events<T, D>(
146 rx: Receiver<CrateEvent<T, D>>,
147 expected_name: Vec<SmolStr>,
148 expected_payload: Vec<Bytes>,
149) where
150 D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
151 T: Transport,
152{
153 let mut actual_name = Vec::with_capacity(expected_name.len());
154 let mut actual_payload = Vec::with_capacity(expected_payload.len());
155
156 loop {
157 futures::select! {
158 event = rx.recv().fuse() => {
159 let Ok(event) = event else { break };
160 match event {
161 CrateEvent::User(e) => {
162 actual_name.push(e.name.clone());
163 actual_payload.push(e.payload.clone());
164 }
165 _ => continue,
166 }
167 }
168 _ = <T::Runtime as RuntimeLite>::sleep(Duration::from_millis(10)).fuse() => {
169 break;
170 }
171 }
172 }
173
174 assert_eq!(actual_name, expected_name);
175 assert_eq!(actual_payload, expected_payload);
176}
177
178async fn test_query_events<T, D>(
181 rx: Receiver<CrateEvent<T, D>>,
182 expected_name: Vec<SmolStr>,
183 expected_payload: Vec<Bytes>,
184) where
185 D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
186 T: Transport,
187{
188 let mut actual_name = Vec::with_capacity(expected_name.len());
189 let mut actual_payload = Vec::with_capacity(expected_payload.len());
190
191 loop {
192 futures::select! {
193 event = rx.recv().fuse() => {
194 let Ok(event) = event else { break };
195 match event {
196 CrateEvent::Query(e) => {
197 actual_name.push(e.name.clone());
198 actual_payload.push(e.payload.clone());
199 }
200 CrateEvent::InternalQuery { query, .. } => {
201 actual_name.push(query.name.clone());
202 actual_payload.push(query.payload.clone());
203 }
204 _ => continue,
205 }
206 }
207 _ = <T::Runtime as RuntimeLite>::sleep(Duration::from_millis(10)).fuse() => {
208 break;
209 }
210 }
211 }
212
213 assert_eq!(actual_name, expected_name);
214 assert_eq!(actual_payload, expected_payload);
215}
216
217pub async fn queries_pass_through<T>(s: Serf<T>)
219where
220 T: Transport,
221{
222 let (tx, rx) = async_channel::bounded(4);
223 let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1);
224 let (event_tx, _handle) = SerfQueries::<T, DefaultDelegate<T>>::new(Some(tx), shutdown_rx);
225
226 let event = CrateEvent::from(
228 UserEventMessage::default()
229 .with_name("foo".into())
230 .with_ltime(42.into()),
231 );
232 event_tx.send(event.clone()).await.unwrap();
233
234 let query = s.query_event(QueryMessage {
236 ltime: 42.into(),
237 id: 1,
238 from: s.memberlist().advertise_node(),
239 filters: TinyVec::new(),
240 flags: QueryFlag::empty(),
241 relay_factor: 0,
242 timeout: Default::default(),
243 name: "foo".into(),
244 payload: Bytes::new(),
245 });
246 event_tx.send(CrateEvent::from(query)).await.unwrap();
247
248 let event = CrateEvent::from(MemberEvent {
250 ty: MemberEventType::Join,
251 members: TinyVec::new().into(),
252 });
253 event_tx.send(event).await.unwrap();
254
255 for _ in 0..3 {
257 let sleep = <T::Runtime as RuntimeLite>::sleep(Duration::from_millis(100));
258 futures::select! {
259 _ = rx.recv().fuse() => {},
260 _ = sleep.fuse() => panic!("timeout"),
261 }
262 }
263}
264
265pub async fn queries_ping<T>(s: Serf<T>)
267where
268 T: Transport,
269{
270 let (tx, rx) = async_channel::bounded(4);
271 let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1);
272 let (event_tx, _handle) = SerfQueries::<T, DefaultDelegate<T>>::new(Some(tx), shutdown_rx);
273
274 let query = s.query_event(QueryMessage {
276 ltime: 42.into(),
277 id: 1,
278 from: s.memberlist().advertise_node(),
279 filters: TinyVec::new(),
280 flags: QueryFlag::empty(),
281 relay_factor: 0,
282 timeout: Default::default(),
283 name: "ping".into(),
284 payload: Bytes::new(),
285 });
286 event_tx
287 .send(CrateEvent::from((InternalQueryEvent::Ping, query)))
288 .await
289 .unwrap();
290
291 let sleep = <T::Runtime as RuntimeLite>::sleep(Duration::from_millis(50));
292 futures::select! {
293 _ = rx.recv().fuse() => panic!("should not passthrough query!"),
294 _ = sleep.fuse() => {},
295 }
296}
297
298pub async fn queries_conflict_same_name<T>(s: Serf<T>)
300where
301 T: Transport,
302{
303 let (tx, rx) = async_channel::bounded(4);
304 let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1);
305 let (event_tx, _handle) = SerfQueries::<T, DefaultDelegate<T>>::new(Some(tx), shutdown_rx);
306
307 let query = s.query_event(QueryMessage {
309 ltime: 42.into(),
310 id: 1,
311 from: s.memberlist().advertise_node(),
312 filters: TinyVec::new(),
313 flags: QueryFlag::empty(),
314 relay_factor: 0,
315 timeout: Default::default(),
316 name: "conflict".into(),
317 payload: Bytes::new(),
318 });
319 let id = s.memberlist().local_id().clone();
320 event_tx
321 .send(CrateEvent::from((InternalQueryEvent::Conflict(id), query)))
322 .await
323 .unwrap();
324
325 let sleep = <T::Runtime as RuntimeLite>::sleep(Duration::from_millis(50));
326 futures::select! {
327 _ = rx.recv().fuse() => panic!("should not passthrough query!"),
328 _ = sleep.fuse() => {},
329 }
330}
331
/// Estimates how many keys fit into a key-list query response, and prints the
/// resulting bytes-per-key factor.
///
/// The query response size limit is raised to 10x the configured value. The response
/// is seeded with `size_limit / 25 + 1` identical keys, then the key count is walked
/// down until the encoded query response passes `check_response_size`.
///
/// # Panics
///
/// Panics if the Serf instance cannot be created, if encoding fails, or if no
/// non-zero number of keys fits within the limit.
#[cfg(feature = "encryption")]
pub async fn estimate_max_keys_in_list_key_response_factor<T>(
    transport_opts: T::Options,
    opts: Options,
) where
    T: Transport,
{
    use memberlist_core::types::SecretKey;
    use serf_types::KeyResponseMessage;

    let size_limit = opts.query_response_size_limit() * 10;
    let opts = opts.with_query_response_size_limit(size_limit);
    let s = Serf::<T>::new(transport_opts, opts).await.unwrap();
    let query = s.query_event(QueryMessage {
        ltime: 0.into(),
        id: 0,
        from: s.memberlist().advertise_node(),
        filters: TinyVec::new(),
        flags: QueryFlag::empty(),
        relay_factor: 0,
        timeout: Default::default(),
        name: Default::default(),
        payload: Default::default(),
    });

    let mut resp = KeyResponseMessage::default();
    for _ in 0..=(size_limit / 25) {
        resp.keys.push(SecretKey::from([1; 16]));
    }

    let mut found = 0;
    for i in (0..=resp.keys.len()).rev() {
        // Trim *before* encoding so the size check below measures exactly `i` keys and
        // `found` matches the number of keys left in `resp`. Trimming only after a failed
        // check would measure `i + 1` keys and leave `found` one too small.
        resp.keys.truncate(i);

        let encoded_len = <DefaultDelegate<T> as TransformDelegate>::message_encoded_len(&resp);
        let mut dst = vec![0; encoded_len];
        <DefaultDelegate<T> as TransformDelegate>::encode_message(&resp, &mut dst).unwrap();

        // Wrap the encoded key response in a query response, which is what the limit applies to.
        let qresp = query.create_response(dst.into());
        let encoded_len = <DefaultDelegate<T> as TransformDelegate>::message_encoded_len(&qresp);
        let mut dst = vec![0; encoded_len];
        <DefaultDelegate<T> as TransformDelegate>::encode_message(&qresp, &mut dst).unwrap();

        if query.check_response_size(&dst).is_ok() {
            found = i;
            break;
        }
    }

    assert_ne!(found, 0, "Do not find anything!");

    println!(
        "max keys in response with {} bytes: {}",
        size_limit,
        resp.keys.len()
    );
    // `resp.keys.len() == found != 0`, so the division cannot be by zero.
    println!("factor: {}", size_limit / resp.keys.len());
}
393
/// Runs `SerfQueries::key_list_response_with_correct_size` on several key-list
/// responses under a 1024-byte query response size limit.
///
/// Each case is `(expected key count after the call, whether a "truncated" message is
/// expected, response)`.
///
/// NOTE(review): errors are skipped with `continue`, and count/message mismatches are only
/// `println!`ed. This function therefore never fails on a wrong result; only the
/// `Serf::new(..).unwrap()` can panic. Consider turning the checks into assertions once the
/// expected values below are confirmed.
#[cfg(feature = "encryption")]
pub async fn key_list_key_response_with_correct_size<T>(transport_opts: T::Options, opts: Options)
where
    T: Transport,
{
    use memberlist_core::types::SecretKey;
    use serf_types::{Encodable, KeyResponseMessage};

    // Limit query responses to 1024 bytes.
    let opts = opts.with_query_response_size_limit(1024);
    let s = Serf::<T>::new(transport_opts, opts).await.unwrap();
    let query = s.query_event(QueryMessage {
        ltime: 0.into(),
        id: 0,
        from: s.memberlist().advertise_node(),
        filters: TinyVec::new(),
        flags: QueryFlag::empty(),
        relay_factor: 0,
        timeout: Default::default(),
        name: Default::default(),
        payload: Default::default(),
    });

    // All cases use copies of the same 16-byte all-zero key; `encoded_len` is the encoded
    // length of one such key (presumably in bytes — confirm against `Encodable`).
    let k = [0; 16];
    let encoded_len = SecretKey::from(k).encoded_len();
    let cases = [
        // Empty response.
        (0, false, KeyResponseMessage::default()),
        // A single key.
        (1, false, {
            let mut msg = KeyResponseMessage::default();
            msg.add_key(SecretKey::from(k));
            msg
        }),
        // 50 keys.
        (50, true, {
            let mut msg = KeyResponseMessage::default();
            for _ in 0..50 {
                msg.add_key(SecretKey::from(k));
            }
            msg
        }),
        // NOTE(review): `encoded_len` (an encoded length) is used as a key count here, and the
        // expected count `encoded_len` is larger than the `encoded_len - 2` keys inserted. This
        // can only match if the helper adds keys, which looks unlikely — confirm the intent.
        (encoded_len, true, {
            let mut msg = KeyResponseMessage::default();
            for _ in 0..encoded_len - 2 {
                msg.add_key(SecretKey::from(k));
            }
            msg
        }),
        // `encoded_len` keys, expecting the count to stay `encoded_len` and a truncation message.
        (encoded_len, true, {
            let mut msg = KeyResponseMessage::default();
            for _ in 0..encoded_len {
                msg.add_key(SecretKey::from(k));
            }
            msg
        }),
        // 18 keys.
        (18, true, {
            let mut msg = KeyResponseMessage::default();
            for _ in 0..18 {
                msg.add_key(SecretKey::from(k));
            }
            msg
        }),
    ];

    for (expected, has_msg, mut resp) in cases {
        // NOTE(review): an error is only printed; the case is skipped without failing.
        if let Err(e) = SerfQueries::key_list_response_with_correct_size(&query, &mut resp) {
            println!("error: {:?}", e);
            continue;
        }

        // NOTE(review): a wrong key count is only printed, not asserted.
        if resp.keys.len() != expected {
            println!("expected: {}, got: {}", expected, resp.keys.len());
        }

        // NOTE(review): a missing truncation message is only printed, not asserted.
        if has_msg && !resp.message.contains("truncated") {
            println!("truncation message should be set");
        }
    }
}
476}