serf_core/
key_manager.rs

1use std::{collections::HashMap, sync::OnceLock};
2
3use crate::types::MessageRef;
4use async_channel::Receiver;
5use async_lock::RwLock;
6use futures::StreamExt;
7use memberlist_core::{CheapClone, proto::SecretKey, tracing, transport::Transport};
8use smol_str::{SmolStr, format_smolstr};
9
10use crate::event::{
11  INTERNAL_INSTALL_KEY, INTERNAL_LIST_KEYS, INTERNAL_REMOVE_KEY, INTERNAL_USE_KEY,
12  InternalQueryEvent,
13};
14
15use super::{
16  Serf, SerfWeakRef,
17  delegate::Delegate,
18  error::Error,
19  serf::{NodeResponse, QueryResponse},
20  types::KeyRequestMessage,
21};
22
23/// KeyResponse is used to relay a query for a list of all keys in use.
24#[viewit::viewit(
25  vis_all = "pub(crate)",
26  getters(style = "move", vis_all = "pub"),
27  setters(skip)
28)]
29#[derive(Default, Debug)]
30pub struct KeyResponse<I> {
31  /// Map of node id to response message
32  #[viewit(getter(
33    const,
34    style = "ref",
35    attrs(doc = "Returns a map of node id to response message.")
36  ))]
37  messages: HashMap<I, SmolStr>,
38  /// Total nodes memberlist knows of
39  #[viewit(getter(const, attrs(doc = "Returns the total nodes memberlist knows of.")))]
40  num_nodes: usize,
41  /// Total responses received
42  #[viewit(getter(const, attrs(doc = "Returns the total responses received.")))]
43  num_resp: usize,
44  /// Total errors from request
45  #[viewit(getter(const, attrs(doc = "Returns the total errors from request.")))]
46  num_err: usize,
47
48  /// A mapping of the value of the key bytes to the
49  /// number of nodes that have the key installed.
50  #[viewit(getter(
51    const,
52    style = "ref",
53    attrs(
54      doc = "Returns a mapping of the value of the key bytes to the number of nodes that have the key installed.."
55    )
56  ))]
57  keys: HashMap<SecretKey, usize>,
58
59  /// A mapping of the value of the primary
60  /// key bytes to the number of nodes that have the key installed.
61  #[viewit(getter(
62    const,
63    style = "ref",
64    attrs(
65      doc = "Returns a mapping of the value of the primary key bytes to the number of nodes that have the key installed."
66    )
67  ))]
68  primary_keys: HashMap<SecretKey, usize>,
69}
70
/// KeyRequestOptions is used to contain optional parameters for a keyring operation.
///
/// The `Default` value has a `relay_factor` of `0`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct KeyRequestOptions {
  /// The number of duplicate query responses to send by relaying through
  /// other nodes, for redundancy
  pub relay_factor: u8,
}
77
/// `KeyManager` encapsulates all functionality within Serf for handling
/// encryption keyring changes across a cluster.
///
/// It keeps a weak reference back to the owning Serf instance (installed once
/// through `store`) and a lock used to order keyring operations.
pub struct KeyManager<T, D>
where
  D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
  T: Transport,
{
  /// Weak handle to the owning Serf instance; unset until `store` is called.
  serf: OnceLock<SerfWeakRef<T, D>>,
  /// The lock is used to serialize keys related handlers: `list_keys` takes
  /// it for reading, while install/use/remove take it for writing.
  l: RwLock<()>,
}
89
90impl<T, D> KeyManager<T, D>
91where
92  D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
93  T: Transport,
94{
95  pub(crate) fn new() -> Self {
96    Self {
97      serf: OnceLock::new(),
98      l: RwLock::new(()),
99    }
100  }
101
102  pub(crate) fn store(&self, serf: SerfWeakRef<T, D>) {
103    // No error handling here, because we never call this in parallel
104    let _ = self.serf.set(serf);
105  }
106
107  fn this(&self) -> Option<Serf<T, D>> {
108    self.serf.get().and_then(|weak_ref| weak_ref.upgrade())
109  }
110
111  /// Handles broadcasting a query to all members and gathering
112  /// responses from each of them, returning a list of messages from each node
113  /// and any applicable error conditions.
114  pub async fn install_key(
115    &self,
116    key: SecretKey,
117    opts: Option<KeyRequestOptions>,
118  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
119    let _mu = self.l.write().await;
120    self
121      .handle_key_request(
122        Some(key),
123        INTERNAL_INSTALL_KEY,
124        opts,
125        InternalQueryEvent::InstallKey,
126      )
127      .await
128  }
129
130  /// Handles broadcasting a primary key change to all members in the
131  /// cluster, and gathering any response messages. If successful, there should
132  /// be an empty KeyResponse returned.
133  pub async fn use_key(
134    &self,
135    key: SecretKey,
136    opts: Option<KeyRequestOptions>,
137  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
138    let _mu = self.l.write().await;
139    self
140      .handle_key_request(
141        Some(key),
142        INTERNAL_USE_KEY,
143        opts,
144        InternalQueryEvent::UseKey,
145      )
146      .await
147  }
148
149  /// Handles broadcasting a key to the cluster for removal. Each member
150  /// will receive this event, and if they have the key in their keyring, remove
151  /// it. If any errors are encountered, RemoveKey will collect and relay them.
152  pub async fn remove_key(
153    &self,
154    key: SecretKey,
155    opts: Option<KeyRequestOptions>,
156  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
157    let _mu = self.l.write().await;
158    self
159      .handle_key_request(
160        Some(key),
161        INTERNAL_REMOVE_KEY,
162        opts,
163        InternalQueryEvent::RemoveKey,
164      )
165      .await
166  }
167
168  /// Used to collect installed keys from members in a Serf cluster
169  /// and return an aggregated list of all installed keys. This is useful to
170  /// operators to ensure that there are no lingering keys installed on any agents.
171  /// Since having multiple keys installed can cause performance penalties in some
172  /// cases, it's important to verify this information and remove unneeded keys.
173  pub async fn list_keys(&self) -> Result<KeyResponse<T::Id>, Error<T, D>> {
174    let _mu = self.l.read().await;
175    self
176      .handle_key_request(None, INTERNAL_LIST_KEYS, None, InternalQueryEvent::ListKey)
177      .await
178  }
179
180  pub(crate) async fn handle_key_request(
181    &self,
182    key: Option<SecretKey>,
183    ty: &str,
184    opts: Option<KeyRequestOptions>,
185    event: InternalQueryEvent<T::Id>,
186  ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
187    let kr = KeyRequestMessage { key };
188    let buf = crate::types::encode_message_to_bytes(&kr)?;
189
190    let Some(this) = self.this() else {
191      return Ok(KeyResponse {
192        num_nodes: 0,
193        messages: HashMap::new(),
194        num_resp: 0,
195        num_err: 0,
196        keys: HashMap::new(),
197        primary_keys: HashMap::new(),
198      });
199    };
200
201    let mut q_param = this.default_query_param().await;
202    if let Some(opts) = opts {
203      q_param.relay_factor = opts.relay_factor;
204    }
205    let qresp: QueryResponse<T::Id, T::ResolvedAddress> = this
206      .internal_query(SmolStr::new(ty), buf, Some(q_param), event)
207      .await?;
208
209    // Handle the response stream and populate the KeyResponse
210    let resp = Self::stream_key_response(&this, qresp.response_rx()).await;
211
212    // Check the response for any reported failure conditions
213    if resp.num_err > 0 {
214      tracing::error!(
215        "serf: {}/{} nodes reported failure",
216        resp.num_err,
217        resp.num_nodes
218      );
219    }
220
221    if resp.num_resp != resp.num_nodes {
222      tracing::error!(
223        "serf: {}/{} nodes responded success",
224        resp.num_resp,
225        resp.num_nodes
226      );
227    }
228
229    Ok(resp)
230  }
231
232  async fn stream_key_response(
233    this: &Serf<T, D>,
234    ch: Receiver<NodeResponse<T::Id, T::ResolvedAddress>>,
235  ) -> KeyResponse<T::Id> {
236    let mut resp = KeyResponse {
237      num_nodes: this.num_members().await,
238      messages: HashMap::new(),
239      num_resp: 0,
240      num_err: 0,
241      keys: HashMap::new(),
242      primary_keys: HashMap::new(),
243    };
244    futures::pin_mut!(ch);
245    while let Some(r) = ch.next().await {
246      resp.num_resp += 1;
247
248      // Decode the response
249      if r.payload.is_empty() {
250        resp
251          .messages
252          .insert(r.from.id().cheap_clone(), SmolStr::new("empty payload"));
253        resp.num_err += 1;
254
255        if resp.num_resp == resp.num_nodes {
256          return resp;
257        }
258        continue;
259      }
260
261      let node_response =
262        match crate::types::decode_message::<T::Id, T::ResolvedAddress>(&r.payload) {
263          Ok(msg) => match msg {
264            MessageRef::KeyResponse(kr) => kr,
265            msg => {
266              tracing::error!(type=%msg.ty(), "serf: invalid key query response type");
267
268              resp.messages.insert(
269                r.from.id().cheap_clone(),
270                format_smolstr!("invalid key query response: {}", msg.ty()),
271              );
272              resp.num_err += 1;
273
274              if resp.num_resp == resp.num_nodes {
275                return resp;
276              }
277              continue;
278            }
279          },
280          Err(e) => {
281            tracing::error!(err=%e, "serf: failed to decode key query response");
282            resp
283              .messages
284              .insert(r.from.id().cheap_clone(), format_smolstr!("{e}"));
285            resp.num_err += 1;
286
287            if resp.num_resp == resp.num_nodes {
288              return resp;
289            }
290            continue;
291          }
292        };
293
294      if !node_response.result() {
295        resp.messages.insert(
296          r.from.id().cheap_clone(),
297          SmolStr::new(node_response.message()),
298        );
299        resp.num_err += 1;
300      }
301
302      if node_response.result() && !node_response.message().is_empty() {
303        tracing::warn!("serf: {}", node_response.message());
304        resp.messages.insert(
305          r.from.id().cheap_clone(),
306          SmolStr::new(node_response.message()),
307        );
308      }
309
310      // Currently only used for key list queries, this adds keys to a counter
311      // and increments them for each node response which contains them.
312      let res = node_response
313        .keys()
314        .iter::<SecretKey>()
315        .try_for_each(|res| {
316          res.map(|k| {
317            let count = resp.keys.entry(k).or_insert(0);
318            *count += 1;
319          })
320        });
321
322      if let Err(e) = res {
323        resp.messages.insert(
324          r.from.id().cheap_clone(),
325          SmolStr::new(format!("Failed to decode key query response: {:?}", e)),
326        );
327        resp.num_err += 1;
328
329        if resp.num_resp == resp.num_nodes {
330          return resp;
331        }
332        continue;
333      }
334
335      if let Some(pk) = node_response.primary_key() {
336        let ctr = resp.primary_keys.entry(*pk).or_insert(0);
337        *ctr += 1;
338      }
339
340      // Return early if all nodes have responded. This allows us to avoid
341      // waiting for the full timeout when there is nothing left to do.
342      if resp.num_resp == resp.num_nodes {
343        return resp;
344      }
345    }
346    resp
347  }
348}