zerodds_discovery/type_lookup/
client.rs1use alloc::boxed::Box;
29use alloc::collections::{BTreeMap, VecDeque};
30use alloc::vec::Vec;
31
32use zerodds_cdr::{BufferWriter, EncodeError, Endianness};
33use zerodds_types::type_lookup::{
34 ContinuationPoint, GetTypeDependenciesReply, GetTypeDependenciesRequest, GetTypesReply,
35 GetTypesRequest,
36};
37use zerodds_types::{EquivalenceHash, TypeIdentifier};
38
/// Identifier of one type-lookup request tracked by [`TypeLookupClient`].
///
/// The wrapped `u64` is public, so the id can be built from and compared against raw
/// values. Ids order by that number.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RequestId(pub u64);

impl RequestId {
    /// Wraps the raw value `v` in a [`RequestId`].
    #[must_use]
    pub fn from_u64(v: u64) -> Self {
        RequestId(v)
    }
}
54
55#[derive(Debug, Clone)]
58pub enum TypeLookupReply {
59 Types(GetTypesReply),
61 Dependencies(GetTypeDependenciesReply),
63}
64
65pub type ClientCallback = Box<dyn FnMut(TypeLookupReply) + Send>;
67
68struct Pending {
70 callback: ClientCallback,
71}
72
73impl core::fmt::Debug for Pending {
74 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
75 f.debug_struct("Pending").finish()
76 }
77}
78
79#[derive(Debug)]
85pub struct TypeLookupClient {
86 pending: BTreeMap<RequestId, Pending>,
87 pending_order: VecDeque<RequestId>,
89 next_seq: u64,
90 max_pending: usize,
91}
92
93impl TypeLookupClient {
94 pub const DEFAULT_MAX_PENDING: usize = 256;
96
97 #[must_use]
99 pub fn new() -> Self {
100 Self::with_capacity(Self::DEFAULT_MAX_PENDING)
101 }
102
103 #[must_use]
105 pub fn with_capacity(max_pending: usize) -> Self {
106 Self {
107 pending: BTreeMap::new(),
108 pending_order: VecDeque::new(),
109 next_seq: 1,
110 max_pending: max_pending.max(1),
111 }
112 }
113
114 #[must_use]
116 pub fn pending_count(&self) -> usize {
117 self.pending.len()
118 }
119
120 pub fn request_types(
124 &mut self,
125 _ids: Vec<TypeIdentifier>,
126 callback: ClientCallback,
127 ) -> RequestId {
128 self.alloc_pending(callback)
129 }
130
131 pub fn request_type_dependencies(
133 &mut self,
134 _ids: Vec<TypeIdentifier>,
135 _continuation_point: ContinuationPoint,
136 callback: ClientCallback,
137 ) -> RequestId {
138 self.alloc_pending(callback)
139 }
140
141 fn alloc_pending(&mut self, callback: ClientCallback) -> RequestId {
142 let id = RequestId(self.next_seq);
143 self.next_seq = self.next_seq.saturating_add(1);
144
145 while self.pending.len() >= self.max_pending {
147 if let Some(old) = self.pending_order.pop_front() {
148 self.pending.remove(&old);
149 } else {
150 break;
151 }
152 }
153
154 self.pending.insert(id, Pending { callback });
155 self.pending_order.push_back(id);
156 id
157 }
158
159 pub fn handle_reply(&mut self, request_id: RequestId, reply: TypeLookupReply) -> bool {
166 let Some(mut entry) = self.pending.remove(&request_id) else {
167 return false;
168 };
169 if let Some(pos) = self.pending_order.iter().position(|x| *x == request_id) {
171 self.pending_order.remove(pos);
172 }
173 (entry.callback)(reply);
174 true
175 }
176
177 pub fn clear(&mut self) {
179 self.pending.clear();
180 self.pending_order.clear();
181 }
182}
183
184impl Default for TypeLookupClient {
185 fn default() -> Self {
186 Self::new()
187 }
188}
189
190pub fn request_types_payload(ids: &[TypeIdentifier]) -> Result<Vec<u8>, EncodeError> {
195 let req = GetTypesRequest {
196 type_ids: ids.to_vec(),
197 };
198 let mut w = BufferWriter::new(Endianness::Little);
199 req.encode_into(&mut w)?;
200 Ok(w.into_bytes())
201}
202
203pub fn request_dependencies_payload(
208 ids: &[TypeIdentifier],
209 continuation_point: ContinuationPoint,
210) -> Result<Vec<u8>, EncodeError> {
211 let req = GetTypeDependenciesRequest {
212 type_ids: ids.to_vec(),
213 continuation_point,
214 };
215 let mut w = BufferWriter::new(Endianness::Little);
216 req.encode_into(&mut w)?;
217 Ok(w.into_bytes())
218}
219
220#[must_use]
222pub fn hashes_to_minimal_ids(hashes: &[EquivalenceHash]) -> Vec<TypeIdentifier> {
223 hashes
224 .iter()
225 .map(|h| TypeIdentifier::EquivalenceHashMinimal(*h))
226 .collect()
227}
228
#[cfg(test)]
#[allow(clippy::unwrap_used)] // a poisoned lock or encode failure should fail the test
mod tests {
    use super::*;
    extern crate std;
    use std::sync::Arc;
    use std::sync::Mutex;

