serf_types/
key.rs

1use byteorder::{ByteOrder, NetworkEndian};
2use indexmap::IndexMap;
3use memberlist_types::{SecretKey, SecretKeyTransformError, SecretKeys, SecretKeysTransformError};
4use smol_str::SmolStr;
5use transformable::{StringTransformError, Transformable};
6
7/// KeyRequest is used to contain input parameters which get broadcasted to all
8/// nodes as part of a key query operation.
9#[viewit::viewit(setters(prefix = "with"))]
10#[derive(Debug, Clone, Copy, PartialEq, Eq)]
11#[repr(transparent)]
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13#[cfg_attr(feature = "serde", serde(transparent))]
14pub struct KeyRequestMessage {
15  /// The secret key
16  #[viewit(
17    getter(const, attrs(doc = "Returns the secret key")),
18    setter(const, attrs(doc = "Sets the secret key (Builder pattern)"))
19  )]
20  key: Option<SecretKey>,
21}
22
23/// The error that can occur when transforming a [`KeyRequestMessage`]
24#[derive(Debug, thiserror::Error)]
25pub enum OptionSecretKeyTransformError {
26  /// Not enough bytes to decode [`Option<SecretKey>`]
27  #[error("not enough bytes to decode")]
28  NotEnoughBytes,
29  /// Encode buffer too small
30  #[error("encode buffer too small")]
31  BufferTooSmall,
32  /// Error transforming a secret key
33  #[error(transparent)]
34  SecretKey(#[from] SecretKeyTransformError),
35}
36
37impl Transformable for KeyRequestMessage {
38  type Error = OptionSecretKeyTransformError;
39
40  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
41    let encoded_len = self.encoded_len();
42    if dst.len() < encoded_len {
43      return Err(Self::Error::BufferTooSmall);
44    }
45
46    match &self.key {
47      None => {
48        dst[0] = 0;
49        Ok(1)
50      }
51      Some(key) => key.encode(dst).map_err(Self::Error::SecretKey),
52    }
53  }
54
55  fn encoded_len(&self) -> usize {
56    match &self.key {
57      Some(key) => key.encoded_len(),
58      None => 1,
59    }
60  }
61
62  fn encode_to_writer<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
63    match &self.key {
64      None => {
65        writer.write_all(&[0])?;
66        Ok(1)
67      }
68      Some(key) => key.encode_to_writer(writer),
69    }
70  }
71
72  async fn encode_to_async_writer<W: futures::AsyncWrite + Send + Unpin>(
73    &self,
74    writer: &mut W,
75  ) -> std::io::Result<usize> {
76    use futures::AsyncWriteExt;
77
78    match &self.key {
79      None => {
80        writer.write_all(&[0]).await?;
81        Ok(1)
82      }
83      Some(key) => key.encode_to_async_writer(writer).await,
84    }
85  }
86
87  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
88  where
89    Self: Sized,
90  {
91    if src.is_empty() {
92      return Err(Self::Error::NotEnoughBytes);
93    }
94
95    match src[0] {
96      0 => Ok((1, Self { key: None })),
97      _ => {
98        let (n, key) = SecretKey::decode(src).map_err(Self::Error::SecretKey)?;
99        Ok((n, Self { key: Some(key) }))
100      }
101    }
102  }
103
104  fn decode_from_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<(usize, Self)>
105  where
106    Self: Sized,
107  {
108    let mut buf = [0u8; 1];
109    reader.read_exact(&mut buf)?;
110
111    match buf[0] {
112      0 => Ok((1, Self { key: None })),
113      16 => {
114        let mut buf = [0u8; 16];
115        reader.read_exact(&mut buf)?;
116        Ok((
117          17,
118          Self {
119            key: Some(SecretKey::from(buf)),
120          },
121        ))
122      }
123      24 => {
124        let mut buf = [0u8; 24];
125        reader.read_exact(&mut buf)?;
126        Ok((
127          25,
128          Self {
129            key: Some(SecretKey::from(buf)),
130          },
131        ))
132      }
133      32 => {
134        let mut buf = [0u8; 32];
135        reader.read_exact(&mut buf)?;
136        Ok((
137          33,
138          Self {
139            key: Some(SecretKey::from(buf)),
140          },
141        ))
142      }
143      _ => Err(std::io::Error::new(
144        std::io::ErrorKind::InvalidData,
145        "unknown secret key kind",
146      )),
147    }
148  }
149
150  async fn decode_from_async_reader<R: futures::AsyncRead + Send + Unpin>(
151    reader: &mut R,
152  ) -> std::io::Result<(usize, Self)>
153  where
154    Self: Sized,
155  {
156    use futures::AsyncReadExt;
157
158    let mut buf = [0u8; 1];
159    reader.read_exact(&mut buf).await?;
160
161    match buf[0] {
162      0 => Ok((1, Self { key: None })),
163      16 => {
164        let mut buf = [0u8; 16];
165        reader.read_exact(&mut buf).await?;
166        Ok((
167          17,
168          Self {
169            key: Some(SecretKey::from(buf)),
170          },
171        ))
172      }
173      24 => {
174        let mut buf = [0u8; 24];
175        reader.read_exact(&mut buf).await?;
176        Ok((
177          25,
178          Self {
179            key: Some(SecretKey::from(buf)),
180          },
181        ))
182      }
183      32 => {
184        let mut buf = [0u8; 32];
185        reader.read_exact(&mut buf).await?;
186        Ok((
187          33,
188          Self {
189            key: Some(SecretKey::from(buf)),
190          },
191        ))
192      }
193      _ => Err(std::io::Error::new(
194        std::io::ErrorKind::InvalidData,
195        "unknown secret key kind",
196      )),
197    }
198  }
199}
200
/// Response a node sends back for a key query operation.
///
/// It carries a result flag, a free-form message, the node's installed keys
/// (for listing queries) and, optionally, its primary key.
#[viewit::viewit(setters(prefix = "with"))]
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KeyResponseMessage {
  /// Indicates true/false if there were errors or not
  #[viewit(
    getter(const, attrs(doc = "Returns true/false if there were errors or not")),
    setter(
      const,
      attrs(doc = "Sets true/false if there were errors or not (Builder pattern)")
    )
  )]
  result: bool,
  /// Contains error messages or other information
  #[viewit(
    getter(
      const,
      style = "ref",
      attrs(doc = "Returns the error messages or other information")
    ),
    setter(attrs(doc = "Sets the error messages or other information (Builder pattern)"))
  )]
  message: SmolStr,
  /// Used in listing queries to relay a list of installed keys
  #[viewit(
    getter(const, style = "ref", attrs(doc = "Returns a list of installed keys")),
    setter(attrs(doc = "Sets the list of installed keys (Builder pattern)"))
  )]
  keys: SecretKeys,
  /// Used in listing queries to relay the primary key
  #[viewit(
    getter(const, attrs(doc = "Returns the primary key")),
    setter(attrs(doc = "Sets the primary key (Builder pattern)"))
  )]
  primary_key: Option<SecretKey>,
}
239
240impl KeyResponseMessage {
241  /// Adds a key to the list of keys
242  #[inline]
243  pub fn add_key(&mut self, key: SecretKey) -> &mut Self {
244    self.keys.push(key);
245    self
246  }
247}
248
249/// Error that can occur when transforming a [`KeyResponseMessage`].
250#[derive(Debug, thiserror::Error)]
251pub enum KeyResponseMessageTransformError {
252  /// Not enough bytes to decode KeyResponseMessage
253  #[error("not enough bytes to decode `KeyResponseMessage`")]
254  NotEnoughBytes,
255  /// Encode buffer too small
256  #[error("encode buffer too small")]
257  BufferTooSmall,
258  /// Error transforming a message field
259  #[error(transparent)]
260  Message(#[from] StringTransformError),
261  /// Error transforming a `primary_key` field
262  #[error(transparent)]
263  PrimaryKey(#[from] OptionSecretKeyTransformError),
264  /// Error transforming a `keys` field
265  #[error(transparent)]
266  Keys(#[from] SecretKeysTransformError),
267}
268
269impl Transformable for KeyResponseMessage {
270  type Error = KeyResponseMessageTransformError;
271
272  fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
273    let encoded_len = self.encoded_len();
274    if dst.len() < encoded_len {
275      return Err(Self::Error::BufferTooSmall);
276    }
277
278    let mut offset = 0;
279    NetworkEndian::write_u32(&mut dst[offset..offset + 4], encoded_len as u32);
280    offset += 4;
281    dst[offset] = self.result as u8;
282    offset += 1;
283    offset += self.message.encode(&mut dst[offset..])?;
284    offset += self.keys.encode(&mut dst[offset..])?;
285    offset += KeyRequestMessage {
286      key: self.primary_key,
287    }
288    .encode(&mut dst[offset..])?;
289
290    debug_assert_eq!(
291      offset, encoded_len,
292      "expect write {} bytes, but actual write {} bytes",
293      encoded_len, offset
294    );
295
296    Ok(offset)
297  }
298
299  fn encoded_len(&self) -> usize {
300    4 + 1
301      + self.message.encoded_len()
302      + self.keys.encoded_len()
303      + KeyRequestMessage {
304        key: self.primary_key,
305      }
306      .encoded_len()
307  }
308
309  fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
310  where
311    Self: Sized,
312  {
313    let src_len = src.len();
314    if src_len < 5 {
315      return Err(Self::Error::NotEnoughBytes);
316    }
317
318    let mut offset = 0;
319    let encoded_len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize;
320    if src_len < encoded_len {
321      return Err(Self::Error::NotEnoughBytes);
322    }
323    offset += 4;
324
325    let result = src[offset] != 0;
326    offset += 1;
327    let (n, message) = SmolStr::decode(&src[offset..])?;
328    offset += n;
329    let (n, keys) = SecretKeys::decode(&src[offset..])?;
330    offset += n;
331    let (n, primary_key) = KeyRequestMessage::decode(&src[offset..])?;
332    offset += n;
333
334    debug_assert_eq!(
335      offset, encoded_len,
336      "expect read {} bytes, but actual read {} bytes",
337      encoded_len, offset
338    );
339
340    Ok((
341      offset,
342      Self {
343        result,
344        message,
345        keys,
346        primary_key: primary_key.key,
347      },
348    ))
349  }
350}
351
352/// KeyResponse is used to relay a query for a list of all keys in use.
353#[viewit::viewit(setters(prefix = "with"))]
354#[derive(Default)]
355pub struct KeyResponse<I> {
356  /// Map of node id to response message
357  #[viewit(
358    getter(
359      const,
360      style = "ref",
361      attrs(doc = "Returns the map of node id to response message")
362    ),
363    setter(attrs(doc = "Sets the map of node id to response message (Builder pattern)"))
364  )]
365  messages: IndexMap<I, SmolStr>,
366  /// Total nodes memberlist knows of
367  #[viewit(
368    getter(const, attrs(doc = "Returns the total nodes memberlist knows of")),
369    setter(
370      const,
371      attrs(doc = "Sets total nodes memberlist knows of (Builder pattern)")
372    )
373  )]
374  num_nodes: usize,
375  /// Total responses received
376  #[viewit(
377    getter(const, attrs(doc = "Returns the total responses received")),
378    setter(
379      const,
380      attrs(doc = "Sets the total responses received (Builder pattern)")
381    )
382  )]
383  num_resp: usize,
384  /// Total errors from request
385  #[viewit(
386    getter(const, attrs(doc = "Returns the total errors from request")),
387    setter(
388      const,
389      attrs(doc = "Sets the total errors from request (Builder pattern)")
390    )
391  )]
392  num_err: usize,
393
394  /// A mapping of the value of the key bytes to the
395  /// number of nodes that have the key installed.
396  #[viewit(
397    getter(
398      const,
399      style = "ref",
400      attrs(
401        doc = "Returns a mapping of the value of the key bytes to the number of nodes that have the key installed."
402      )
403    ),
404    setter(attrs(
405      doc = "Sets a mapping of the value of the key bytes to the number of nodes that have the key installed (Builder pattern)"
406    ))
407  )]
408  keys: IndexMap<SecretKey, usize>,
409
410  /// A mapping of the value of the primary
411  /// key bytes to the number of nodes that have the key installed.
412  #[viewit(
413    getter(
414      const,
415      style = "ref",
416      attrs(
417        doc = "Returns a mapping of the value of the primary key bytes to the number of nodes that have the key installed."
418      )
419    ),
420    setter(attrs(
421      doc = "Sets a mapping of the value of the primary key bytes to the number of nodes that have the key installed. (Builder pattern)"
422    ))
423  )]
424  primary_keys: IndexMap<SecretKey, usize>,
425}
426
/// Optional parameters that tune a keyring operation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KeyRequestOptions {
  /// How many duplicate query responses to send by relaying them through
  /// other nodes, for redundancy.
  pub relay_factor: u8,
}
434
#[cfg(test)]
mod tests {
  use rand::{distributions::Alphanumeric, thread_rng, Rng};

