1use std::{collections::HashMap, sync::OnceLock};
2
3use crate::types::MessageRef;
4use async_channel::Receiver;
5use async_lock::RwLock;
6use futures::StreamExt;
7use memberlist_core::{CheapClone, proto::SecretKey, tracing, transport::Transport};
8use smol_str::{SmolStr, format_smolstr};
9
10use crate::event::{
11 INTERNAL_INSTALL_KEY, INTERNAL_LIST_KEYS, INTERNAL_REMOVE_KEY, INTERNAL_USE_KEY,
12 InternalQueryEvent,
13};
14
15use super::{
16 Serf, SerfWeakRef,
17 delegate::Delegate,
18 error::Error,
19 serf::{NodeResponse, QueryResponse},
20 types::KeyRequestMessage,
21};
22
23#[viewit::viewit(
25 vis_all = "pub(crate)",
26 getters(style = "move", vis_all = "pub"),
27 setters(skip)
28)]
29#[derive(Default, Debug)]
30pub struct KeyResponse<I> {
31 #[viewit(getter(
33 const,
34 style = "ref",
35 attrs(doc = "Returns a map of node id to response message.")
36 ))]
37 messages: HashMap<I, SmolStr>,
38 #[viewit(getter(const, attrs(doc = "Returns the total nodes memberlist knows of.")))]
40 num_nodes: usize,
41 #[viewit(getter(const, attrs(doc = "Returns the total responses received.")))]
43 num_resp: usize,
44 #[viewit(getter(const, attrs(doc = "Returns the total errors from request.")))]
46 num_err: usize,
47
48 #[viewit(getter(
51 const,
52 style = "ref",
53 attrs(
54 doc = "Returns a mapping of the value of the key bytes to the number of nodes that have the key installed.."
55 )
56 ))]
57 keys: HashMap<SecretKey, usize>,
58
59 #[viewit(getter(
62 const,
63 style = "ref",
64 attrs(
65 doc = "Returns a mapping of the value of the primary key bytes to the number of nodes that have the key installed."
66 )
67 ))]
68 primary_keys: HashMap<SecretKey, usize>,
69}
70
/// Optional parameters for a key-management request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KeyRequestOptions {
    /// Relay factor applied to the query parameters of the request, replacing the
    /// default query parameter value.
    pub relay_factor: u8,
}
77
/// Issues key-management queries (install, use, remove, list) through a `Serf` instance.
pub struct KeyManager<T, D>
where
    D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
    T: Transport,
{
    /// Weak handle to the associated `Serf`; set once via `store` and upgraded on demand.
    serf: OnceLock<SerfWeakRef<T, D>>,
    /// Data-less lock: install/use/remove acquire it for writing, `list_keys` for reading.
    l: RwLock<()>,
}
89
90impl<T, D> KeyManager<T, D>
91where
92 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
93 T: Transport,
94{
95 pub(crate) fn new() -> Self {
96 Self {
97 serf: OnceLock::new(),
98 l: RwLock::new(()),
99 }
100 }
101
102 pub(crate) fn store(&self, serf: SerfWeakRef<T, D>) {
103 let _ = self.serf.set(serf);
105 }
106
107 fn this(&self) -> Option<Serf<T, D>> {
108 self.serf.get().and_then(|weak_ref| weak_ref.upgrade())
109 }
110
111 pub async fn install_key(
115 &self,
116 key: SecretKey,
117 opts: Option<KeyRequestOptions>,
118 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
119 let _mu = self.l.write().await;
120 self
121 .handle_key_request(
122 Some(key),
123 INTERNAL_INSTALL_KEY,
124 opts,
125 InternalQueryEvent::InstallKey,
126 )
127 .await
128 }
129
130 pub async fn use_key(
134 &self,
135 key: SecretKey,
136 opts: Option<KeyRequestOptions>,
137 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
138 let _mu = self.l.write().await;
139 self
140 .handle_key_request(
141 Some(key),
142 INTERNAL_USE_KEY,
143 opts,
144 InternalQueryEvent::UseKey,
145 )
146 .await
147 }
148
149 pub async fn remove_key(
153 &self,
154 key: SecretKey,
155 opts: Option<KeyRequestOptions>,
156 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
157 let _mu = self.l.write().await;
158 self
159 .handle_key_request(
160 Some(key),
161 INTERNAL_REMOVE_KEY,
162 opts,
163 InternalQueryEvent::RemoveKey,
164 )
165 .await
166 }
167
168 pub async fn list_keys(&self) -> Result<KeyResponse<T::Id>, Error<T, D>> {
174 let _mu = self.l.read().await;
175 self
176 .handle_key_request(None, INTERNAL_LIST_KEYS, None, InternalQueryEvent::ListKey)
177 .await
178 }
179
180 pub(crate) async fn handle_key_request(
181 &self,
182 key: Option<SecretKey>,
183 ty: &str,
184 opts: Option<KeyRequestOptions>,
185 event: InternalQueryEvent<T::Id>,
186 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
187 let kr = KeyRequestMessage { key };
188 let buf = crate::types::encode_message_to_bytes(&kr)?;
189
190 let Some(this) = self.this() else {
191 return Ok(KeyResponse {
192 num_nodes: 0,
193 messages: HashMap::new(),
194 num_resp: 0,
195 num_err: 0,
196 keys: HashMap::new(),
197 primary_keys: HashMap::new(),
198 });
199 };
200
201 let mut q_param = this.default_query_param().await;
202 if let Some(opts) = opts {
203 q_param.relay_factor = opts.relay_factor;
204 }
205 let qresp: QueryResponse<T::Id, T::ResolvedAddress> = this
206 .internal_query(SmolStr::new(ty), buf, Some(q_param), event)
207 .await?;
208
209 let resp = Self::stream_key_response(&this, qresp.response_rx()).await;
211
212 if resp.num_err > 0 {
214 tracing::error!(
215 "serf: {}/{} nodes reported failure",
216 resp.num_err,
217 resp.num_nodes
218 );
219 }
220
221 if resp.num_resp != resp.num_nodes {
222 tracing::error!(
223 "serf: {}/{} nodes responded success",
224 resp.num_resp,
225 resp.num_nodes
226 );
227 }
228
229 Ok(resp)
230 }
231
232 async fn stream_key_response(
233 this: &Serf<T, D>,
234 ch: Receiver<NodeResponse<T::Id, T::ResolvedAddress>>,
235 ) -> KeyResponse<T::Id> {
236 let mut resp = KeyResponse {
237 num_nodes: this.num_members().await,
238 messages: HashMap::new(),
239 num_resp: 0,
240 num_err: 0,
241 keys: HashMap::new(),
242 primary_keys: HashMap::new(),
243 };
244 futures::pin_mut!(ch);
245 while let Some(r) = ch.next().await {
246 resp.num_resp += 1;
247
248 if r.payload.is_empty() {
250 resp
251 .messages
252 .insert(r.from.id().cheap_clone(), SmolStr::new("empty payload"));
253 resp.num_err += 1;
254
255 if resp.num_resp == resp.num_nodes {
256 return resp;
257 }
258 continue;
259 }
260
261 let node_response =
262 match crate::types::decode_message::<T::Id, T::ResolvedAddress>(&r.payload) {
263 Ok(msg) => match msg {
264 MessageRef::KeyResponse(kr) => kr,
265 msg => {
266 tracing::error!(type=%msg.ty(), "serf: invalid key query response type");
267
268 resp.messages.insert(
269 r.from.id().cheap_clone(),
270 format_smolstr!("invalid key query response: {}", msg.ty()),
271 );
272 resp.num_err += 1;
273
274 if resp.num_resp == resp.num_nodes {
275 return resp;
276 }
277 continue;
278 }
279 },
280 Err(e) => {
281 tracing::error!(err=%e, "serf: failed to decode key query response");
282 resp
283 .messages
284 .insert(r.from.id().cheap_clone(), format_smolstr!("{e}"));
285 resp.num_err += 1;
286
287 if resp.num_resp == resp.num_nodes {
288 return resp;
289 }
290 continue;
291 }
292 };
293
294 if !node_response.result() {
295 resp.messages.insert(
296 r.from.id().cheap_clone(),
297 SmolStr::new(node_response.message()),
298 );
299 resp.num_err += 1;
300 }
301
302 if node_response.result() && !node_response.message().is_empty() {
303 tracing::warn!("serf: {}", node_response.message());
304 resp.messages.insert(
305 r.from.id().cheap_clone(),
306 SmolStr::new(node_response.message()),
307 );
308 }
309
310 let res = node_response
313 .keys()
314 .iter::<SecretKey>()
315 .try_for_each(|res| {
316 res.map(|k| {
317 let count = resp.keys.entry(k).or_insert(0);
318 *count += 1;
319 })
320 });
321
322 if let Err(e) = res {
323 resp.messages.insert(
324 r.from.id().cheap_clone(),
325 SmolStr::new(format!("Failed to decode key query response: {:?}", e)),
326 );
327 resp.num_err += 1;
328
329 if resp.num_resp == resp.num_nodes {
330 return resp;
331 }
332 continue;
333 }
334
335 if let Some(pk) = node_response.primary_key() {
336 let ctr = resp.primary_keys.entry(*pk).or_insert(0);
337 *ctr += 1;
338 }
339
340 if resp.num_resp == resp.num_nodes {
343 return resp;
344 }
345 }
346 resp
347 }
348}