serf_core/
key_manager.rs

1use std::{collections::HashMap, sync::OnceLock};
2
3use async_channel::Receiver;
4use async_lock::RwLock;
5use futures::StreamExt;
6use memberlist_core::{
7  bytes::{BufMut, BytesMut},
8  tracing,
9  transport::{AddressResolver, Transport},
10  types::SecretKey,
11  CheapClone,
12};
13use smol_str::SmolStr;
14
15use crate::event::{
16  InternalQueryEvent, INTERNAL_INSTALL_KEY, INTERNAL_LIST_KEYS, INTERNAL_REMOVE_KEY,
17  INTERNAL_USE_KEY,
18};
19
20use super::{
21  delegate::{Delegate, TransformDelegate},
22  error::Error,
23  serf::{NodeResponse, QueryResponse},
24  types::{KeyRequestMessage, MessageType, SerfMessage},
25  Serf,
26};
27
28/// KeyResponse is used to relay a query for a list of all keys in use.
29#[viewit::viewit(
30  vis_all = "pub(crate)",
31  getters(style = "move", vis_all = "pub"),
32  setters(skip)
33)]
34#[derive(Default, Debug)]
35pub struct KeyResponse<I> {
36  /// Map of node id to response message
37  #[viewit(getter(
38    const,
39    style = "ref",
40    attrs(doc = "Returns a map of node id to response message.")
41  ))]
42  messages: HashMap<I, SmolStr>,
43  /// Total nodes memberlist knows of
44  #[viewit(getter(const, attrs(doc = "Returns the total nodes memberlist knows of.")))]
45  num_nodes: usize,
46  /// Total responses received
47  #[viewit(getter(const, attrs(doc = "Returns the total responses received.")))]
48  num_resp: usize,
49  /// Total errors from request
50  #[viewit(getter(const, attrs(doc = "Returns the total errors from request.")))]
51  num_err: usize,
52
53  /// A mapping of the value of the key bytes to the
54  /// number of nodes that have the key installed.
55  #[viewit(getter(
56    const,
57    style = "ref",
58    attrs(
59      doc = "Returns a mapping of the value of the key bytes to the number of nodes that have the key installed.."
60    )
61  ))]
62  keys: HashMap<SecretKey, usize>,
63
64  /// A mapping of the value of the primary
65  /// key bytes to the number of nodes that have the key installed.
66  #[viewit(getter(
67    const,
68    style = "ref",
69    attrs(
70      doc = "Returns a mapping of the value of the primary key bytes to the number of nodes that have the key installed."
71    )
72  ))]
73  primary_keys: HashMap<SecretKey, usize>,
74}
75
/// KeyRequestOptions is used to contain optional parameters for a keyring operation.
///
/// This is plain configuration data, so it is cheap to copy, compare and
/// print for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KeyRequestOptions {
  /// The number of duplicate query responses to send by relaying through
  /// other nodes, for redundancy
  pub relay_factor: u8,
}
82
83/// `KeyManager` encapsulates all functionality within Serf for handling
84/// encryption keyring changes across a cluster.
85pub struct KeyManager<T, D>
86where
87  D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
88  T: Transport,
89{
90  serf: OnceLock<Serf<T, D>>,
91  /// The lock is used to serialize keys related handlers
92  l: RwLock<()>,
93}
94
95impl<T, D> KeyManager<T, D>
96where
97  D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
98  T: Transport,
99{
100  pub(crate) fn new() -> Self {
101    Self {
102      serf: OnceLock::new(),
103      l: RwLock::new(()),
104    }
105  }
106
107  pub(crate) fn store(&self, serf: Serf<T, D>) {
108    // No error handling here, because we never call this in parallel
109    let _ = self.serf.set(serf);
110  }
111
112  /// Handles broadcasting a query to all members and gathering
113  /// responses from each of them, returning a list of messages from each node
114  /// and any applicable error conditions.
115  pub async fn install_key(
116    &self,
117    key: SecretKey,
118    opts: Option<KeyRequestOptions>,
119  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
120    let _mu = self.l.write().await;
121    self
122      .handle_key_request(
123        Some(key),
124        INTERNAL_INSTALL_KEY,
125        opts,
126        InternalQueryEvent::InstallKey,
127      )
128      .await
129  }
130
131  /// Handles broadcasting a primary key change to all members in the
132  /// cluster, and gathering any response messages. If successful, there should
133  /// be an empty KeyResponse returned.
134  pub async fn use_key(
135    &self,
136    key: SecretKey,
137    opts: Option<KeyRequestOptions>,
138  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
139    let _mu = self.l.write().await;
140    self
141      .handle_key_request(
142        Some(key),
143        INTERNAL_USE_KEY,
144        opts,
145        InternalQueryEvent::UseKey,
146      )
147      .await
148  }
149
150  /// Handles broadcasting a key to the cluster for removal. Each member
151  /// will receive this event, and if they have the key in their keyring, remove
152  /// it. If any errors are encountered, RemoveKey will collect and relay them.
153  pub async fn remove_key(
154    &self,
155    key: SecretKey,
156    opts: Option<KeyRequestOptions>,
157  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
158    let _mu = self.l.write().await;
159    self
160      .handle_key_request(
161        Some(key),
162        INTERNAL_REMOVE_KEY,
163        opts,
164        InternalQueryEvent::RemoveKey,
165      )
166      .await
167  }
168
169  /// Used to collect installed keys from members in a Serf cluster
170  /// and return an aggregated list of all installed keys. This is useful to
171  /// operators to ensure that there are no lingering keys installed on any agents.
172  /// Since having multiple keys installed can cause performance penalties in some
173  /// cases, it's important to verify this information and remove unneeded keys.
174  pub async fn list_keys(&self) -> Result<KeyResponse<T::Id>, Error<T, D>> {
175    let _mu = self.l.read().await;
176    self
177      .handle_key_request(None, INTERNAL_LIST_KEYS, None, InternalQueryEvent::ListKey)
178      .await
179  }
180
181  pub(crate) async fn handle_key_request(
182    &self,
183    key: Option<SecretKey>,
184    ty: &str,
185    opts: Option<KeyRequestOptions>,
186    event: InternalQueryEvent<T::Id>,
187  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
188    let kr = KeyRequestMessage { key };
189    let expected_encoded_len = <D as TransformDelegate>::message_encoded_len(&kr);
190    let mut buf = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type
191    buf.put_u8(MessageType::KeyRequest as u8);
192    buf.resize(expected_encoded_len + 1, 0);
193    // Encode the query request
194    let len = <D as TransformDelegate>::encode_message(&kr, &mut buf[1..])
195      .map_err(Error::transform_delegate)?;
196
197    debug_assert_eq!(
198      len, expected_encoded_len,
199      "expected encoded len {} mismatch the actual encoded len {}",
200      expected_encoded_len, len
201    );
202
203    let serf = self.serf.get().unwrap();
204    let mut q_param = serf.default_query_param().await;
205    if let Some(opts) = opts {
206      q_param.relay_factor = opts.relay_factor;
207    }
208    let qresp: QueryResponse<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress> = serf
209      .internal_query(SmolStr::new(ty), buf.freeze(), Some(q_param), event)
210      .await?;
211
212    // Handle the response stream and populate the KeyResponse
213    let resp = self.stream_key_response(qresp.response_rx()).await;
214
215    // Check the response for any reported failure conditions
216    if resp.num_err > 0 {
217      tracing::error!(
218        "serf: {}/{} nodes reported failure",
219        resp.num_err,
220        resp.num_nodes
221      );
222    }
223
224    if resp.num_resp != resp.num_nodes {
225      tracing::error!(
226        "serf: {}/{} nodes responded success",
227        resp.num_resp,
228        resp.num_nodes
229      );
230    }
231
232    Ok(resp)
233  }
234
235  async fn stream_key_response(
236    &self,
237    ch: Receiver<NodeResponse<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>>,
238  ) -> KeyResponse<T::Id> {
239    let mut resp = KeyResponse {
240      num_nodes: self.serf.get().unwrap().num_members().await,
241      messages: HashMap::new(),
242      num_resp: 0,
243      num_err: 0,
244      keys: HashMap::new(),
245      primary_keys: HashMap::new(),
246    };
247    futures::pin_mut!(ch);
248    while let Some(r) = ch.next().await {
249      resp.num_resp += 1;
250
251      // Decode the response
252      if r.payload.is_empty() || r.payload[0] != MessageType::KeyResponse as u8 {
253        resp.messages.insert(
254          r.from.id().cheap_clone(),
255          SmolStr::new(format!(
256            "Invalid key query response type: {:?}",
257            r.payload.as_ref()
258          )),
259        );
260        resp.num_err += 1;
261
262        if resp.num_resp == resp.num_nodes {
263          return resp;
264        }
265        continue;
266      }
267
268      let node_response =
269        match <D as TransformDelegate>::decode_message(MessageType::KeyResponse, &r.payload[1..]) {
270          Ok((_, nr)) => match nr {
271            SerfMessage::KeyResponse(kr) => kr,
272            msg => {
273              resp.messages.insert(
274                r.from.id().cheap_clone(),
275                SmolStr::new(format!(
276                  "Invalid key query response type: {:?}",
277                  msg.ty().as_str()
278                )),
279              );
280              resp.num_err += 1;
281
282              if resp.num_resp == resp.num_nodes {
283                return resp;
284              }
285              continue;
286            }
287          },
288          Err(e) => {
289            resp.messages.insert(
290              r.from.id().cheap_clone(),
291              SmolStr::new(format!("Failed to decode key query response: {:?}", e)),
292            );
293            resp.num_err += 1;
294
295            if resp.num_resp == resp.num_nodes {
296              return resp;
297            }
298            continue;
299          }
300        };
301
302      if !node_response.result {
303        resp
304          .messages
305          .insert(r.from.id().cheap_clone(), node_response.message);
306        resp.num_err += 1;
307      } else if node_response.result && node_response.message.is_empty() {
308        tracing::warn!("serf: {}", node_response.message);
309        resp
310          .messages
311          .insert(r.from.id().cheap_clone(), node_response.message);
312      }
313
314      // Currently only used for key list queries, this adds keys to a counter
315      // and increments them for each node response which contains them.
316      for k in node_response.keys {
317        let count = resp.keys.entry(k).or_insert(0);
318        *count += 1;
319      }
320
321      if let Some(pk) = node_response.primary_key {
322        let ctr = resp.primary_keys.entry(pk).or_insert(0);
323        *ctr += 1;
324      }
325
326      // Return early if all nodes have responded. This allows us to avoid
327      // waiting for the full timeout when there is nothing left to do.
328      if resp.num_resp == resp.num_nodes {
329        return resp;
330      }
331    }
332    resp
333  }
334}