Skip to main content

sfo_cmd_server/client/
mod.rs

1mod client;
2
3use bucky_raw_codec::{RawDecode, RawEncode, RawFixedBytes};
4use callback_result::CallbackWaiter;
5pub use client::*;
6use num::{FromPrimitive, ToPrimitive};
7use sfo_pool::{
8    ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, WorkerClassification,
9};
10use std::collections::{HashMap, VecDeque};
11use std::hash::Hash;
12use std::marker::PhantomData;
13use std::ops::{Deref, DerefMut};
14use std::sync::{Arc, Mutex};
15use std::time::Duration;
16use tokio::sync::Notify;
17
18mod classified_client;
19pub use classified_client::*;
20
21use crate::errors::CmdResult;
22use crate::{CmdBody, CmdHandler, CmdTunnelMeta, PeerId, TunnelId};
23
24pub trait CmdSend<M: CmdTunnelMeta>: Send + 'static {
25    fn get_tunnel_meta(&self) -> Option<Arc<M>>;
26    fn get_remote_peer_id(&self) -> PeerId;
27}
28
29pub trait SendGuard<M: CmdTunnelMeta, S: CmdSend<M>>: Send + 'static + Deref<Target = S> {}
30
/// Whether a tracked tunnel is free or currently reserved by a caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TunnelBorrowState {
    /// Nobody holds the tunnel; the next reservation succeeds immediately.
    Idle,
    /// The tunnel is reserved and has not been released yet.
    Borrowed,
}
36
37struct TunnelRuntimeEntry {
38    state: TunnelBorrowState,
39    waiters: VecDeque<Arc<Notify>>,
40}
41
42pub(crate) enum TunnelReserveResult {
43    Acquired,
44    Wait(Arc<Notify>),
45    Missing,
46}
47
48#[derive(Default)]
49pub(crate) struct TunnelRuntimeRegistry {
50    state: Mutex<HashMap<TunnelId, TunnelRuntimeEntry>>,
51}
52
53impl TunnelRuntimeRegistry {
54    pub(crate) fn new() -> Arc<Self> {
55        Arc::new(Self::default())
56    }
57
58    pub(crate) fn reserve_existing(&self, tunnel_id: TunnelId) -> TunnelReserveResult {
59        let mut state = self.state.lock().unwrap();
60        match state.get_mut(&tunnel_id) {
61            Some(entry) => match entry.state {
62                TunnelBorrowState::Idle => {
63                    entry.state = TunnelBorrowState::Borrowed;
64                    TunnelReserveResult::Acquired
65                }
66                TunnelBorrowState::Borrowed => {
67                    let notify = Arc::new(Notify::new());
68                    entry.waiters.push_back(notify.clone());
69                    TunnelReserveResult::Wait(notify)
70                }
71            },
72            None => TunnelReserveResult::Missing,
73        }
74    }
75
76    pub(crate) fn mark_borrowed(&self, tunnel_id: TunnelId) -> bool {
77        let mut state = self.state.lock().unwrap();
78        match state.get_mut(&tunnel_id) {
79            Some(entry) => match entry.state {
80                TunnelBorrowState::Idle => {
81                    entry.state = TunnelBorrowState::Borrowed;
82                    true
83                }
84                TunnelBorrowState::Borrowed => false,
85            },
86            None => {
87                state.insert(
88                    tunnel_id,
89                    TunnelRuntimeEntry {
90                        state: TunnelBorrowState::Borrowed,
91                        waiters: VecDeque::new(),
92                    },
93                );
94                true
95            }
96        }
97    }
98
99    pub(crate) fn remove(&self, tunnel_id: TunnelId) {
100        let waiters = {
101            let mut state = self.state.lock().unwrap();
102            state
103                .remove(&tunnel_id)
104                .map(|entry| entry.waiters.into_iter().collect::<Vec<_>>())
105                .unwrap_or_default()
106        };
107        for waiter in waiters {
108            waiter.notify_one();
109        }
110    }
111
112    pub(crate) fn release(&self, tunnel_id: TunnelId, alive: bool) {
113        let (waiter, waiters) = {
114            let mut state = self.state.lock().unwrap();
115            match state.get_mut(&tunnel_id) {
116                Some(entry) if alive => {
117                    entry.state = TunnelBorrowState::Idle;
118                    (entry.waiters.pop_front(), Vec::new())
119                }
120                Some(_) => {
121                    let entry = state.remove(&tunnel_id).unwrap();
122                    (None, entry.waiters.into_iter().collect::<Vec<_>>())
123                }
124                None => (None, Vec::new()),
125            }
126        };
127        if let Some(waiter) = waiter {
128            waiter.notify_one();
129        }
130        for waiter in waiters {
131            waiter.notify_one();
132        }
133    }
134
135    pub(crate) fn clear(&self) {
136        let waiters = {
137            let mut state = self.state.lock().unwrap();
138            state
139                .drain()
140                .flat_map(|(_, entry)| entry.waiters.into_iter())
141                .collect::<Vec<_>>()
142        };
143        for waiter in waiters {
144            waiter.notify_one();
145        }
146    }
147}
148
149pub struct TrackedSendGuard<
150    C: WorkerClassification,
151    M: CmdTunnelMeta,
152    S: ClassifiedWorker<C> + CmdSend<M>,
153    F: ClassifiedWorkerFactory<C, S>,
154> {
155    worker_guard: Option<ClassifiedWorkerGuard<C, S, F>>,
156    runtime_tunnels: Arc<TunnelRuntimeRegistry>,
157    tunnel_id: TunnelId,
158    _p: PhantomData<M>,
159}
160
161impl<
162    C: WorkerClassification,
163    M: CmdTunnelMeta,
164    S: ClassifiedWorker<C> + CmdSend<M>,
165    F: ClassifiedWorkerFactory<C, S>,
166> TrackedSendGuard<C, M, S, F>
167{
168    pub(crate) fn new(
169        worker_guard: ClassifiedWorkerGuard<C, S, F>,
170        runtime_tunnels: Arc<TunnelRuntimeRegistry>,
171        tunnel_id: TunnelId,
172    ) -> Self {
173        Self {
174            worker_guard: Some(worker_guard),
175            runtime_tunnels,
176            tunnel_id,
177            _p: PhantomData,
178        }
179    }
180}
181
182impl<
183    C: WorkerClassification,
184    M: CmdTunnelMeta,
185    S: ClassifiedWorker<C> + CmdSend<M>,
186    F: ClassifiedWorkerFactory<C, S>,
187> Deref for TrackedSendGuard<C, M, S, F>
188{
189    type Target = S;
190
191    fn deref(&self) -> &Self::Target {
192        self.worker_guard.as_ref().unwrap().deref()
193    }
194}
195
196impl<
197    C: WorkerClassification,
198    M: CmdTunnelMeta,
199    S: ClassifiedWorker<C> + CmdSend<M>,
200    F: ClassifiedWorkerFactory<C, S>,
201> DerefMut for TrackedSendGuard<C, M, S, F>
202{
203    fn deref_mut(&mut self) -> &mut Self::Target {
204        self.worker_guard.as_mut().unwrap().deref_mut()
205    }
206}
207
208impl<
209    C: WorkerClassification,
210    M: CmdTunnelMeta,
211    S: ClassifiedWorker<C> + CmdSend<M>,
212    F: ClassifiedWorkerFactory<C, S>,
213> Drop for TrackedSendGuard<C, M, S, F>
214{
215    fn drop(&mut self) {
216        if let Some(worker_guard) = self.worker_guard.take() {
217            let alive = worker_guard.is_work();
218            drop(worker_guard);
219            self.runtime_tunnels.release(self.tunnel_id, alive);
220        }
221    }
222}
223
224impl<
225    C: WorkerClassification,
226    M: CmdTunnelMeta,
227    S: ClassifiedWorker<C> + CmdSend<M>,
228    F: ClassifiedWorkerFactory<C, S>,
229> SendGuard<M, S> for TrackedSendGuard<C, M, S, F>
230{
231}
232
/// Client-side command API.
///
/// Tunnel selection:
/// - Methods without a `tunnel_id` argument leave the choice of tunnel to the
///   implementation.
/// - The `*_by_specify_tunnel*` variants target the tunnel with the given
///   [`TunnelId`].
///
/// Body shapes (each send family has one method per shape):
/// - a contiguous `&[u8]`;
/// - a list of slices (`*_parts`, plus the deprecated `*2` aliases);
/// - an owned [`CmdBody`] (`send_cmd*`).
///
/// The `*_with_resp` variants return a response [`CmdBody`].
///
/// Type parameters:
/// * `LEN` – integer-like codec type (`FromPrimitive + ToPrimitive`). It is
///   presumably the wire type of the body-length field; confirm in the
///   implementations.
/// * `CMD` – command identifier; `Eq + Hash`, presumably so handlers can be
///   keyed by it.
/// * `M` – tunnel metadata type exposed by the sender.
/// * `S` – per-tunnel sender ([`CmdSend`]).
/// * `G` – guard returned by `get_send` that dereferences to `S` ([`SendGuard`]).
#[async_trait::async_trait]
pub trait CmdClient<
    LEN: RawEncode
        + for<'a> RawDecode<'a>
        + Copy
        + RawFixedBytes
        + Sync
        + Send
        + 'static
        + FromPrimitive
        + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash,
    M: CmdTunnelMeta,
    S: CmdSend<M>,
    G: SendGuard<M, S>,