  use super::*;

  /// Builds a `SecretKey` from `N` freshly generated random bytes.
  fn random_secret_key<const N: usize>() -> SecretKey
  where
    SecretKey: From<[u8; N]>,
  {
    let mut bytes = [0u8; N];
    rand::thread_rng().fill(&mut bytes[..]);
    SecretKey::from(bytes)
  }

  /// A random key of `kind` bytes, or `None` if `kind` is not 16, 24 or 32.
  fn random_key_of_kind(kind: u8) -> Option<SecretKey> {
    match kind {
      16 => Some(random_secret_key::<16>()),
      24 => Some(random_secret_key::<24>()),
      32 => Some(random_secret_key::<32>()),
      _ => None,
    }
  }

  impl KeyRequestMessage {
    pub(crate) fn random(kind: u8) -> Self {
      // Roughly half of the generated messages carry no key at all.
      let key = if rand::random::<bool>() {
        random_key_of_kind(kind)
      } else {
        None
      };
      Self { key }
    }
  }

  impl KeyResponseMessage {
    pub(crate) fn random(num_keys: usize, size: usize) -> Self {
      const KINDS: [u8; 3] = [16, 24, 32];

      let mut keys = SecretKeys::new();
      for idx in 0..num_keys {
        let key = random_key_of_kind(KINDS[idx % KINDS.len()])
          .expect("every entry of KINDS is a supported key length");
        keys.push(key);
      }

      let primary_key = if rand::random::<bool>() {
        Some(random_secret_key::<32>())
      } else {
        None
      };

      let message: String = thread_rng()
        .sample_iter(Alphanumeric)
        .take(size)
        .map(char::from)
        .collect();

      Self {
        result: rand::random(),
        message: message.into(),
        keys,
        primary_key,
      }
    }
  }

