1use byteorder::{ByteOrder, NetworkEndian};
2use indexmap::IndexMap;
3use memberlist_types::{SecretKey, SecretKeyTransformError, SecretKeys, SecretKeysTransformError};
4use smol_str::SmolStr;
5use transformable::{StringTransformError, Transformable};
6
7#[viewit::viewit(setters(prefix = "with"))]
10#[derive(Debug, Clone, Copy, PartialEq, Eq)]
11#[repr(transparent)]
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13#[cfg_attr(feature = "serde", serde(transparent))]
14pub struct KeyRequestMessage {
15 #[viewit(
17 getter(const, attrs(doc = "Returns the secret key")),
18 setter(const, attrs(doc = "Sets the secret key (Builder pattern)"))
19 )]
20 key: Option<SecretKey>,
21}
22
23#[derive(Debug, thiserror::Error)]
25pub enum OptionSecretKeyTransformError {
26 #[error("not enough bytes to decode")]
28 NotEnoughBytes,
29 #[error("encode buffer too small")]
31 BufferTooSmall,
32 #[error(transparent)]
34 SecretKey(#[from] SecretKeyTransformError),
35}
36
37impl Transformable for KeyRequestMessage {
38 type Error = OptionSecretKeyTransformError;
39
40 fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
41 let encoded_len = self.encoded_len();
42 if dst.len() < encoded_len {
43 return Err(Self::Error::BufferTooSmall);
44 }
45
46 match &self.key {
47 None => {
48 dst[0] = 0;
49 Ok(1)
50 }
51 Some(key) => key.encode(dst).map_err(Self::Error::SecretKey),
52 }
53 }
54
55 fn encoded_len(&self) -> usize {
56 match &self.key {
57 Some(key) => key.encoded_len(),
58 None => 1,
59 }
60 }
61
62 fn encode_to_writer<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
63 match &self.key {
64 None => {
65 writer.write_all(&[0])?;
66 Ok(1)
67 }
68 Some(key) => key.encode_to_writer(writer),
69 }
70 }
71
72 async fn encode_to_async_writer<W: futures::AsyncWrite + Send + Unpin>(
73 &self,
74 writer: &mut W,
75 ) -> std::io::Result<usize> {
76 use futures::AsyncWriteExt;
77
78 match &self.key {
79 None => {
80 writer.write_all(&[0]).await?;
81 Ok(1)
82 }
83 Some(key) => key.encode_to_async_writer(writer).await,
84 }
85 }
86
87 fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
88 where
89 Self: Sized,
90 {
91 if src.is_empty() {
92 return Err(Self::Error::NotEnoughBytes);
93 }
94
95 match src[0] {
96 0 => Ok((1, Self { key: None })),
97 _ => {
98 let (n, key) = SecretKey::decode(src).map_err(Self::Error::SecretKey)?;
99 Ok((n, Self { key: Some(key) }))
100 }
101 }
102 }
103
104 fn decode_from_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<(usize, Self)>
105 where
106 Self: Sized,
107 {
108 let mut buf = [0u8; 1];
109 reader.read_exact(&mut buf)?;
110
111 match buf[0] {
112 0 => Ok((1, Self { key: None })),
113 16 => {
114 let mut buf = [0u8; 16];
115 reader.read_exact(&mut buf)?;
116 Ok((
117 17,
118 Self {
119 key: Some(SecretKey::from(buf)),
120 },
121 ))
122 }
123 24 => {
124 let mut buf = [0u8; 24];
125 reader.read_exact(&mut buf)?;
126 Ok((
127 25,
128 Self {
129 key: Some(SecretKey::from(buf)),
130 },
131 ))
132 }
133 32 => {
134 let mut buf = [0u8; 32];
135 reader.read_exact(&mut buf)?;
136 Ok((
137 33,
138 Self {
139 key: Some(SecretKey::from(buf)),
140 },
141 ))
142 }
143 _ => Err(std::io::Error::new(
144 std::io::ErrorKind::InvalidData,
145 "unknown secret key kind",
146 )),
147 }
148 }
149
150 async fn decode_from_async_reader<R: futures::AsyncRead + Send + Unpin>(
151 reader: &mut R,
152 ) -> std::io::Result<(usize, Self)>
153 where
154 Self: Sized,
155 {
156 use futures::AsyncReadExt;
157
158 let mut buf = [0u8; 1];
159 reader.read_exact(&mut buf).await?;
160
161 match buf[0] {
162 0 => Ok((1, Self { key: None })),
163 16 => {
164 let mut buf = [0u8; 16];
165 reader.read_exact(&mut buf).await?;
166 Ok((
167 17,
168 Self {
169 key: Some(SecretKey::from(buf)),
170 },
171 ))
172 }
173 24 => {
174 let mut buf = [0u8; 24];
175 reader.read_exact(&mut buf).await?;
176 Ok((
177 25,
178 Self {
179 key: Some(SecretKey::from(buf)),
180 },
181 ))
182 }
183 32 => {
184 let mut buf = [0u8; 32];
185 reader.read_exact(&mut buf).await?;
186 Ok((
187 33,
188 Self {
189 key: Some(SecretKey::from(buf)),
190 },
191 ))
192 }
193 _ => Err(std::io::Error::new(
194 std::io::ErrorKind::InvalidData,
195 "unknown secret key kind",
196 )),
197 }
198 }
199}
200
/// Response to a key request: an outcome flag, a free-form message, the list of
/// installed keys and, optionally, the primary key.
///
/// NOTE(review): this struct is gated on `feature = "encryption"`, but the
/// `impl` blocks for it in this file carry no matching cfg. Confirm that the
/// enclosing module is gated as well; otherwise builds without the feature will fail.
#[viewit::viewit(setters(prefix = "with"))]
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeyResponseMessage {
    /// Outcome flag for the request; encoded as a single byte (`0` / `1`).
    /// NOTE(review): confirm whether `true` means success or "had errors".
    #[viewit(
        getter(const, attrs(doc = "Returns true/false if there were errors or not")),
        setter(
            const,
            attrs(doc = "Sets true/false if there were errors or not (Builder pattern)")
        )
    )]
    result: bool,
    /// Error messages or other human-readable information.
    #[viewit(
        getter(
            const,
            style = "ref",
            attrs(doc = "Returns the error messages or other information")
        ),
        setter(attrs(doc = "Sets the error messages or other information (Builder pattern)"))
    )]
    message: SmolStr,
    /// Keys currently installed.
    #[viewit(
        getter(const, style = "ref", attrs(doc = "Returns a list of installed keys")),
        setter(attrs(doc = "Sets the list of installed keys (Builder pattern)"))
    )]
    keys: SecretKeys,
    /// The primary key, if one is reported.
    #[viewit(
        getter(const, attrs(doc = "Returns the primary key")),
        setter(attrs(doc = "Sets the primary key (Builder pattern)"))
    )]
    primary_key: Option<SecretKey>,
}
239
240impl KeyResponseMessage {
241 #[inline]
243 pub fn add_key(&mut self, key: SecretKey) -> &mut Self {
244 self.keys.push(key);
245 self
246 }
247}
248
249#[derive(Debug, thiserror::Error)]
251pub enum KeyResponseMessageTransformError {
252 #[error("not enough bytes to decode `KeyResponseMessage`")]
254 NotEnoughBytes,
255 #[error("encode buffer too small")]
257 BufferTooSmall,
258 #[error(transparent)]
260 Message(#[from] StringTransformError),
261 #[error(transparent)]
263 PrimaryKey(#[from] OptionSecretKeyTransformError),
264 #[error(transparent)]
266 Keys(#[from] SecretKeysTransformError),
267}
268
269impl Transformable for KeyResponseMessage {
270 type Error = KeyResponseMessageTransformError;
271
272 fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
273 let encoded_len = self.encoded_len();
274 if dst.len() < encoded_len {
275 return Err(Self::Error::BufferTooSmall);
276 }
277
278 let mut offset = 0;
279 NetworkEndian::write_u32(&mut dst[offset..offset + 4], encoded_len as u32);
280 offset += 4;
281 dst[offset] = self.result as u8;
282 offset += 1;
283 offset += self.message.encode(&mut dst[offset..])?;
284 offset += self.keys.encode(&mut dst[offset..])?;
285 offset += KeyRequestMessage {
286 key: self.primary_key,
287 }
288 .encode(&mut dst[offset..])?;
289
290 debug_assert_eq!(
291 offset, encoded_len,
292 "expect write {} bytes, but actual write {} bytes",
293 encoded_len, offset
294 );
295
296 Ok(offset)
297 }
298
299 fn encoded_len(&self) -> usize {
300 4 + 1
301 + self.message.encoded_len()
302 + self.keys.encoded_len()
303 + KeyRequestMessage {
304 key: self.primary_key,
305 }
306 .encoded_len()
307 }
308
309 fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
310 where
311 Self: Sized,
312 {
313 let src_len = src.len();
314 if src_len < 5 {
315 return Err(Self::Error::NotEnoughBytes);
316 }
317
318 let mut offset = 0;
319 let encoded_len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize;
320 if src_len < encoded_len {
321 return Err(Self::Error::NotEnoughBytes);
322 }
323 offset += 4;
324
325 let result = src[offset] != 0;
326 offset += 1;
327 let (n, message) = SmolStr::decode(&src[offset..])?;
328 offset += n;
329 let (n, keys) = SecretKeys::decode(&src[offset..])?;
330 offset += n;
331 let (n, primary_key) = KeyRequestMessage::decode(&src[offset..])?;
332 offset += n;
333
334 debug_assert_eq!(
335 offset, encoded_len,
336 "expect read {} bytes, but actual read {} bytes",
337 encoded_len, offset
338 );
339
340 Ok((
341 offset,
342 Self {
343 result,
344 message,
345 keys,
346 primary_key: primary_key.key,
347 },
348 ))
349 }
350}
351
352#[viewit::viewit(setters(prefix = "with"))]
354#[derive(Default)]
355pub struct KeyResponse<I> {
356 #[viewit(
358 getter(
359 const,
360 style = "ref",
361 attrs(doc = "Returns the map of node id to response message")
362 ),
363 setter(attrs(doc = "Sets the map of node id to response message (Builder pattern)"))
364 )]
365 messages: IndexMap<I, SmolStr>,
366 #[viewit(
368 getter(const, attrs(doc = "Returns the total nodes memberlist knows of")),
369 setter(
370 const,
371 attrs(doc = "Sets total nodes memberlist knows of (Builder pattern)")
372 )
373 )]
374 num_nodes: usize,
375 #[viewit(
377 getter(const, attrs(doc = "Returns the total responses received")),
378 setter(
379 const,
380 attrs(doc = "Sets the total responses received (Builder pattern)")
381 )
382 )]
383 num_resp: usize,
384 #[viewit(
386 getter(const, attrs(doc = "Returns the total errors from request")),
387 setter(
388 const,
389 attrs(doc = "Sets the total errors from request (Builder pattern)")
390 )
391 )]
392 num_err: usize,
393
394 #[viewit(
397 getter(
398 const,
399 style = "ref",
400 attrs(
401 doc = "Returns a mapping of the value of the key bytes to the number of nodes that have the key installed."
402 )
403 ),
404 setter(attrs(
405 doc = "Sets a mapping of the value of the key bytes to the number of nodes that have the key installed (Builder pattern)"
406 ))
407 )]
408 keys: IndexMap<SecretKey, usize>,
409
410 #[viewit(
413 getter(
414 const,
415 style = "ref",
416 attrs(
417 doc = "Returns a mapping of the value of the primary key bytes to the number of nodes that have the key installed."
418 )
419 ),
420 setter(attrs(
421 doc = "Sets a mapping of the value of the primary key bytes to the number of nodes that have the key installed. (Builder pattern)"
422 ))
423 )]
424 primary_keys: IndexMap<SecretKey, usize>,
425}
426
/// Tuning options for issuing a key request.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KeyRequestOptions {
    /// Relay factor applied to the request.
    /// NOTE(review): presumably the number of extra nodes asked to relay
    /// responses — confirm against the caller.
    pub relay_factor: u8,
}
434
#[cfg(test)]
mod tests {
    use rand::{distributions::Alphanumeric, thread_rng, Rng};