    /// Builds an empty `GetTypes` reply for tests that only exercise routing.
    fn empty_types_reply() -> TypeLookupReply {
        TypeLookupReply::Types(GetTypesReply::default())
    }

    #[test]
    fn request_id_unique_and_monotone() {
        let mut c = TypeLookupClient::new();
        let id1 = c.request_types(alloc::vec![], Box::new(|_| {}));
        let id2 = c.request_types(alloc::vec![], Box::new(|_| {}));
        let id3 = c.request_types(alloc::vec![], Box::new(|_| {}));
        assert!(id1 < id2);
        assert!(id2 < id3);
    }

    #[test]
    fn handle_reply_unknown_id_is_ignored() {
        let mut c = TypeLookupClient::new();
        let consumed = c.handle_reply(RequestId(99), empty_types_reply());
        assert!(!consumed);
    }

    #[test]
    fn handle_reply_invokes_callback() {
        let calls = Arc::new(Mutex::new(0u32));
        let calls_clone = Arc::clone(&calls);
        let mut c = TypeLookupClient::new();
        let id = c.request_types(
            alloc::vec![],
            Box::new(move |_| {
                *calls_clone.lock().unwrap() += 1;
            }),
        );
        assert_eq!(*calls.lock().unwrap(), 0);

        let consumed = c.handle_reply(id, empty_types_reply());
        assert!(consumed);
        assert_eq!(*calls.lock().unwrap(), 1);
        assert_eq!(c.pending_count(), 0);
    }

    #[test]
    fn double_reply_runs_callback_only_once() {
        let calls = Arc::new(Mutex::new(0u32));
        let calls_clone = Arc::clone(&calls);
        let mut c = TypeLookupClient::new();
        let id = c.request_types(
            alloc::vec![],
            Box::new(move |_| {
                *calls_clone.lock().unwrap() += 1;
            }),
        );
        assert!(c.handle_reply(id, empty_types_reply()));
        // The entry is gone after the first reply, so the second is rejected.
        assert!(!c.handle_reply(id, empty_types_reply()));
        assert_eq!(*calls.lock().unwrap(), 1);
    }

    #[test]
    fn pending_cap_evicts_oldest() {
        let mut c = TypeLookupClient::with_capacity(2);
        let id1 = c.request_types(alloc::vec![], Box::new(|_| {}));
        let id2 = c.request_types(alloc::vec![], Box::new(|_| {}));
        let id3 = c.request_types(alloc::vec![], Box::new(|_| {}));
        assert_eq!(c.pending_count(), 2);
        // The oldest request is the one that had to go.
        assert!(!c.pending.contains_key(&id1));
        assert!(c.pending.contains_key(&id2));
        assert!(c.pending.contains_key(&id3));
        // The eviction queue stays in sync with the map.
        assert_eq!(c.pending_order.len(), 2);
        // A late reply to the evicted request is not consumed.
        assert!(!c.handle_reply(id1, empty_types_reply()));
    }

