1mod client;
2
3use bucky_raw_codec::{RawDecode, RawEncode, RawFixedBytes};
4use callback_result::CallbackWaiter;
5pub use client::*;
6use num::{FromPrimitive, ToPrimitive};
7use sfo_pool::{
8 ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, WorkerClassification,
9};
10use std::collections::{HashMap, VecDeque};
11use std::hash::Hash;
12use std::marker::PhantomData;
13use std::ops::{Deref, DerefMut};
14use std::sync::{Arc, Mutex};
15use std::time::Duration;
16use tokio::sync::Notify;
17
18mod classified_client;
19pub use classified_client::*;
20
21use crate::errors::CmdResult;
22use crate::{CmdBody, CmdHandler, CmdTunnelMeta, PeerId, TunnelId};
23
24pub trait CmdSend<M: CmdTunnelMeta>: Send + 'static {
25 fn get_tunnel_meta(&self) -> Option<Arc<M>>;
26 fn get_remote_peer_id(&self) -> PeerId;
27}
28
29pub trait SendGuard<M: CmdTunnelMeta, S: CmdSend<M>>: Send + 'static + Deref<Target = S> {}
30
/// Whether a tracked tunnel is currently handed out to a caller.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum TunnelBorrowState {
    /// Not borrowed: the next reservation may claim it.
    Idle,
    /// Borrowed: further reservations have to queue and wait.
    Borrowed,
}
36
37struct TunnelRuntimeEntry {
38 state: TunnelBorrowState,
39 waiters: VecDeque<Arc<Notify>>,
40}
41
42pub(crate) enum TunnelReserveResult {
43 Acquired,
44 Wait(Arc<Notify>),
45 Missing,
46}
47
48#[derive(Default)]
49pub(crate) struct TunnelRuntimeRegistry {
50 state: Mutex<HashMap<TunnelId, TunnelRuntimeEntry>>,
51}
52
53impl TunnelRuntimeRegistry {
54 pub(crate) fn new() -> Arc<Self> {
55 Arc::new(Self::default())
56 }
57
58 pub(crate) fn reserve_existing(&self, tunnel_id: TunnelId) -> TunnelReserveResult {
59 let mut state = self.state.lock().unwrap();
60 match state.get_mut(&tunnel_id) {
61 Some(entry) => match entry.state {
62 TunnelBorrowState::Idle => {
63 entry.state = TunnelBorrowState::Borrowed;
64 TunnelReserveResult::Acquired
65 }
66 TunnelBorrowState::Borrowed => {
67 let notify = Arc::new(Notify::new());
68 entry.waiters.push_back(notify.clone());
69 TunnelReserveResult::Wait(notify)
70 }
71 },
72 None => TunnelReserveResult::Missing,
73 }
74 }
75
76 pub(crate) fn mark_borrowed(&self, tunnel_id: TunnelId) -> bool {
77 let mut state = self.state.lock().unwrap();
78 match state.get_mut(&tunnel_id) {
79 Some(entry) => match entry.state {
80 TunnelBorrowState::Idle => {
81 entry.state = TunnelBorrowState::Borrowed;
82 true
83 }
84 TunnelBorrowState::Borrowed => false,
85 },
86 None => {
87 state.insert(
88 tunnel_id,
89 TunnelRuntimeEntry {
90 state: TunnelBorrowState::Borrowed,
91 waiters: VecDeque::new(),
92 },
93 );
94 true
95 }
96 }
97 }
98
99 pub(crate) fn remove(&self, tunnel_id: TunnelId) {
100 let waiters = {
101 let mut state = self.state.lock().unwrap();
102 state
103 .remove(&tunnel_id)
104 .map(|entry| entry.waiters.into_iter().collect::<Vec<_>>())
105 .unwrap_or_default()
106 };
107 for waiter in waiters {
108 waiter.notify_one();
109 }
110 }
111
112 pub(crate) fn release(&self, tunnel_id: TunnelId, alive: bool) {
113 let (waiter, waiters) = {
114 let mut state = self.state.lock().unwrap();
115 match state.get_mut(&tunnel_id) {
116 Some(entry) if alive => {
117 entry.state = TunnelBorrowState::Idle;
118 (entry.waiters.pop_front(), Vec::new())
119 }
120 Some(_) => {
121 let entry = state.remove(&tunnel_id).unwrap();
122 (None, entry.waiters.into_iter().collect::<Vec<_>>())
123 }
124 None => (None, Vec::new()),
125 }
126 };
127 if let Some(waiter) = waiter {
128 waiter.notify_one();
129 }
130 for waiter in waiters {
131 waiter.notify_one();
132 }
133 }
134
135 pub(crate) fn clear(&self) {
136 let waiters = {
137 let mut state = self.state.lock().unwrap();
138 state
139 .drain()
140 .flat_map(|(_, entry)| entry.waiters.into_iter())
141 .collect::<Vec<_>>()
142 };
143 for waiter in waiters {
144 waiter.notify_one();
145 }
146 }
147}
148
149pub struct TrackedSendGuard<
150 C: WorkerClassification,
151 M: CmdTunnelMeta,
152 S: ClassifiedWorker<C> + CmdSend<M>,
153 F: ClassifiedWorkerFactory<C, S>,
154> {
155 worker_guard: Option<ClassifiedWorkerGuard<C, S, F>>,
156 runtime_tunnels: Arc<TunnelRuntimeRegistry>,
157 tunnel_id: TunnelId,
158 _p: PhantomData<M>,
159}
160
161impl<
162 C: WorkerClassification,
163 M: CmdTunnelMeta,
164 S: ClassifiedWorker<C> + CmdSend<M>,
165 F: ClassifiedWorkerFactory<C, S>,
166> TrackedSendGuard<C, M, S, F>
167{
168 pub(crate) fn new(
169 worker_guard: ClassifiedWorkerGuard<C, S, F>,
170 runtime_tunnels: Arc<TunnelRuntimeRegistry>,
171 tunnel_id: TunnelId,
172 ) -> Self {
173 Self {
174 worker_guard: Some(worker_guard),
175 runtime_tunnels,
176 tunnel_id,
177 _p: PhantomData,
178 }
179 }
180}
181
182impl<
183 C: WorkerClassification,
184 M: CmdTunnelMeta,
185 S: ClassifiedWorker<C> + CmdSend<M>,
186 F: ClassifiedWorkerFactory<C, S>,
187> Deref for TrackedSendGuard<C, M, S, F>
188{
189 type Target = S;
190
191 fn deref(&self) -> &Self::Target {
192 self.worker_guard.as_ref().unwrap().deref()
193 }
194}
195
196impl<
197 C: WorkerClassification,
198 M: CmdTunnelMeta,
199 S: ClassifiedWorker<C> + CmdSend<M>,
200 F: ClassifiedWorkerFactory<C, S>,
201> DerefMut for TrackedSendGuard<C, M, S, F>
202{
203 fn deref_mut(&mut self) -> &mut Self::Target {
204 self.worker_guard.as_mut().unwrap().deref_mut()
205 }
206}
207
208impl<
209 C: WorkerClassification,
210 M: CmdTunnelMeta,
211 S: ClassifiedWorker<C> + CmdSend<M>,
212 F: ClassifiedWorkerFactory<C, S>,
213> Drop for TrackedSendGuard<C, M, S, F>
214{
215 fn drop(&mut self) {
216 if let Some(worker_guard) = self.worker_guard.take() {
217 let alive = worker_guard.is_work();
218 drop(worker_guard);
219 self.runtime_tunnels.release(self.tunnel_id, alive);
220 }
221 }
222}
223
224impl<
225 C: WorkerClassification,
226 M: CmdTunnelMeta,
227 S: ClassifiedWorker<C> + CmdSend<M>,
228 F: ClassifiedWorkerFactory<C, S>,
229> SendGuard<M, S> for TrackedSendGuard<C, M, S, F>
230{
231}
232
/// Client-side command API over command tunnels to a remote peer.
///
/// Type parameters, as constrained by the bounds below:
/// - `LEN`: fixed-size, raw-codec numeric type. Presumably the width of the on-wire
///   length field; confirm against the codec.
/// - `CMD`: fixed-size, raw-codec command identifier. It is `Eq + Hash`, presumably so it
///   can key the handlers set up by `register_cmd_handler`.
/// - `M`: tunnel metadata type.
/// - `S`: per-tunnel sender.
/// - `G`: guard returned by `get_send`, which dereferences to `S`.
///
/// The sending methods form families that vary along three axes:
/// - Body shape: one slice (`send*`), a list of slices (`send_parts*`), or an owned
///   [`CmdBody`] (`send_cmd*`).
/// - Response: the `*_with_resp` variants return a response body and take a `timeout`,
///   presumably bounding the wait for that response.
/// - Tunnel choice: the `*_by_specify_tunnel*` variants name the tunnel explicitly.
///
/// The `send2*` methods are deprecated aliases that forward to the matching
/// `send_parts*` method.
#[async_trait::async_trait]
pub trait CmdClient<
    LEN: RawEncode
        + for<'a> RawDecode<'a>
        + Copy
        + RawFixedBytes
        + Sync
        + Send
        + 'static
        + FromPrimitive
        + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash,
    M: CmdTunnelMeta,
    S: CmdSend<M>,
    G: SendGuard<M, S>,
