serf_core/
key_manager.rs

1use std::{collections::HashMap, sync::OnceLock};
2
3use crate::types::MessageRef;
4use async_channel::Receiver;
5use async_lock::RwLock;
6use futures::StreamExt;
7use memberlist_core::{CheapClone, proto::SecretKey, tracing, transport::Transport};
8use smol_str::{SmolStr, format_smolstr};
9
10use crate::event::{
11  INTERNAL_INSTALL_KEY, INTERNAL_LIST_KEYS, INTERNAL_REMOVE_KEY, INTERNAL_USE_KEY,
12  InternalQueryEvent,
13};
14
15use super::{
16  Serf,
17  delegate::Delegate,
18  error::Error,
19  serf::{NodeResponse, QueryResponse},
20  types::KeyRequestMessage,
21};
22
23/// KeyResponse is used to relay a query for a list of all keys in use.
24#[viewit::viewit(
25  vis_all = "pub(crate)",
26  getters(style = "move", vis_all = "pub"),
27  setters(skip)
28)]
29#[derive(Default, Debug)]
30pub struct KeyResponse<I> {
31  /// Map of node id to response message
32  #[viewit(getter(
33    const,
34    style = "ref",
35    attrs(doc = "Returns a map of node id to response message.")
36  ))]
37  messages: HashMap<I, SmolStr>,
38  /// Total nodes memberlist knows of
39  #[viewit(getter(const, attrs(doc = "Returns the total nodes memberlist knows of.")))]
40  num_nodes: usize,
41  /// Total responses received
42  #[viewit(getter(const, attrs(doc = "Returns the total responses received.")))]
43  num_resp: usize,
44  /// Total errors from request
45  #[viewit(getter(const, attrs(doc = "Returns the total errors from request.")))]
46  num_err: usize,
47
48  /// A mapping of the value of the key bytes to the
49  /// number of nodes that have the key installed.
50  #[viewit(getter(
51    const,
52    style = "ref",
53    attrs(
54      doc = "Returns a mapping of the value of the key bytes to the number of nodes that have the key installed.."
55    )
56  ))]
57  keys: HashMap<SecretKey, usize>,
58
59  /// A mapping of the value of the primary
60  /// key bytes to the number of nodes that have the key installed.
61  #[viewit(getter(
62    const,
63    style = "ref",
64    attrs(
65      doc = "Returns a mapping of the value of the primary key bytes to the number of nodes that have the key installed."
66    )
67  ))]
68  primary_keys: HashMap<SecretKey, usize>,
69}
70
/// Optional parameters for a keyring operation.
///
/// The [`Default`] value uses a relay factor of `0`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct KeyRequestOptions {
  /// The number of duplicate query responses to send by relaying through
  /// other nodes, for redundancy
  pub relay_factor: u8,
}
77
78/// `KeyManager` encapsulates all functionality within Serf for handling
79/// encryption keyring changes across a cluster.
80pub struct KeyManager<T, D>
81where
82  D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
83  T: Transport,
84{
85  serf: OnceLock<Serf<T, D>>,
86  /// The lock is used to serialize keys related handlers
87  l: RwLock<()>,
88}
89
90impl<T, D> KeyManager<T, D>
91where
92  D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
93  T: Transport,
94{
95  pub(crate) fn new() -> Self {
96    Self {
97      serf: OnceLock::new(),
98      l: RwLock::new(()),
99    }
100  }
101
102  pub(crate) fn store(&self, serf: Serf<T, D>) {
103    // No error handling here, because we never call this in parallel
104    let _ = self.serf.set(serf);
105  }
106
107  /// Handles broadcasting a query to all members and gathering
108  /// responses from each of them, returning a list of messages from each node
109  /// and any applicable error conditions.
110  pub async fn install_key(
111    &self,
112    key: SecretKey,
113    opts: Option<KeyRequestOptions>,
114  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
115    let _mu = self.l.write().await;
116    self
117      .handle_key_request(
118        Some(key),
119        INTERNAL_INSTALL_KEY,
120        opts,
121        InternalQueryEvent::InstallKey,
122      )
123      .await
124  }
125
126  /// Handles broadcasting a primary key change to all members in the
127  /// cluster, and gathering any response messages. If successful, there should
128  /// be an empty KeyResponse returned.
129  pub async fn use_key(
130    &self,
131    key: SecretKey,
132    opts: Option<KeyRequestOptions>,
133  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
134    let _mu = self.l.write().await;
135    self
136      .handle_key_request(
137        Some(key),
138        INTERNAL_USE_KEY,
139        opts,
140        InternalQueryEvent::UseKey,
141      )
142      .await
143  }
144
145  /// Handles broadcasting a key to the cluster for removal. Each member
146  /// will receive this event, and if they have the key in their keyring, remove
147  /// it. If any errors are encountered, RemoveKey will collect and relay them.
148  pub async fn remove_key(
149    &self,
150    key: SecretKey,
151    opts: Option<KeyRequestOptions>,
152  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
153    let _mu = self.l.write().await;
154    self
155      .handle_key_request(
156        Some(key),
157        INTERNAL_REMOVE_KEY,
158        opts,
159        InternalQueryEvent::RemoveKey,
160      )
161      .await
162  }
163
164  /// Used to collect installed keys from members in a Serf cluster
165  /// and return an aggregated list of all installed keys. This is useful to
166  /// operators to ensure that there are no lingering keys installed on any agents.
167  /// Since having multiple keys installed can cause performance penalties in some
168  /// cases, it's important to verify this information and remove unneeded keys.
169  pub async fn list_keys(&self) -> Result<KeyResponse<T::Id>, Error<T, D>> {
170    let _mu = self.l.read().await;
171    self
172      .handle_key_request(None, INTERNAL_LIST_KEYS, None, InternalQueryEvent::ListKey)
173      .await
174  }
175
176  pub(crate) async fn handle_key_request(
177    &self,
178    key: Option<SecretKey>,
179    ty: &str,
180    opts: Option<KeyRequestOptions>,
181    event: InternalQueryEvent<T::Id>,
182  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
183    let kr = KeyRequestMessage { key };
184    let buf = crate::types::encode_message_to_bytes(&kr)?;
185
186    let serf = self.serf.get().unwrap();
187    let mut q_param = serf.default_query_param().await;
188    if let Some(opts) = opts {
189      q_param.relay_factor = opts.relay_factor;
190    }
191    let qresp: QueryResponse<T::Id, T::ResolvedAddress> = serf
192      .internal_query(SmolStr::new(ty), buf, Some(q_param), event)
193      .await?;
194
195    // Handle the response stream and populate the KeyResponse
196    let resp = self.stream_key_response(qresp.response_rx()).await;
197
198    // Check the response for any reported failure conditions
199    if resp.num_err > 0 {
200      tracing::error!(
201        "serf: {}/{} nodes reported failure",
202        resp.num_err,
203        resp.num_nodes
204      );
205    }
206
207    if resp.num_resp != resp.num_nodes {
208      tracing::error!(
209        "serf: {}/{} nodes responded success",
210        resp.num_resp,
211        resp.num_nodes
212      );
213    }
214
215    Ok(resp)
216  }
217
218  async fn stream_key_response(
219    &self,
220    ch: Receiver<NodeResponse<T::Id, T::ResolvedAddress>>,
221  ) -> KeyResponse<T::Id> {
222    let mut resp = KeyResponse {
223      num_nodes: self.serf.get().unwrap().num_members().await,
224      messages: HashMap::new(),
225      num_resp: 0,
226      num_err: 0,
227      keys: HashMap::new(),
228      primary_keys: HashMap::new(),
229    };
230    futures::pin_mut!(ch);
231    while let Some(r) = ch.next().await {
232      resp.num_resp += 1;
233
234      // Decode the response
235      if r.payload.is_empty() {
236        resp
237          .messages
238          .insert(r.from.id().cheap_clone(), SmolStr::new("empty payload"));
239        resp.num_err += 1;
240
241        if resp.num_resp == resp.num_nodes {
242          return resp;
243        }
244        continue;
245      }
246
247      let node_response =
248        match crate::types::decode_message::<T::Id, T::ResolvedAddress>(&r.payload) {
249          Ok(msg) => match msg {
250            MessageRef::KeyResponse(kr) => kr,
251            msg => {
252              tracing::error!(type=%msg.ty(), "serf: invalid key query response type");
253
254              resp.messages.insert(
255                r.from.id().cheap_clone(),
256                format_smolstr!("invalid key query response: {}", msg.ty()),
257              );
258              resp.num_err += 1;
259
260              if resp.num_resp == resp.num_nodes {
261                return resp;
262              }
263              continue;
264            }
265          },
266          Err(e) => {
267            tracing::error!(err=%e, "serf: failed to decode key query response");
268            resp
269              .messages
270              .insert(r.from.id().cheap_clone(), format_smolstr!("{e}"));
271            resp.num_err += 1;
272
273            if resp.num_resp == resp.num_nodes {
274              return resp;
275            }
276            continue;
277          }
278        };
279
280      if !node_response.result() {
281        resp.messages.insert(
282          r.from.id().cheap_clone(),
283          SmolStr::new(node_response.message()),
284        );
285        resp.num_err += 1;
286      }
287
288      if node_response.result() && !node_response.message().is_empty() {
289        tracing::warn!("serf: {}", node_response.message());
290        resp.messages.insert(
291          r.from.id().cheap_clone(),
292          SmolStr::new(node_response.message()),
293        );
294      }
295
296      // Currently only used for key list queries, this adds keys to a counter
297      // and increments them for each node response which contains them.
298      let res = node_response
299        .keys()
300        .iter::<SecretKey>()
301        .try_for_each(|res| {
302          res.map(|k| {
303            let count = resp.keys.entry(k).or_insert(0);
304            *count += 1;
305          })
306        });
307
308      if let Err(e) = res {
309        resp.messages.insert(
310          r.from.id().cheap_clone(),
311          SmolStr::new(format!("Failed to decode key query response: {:?}", e)),
312        );
313        resp.num_err += 1;
314
315        if resp.num_resp == resp.num_nodes {
316          return resp;
317        }
318        continue;
319      }
320
321      if let Some(pk) = node_response.primary_key() {
322        let ctr = resp.primary_keys.entry(*pk).or_insert(0);
323        *ctr += 1;
324      }
325
326      // Return early if all nodes have responded. This allows us to avoid
327      // waiting for the full timeout when there is nothing left to do.
328      if resp.num_resp == resp.num_nodes {
329        return resp;
330      }
331    }
332    resp
333  }
334}