Skip to main content

sfo_cmd_server/client/
classified_client.rs

1use crate::client::{
2    ClassifiedCmdClient, ClassifiedSendGuard, CmdClient, CmdSend, RespWaiter, RespWaiterRef,
3    gen_resp_id, gen_seq,
4};
5use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
6use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
7use crate::peer_id::PeerId;
8use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
9use async_named_locker::ObjectHolder;
10use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
11use num::{FromPrimitive, ToPrimitive};
12use sfo_pool::{
13    ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
14    ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
15    pool_err,
16};
17use sfo_split::{RHalf, Splittable, WHalf};
18use std::fmt::Debug;
19use std::hash::Hash;
20use std::ops::DerefMut;
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>:
28    CmdTunnelRead<M> + 'static + Send
29{
30    fn get_classification(&self) -> C;
31}
32
33pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>:
34    CmdTunnelWrite<M> + 'static + Send
35{
36    fn get_classification(&self) -> C;
37}
38
39pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
40pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
41pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
42
43#[derive(Debug, Clone, Copy, Eq, Hash)]
44pub struct CmdClientTunnelClassification<C: WorkerClassification> {
45    tunnel_id: Option<TunnelId>,
46    classification: Option<C>,
47}
48
49impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
50    fn eq(&self, other: &Self) -> bool {
51        self.tunnel_id == other.tunnel_id && self.classification == other.classification
52    }
53}
54
55#[async_trait::async_trait]
56pub trait ClassifiedCmdTunnelFactory<
57    C: WorkerClassification,
58    M: CmdTunnelMeta,
59    R: ClassifiedCmdTunnelRead<C, M>,
60    W: ClassifiedCmdTunnelWrite<C, M>,
61>: Send + Sync + 'static
62{
63    async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
64}
65
66pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
67where
68    C: WorkerClassification,
69    M: CmdTunnelMeta,
70    R: ClassifiedCmdTunnelRead<C, M>,
71    W: ClassifiedCmdTunnelWrite<C, M>,
72    LEN: RawEncode
73        + for<'a> RawDecode<'a>
74        + Copy
75        + Send
76        + Sync
77        + 'static
78        + FromPrimitive
79        + ToPrimitive,
80    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
81{
82    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
83    pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
84    pub(crate) is_work: bool,
85    pub(crate) classification: C,
86    pub(crate) tunnel_id: TunnelId,
87    pub(crate) resp_waiter: RespWaiterRef,
88    pub(crate) remote_id: PeerId,
89    pub(crate) tunnel_meta: Option<Arc<M>>,
90    _p: std::marker::PhantomData<(LEN, CMD)>,
91}
92
93// impl<C, R, W, LEN, CMD> Deref for ClassifiedCmdSend<C, R, W, LEN, CMD>
94// where C: WorkerClassification,
95//       R: ClassifiedCmdTunnelRead<C>,
96//       W: ClassifiedCmdTunnelWrite<C>,
97//       LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
98//       CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
99//     type Target = W;
100//
101//     fn deref(&self) -> &Self::Target {
102//         self.write.deref()
103//     }
104// }
105
106impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
107where
108    C: WorkerClassification,
109    M: CmdTunnelMeta,
110    R: ClassifiedCmdTunnelRead<C, M>,
111    W: ClassifiedCmdTunnelWrite<C, M>,
112    LEN: RawEncode
113        + for<'a> RawDecode<'a>
114        + Copy
115        + Send
116        + Sync
117        + 'static
118        + FromPrimitive
119        + ToPrimitive,
120    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
121{
122    pub(crate) fn new(
123        tunnel_id: TunnelId,
124        classification: C,
125        recv_handle: JoinHandle<CmdResult<()>>,
126        write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
127        resp_waiter: RespWaiterRef,
128        remote_id: PeerId,
129        tunnel_meta: Option<Arc<M>>,
130    ) -> Self {
131        Self {
132            recv_handle,
133            write,
134            is_work: true,
135            classification,
136            tunnel_id,
137            resp_waiter,
138            remote_id,
139            tunnel_meta,
140            _p: Default::default(),
141        }
142    }
143
144    pub fn get_tunnel_id(&self) -> TunnelId {
145        self.tunnel_id
146    }
147
148    pub fn set_disable(&mut self) {
149        self.is_work = false;
150        self.recv_handle.abort();
151    }
152
153    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
154        log::trace!(
155            "client {:?} send cmd: {:?}, len: {}, data: {}",
156            self.tunnel_id,
157            cmd,
158            body.len(),
159            hex::encode(body)
160        );
161        let header = CmdHeader::<LEN, CMD>::new(
162            version,
163            false,
164            None,
165            cmd,
166            LEN::from_u64(body.len() as u64).unwrap(),
167        );
168        let buf = header
169            .to_vec()
170            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
171        let ret = self.send_inner(buf.as_slice(), body).await;
172        if let Err(e) = ret {
173            self.set_disable();
174            return Err(e);
175        }
176        Ok(())
177    }
178
179    pub async fn send_with_resp(
180        &mut self,
181        cmd: CMD,
182        version: u8,
183        body: &[u8],
184        timeout: Duration,
185    ) -> CmdResult<CmdBody> {
186        if let Some(id) = tokio::task::try_id() {
187            if id == self.recv_handle.id() {
188                return Err(cmd_err!(
189                    CmdErrorCode::Failed,
190                    "can't send with resp in recv task"
191                ));
192            }
193        }
194        log::trace!(
195            "client {:?} send cmd: {:?}, len: {}, data: {}",
196            self.tunnel_id,
197            cmd,
198            body.len(),
199            hex::encode(body)
200        );
201        let seq = gen_seq();
202        let header = CmdHeader::<LEN, CMD>::new(
203            version,
204            false,
205            Some(seq),
206            cmd,
207            LEN::from_u64(body.len() as u64).unwrap(),
208        );
209        let buf = header
210            .to_vec()
211            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
212        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
213        let waiter = self.resp_waiter.clone();
214        let resp_waiter = waiter
215            .create_timeout_result_future(resp_id, timeout)
216            .map_err(into_cmd_err!(
217                CmdErrorCode::Failed,
218                "create timeout result future error"
219            ))?;
220        let ret = self.send_inner(buf.as_slice(), body).await;
221        if let Err(e) = ret {
222            self.set_disable();
223            return Err(e);
224        }
225        let resp = resp_waiter
226            .await
227            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
228        Ok(resp)
229    }
230
231    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
232        let mut len = 0;
233        for b in body.iter() {
234            len += b.len();
235            log::trace!(
236                "client {:?} send2 cmd {:?} body: {}",
237                self.tunnel_id,
238                cmd,
239                hex::encode(b)
240            );
241        }
242        log::trace!(
243            "client {:?} send2 cmd: {:?}, len {}",
244            self.tunnel_id,
245            cmd,
246            len
247        );
248        let header = CmdHeader::<LEN, CMD>::new(
249            version,
250            false,
251            None,
252            cmd,
253            LEN::from_u64(len as u64).unwrap(),
254        );
255        let buf = header
256            .to_vec()
257            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
258        let ret = self.send_inner2(buf.as_slice(), body).await;
259        if let Err(e) = ret {
260            self.set_disable();
261            return Err(e);
262        }
263        Ok(())
264    }
265
266    pub async fn send2_with_resp(
267        &mut self,
268        cmd: CMD,
269        version: u8,
270        body: &[&[u8]],
271        timeout: Duration,
272    ) -> CmdResult<CmdBody> {
273        if let Some(id) = tokio::task::try_id() {
274            if id == self.recv_handle.id() {
275                return Err(cmd_err!(
276                    CmdErrorCode::Failed,
277                    "can't send with resp in recv task"
278                ));
279            }
280        }
281        let mut len = 0;
282        for b in body.iter() {
283            len += b.len();
284            log::trace!(
285                "client {:?} send2 cmd {:?} body: {}",
286                self.tunnel_id,
287                cmd,
288                hex::encode(b)
289            );
290        }
291        log::trace!(
292            "client {:?} send2 cmd: {:?}, len {}",
293            self.tunnel_id,
294            cmd,
295            len
296        );
297        let seq = gen_seq();
298        let header = CmdHeader::<LEN, CMD>::new(
299            version,
300            false,
301            Some(seq),
302            cmd,
303            LEN::from_u64(len as u64).unwrap(),
304        );
305        let buf = header
306            .to_vec()
307            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
308        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
309        let waiter = self.resp_waiter.clone();
310        let resp_waiter = waiter
311            .create_timeout_result_future(resp_id, timeout)
312            .map_err(into_cmd_err!(
313                CmdErrorCode::Failed,
314                "create timeout result future error"
315            ))?;
316        let ret = self.send_inner2(buf.as_slice(), body).await;
317        if let Err(e) = ret {
318            self.set_disable();
319            return Err(e);
320        }
321        let resp = resp_waiter
322            .await
323            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
324        Ok(resp)
325    }
326
327    pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
328        log::trace!(
329            "client {:?} send cmd: {:?}, len: {}",
330            self.tunnel_id,
331            cmd,
332            body.len()
333        );
334        let header = CmdHeader::<LEN, CMD>::new(
335            version,
336            false,
337            None,
338            cmd,
339            LEN::from_u64(body.len()).unwrap(),
340        );
341        let buf = header
342            .to_vec()
343            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
344        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
345        if let Err(e) = ret {
346            self.set_disable();
347            return Err(e);
348        }
349        Ok(())
350    }
351
352    pub async fn send_cmd_with_resp(
353        &mut self,
354        cmd: CMD,
355        version: u8,
356        body: CmdBody,
357        timeout: Duration,
358    ) -> CmdResult<CmdBody> {
359        if let Some(id) = tokio::task::try_id() {
360            if id == self.recv_handle.id() {
361                return Err(cmd_err!(
362                    CmdErrorCode::Failed,
363                    "can't send with resp in recv task"
364                ));
365            }
366        }
367        log::trace!(
368            "client {:?} send cmd: {:?}, len: {}",
369            self.tunnel_id,
370            cmd,
371            body.len()
372        );
373        let seq = gen_seq();
374        let header = CmdHeader::<LEN, CMD>::new(
375            version,
376            false,
377            Some(seq),
378            cmd,
379            LEN::from_u64(body.len()).unwrap(),
380        );
381        let buf = header
382            .to_vec()
383            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
384        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
385        let waiter = self.resp_waiter.clone();
386        let resp_waiter = waiter
387            .create_timeout_result_future(resp_id, timeout)
388            .map_err(into_cmd_err!(
389                CmdErrorCode::Failed,
390                "create timeout result future error"
391            ))?;
392        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
393        if let Err(e) = ret {
394            self.set_disable();
395            return Err(e);
396        }
397        let resp = resp_waiter
398            .await
399            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
400        Ok(resp)
401    }
402
403    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
404        let mut write = self.write.get().await;
405        if header.len() > 255 {
406            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
407        }
408        write
409            .write_u8(header.len() as u8)
410            .await
411            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
412        write
413            .write_all(header)
414            .await
415            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
416        write
417            .write_all(body)
418            .await
419            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
420        write
421            .flush()
422            .await
423            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
424        Ok(())
425    }
426
427    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
428        let mut write = self.write.get().await;
429        if header.len() > 255 {
430            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
431        }
432        write
433            .write_u8(header.len() as u8)
434            .await
435            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
436        write
437            .write_all(header)
438            .await
439            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
440        for b in body.iter() {
441            write
442                .write_all(b)
443                .await
444                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
445        }
446        write
447            .flush()
448            .await
449            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
450        Ok(())
451    }
452
453    async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
454        let mut write = self.write.get().await;
455        if header.len() > 255 {
456            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
457        }
458        write
459            .write_u8(header.len() as u8)
460            .await
461            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
462        write
463            .write_all(header)
464            .await
465            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
466        tokio::io::copy(&mut body, write.deref_mut().deref_mut())
467            .await
468            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
469        write
470            .flush()
471            .await
472            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
473        Ok(())
474    }
475}
476
477impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
478where
479    C: WorkerClassification,
480    M: CmdTunnelMeta,
481    R: ClassifiedCmdTunnelRead<C, M>,
482    W: ClassifiedCmdTunnelWrite<C, M>,
483    LEN: RawEncode
484        + for<'a> RawDecode<'a>
485        + Copy
486        + Send
487        + Sync
488        + 'static
489        + FromPrimitive
490        + ToPrimitive,
491    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
492{
493    fn drop(&mut self) {
494        self.set_disable();
495    }
496}
497
498impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
499where
500    C: WorkerClassification,
501    M: CmdTunnelMeta,
502    R: ClassifiedCmdTunnelRead<C, M>,
503    W: ClassifiedCmdTunnelWrite<C, M>,
504    LEN: RawEncode
505        + for<'a> RawDecode<'a>
506        + Copy
507        + Send
508        + Sync
509        + 'static
510        + FromPrimitive
511        + ToPrimitive,
512    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
513{
514    fn get_tunnel_meta(&self) -> Option<Arc<M>> {
515        self.tunnel_meta.clone()
516    }
517
518    fn get_remote_peer_id(&self) -> PeerId {
519        self.remote_id.clone()
520    }
521}
522
523impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>>
524    for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
525where
526    C: WorkerClassification,
527    M: CmdTunnelMeta,
528    R: ClassifiedCmdTunnelRead<C, M>,
529    W: ClassifiedCmdTunnelWrite<C, M>,
530    LEN: RawEncode
531        + for<'a> RawDecode<'a>
532        + Copy
533        + Send
534        + Sync
535        + 'static
536        + FromPrimitive
537        + ToPrimitive,
538    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
539{
540    fn is_work(&self) -> bool {
541        self.is_work && !self.recv_handle.is_finished()
542    }
543
544    fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
545        if c.tunnel_id.is_some() {
546            self.tunnel_id == c.tunnel_id.unwrap()
547        } else {
548            if c.classification.is_some() {
549                self.classification == c.classification.unwrap()
550            } else {
551                true
552            }
553        }
554    }
555
556    fn classification(&self) -> CmdClientTunnelClassification<C> {
557        CmdClientTunnelClassification {
558            tunnel_id: Some(self.tunnel_id),
559            classification: Some(self.classification.clone()),
560        }
561    }
562}
563
564pub struct ClassifiedCmdWriteFactory<
565    C: WorkerClassification,
566    M: CmdTunnelMeta,
567    R: ClassifiedCmdTunnelRead<C, M>,
568    W: ClassifiedCmdTunnelWrite<C, M>,
569    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
570    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
571    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
572> {
573    tunnel_factory: F,
574    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
575    resp_waiter: RespWaiterRef,
576    tunnel_id_generator: TunnelIdGenerator,
577    _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
578}
579
580impl<
581    C: WorkerClassification,
582    M: CmdTunnelMeta,
583    R: ClassifiedCmdTunnelRead<C, M>,
584    W: ClassifiedCmdTunnelWrite<C, M>,
585    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
586    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
587    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
588> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
589{
590    pub(crate) fn new(
591        tunnel_factory: F,
592        cmd_handler: impl CmdHandler<LEN, CMD>,
593        resp_waiter: RespWaiterRef,
594    ) -> Self {
595        Self {
596            tunnel_factory,
597            cmd_handler: Arc::new(cmd_handler),
598            resp_waiter,
599            tunnel_id_generator: TunnelIdGenerator::new(),
600            _p: Default::default(),
601        }
602    }
603}
604
605#[async_trait::async_trait]
606impl<
607    C: WorkerClassification,
608    M: CmdTunnelMeta,
609    R: ClassifiedCmdTunnelRead<C, M>,
610    W: ClassifiedCmdTunnelWrite<C, M>,
611    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
612    LEN: RawEncode
613        + for<'a> RawDecode<'a>
614        + Copy
615        + Send
616        + Sync
617        + 'static
618        + FromPrimitive
619        + ToPrimitive
620        + RawFixedBytes,
621    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
622> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>>
623    for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
624{
625    async fn create(
626        &self,
627        classification: Option<CmdClientTunnelClassification<C>>,
628    ) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
629        if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
630            return Err(pool_err!(
631                PoolErrorCode::Failed,
632                "tunnel {:?} not found",
633                classification.as_ref().unwrap().tunnel_id.unwrap()
634            ));
635        }
636
637        let classification = if classification.is_some()
638            && classification.as_ref().unwrap().classification.is_some()
639        {
640            classification.unwrap().classification
641        } else {
642            None
643        };
644        let tunnel = self
645            .tunnel_factory
646            .create_tunnel(classification)
647            .await
648            .map_err(into_pool_err!(PoolErrorCode::Failed))?;
649        let classification = tunnel.get_classification();
650        let peer_id = tunnel.get_remote_peer_id();
651        let tunnel_id = self.tunnel_id_generator.generate();
652        let (mut recv, write) = tunnel.split();
653        let local_id = recv.get_local_peer_id();
654        let remote_id = peer_id.clone();
655        let tunnel_meta = recv.get_tunnel_meta();
656        let write = ObjectHolder::new(write);
657        let resp_write = write.clone();
658        let cmd_handler = self.cmd_handler.clone();
659        let handle = spawn(async move {
660            let ret: CmdResult<()> = async move {
661                loop {
662                    let header_len = recv
663                        .read_u8()
664                        .await
665                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
666                    let mut header = vec![0u8; header_len as usize];
667                    let n = recv
668                        .read_exact(header.as_mut())
669                        .await
670                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
671                    if n == 0 {
672                        break;
673                    }
674                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
675                        .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
676                    log::trace!(
677                        "recv cmd {:?} from {} len {} tunnel {:?}",
678                        header.cmd_code(),
679                        peer_id,
680                        header.pkg_len().to_u64().unwrap(),
681                        tunnel_id
682                    );
683                    let body_len = header.pkg_len().to_u64().unwrap();
684                    let cmd_read =
685                        CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
686                    let waiter = cmd_read.get_waiter();
687                    let future = waiter
688                        .create_result_future()
689                        .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
690                    let version = header.version();
691                    let seq = header.seq();
692                    let cmd_code = header.cmd_code();
693                    match cmd_handler
694                        .handle(
695                            local_id.clone(),
696                            peer_id.clone(),
697                            tunnel_id,
698                            header,
699                            CmdBody::from_reader(BufReader::new(cmd_read), body_len),
700                        )
701                        .await
702                    {
703                        Ok(Some(mut body)) => {
704                            let mut write = resp_write.get().await;
705                            let header = CmdHeader::<LEN, CMD>::new(
706                                version,
707                                true,
708                                seq,
709                                cmd_code,
710                                LEN::from_u64(body.len()).unwrap(),
711                            );
712                            let buf = header
713                                .to_vec()
714                                .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
715                            if buf.len() > 255 {
716                                return Err(cmd_err!(
717                                    CmdErrorCode::InvalidParam,
718                                    "header len too large"
719                                ));
720                            }
721                            write
722                                .write_u8(buf.len() as u8)
723                                .await
724                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
725                            write
726                                .write_all(buf.as_slice())
727                                .await
728                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
729                            tokio::io::copy(&mut body, write.deref_mut().deref_mut())
730                                .await
731                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
732                            write
733                                .flush()
734                                .await
735                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
736                        }
737                        Err(e) => {
738                            log::error!("handle cmd error: {:?}", e);
739                        }
740                        _ => {}
741                    }
742                    recv = future
743                        .await
744                        .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
745                    log::debug!(
746                        "handle cmd {:?} from {} len {} tunnel {:?} complete",
747                        cmd_code,
748                        peer_id,
749                        body_len,
750                        tunnel_id
751                    );
752                }
753                Ok(())
754            }
755            .await;
756            if ret.is_err() {
757                log::error!("recv cmd error: {:?}", ret.as_ref().err().unwrap());
758            }
759            ret
760        });
761        Ok(ClassifiedCmdSend::new(
762            tunnel_id,
763            classification,
764            handle,
765            write,
766            self.resp_waiter.clone(),
767            remote_id,
768            tunnel_meta,
769        ))
770    }
771}
772
773pub struct DefaultClassifiedCmdClient<
774    C: WorkerClassification,
775    M: CmdTunnelMeta,
776    R: ClassifiedCmdTunnelRead<C, M>,
777    W: ClassifiedCmdTunnelWrite<C, M>,
778    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
779    LEN: RawEncode
780        + for<'a> RawDecode<'a>
781        + Copy
782        + Send
783        + Sync
784        + 'static
785        + FromPrimitive
786        + ToPrimitive
787        + RawFixedBytes,
788    CMD: RawEncode
789        + for<'a> RawDecode<'a>
790        + Copy
791        + Send
792        + Sync
793        + 'static
794        + RawFixedBytes
795        + Eq
796        + Hash
797        + Debug,
798> {
799    tunnel_pool: ClassifiedWorkerPoolRef<
800        CmdClientTunnelClassification<C>,
801        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
802        ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
803    >,
804    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
805}
806
807impl<
808    C: WorkerClassification,
809    M: CmdTunnelMeta,
810    R: ClassifiedCmdTunnelRead<C, M>,
811    W: ClassifiedCmdTunnelWrite<C, M>,
812    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
813    LEN: RawEncode
814        + for<'a> RawDecode<'a>
815        + Copy
816        + Send
817        + Sync
818        + 'static
819        + FromPrimitive
820        + ToPrimitive
821        + RawFixedBytes,
822    CMD: RawEncode
823        + for<'a> RawDecode<'a>
824        + Copy
825        + Send
826        + Sync
827        + 'static
828        + RawFixedBytes
829        + Eq
830        + Hash
831        + Debug,
832> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
833{
834    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
835        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
836        let resp_waiter = Arc::new(RespWaiter::new());
837        let handler_map = cmd_handler_map.clone();
838        let waiter = resp_waiter.clone();
839        Arc::new(Self {
840            tunnel_pool: ClassifiedWorkerPool::new(
841                tunnel_count,
842                ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(
843                    factory,
844                    move |local_id: PeerId,
845                          peer_id: PeerId,
846                          tunnel_id: TunnelId,
847                          header: CmdHeader<LEN, CMD>,
848                          body_read: CmdBody| {
849                        let handler_map = handler_map.clone();
850                        let waiter = waiter.clone();
851                        async move {
852                            if header.is_resp() && header.seq().is_some() {
853                                let resp_id = gen_resp_id(
854                                    tunnel_id,
855                                    header.cmd_code(),
856                                    header.seq().unwrap(),
857                                );
858                                let _ = waiter.set_result(resp_id, body_read);
859                                Ok(None)
860                            } else {
861                                if let Some(handler) = handler_map.get(header.cmd_code()) {
862                                    handler
863                                        .handle(local_id, peer_id, tunnel_id, header, body_read)
864                                        .await
865                                } else {
866                                    Ok(None)
867                                }
868                            }
869                        }
870                    },
871                    resp_waiter.clone(),
872                ),
873            ),
874            cmd_handler_map,
875        })
876    }
877
878    async fn get_send(
879        &self,
880    ) -> CmdResult<
881        ClassifiedWorkerGuard<
882            CmdClientTunnelClassification<C>,
883            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
884            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
885        >,
886    > {
887        self.tunnel_pool
888            .get_worker()
889            .await
890            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
891    }
892
893    async fn get_send_of_tunnel_id(
894        &self,
895        tunnel_id: TunnelId,
896    ) -> CmdResult<
897        ClassifiedWorkerGuard<
898            CmdClientTunnelClassification<C>,
899            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
900            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
901        >,
902    > {
903        self.tunnel_pool
904            .get_classified_worker(CmdClientTunnelClassification {
905                tunnel_id: Some(tunnel_id),
906                classification: None,
907            })
908            .await
909            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
910    }
911
912    async fn get_classified_send(
913        &self,
914        classification: C,
915    ) -> CmdResult<
916        ClassifiedWorkerGuard<
917            CmdClientTunnelClassification<C>,
918            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
919            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
920        >,
921    > {
922        self.tunnel_pool
923            .get_classified_worker(CmdClientTunnelClassification {
924                tunnel_id: None,
925                classification: Some(classification),
926            })
927            .await
928            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
929    }
930}
931
932pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
933    CmdClientTunnelClassification<C>,
934    M,
935    ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
936    ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
937>;
938#[async_trait::async_trait]
939impl<
940    C: WorkerClassification,
941    M: CmdTunnelMeta,
942    R: ClassifiedCmdTunnelRead<C, M>,
943    W: ClassifiedCmdTunnelWrite<C, M>,
944    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
945    LEN: RawEncode
946        + for<'a> RawDecode<'a>
947        + Copy
948        + Send
949        + Sync
950        + 'static
951        + FromPrimitive
952        + ToPrimitive
953        + RawFixedBytes,
954    CMD: RawEncode
955        + for<'a> RawDecode<'a>
956        + Copy
957        + Send
958        + Sync
959        + 'static
960        + RawFixedBytes
961        + Eq
962        + Hash
963        + Debug,
964>
965    CmdClient<
966        LEN,
967        CMD,
968        M,
969        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
970        ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
971    > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
972{
973    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
974        self.cmd_handler_map.insert(cmd, handler);
975    }
976
977    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
978        let mut send = self.get_send().await?;
979        send.send(cmd, version, body).await
980    }
981
982    async fn send_with_resp(
983        &self,
984        cmd: CMD,
985        version: u8,
986        body: &[u8],
987        timeout: Duration,
988    ) -> CmdResult<CmdBody> {
989        let mut send = self.get_send().await?;
990        send.send_with_resp(cmd, version, body, timeout).await
991    }
992
993    async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
994        let mut send = self.get_send().await?;
995        send.send2(cmd, version, body).await
996    }
997
998    async fn send2_with_resp(
999        &self,
1000        cmd: CMD,
1001        version: u8,
1002        body: &[&[u8]],
1003        timeout: Duration,
1004    ) -> CmdResult<CmdBody> {
1005        let mut send = self.get_send().await?;
1006        send.send2_with_resp(cmd, version, body, timeout).await
1007    }
1008
1009    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
1010        let mut send = self.get_send().await?;
1011        send.send_cmd(cmd, version, body).await
1012    }
1013
1014    async fn send_cmd_with_resp(
1015        &self,
1016        cmd: CMD,
1017        version: u8,
1018        body: CmdBody,
1019        timeout: Duration,
1020    ) -> CmdResult<CmdBody> {
1021        let mut send = self.get_send().await?;
1022        send.send_cmd_with_resp(cmd, version, body, timeout).await
1023    }
1024
1025    async fn send_by_specify_tunnel(
1026        &self,
1027        tunnel_id: TunnelId,
1028        cmd: CMD,
1029        version: u8,
1030        body: &[u8],
1031    ) -> CmdResult<()> {
1032        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1033        send.send(cmd, version, body).await
1034    }
1035
1036    async fn send_by_specify_tunnel_with_resp(
1037        &self,
1038        tunnel_id: TunnelId,
1039        cmd: CMD,
1040        version: u8,
1041        body: &[u8],
1042        timeout: Duration,
1043    ) -> CmdResult<CmdBody> {
1044        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1045        send.send_with_resp(cmd, version, body, timeout).await
1046    }
1047
1048    async fn send2_by_specify_tunnel(
1049        &self,
1050        tunnel_id: TunnelId,
1051        cmd: CMD,
1052        version: u8,
1053        body: &[&[u8]],
1054    ) -> CmdResult<()> {
1055        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1056        send.send2(cmd, version, body).await
1057    }
1058
1059    async fn send2_by_specify_tunnel_with_resp(
1060        &self,
1061        tunnel_id: TunnelId,
1062        cmd: CMD,
1063        version: u8,
1064        body: &[&[u8]],
1065        timeout: Duration,
1066    ) -> CmdResult<CmdBody> {
1067        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1068        send.send2_with_resp(cmd, version, body, timeout).await
1069    }
1070
1071    async fn send_cmd_by_specify_tunnel(
1072        &self,
1073        tunnel_id: TunnelId,
1074        cmd: CMD,
1075        version: u8,
1076        body: CmdBody,
1077    ) -> CmdResult<()> {
1078        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1079        send.send_cmd(cmd, version, body).await
1080    }
1081
1082    async fn send_cmd_by_specify_tunnel_with_resp(
1083        &self,
1084        tunnel_id: TunnelId,
1085        cmd: CMD,
1086        version: u8,
1087        body: CmdBody,
1088        timeout: Duration,
1089    ) -> CmdResult<CmdBody> {
1090        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1091        send.send_cmd_with_resp(cmd, version, body, timeout).await
1092    }
1093
1094    async fn clear_all_tunnel(&self) {
1095        self.tunnel_pool.clear_all_worker().await;
1096    }
1097
1098    async fn get_send(
1099        &self,
1100        tunnel_id: TunnelId,
1101    ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
1102        Ok(ClassifiedSendGuard {
1103            worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1104            _p: std::marker::PhantomData,
1105        })
1106    }
1107}
1108
1109#[async_trait::async_trait]
1110impl<
1111    C: WorkerClassification,
1112    M: CmdTunnelMeta,
1113    R: ClassifiedCmdTunnelRead<C, M>,
1114    W: ClassifiedCmdTunnelWrite<C, M>,
1115    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
1116    LEN: RawEncode
1117        + for<'a> RawDecode<'a>
1118        + Copy
1119        + Send
1120        + Sync
1121        + 'static
1122        + FromPrimitive
1123        + ToPrimitive
1124        + RawFixedBytes,
1125    CMD: RawEncode
1126        + for<'a> RawDecode<'a>
1127        + Copy
1128        + Send
1129        + Sync
1130        + 'static
1131        + RawFixedBytes
1132        + Eq
1133        + Hash
1134        + Debug,
1135>
1136    ClassifiedCmdClient<
1137        LEN,
1138        CMD,
1139        C,
1140        M,
1141        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1142        ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1143    > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1144{
1145    async fn send_by_classified_tunnel(
1146        &self,
1147        classification: C,
1148        cmd: CMD,
1149        version: u8,
1150        body: &[u8],
1151    ) -> CmdResult<()> {
1152        let mut send = self.get_classified_send(classification).await?;
1153        send.send(cmd, version, body).await
1154    }
1155
1156    async fn send_by_classified_tunnel_with_resp(
1157        &self,
1158        classification: C,
1159        cmd: CMD,
1160        version: u8,
1161        body: &[u8],
1162        timeout: Duration,
1163    ) -> CmdResult<CmdBody> {
1164        let mut send = self.get_classified_send(classification).await?;
1165        send.send_with_resp(cmd, version, body, timeout).await
1166    }
1167
1168    async fn send2_by_classified_tunnel(
1169        &self,
1170        classification: C,
1171        cmd: CMD,
1172        version: u8,
1173        body: &[&[u8]],
1174    ) -> CmdResult<()> {
1175        let mut send = self.get_classified_send(classification).await?;
1176        send.send2(cmd, version, body).await
1177    }
1178
1179    async fn send2_by_classified_tunnel_with_resp(
1180        &self,
1181        classification: C,
1182        cmd: CMD,
1183        version: u8,
1184        body: &[&[u8]],
1185        timeout: Duration,
1186    ) -> CmdResult<CmdBody> {
1187        let mut send = self.get_classified_send(classification).await?;
1188        send.send2_with_resp(cmd, version, body, timeout).await
1189    }
1190
1191    async fn send_cmd_by_classified_tunnel(
1192        &self,
1193        classification: C,
1194        cmd: CMD,
1195        version: u8,
1196        body: CmdBody,
1197    ) -> CmdResult<()> {
1198        let mut send = self.get_classified_send(classification).await?;
1199        send.send_cmd(cmd, version, body).await
1200    }
1201
1202    async fn send_cmd_by_classified_tunnel_with_resp(
1203        &self,
1204        classification: C,
1205        cmd: CMD,
1206        version: u8,
1207        body: CmdBody,
1208        timeout: Duration,
1209    ) -> CmdResult<CmdBody> {
1210        let mut send = self.get_classified_send(classification).await?;
1211        send.send_cmd_with_resp(cmd, version, body, timeout).await
1212    }
1213
1214    async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
1215        let send = self.get_classified_send(classification).await?;
1216        Ok(send.get_tunnel_id())
1217    }
1218
1219    async fn get_send_by_classified(
1220        &self,
1221        classification: C,
1222    ) -> CmdResult<
1223        ClassifiedSendGuard<
1224            CmdClientTunnelClassification<C>,
1225            M,
1226            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1227            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
1228        >,
1229    > {
1230        Ok(ClassifiedSendGuard {
1231            worker_guard: self.get_classified_send(classification).await?,
1232            _p: std::marker::PhantomData,
1233        })
1234    }
1235}