    use super::*;

    /// Builds a secret key of `kind` random bytes for `kind` in {16, 24, 32};
    /// any other `kind` yields `None`.
    fn random_secret_key(kind: u8) -> Option<SecretKey> {
        match kind {
            16 => {
                let mut raw = [0u8; 16];
                thread_rng().fill(&mut raw);
                Some(SecretKey::from(raw))
            }
            24 => {
                let mut raw = [0u8; 24];
                thread_rng().fill(&mut raw);
                Some(SecretKey::from(raw))
            }
            32 => {
                let mut raw = [0u8; 32];
                thread_rng().fill(&mut raw);
                Some(SecretKey::from(raw))
            }
            _ => None,
        }
    }

    impl KeyRequestMessage {
        /// Random message: with probability 1/2 it carries a `kind`-byte key
        /// (when `kind` is a supported size), otherwise no key.
        pub(crate) fn random(kind: u8) -> Self {
            let include_key: bool = rand::random();
            Self {
                key: if include_key {
                    random_secret_key(kind)
                } else {
                    None
                },
            }
        }
    }

    impl KeyResponseMessage {
        /// Random response with `num_keys` installed keys (cycling 16/24/32
        /// bytes), an optional 32-byte primary key and a `size`-char message.
        pub(crate) fn random(num_keys: usize, size: usize) -> Self {
            const KINDS: [u8; 3] = [16, 24, 32];

            let mut keys = SecretKeys::new();
            for idx in 0..num_keys {
                let key = random_secret_key(KINDS[idx % KINDS.len()])
                    .expect("16, 24 and 32 are all supported key sizes");
                keys.push(key);
            }

            let primary_key = if rand::random() {
                random_secret_key(32)
            } else {
                None
            };

            let message: String = thread_rng()
                .sample_iter(Alphanumeric)
                .take(size)
                .map(char::from)
                .collect();

            Self {
                result: rand::random(),
                message: message.into(),
                keys,
                primary_key,
            }
        }
    }