    #[test]
    fn evicted_callback_is_never_invoked() {
        let calls = Arc::new(Mutex::new(0u32));
        let calls_clone = Arc::clone(&calls);
        let mut c = TypeLookupClient::with_capacity(1);
        let evicted = c.request_types(
            alloc::vec![],
            Box::new(move |_| {
                *calls_clone.lock().unwrap() += 1;
            }),
        );
        let survivor = c.request_types(alloc::vec![], Box::new(|_| {}));
        assert!(!c.handle_reply(evicted, empty_types_reply()));
        assert!(c.handle_reply(survivor, empty_types_reply()));
        assert_eq!(*calls.lock().unwrap(), 0);
    }

    #[test]
    fn zero_capacity_is_clamped_to_one() {
        let mut c = TypeLookupClient::with_capacity(0);
        let id = c.request_types(alloc::vec![], Box::new(|_| {}));
        assert_eq!(c.pending_count(), 1);
        assert!(c.handle_reply(id, empty_types_reply()));
    }

    #[test]
    fn clear_drops_all_pending() {
        let mut c = TypeLookupClient::new();
        let id1 = c.request_types(alloc::vec![], Box::new(|_| {}));
        let id2 = c.request_types(alloc::vec![], Box::new(|_| {}));
        assert_eq!(c.pending_count(), 2);
        c.clear();
        assert_eq!(c.pending_count(), 0);
        assert!(!c.handle_reply(id1, empty_types_reply()));
        assert!(!c.handle_reply(id2, empty_types_reply()));
    }

    #[test]
    fn request_types_payload_roundtrips() {
        let ids = alloc::vec![
            TypeIdentifier::EquivalenceHashMinimal(EquivalenceHash([0x55; 14])),
            TypeIdentifier::Primitive(zerodds_types::PrimitiveKind::Int32),
        ];
        let bytes = request_types_payload(&ids).unwrap();
        // NOTE(review): this only checks that the payload is non-trivially sized; a real
        // round-trip would decode `bytes` and compare against `ids`.
        assert!(bytes.len() >= 4);
    }

    #[test]
    fn dependencies_payload_carries_continuation() {
        let ids = alloc::vec![TypeIdentifier::EquivalenceHashMinimal(EquivalenceHash(
            [0x77; 14]
        ))];
        let cp = ContinuationPoint(alloc::vec![1, 2, 3]);
        let bytes = request_dependencies_payload(&ids, cp).unwrap();
        assert!(!bytes.is_empty());
    }

    #[test]
    fn hashes_to_minimal_ids_maps_each() {
        let hashes = alloc::vec![EquivalenceHash([1; 14]), EquivalenceHash([2; 14])];
        let ids = hashes_to_minimal_ids(&hashes);
        assert_eq!(ids.len(), 2);
        assert!(matches!(ids[0], TypeIdentifier::EquivalenceHashMinimal(_)));
        assert!(matches!(ids[1], TypeIdentifier::EquivalenceHashMinimal(_)));
    }

    #[test]
    fn callback_can_mutate_via_arc_mutex() {
        // `ClientCallback` is `Send`, so callbacks share state through `Arc<Mutex<_>>`.
        // Two callbacks update the same counter, with replies arriving out of order.
        let total = Arc::new(Mutex::new(0u32));
        let mut c = TypeLookupClient::new();

        let first_handle = Arc::clone(&total);
        let id1 = c.request_types(
            alloc::vec![],
            Box::new(move |_| {
                *first_handle.lock().unwrap() += 1;
            }),
        );
        let second_handle = Arc::clone(&total);
        let id2 = c.request_types(
            alloc::vec![],
            Box::new(move |_| {
                *second_handle.lock().unwrap() += 10;
            }),
        );

        assert!(c.handle_reply(id2, empty_types_reply()));
        assert_eq!(*total.lock().unwrap(), 10);
        assert!(c.handle_reply(id1, empty_types_reply()));
        assert_eq!(*total.lock().unwrap(), 11);
        assert_eq!(c.pending_count(), 0);
    }
}