1use std::{collections::HashMap, sync::OnceLock};
2
3use crate::types::MessageRef;
4use async_channel::Receiver;
5use async_lock::RwLock;
6use futures::StreamExt;
7use memberlist_core::{CheapClone, proto::SecretKey, tracing, transport::Transport};
8use smol_str::{SmolStr, format_smolstr};
9
10use crate::event::{
11 INTERNAL_INSTALL_KEY, INTERNAL_LIST_KEYS, INTERNAL_REMOVE_KEY, INTERNAL_USE_KEY,
12 InternalQueryEvent,
13};
14
15use super::{
16 Serf,
17 delegate::Delegate,
18 error::Error,
19 serf::{NodeResponse, QueryResponse},
20 types::KeyRequestMessage,
21};
22
23#[viewit::viewit(
25 vis_all = "pub(crate)",
26 getters(style = "move", vis_all = "pub"),
27 setters(skip)
28)]
29#[derive(Default, Debug)]
30pub struct KeyResponse<I> {
31 #[viewit(getter(
33 const,
34 style = "ref",
35 attrs(doc = "Returns a map of node id to response message.")
36 ))]
37 messages: HashMap<I, SmolStr>,
38 #[viewit(getter(const, attrs(doc = "Returns the total nodes memberlist knows of.")))]
40 num_nodes: usize,
41 #[viewit(getter(const, attrs(doc = "Returns the total responses received.")))]
43 num_resp: usize,
44 #[viewit(getter(const, attrs(doc = "Returns the total errors from request.")))]
46 num_err: usize,
47
48 #[viewit(getter(
51 const,
52 style = "ref",
53 attrs(
54 doc = "Returns a mapping of the value of the key bytes to the number of nodes that have the key installed.."
55 )
56 ))]
57 keys: HashMap<SecretKey, usize>,
58
59 #[viewit(getter(
62 const,
63 style = "ref",
64 attrs(
65 doc = "Returns a mapping of the value of the primary key bytes to the number of nodes that have the key installed."
66 )
67 ))]
68 primary_keys: HashMap<SecretKey, usize>,
69}
70
/// Optional parameters for a keyring operation.
///
/// `KeyRequestOptions::default()` leaves the relay factor at `0`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct KeyRequestOptions {
  /// Relay factor to use for the key query, replacing the value taken from
  /// the default query parameters.
  pub relay_factor: u8,
}
77
/// Issues cluster-wide keyring operations (install, use, remove, list) through
/// a [`Serf`] instance.
///
/// The `Serf` handle is not supplied at construction; it is attached afterwards
/// via `store`.
pub struct KeyManager<T, D>
where
  D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
  T: Transport,
{
  /// The Serf instance keyring queries are sent through; set at most once by `store`.
  serf: OnceLock<Serf<T, D>>,
  /// Serializes keyring operations: install/use/remove take it exclusively,
  /// list takes it shared.
  l: RwLock<()>,
}
89
90impl<T, D> KeyManager<T, D>
91where
92 D: Delegate<Id = T::Id, Address = T::ResolvedAddress>,
93 T: Transport,
94{
95 pub(crate) fn new() -> Self {
96 Self {
97 serf: OnceLock::new(),
98 l: RwLock::new(()),
99 }
100 }
101
102 pub(crate) fn store(&self, serf: Serf<T, D>) {
103 let _ = self.serf.set(serf);
105 }
106
107 pub async fn install_key(
111 &self,
112 key: SecretKey,
113 opts: Option<KeyRequestOptions>,
114 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
115 let _mu = self.l.write().await;
116 self
117 .handle_key_request(
118 Some(key),
119 INTERNAL_INSTALL_KEY,
120 opts,
121 InternalQueryEvent::InstallKey,
122 )
123 .await
124 }
125
126 pub async fn use_key(
130 &self,
131 key: SecretKey,
132 opts: Option<KeyRequestOptions>,
133 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
134 let _mu = self.l.write().await;
135 self
136 .handle_key_request(
137 Some(key),
138 INTERNAL_USE_KEY,
139 opts,
140 InternalQueryEvent::UseKey,
141 )
142 .await
143 }
144
145 pub async fn remove_key(
149 &self,
150 key: SecretKey,
151 opts: Option<KeyRequestOptions>,
152 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
153 let _mu = self.l.write().await;
154 self
155 .handle_key_request(
156 Some(key),
157 INTERNAL_REMOVE_KEY,
158 opts,
159 InternalQueryEvent::RemoveKey,
160 )
161 .await
162 }
163
164 pub async fn list_keys(&self) -> Result<KeyResponse<T::Id>, Error<T, D>> {
170 let _mu = self.l.read().await;
171 self
172 .handle_key_request(None, INTERNAL_LIST_KEYS, None, InternalQueryEvent::ListKey)
173 .await
174 }
175
176 pub(crate) async fn handle_key_request(
177 &self,
178 key: Option<SecretKey>,
179 ty: &str,
180 opts: Option<KeyRequestOptions>,
181 event: InternalQueryEvent<T::Id>,
182 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
183 let kr = KeyRequestMessage { key };
184 let buf = crate::types::encode_message_to_bytes(&kr)?;
185
186 let serf = self.serf.get().unwrap();
187 let mut q_param = serf.default_query_param().await;
188 if let Some(opts) = opts {
189 q_param.relay_factor = opts.relay_factor;
190 }
191 let qresp: QueryResponse<T::Id, T::ResolvedAddress> = serf
192 .internal_query(SmolStr::new(ty), buf, Some(q_param), event)
193 .await?;
194
195 let resp = self.stream_key_response(qresp.response_rx()).await;
197
198 if resp.num_err > 0 {
200 tracing::error!(
201 "serf: {}/{} nodes reported failure",
202 resp.num_err,
203 resp.num_nodes
204 );
205 }
206
207 if resp.num_resp != resp.num_nodes {
208 tracing::error!(
209 "serf: {}/{} nodes responded success",
210 resp.num_resp,
211 resp.num_nodes
212 );
213 }
214
215 Ok(resp)
216 }
217
218 async fn stream_key_response(
219 &self,
220 ch: Receiver<NodeResponse<T::Id, T::ResolvedAddress>>,
221 ) -> KeyResponse<T::Id> {
222 let mut resp = KeyResponse {
223 num_nodes: self.serf.get().unwrap().num_members().await,
224 messages: HashMap::new(),
225 num_resp: 0,
226 num_err: 0,
227 keys: HashMap::new(),
228 primary_keys: HashMap::new(),
229 };
230 futures::pin_mut!(ch);
231 while let Some(r) = ch.next().await {
232 resp.num_resp += 1;
233
234 if r.payload.is_empty() {
236 resp
237 .messages
238 .insert(r.from.id().cheap_clone(), SmolStr::new("empty payload"));
239 resp.num_err += 1;
240
241 if resp.num_resp == resp.num_nodes {
242 return resp;
243 }
244 continue;
245 }
246
247 let node_response =
248 match crate::types::decode_message::<T::Id, T::ResolvedAddress>(&r.payload) {
249 Ok(msg) => match msg {
250 MessageRef::KeyResponse(kr) => kr,
251 msg => {
252 tracing::error!(type=%msg.ty(), "serf: invalid key query response type");
253
254 resp.messages.insert(
255 r.from.id().cheap_clone(),
256 format_smolstr!("invalid key query response: {}", msg.ty()),
257 );
258 resp.num_err += 1;
259
260 if resp.num_resp == resp.num_nodes {
261 return resp;
262 }
263 continue;
264 }
265 },
266 Err(e) => {
267 tracing::error!(err=%e, "serf: failed to decode key query response");
268 resp
269 .messages
270 .insert(r.from.id().cheap_clone(), format_smolstr!("{e}"));
271 resp.num_err += 1;
272
273 if resp.num_resp == resp.num_nodes {
274 return resp;
275 }
276 continue;
277 }
278 };
279
280 if !node_response.result() {
281 resp.messages.insert(
282 r.from.id().cheap_clone(),
283 SmolStr::new(node_response.message()),
284 );
285 resp.num_err += 1;
286 }
287
288 if node_response.result() && !node_response.message().is_empty() {
289 tracing::warn!("serf: {}", node_response.message());
290 resp.messages.insert(
291 r.from.id().cheap_clone(),
292 SmolStr::new(node_response.message()),
293 );
294 }
295
296 let res = node_response
299 .keys()
300 .iter::<SecretKey>()
301 .try_for_each(|res| {
302 res.map(|k| {
303 let count = resp.keys.entry(k).or_insert(0);
304 *count += 1;
305 })
306 });
307
308 if let Err(e) = res {
309 resp.messages.insert(
310 r.from.id().cheap_clone(),
311 SmolStr::new(format!("Failed to decode key query response: {:?}", e)),
312 );
313 resp.num_err += 1;
314
315 if resp.num_resp == resp.num_nodes {
316 return resp;
317 }
318 continue;
319 }
320
321 if let Some(pk) = node_response.primary_key() {
322 let ctr = resp.primary_keys.entry(*pk).or_insert(0);
323 *ctr += 1;
324 }
325
326 if resp.num_resp == resp.num_nodes {
329 return resp;
330 }
331 }
332 resp
333 }
334}