  #[test]
  fn test_key_request_message_transform() {
    futures::executor::block_on(async {
      for i in 0..100usize {
        let kind = [0u8, 16, 24, 32][i % 4];
        let key = KeyRequestMessage::random(kind);
        let expected_len = key.encoded_len();

        // Produce the same message through each of the three encoders.
        let mut buf = vec![0; expected_len];
        let encoded_len = key.encode(&mut buf).unwrap();
        assert_eq!(encoded_len, expected_len);
        let mut buf1 = vec![];
        assert_eq!(key.encode_to_writer(&mut buf1).unwrap(), expected_len);
        let mut buf2 = vec![];
        assert_eq!(
          key.encode_to_async_writer(&mut buf2).await.unwrap(),
          expected_len
        );

        // Every encoding must round-trip through every decoder.
        for encoded in [&buf, &buf1, &buf2] {
          let (decoded_len, decoded) = KeyRequestMessage::decode(encoded).unwrap();
          assert_eq!(decoded_len, encoded_len);
          assert_eq!(decoded, key);

          let (decoded_len, decoded) =
            KeyRequestMessage::decode_from_reader(&mut std::io::Cursor::new(encoded)).unwrap();
          assert_eq!(decoded_len, encoded_len);
          assert_eq!(decoded, key);

          let (decoded_len, decoded) =
            KeyRequestMessage::decode_from_async_reader(&mut futures::io::Cursor::new(encoded))
              .await
              .unwrap();
          assert_eq!(decoded_len, encoded_len);
          assert_eq!(decoded, key);
        }
      }
    });
  }

  #[test]
  fn test_key_response_message_transform() {
    futures::executor::block_on(async {
      for i in 0..100 {
        let original = KeyResponseMessage::random(i % 10, i);
        let expected_len = original.encoded_len();

        let mut buf = vec![0; expected_len];
        let encoded_len = original.encode(&mut buf).unwrap();
        assert_eq!(encoded_len, expected_len);

        let (decoded_len, decoded) = KeyResponseMessage::decode(&buf).unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, original);

        let (decoded_len, decoded) =
          KeyResponseMessage::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, original);

        let (decoded_len, decoded) =
          KeyResponseMessage::decode_from_async_reader(&mut futures::io::Cursor::new(&buf))
            .await
            .unwrap();
        assert_eq!(decoded_len, encoded_len);
        assert_eq!(decoded, original);
      }
    });
  }
}