1use std::{collections::HashMap, sync::OnceLock};
2
3use async_channel::Receiver;
4use async_lock::RwLock;
5use futures::StreamExt;
6use memberlist_core::{
7 bytes::{BufMut, BytesMut},
8 tracing,
9 transport::{AddressResolver, Transport},
10 types::SecretKey,
11 CheapClone,
12};
13use smol_str::SmolStr;
14
15use crate::event::{
16 InternalQueryEvent, INTERNAL_INSTALL_KEY, INTERNAL_LIST_KEYS, INTERNAL_REMOVE_KEY,
17 INTERNAL_USE_KEY,
18};
19
20use super::{
21 delegate::{Delegate, TransformDelegate},
22 error::Error,
23 serf::{NodeResponse, QueryResponse},
24 types::{KeyRequestMessage, MessageType, SerfMessage},
25 Serf,
26};
27
28#[viewit::viewit(
30 vis_all = "pub(crate)",
31 getters(style = "move", vis_all = "pub"),
32 setters(skip)
33)]
34#[derive(Default, Debug)]
35pub struct KeyResponse<I> {
36 #[viewit(getter(
38 const,
39 style = "ref",
40 attrs(doc = "Returns a map of node id to response message.")
41 ))]
42 messages: HashMap<I, SmolStr>,
43 #[viewit(getter(const, attrs(doc = "Returns the total nodes memberlist knows of.")))]
45 num_nodes: usize,
46 #[viewit(getter(const, attrs(doc = "Returns the total responses received.")))]
48 num_resp: usize,
49 #[viewit(getter(const, attrs(doc = "Returns the total errors from request.")))]
51 num_err: usize,
52
53 #[viewit(getter(
56 const,
57 style = "ref",
58 attrs(
59 doc = "Returns a mapping of the value of the key bytes to the number of nodes that have the key installed.."
60 )
61 ))]
62 keys: HashMap<SecretKey, usize>,
63
64 #[viewit(getter(
67 const,
68 style = "ref",
69 attrs(
70 doc = "Returns a mapping of the value of the primary key bytes to the number of nodes that have the key installed."
71 )
72 ))]
73 primary_keys: HashMap<SecretKey, usize>,
74}
75
/// Optional tuning for a keyring query issued by [`KeyManager`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct KeyRequestOptions {
  /// Overrides the `relay_factor` of the default query parameters used for
  /// the keyring query.
  pub relay_factor: u8,
}
82
83pub struct KeyManager<T, D>
86where
87 D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
88 T: Transport,
89{
90 serf: OnceLock<Serf<T, D>>,
91 l: RwLock<()>,
93}
94
95impl<T, D> KeyManager<T, D>
96where
97 D: Delegate<Id = T::Id, Address = <T::Resolver as AddressResolver>::ResolvedAddress>,
98 T: Transport,
99{
100 pub(crate) fn new() -> Self {
101 Self {
102 serf: OnceLock::new(),
103 l: RwLock::new(()),
104 }
105 }
106
107 pub(crate) fn store(&self, serf: Serf<T, D>) {
108 let _ = self.serf.set(serf);
110 }
111
112 pub async fn install_key(
116 &self,
117 key: SecretKey,
118 opts: Option<KeyRequestOptions>,
119 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
120 let _mu = self.l.write().await;
121 self
122 .handle_key_request(
123 Some(key),
124 INTERNAL_INSTALL_KEY,
125 opts,
126 InternalQueryEvent::InstallKey,
127 )
128 .await
129 }
130
131 pub async fn use_key(
135 &self,
136 key: SecretKey,
137 opts: Option<KeyRequestOptions>,
138 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
139 let _mu = self.l.write().await;
140 self
141 .handle_key_request(
142 Some(key),
143 INTERNAL_USE_KEY,
144 opts,
145 InternalQueryEvent::UseKey,
146 )
147 .await
148 }
149
150 pub async fn remove_key(
154 &self,
155 key: SecretKey,
156 opts: Option<KeyRequestOptions>,
157 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
158 let _mu = self.l.write().await;
159 self
160 .handle_key_request(
161 Some(key),
162 INTERNAL_REMOVE_KEY,
163 opts,
164 InternalQueryEvent::RemoveKey,
165 )
166 .await
167 }
168
169 pub async fn list_keys(&self) -> Result<KeyResponse<T::Id>, Error<T, D>> {
175 let _mu = self.l.read().await;
176 self
177 .handle_key_request(None, INTERNAL_LIST_KEYS, None, InternalQueryEvent::ListKey)
178 .await
179 }
180
181 pub(crate) async fn handle_key_request(
182 &self,
183 key: Option<SecretKey>,
184 ty: &str,
185 opts: Option<KeyRequestOptions>,
186 event: InternalQueryEvent<T::Id>,
187 ) -> Result<KeyResponse<T::Id>, Error<T, D>> {
188 let kr = KeyRequestMessage { key };
189 let expected_encoded_len = <D as TransformDelegate>::message_encoded_len(&kr);
190 let mut buf = BytesMut::with_capacity(expected_encoded_len + 1); buf.put_u8(MessageType::KeyRequest as u8);
192 buf.resize(expected_encoded_len + 1, 0);
193 let len = <D as TransformDelegate>::encode_message(&kr, &mut buf[1..])
195 .map_err(Error::transform_delegate)?;
196
197 debug_assert_eq!(
198 len, expected_encoded_len,
199 "expected encoded len {} mismatch the actual encoded len {}",
200 expected_encoded_len, len
201 );
202
203 let serf = self.serf.get().unwrap();
204 let mut q_param = serf.default_query_param().await;
205 if let Some(opts) = opts {
206 q_param.relay_factor = opts.relay_factor;
207 }
208 let qresp: QueryResponse<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress> = serf
209 .internal_query(SmolStr::new(ty), buf.freeze(), Some(q_param), event)
210 .await?;
211
212 let resp = self.stream_key_response(qresp.response_rx()).await;
214
215 if resp.num_err > 0 {
217 tracing::error!(
218 "serf: {}/{} nodes reported failure",
219 resp.num_err,
220 resp.num_nodes
221 );
222 }
223
224 if resp.num_resp != resp.num_nodes {
225 tracing::error!(
226 "serf: {}/{} nodes responded success",
227 resp.num_resp,
228 resp.num_nodes
229 );
230 }
231
232 Ok(resp)
233 }
234
235 async fn stream_key_response(
236 &self,
237 ch: Receiver<NodeResponse<T::Id, <T::Resolver as AddressResolver>::ResolvedAddress>>,
238 ) -> KeyResponse<T::Id> {
239 let mut resp = KeyResponse {
240 num_nodes: self.serf.get().unwrap().num_members().await,
241 messages: HashMap::new(),
242 num_resp: 0,
243 num_err: 0,
244 keys: HashMap::new(),
245 primary_keys: HashMap::new(),
246 };
247 futures::pin_mut!(ch);
248 while let Some(r) = ch.next().await {
249 resp.num_resp += 1;
250
251 if r.payload.is_empty() || r.payload[0] != MessageType::KeyResponse as u8 {
253 resp.messages.insert(
254 r.from.id().cheap_clone(),
255 SmolStr::new(format!(
256 "Invalid key query response type: {:?}",
257 r.payload.as_ref()
258 )),
259 );
260 resp.num_err += 1;
261
262 if resp.num_resp == resp.num_nodes {
263 return resp;
264 }
265 continue;
266 }
267
268 let node_response =
269 match <D as TransformDelegate>::decode_message(MessageType::KeyResponse, &r.payload[1..]) {
270 Ok((_, nr)) => match nr {
271 SerfMessage::KeyResponse(kr) => kr,
272 msg => {
273 resp.messages.insert(
274 r.from.id().cheap_clone(),
275 SmolStr::new(format!(
276 "Invalid key query response type: {:?}",
277 msg.ty().as_str()
278 )),
279 );
280 resp.num_err += 1;
281
282 if resp.num_resp == resp.num_nodes {
283 return resp;
284 }
285 continue;
286 }
287 },
288 Err(e) => {
289 resp.messages.insert(
290 r.from.id().cheap_clone(),
291 SmolStr::new(format!("Failed to decode key query response: {:?}", e)),
292 );
293 resp.num_err += 1;
294
295 if resp.num_resp == resp.num_nodes {
296 return resp;
297 }
298 continue;
299 }
300 };
301
302 if !node_response.result {
303 resp
304 .messages
305 .insert(r.from.id().cheap_clone(), node_response.message);
306 resp.num_err += 1;
307 } else if node_response.result && node_response.message.is_empty() {
308 tracing::warn!("serf: {}", node_response.message);
309 resp
310 .messages
311 .insert(r.from.id().cheap_clone(), node_response.message);
312 }
313
314 for k in node_response.keys {
317 let count = resp.keys.entry(k).or_insert(0);
318 *count += 1;
319 }
320
321 if let Some(pk) = node_response.primary_key {
322 let ctr = resp.primary_keys.entry(pk).or_insert(0);
323 *ctr += 1;
324 }
325
326 if resp.num_resp == resp.num_nodes {
329 return resp;
330 }
331 }
332 resp
333 }
334}