    #[test]
    fn test_key_request_message_transform() {
        futures::executor::block_on(async {
            const KINDS: [u8; 4] = [0, 16, 24, 32];

            for round in 0..100usize {
                let key = KeyRequestMessage::random(KINDS[round % KINDS.len()]);
                let expected_len = key.encoded_len();

                let mut buf = vec![0u8; expected_len];
                let encoded_len = key.encode(&mut buf).unwrap();
                assert_eq!(encoded_len, expected_len);

                let mut buf1 = Vec::<u8>::new();
                let encoded_len1 = key.encode_to_writer(&mut buf1).unwrap();
                assert_eq!(encoded_len1, expected_len);

                let mut buf2 = Vec::<u8>::new();
                let encoded_len2 = key.encode_to_async_writer(&mut buf2).await.unwrap();
                assert_eq!(encoded_len2, expected_len);

                // Every encoding must round-trip through every decoding path.
                for encoded in [&buf, &buf1, &buf2] {
                    let (decoded_len, decoded) = KeyRequestMessage::decode(encoded).unwrap();
                    assert_eq!(decoded_len, encoded_len);
                    assert_eq!(decoded, key);

                    let (decoded_len, decoded) =
                        KeyRequestMessage::decode_from_reader(&mut std::io::Cursor::new(encoded))
                            .unwrap();
                    assert_eq!(decoded_len, encoded_len);
                    assert_eq!(decoded, key);

                    let (decoded_len, decoded) = KeyRequestMessage::decode_from_async_reader(
                        &mut futures::io::Cursor::new(encoded),
                    )
                    .await
                    .unwrap();
                    assert_eq!(decoded_len, encoded_len);
                    assert_eq!(decoded, key);
                }
            }
        });
    }

    #[test]
    fn test_key_response_message_transform() {
        futures::executor::block_on(async {
            for round in 0..100usize {
                let message = KeyResponseMessage::random(round % 10, round);
                let expected_len = message.encoded_len();

                let mut buf = vec![0u8; expected_len];
                let encoded_len = message.encode(&mut buf).unwrap();
                assert_eq!(encoded_len, expected_len);

                let (decoded_len, decoded) = KeyResponseMessage::decode(&buf).unwrap();
                assert_eq!(decoded_len, encoded_len);
                assert_eq!(decoded, message);

                let (decoded_len, decoded) =
                    KeyResponseMessage::decode_from_reader(&mut std::io::Cursor::new(&buf))
                        .unwrap();
                assert_eq!(decoded_len, encoded_len);
                assert_eq!(decoded, message);

                let (decoded_len, decoded) = KeyResponseMessage::decode_from_async_reader(
                    &mut futures::io::Cursor::new(&buf),
                )
                .await
                .unwrap();
                assert_eq!(decoded_len, encoded_len);
                assert_eq!(decoded, message);
            }
        });
    }
}