>: Send + Sync + 'static
{
    /// Registers `handler` for incoming commands with id `cmd`.
    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>);
    /// Sends `body` as command `cmd` with `version`. No response body is returned.
    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()>;
    /// Like `send`, but returns a response body. `timeout` presumably bounds the wait.
    async fn send_with_resp(
        &self,
        cmd: CMD,
        version: u8,
        body: &[u8],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Like `send`, with the body given as a list of slices.
    async fn send_parts(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()>;
    /// Like `send_with_resp`, with the body given as a list of slices.
    async fn send_parts_with_resp(
        &self,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Deprecated alias that forwards to `send_parts`.
    #[deprecated(note = "use send_parts instead")]
    async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
        self.send_parts(cmd, version, body).await
    }
    /// Deprecated alias that forwards to `send_parts_with_resp`.
    #[deprecated(note = "use send_parts_with_resp instead")]
    async fn send2_with_resp(
        &self,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody> {
        self.send_parts_with_resp(cmd, version, body, timeout).await
    }
    /// Like `send`, taking an owned [`CmdBody`].
    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()>;
    /// Like `send_with_resp`, taking an owned [`CmdBody`].
    async fn send_cmd_with_resp(
        &self,
        cmd: CMD,
        version: u8,
        body: CmdBody,
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Like `send`, but on the tunnel identified by `tunnel_id`.
    async fn send_by_specify_tunnel(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[u8],
    ) -> CmdResult<()>;
    /// Like `send_with_resp`, but on the tunnel identified by `tunnel_id`.
    async fn send_by_specify_tunnel_with_resp(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[u8],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Like `send_parts`, but on the tunnel identified by `tunnel_id`.
    async fn send_parts_by_specify_tunnel(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
    ) -> CmdResult<()>;
    /// Like `send_parts_with_resp`, but on the tunnel identified by `tunnel_id`.
    async fn send_parts_by_specify_tunnel_with_resp(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Deprecated alias that forwards to `send_parts_by_specify_tunnel`.
    #[deprecated(note = "use send_parts_by_specify_tunnel instead")]
    async fn send2_by_specify_tunnel(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
    ) -> CmdResult<()> {
        self.send_parts_by_specify_tunnel(tunnel_id, cmd, version, body)
            .await
    }
    /// Deprecated alias that forwards to `send_parts_by_specify_tunnel_with_resp`.
    #[deprecated(note = "use send_parts_by_specify_tunnel_with_resp instead")]
    async fn send2_by_specify_tunnel_with_resp(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody> {
        self.send_parts_by_specify_tunnel_with_resp(tunnel_id, cmd, version, body, timeout)
            .await
    }
    /// Like `send_cmd`, but on the tunnel identified by `tunnel_id`.
    async fn send_cmd_by_specify_tunnel(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: CmdBody,
    ) -> CmdResult<()>;
    /// Like `send_cmd_with_resp`, but on the tunnel identified by `tunnel_id`.
    async fn send_cmd_by_specify_tunnel_with_resp(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: CmdBody,
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Clears all tunnels held by this client. Exact semantics are up to the implementation.
    async fn clear_all_tunnel(&self);
    /// Returns a guard that dereferences to the sender for tunnel `tunnel_id`.
    async fn get_send(&self, tunnel_id: TunnelId) -> CmdResult<G>;
}
360
/// Extension of [`CmdClient`] that can pick tunnels by a worker classification `C`.
///
/// The method families mirror [`CmdClient`]:
/// - Body shape: a slice, a list of slices, or an owned [`CmdBody`].
/// - Response: the `*_with_resp` variants return a response body and take a `timeout`.
///
/// The tunnel is selected by `classification` instead of by tunnel id. The `send2*`
/// methods are deprecated aliases that forward to the matching `send_parts*` method.
#[async_trait::async_trait]
pub trait ClassifiedCmdClient<
    LEN: RawEncode
        + for<'a> RawDecode<'a>
        + Copy
        + RawFixedBytes
        + Sync
        + Send
        + 'static
        + FromPrimitive
        + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash,
    C: WorkerClassification,
    M: CmdTunnelMeta,
    S: CmdSend<M>,
    G: SendGuard<M, S>,
>: CmdClient<LEN, CMD, M, S, G>
{
    /// Sends `body` as command `cmd` with `version` on a tunnel matching `classification`.
    async fn send_by_classified_tunnel(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[u8],
    ) -> CmdResult<()>;
    /// Like `send_by_classified_tunnel`, but returns a response body.
    async fn send_by_classified_tunnel_with_resp(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[u8],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Like `send_by_classified_tunnel`, with the body given as a list of slices.
    async fn send_parts_by_classified_tunnel(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
    ) -> CmdResult<()>;
    /// Like `send_by_classified_tunnel_with_resp`, with the body given as a list of slices.
    async fn send_parts_by_classified_tunnel_with_resp(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Deprecated alias that forwards to `send_parts_by_classified_tunnel`.
    #[deprecated(note = "use send_parts_by_classified_tunnel instead")]
    async fn send2_by_classified_tunnel(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
    ) -> CmdResult<()> {
        self.send_parts_by_classified_tunnel(classification, cmd, version, body)
            .await
    }
    /// Deprecated alias that forwards to `send_parts_by_classified_tunnel_with_resp`.
    #[deprecated(note = "use send_parts_by_classified_tunnel_with_resp instead")]
    async fn send2_by_classified_tunnel_with_resp(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody> {
        self.send_parts_by_classified_tunnel_with_resp(classification, cmd, version, body, timeout)
            .await
    }
    /// Like `send_by_classified_tunnel`, taking an owned [`CmdBody`].
    async fn send_cmd_by_classified_tunnel(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: CmdBody,
    ) -> CmdResult<()>;
    /// Like `send_by_classified_tunnel_with_resp`, taking an owned [`CmdBody`].
    async fn send_cmd_by_classified_tunnel_with_resp(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: CmdBody,
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Returns the id of a tunnel matching `classification`.
    async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId>;
    /// Returns a guard for a sender on a tunnel matching `classification`.
    async fn get_send_by_classified(&self, classification: C) -> CmdResult<G>;
}
450
451pub(crate) type RespWaiter = CallbackWaiter<u128, CmdBody>;
452pub(crate) type RespWaiterRef = Arc<RespWaiter>;
453
454pub(crate) fn gen_resp_id<
455 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static,
456>(
457 tunnel_id: TunnelId,
458 cmd: CMD,
459 seq: u32,
460) -> u128 {
461 let cmd_buf = cmd.raw_encode_to_buffer().unwrap();
462 let mut cmd = cmd_buf.len() as u64;
463 for chunk in cmd_buf.chunks(8) {
464 let mut buf = [0u8; 8];
465 buf[..chunk.len()].copy_from_slice(chunk);
466 cmd = cmd.rotate_left(13) ^ u64::from_be_bytes(buf);
467 }
468 ((tunnel_id.value() as u128) << 96) | ((seq as u128) << 64) | (cmd as u128)
469}
470
471pub(crate) fn gen_seq() -> u32 {
472 rand::random::<u32>()
473}
474
#[cfg(test)]
mod tests {
    use super::gen_resp_id;
    use crate::TunnelId;

    #[test]
    fn resp_id_changes_with_seq() {
        let first = gen_resp_id(TunnelId::from(7), 0x11u8, 1);
        let second = gen_resp_id(TunnelId::from(7), 0x11u8, 2);
        assert_ne!(first, second);
    }

    #[test]
    fn resp_id_changes_with_cmd() {
        let first = gen_resp_id(TunnelId::from(7), 0x11u8, 5);
        let second = gen_resp_id(TunnelId::from(7), 0x12u8, 5);
        assert_ne!(first, second);
    }

    #[test]
    fn resp_id_changes_with_tunnel() {
        let first = gen_resp_id(TunnelId::from(7), 0x11u8, 5);
        let second = gen_resp_id(TunnelId::from(8), 0x11u8, 5);
        assert_ne!(first, second);
    }

    /// Commands wider than 8 bytes that differ only in their lowest byte must still yield
    /// different ids, so the fold has to mix in every chunk.
    #[test]
    fn resp_id_changes_with_long_cmd_suffix() {
        const CMD_A: u128 = 0x1122_3344_5566_7788_0000_0000_0000_0001u128;
        const CMD_B: u128 = 0x1122_3344_5566_7788_0000_0000_0000_0002u128;
        let first = gen_resp_id(TunnelId::from(7), CMD_A, 5);
        let second = gen_resp_id(TunnelId::from(7), CMD_B, 5);
        assert_ne!(first, second);
    }
}