>: Send + Sync + 'static
{
    /// Registers `handler` for command id `cmd`.
    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>);
    /// Sends `cmd`, tagged with `version`, with a contiguous `body`.
    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()>;
    /// Like `send`, but returns a response body; `timeout` presumably bounds the
    /// wait for it.
    async fn send_with_resp(
        &self,
        cmd: CMD,
        version: u8,
        body: &[u8],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Like `send`, with the body given as several slices.
    async fn send_parts(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()>;
    /// Like `send_with_resp`, with the body given as several slices.
    async fn send_parts_with_resp(
        &self,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Deprecated alias; forwards to `send_parts`.
    #[deprecated(note = "use send_parts instead")]
    async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
        self.send_parts(cmd, version, body).await
    }
    /// Deprecated alias; forwards to `send_parts_with_resp`.
    #[deprecated(note = "use send_parts_with_resp instead")]
    async fn send2_with_resp(
        &self,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody> {
        self.send_parts_with_resp(cmd, version, body, timeout).await
    }
    /// Like `send`, taking an owned [`CmdBody`].
    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()>;
    /// Like `send_with_resp`, taking an owned [`CmdBody`].
    async fn send_cmd_with_resp(
        &self,
        cmd: CMD,
        version: u8,
        body: CmdBody,
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Like `send`, but on the tunnel `tunnel_id`.
    async fn send_by_specify_tunnel(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[u8],
    ) -> CmdResult<()>;
    /// Like `send_with_resp`, but on the tunnel `tunnel_id`.
    async fn send_by_specify_tunnel_with_resp(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[u8],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Like `send_parts`, but on the tunnel `tunnel_id`.
    async fn send_parts_by_specify_tunnel(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
    ) -> CmdResult<()>;
    /// Like `send_parts_with_resp`, but on the tunnel `tunnel_id`.
    async fn send_parts_by_specify_tunnel_with_resp(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Deprecated alias; forwards to `send_parts_by_specify_tunnel`.
    #[deprecated(note = "use send_parts_by_specify_tunnel instead")]
    async fn send2_by_specify_tunnel(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
    ) -> CmdResult<()> {
        self.send_parts_by_specify_tunnel(tunnel_id, cmd, version, body)
            .await
    }
    /// Deprecated alias; forwards to `send_parts_by_specify_tunnel_with_resp`.
    #[deprecated(note = "use send_parts_by_specify_tunnel_with_resp instead")]
    async fn send2_by_specify_tunnel_with_resp(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody> {
        self.send_parts_by_specify_tunnel_with_resp(tunnel_id, cmd, version, body, timeout)
            .await
    }
    /// Like `send_cmd`, but on the tunnel `tunnel_id`.
    async fn send_cmd_by_specify_tunnel(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: CmdBody,
    ) -> CmdResult<()>;
    /// Like `send_cmd_with_resp`, but on the tunnel `tunnel_id`.
    async fn send_cmd_by_specify_tunnel_with_resp(
        &self,
        tunnel_id: TunnelId,
        cmd: CMD,
        version: u8,
        body: CmdBody,
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Clears the client's tunnels. The exact semantics are
    /// implementation-defined; confirm in the implementations.
    async fn clear_all_tunnel(&self);
    /// Returns a guard over the sender of tunnel `tunnel_id`.
    async fn get_send(&self, tunnel_id: TunnelId) -> CmdResult<G>;
}
360
/// Extension of [`CmdClient`] whose send methods take a `classification: C`
/// (a `sfo_pool` [`WorkerClassification`]) instead of an explicit [`TunnelId`].
///
/// How a tunnel is chosen for a classification is up to the implementation.
/// The method families and body shapes mirror those of [`CmdClient`].
#[async_trait::async_trait]
pub trait ClassifiedCmdClient<
    LEN: RawEncode
        + for<'a> RawDecode<'a>
        + Copy
        + RawFixedBytes
        + Sync
        + Send
        + 'static
        + FromPrimitive
        + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash,
    C: WorkerClassification,
    M: CmdTunnelMeta,
    S: CmdSend<M>,
    G: SendGuard<M, S>,
>: CmdClient<LEN, CMD, M, S, G>
{
    /// Sends `cmd` with a contiguous `body` on a tunnel chosen for `classification`.
    async fn send_by_classified_tunnel(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[u8],
    ) -> CmdResult<()>;
    /// Like `send_by_classified_tunnel`, but returns a response body;
    /// `timeout` presumably bounds the wait for it.
    async fn send_by_classified_tunnel_with_resp(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[u8],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Like `send_by_classified_tunnel`, with the body given as several slices.
    async fn send_parts_by_classified_tunnel(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
    ) -> CmdResult<()>;
    /// Like `send_by_classified_tunnel_with_resp`, with the body given as several slices.
    async fn send_parts_by_classified_tunnel_with_resp(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Deprecated alias; forwards to `send_parts_by_classified_tunnel`.
    #[deprecated(note = "use send_parts_by_classified_tunnel instead")]
    async fn send2_by_classified_tunnel(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
    ) -> CmdResult<()> {
        self.send_parts_by_classified_tunnel(classification, cmd, version, body)
            .await
    }
    /// Deprecated alias; forwards to `send_parts_by_classified_tunnel_with_resp`.
    #[deprecated(note = "use send_parts_by_classified_tunnel_with_resp instead")]
    async fn send2_by_classified_tunnel_with_resp(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: &[&[u8]],
        timeout: Duration,
    ) -> CmdResult<CmdBody> {
        self.send_parts_by_classified_tunnel_with_resp(classification, cmd, version, body, timeout)
            .await
    }
    /// Like `send_by_classified_tunnel`, taking an owned [`CmdBody`].
    async fn send_cmd_by_classified_tunnel(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: CmdBody,
    ) -> CmdResult<()>;
    /// Like `send_by_classified_tunnel_with_resp`, taking an owned [`CmdBody`].
    async fn send_cmd_by_classified_tunnel_with_resp(
        &self,
        classification: C,
        cmd: CMD,
        version: u8,
        body: CmdBody,
        timeout: Duration,
    ) -> CmdResult<CmdBody>;
    /// Returns the id of a tunnel for `classification`.
    async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId>;
    /// Returns a guard over the sender of a tunnel for `classification`.
    async fn get_send_by_classified(&self, classification: C) -> CmdResult<G>;
}
450
451pub(crate) type RespWaiter = CallbackWaiter<u128, CmdBody>;
452pub(crate) type RespWaiterRef = Arc<RespWaiter>;
453
454pub(crate) fn gen_resp_id<
455    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static,
456>(
457    tunnel_id: TunnelId,
458    cmd: CMD,
459    seq: u32,
460) -> u128 {
461    let cmd_buf = cmd.raw_encode_to_buffer().unwrap();
462    let mut cmd = cmd_buf.len() as u64;
463    for chunk in cmd_buf.chunks(8) {
464        let mut buf = [0u8; 8];
465        buf[..chunk.len()].copy_from_slice(chunk);
466        cmd = cmd.rotate_left(13) ^ u64::from_be_bytes(buf);
467    }
468    ((tunnel_id.value() as u128) << 96) | ((seq as u128) << 64) | (cmd as u128)
469}
470
471pub(crate) fn gen_seq() -> u32 {
472    rand::random::<u32>()
473}
474
#[cfg(test)]
mod tests {
    use super::gen_resp_id;
    use crate::TunnelId;

    // Each case keeps every input fixed except one and checks that the id changes.

    #[test]
    fn resp_id_changes_with_seq() {
        let tunnel = TunnelId::from(7);
        let first = gen_resp_id(tunnel, 0x11u8, 1);
        let second = gen_resp_id(tunnel, 0x11u8, 2);
        assert_ne!(first, second);
    }

    #[test]
    fn resp_id_changes_with_cmd() {
        let tunnel = TunnelId::from(7);
        let first = gen_resp_id(tunnel, 0x11u8, 5);
        let second = gen_resp_id(tunnel, 0x12u8, 5);
        assert_ne!(first, second);
    }

    #[test]
    fn resp_id_changes_with_tunnel() {
        let first = gen_resp_id(TunnelId::from(7), 0x11u8, 5);
        let second = gen_resp_id(TunnelId::from(8), 0x11u8, 5);
        assert_ne!(first, second);
    }

    #[test]
    fn resp_id_changes_with_long_cmd_suffix() {
        // Wide commands that share their high 64 bits and differ only at the bottom.
        const LOW: u128 = 0x1122_3344_5566_7788_0000_0000_0000_0001;
        const HIGH: u128 = 0x1122_3344_5566_7788_0000_0000_0000_0002;
        let tunnel = TunnelId::from(7);
        assert_ne!(gen_resp_id(tunnel, LOW, 5), gen_resp_id(tunnel, HIGH, 5));
    }
}