Skip to main content

sfo_cmd_server/client/
classified_client.rs

1use crate::client::{
2    ClassifiedCmdClient, ClassifiedSendGuard, CmdClient, CmdSend, RespWaiter, RespWaiterRef,
3    gen_resp_id, gen_seq,
4};
5use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
6use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
7use crate::peer_id::PeerId;
8use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
9use async_named_locker::ObjectHolder;
10use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
11use num::{FromPrimitive, ToPrimitive};
12use sfo_pool::{
13    ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
14    ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
15    pool_err,
16};
17use sfo_split::{RHalf, Splittable, WHalf};
18use std::fmt::Debug;
19use std::hash::Hash;
20use std::ops::DerefMut;
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>:
28    CmdTunnelRead<M> + 'static + Send
29{
30    fn get_classification(&self) -> C;
31}
32
33pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>:
34    CmdTunnelWrite<M> + 'static + Send
35{
36    fn get_classification(&self) -> C;
37}
38
39pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
40pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
41pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
42
43#[derive(Debug, Clone, Copy, Eq, Hash)]
44pub struct CmdClientTunnelClassification<C: WorkerClassification> {
45    tunnel_id: Option<TunnelId>,
46    classification: Option<C>,
47}
48
49impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
50    fn eq(&self, other: &Self) -> bool {
51        self.tunnel_id == other.tunnel_id && self.classification == other.classification
52    }
53}
54
55#[async_trait::async_trait]
56pub trait ClassifiedCmdTunnelFactory<
57    C: WorkerClassification,
58    M: CmdTunnelMeta,
59    R: ClassifiedCmdTunnelRead<C, M>,
60    W: ClassifiedCmdTunnelWrite<C, M>,
61>: Send + Sync + 'static
62{
63    async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
64}
65
66pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
67where
68    C: WorkerClassification,
69    M: CmdTunnelMeta,
70    R: ClassifiedCmdTunnelRead<C, M>,
71    W: ClassifiedCmdTunnelWrite<C, M>,
72    LEN: RawEncode
73        + for<'a> RawDecode<'a>
74        + Copy
75        + Send
76        + Sync
77        + 'static
78        + FromPrimitive
79        + ToPrimitive,
80    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
81{
82    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
83    pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
84    pub(crate) is_work: bool,
85    pub(crate) classification: C,
86    pub(crate) tunnel_id: TunnelId,
87    pub(crate) resp_waiter: RespWaiterRef,
88    pub(crate) remote_id: PeerId,
89    pub(crate) tunnel_meta: Option<Arc<M>>,
90    _p: std::marker::PhantomData<(LEN, CMD)>,
91}
92
93// impl<C, R, W, LEN, CMD> Deref for ClassifiedCmdSend<C, R, W, LEN, CMD>
94// where C: WorkerClassification,
95//       R: ClassifiedCmdTunnelRead<C>,
96//       W: ClassifiedCmdTunnelWrite<C>,
97//       LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
98//       CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
99//     type Target = W;
100//
101//     fn deref(&self) -> &Self::Target {
102//         self.write.deref()
103//     }
104// }
105
106impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
107where
108    C: WorkerClassification,
109    M: CmdTunnelMeta,
110    R: ClassifiedCmdTunnelRead<C, M>,
111    W: ClassifiedCmdTunnelWrite<C, M>,
112    LEN: RawEncode
113        + for<'a> RawDecode<'a>
114        + Copy
115        + Send
116        + Sync
117        + 'static
118        + FromPrimitive
119        + ToPrimitive,
120    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
121{
122    pub(crate) fn new(
123        tunnel_id: TunnelId,
124        classification: C,
125        recv_handle: JoinHandle<CmdResult<()>>,
126        write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
127        resp_waiter: RespWaiterRef,
128        remote_id: PeerId,
129        tunnel_meta: Option<Arc<M>>,
130    ) -> Self {
131        Self {
132            recv_handle,
133            write,
134            is_work: true,
135            classification,
136            tunnel_id,
137            resp_waiter,
138            remote_id,
139            tunnel_meta,
140            _p: Default::default(),
141        }
142    }
143
144    pub fn get_tunnel_id(&self) -> TunnelId {
145        self.tunnel_id
146    }
147
148    pub fn set_disable(&mut self) {
149        self.is_work = false;
150        self.recv_handle.abort();
151    }
152
153    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
154        log::trace!(
155            "client {:?} send cmd: {:?}, len: {}, data: {}",
156            self.tunnel_id,
157            cmd,
158            body.len(),
159            hex::encode(body)
160        );
161        let header = CmdHeader::<LEN, CMD>::new(
162            version,
163            false,
164            None,
165            cmd,
166            LEN::from_u64(body.len() as u64).unwrap(),
167        );
168        let buf = header
169            .to_vec()
170            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
171        let ret = self.send_inner(buf.as_slice(), body).await;
172        if let Err(e) = ret {
173            self.set_disable();
174            return Err(e);
175        }
176        Ok(())
177    }
178
179    pub async fn send_with_resp(
180        &mut self,
181        cmd: CMD,
182        version: u8,
183        body: &[u8],
184        timeout: Duration,
185    ) -> CmdResult<CmdBody> {
186        if let Some(id) = tokio::task::try_id() {
187            if id == self.recv_handle.id() {
188                return Err(cmd_err!(
189                    CmdErrorCode::Failed,
190                    "can't send with resp in recv task"
191                ));
192            }
193        }
194        log::trace!(
195            "client {:?} send cmd: {:?}, len: {}, data: {}",
196            self.tunnel_id,
197            cmd,
198            body.len(),
199            hex::encode(body)
200        );
201        let seq = gen_seq();
202        let header = CmdHeader::<LEN, CMD>::new(
203            version,
204            false,
205            Some(seq),
206            cmd,
207            LEN::from_u64(body.len() as u64).unwrap(),
208        );
209        let buf = header
210            .to_vec()
211            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
212        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
213        let waiter = self.resp_waiter.clone();
214        let resp_waiter = waiter
215            .create_timeout_result_future(resp_id, timeout)
216            .map_err(into_cmd_err!(
217                CmdErrorCode::Failed,
218                "create timeout result future error"
219            ))?;
220        let ret = self.send_inner(buf.as_slice(), body).await;
221        if let Err(e) = ret {
222            self.set_disable();
223            return Err(e);
224        }
225        let resp = resp_waiter
226            .await
227            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
228        Ok(resp)
229    }
230
231    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
232        let mut len = 0;
233        for b in body.iter() {
234            len += b.len();
235            log::trace!(
236                "client {:?} send2 cmd {:?} body: {}",
237                self.tunnel_id,
238                cmd,
239                hex::encode(b)
240            );
241        }
242        log::trace!(
243            "client {:?} send2 cmd: {:?}, len {}",
244            self.tunnel_id,
245            cmd,
246            len
247        );
248        let header = CmdHeader::<LEN, CMD>::new(
249            version,
250            false,
251            None,
252            cmd,
253            LEN::from_u64(len as u64).unwrap(),
254        );
255        let buf = header
256            .to_vec()
257            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
258        let ret = self.send_inner2(buf.as_slice(), body).await;
259        if let Err(e) = ret {
260            self.set_disable();
261            return Err(e);
262        }
263        Ok(())
264    }
265
266    pub async fn send2_with_resp(
267        &mut self,
268        cmd: CMD,
269        version: u8,
270        body: &[&[u8]],
271        timeout: Duration,
272    ) -> CmdResult<CmdBody> {
273        if let Some(id) = tokio::task::try_id() {
274            if id == self.recv_handle.id() {
275                return Err(cmd_err!(
276                    CmdErrorCode::Failed,
277                    "can't send with resp in recv task"
278                ));
279            }
280        }
281        let mut len = 0;
282        for b in body.iter() {
283            len += b.len();
284            log::trace!(
285                "client {:?} send2 cmd {:?} body: {}",
286                self.tunnel_id,
287                cmd,
288                hex::encode(b)
289            );
290        }
291        log::trace!(
292            "client {:?} send2 cmd: {:?}, len {}",
293            self.tunnel_id,
294            cmd,
295            len
296        );
297        let seq = gen_seq();
298        let header = CmdHeader::<LEN, CMD>::new(
299            version,
300            false,
301            Some(seq),
302            cmd,
303            LEN::from_u64(len as u64).unwrap(),
304        );
305        let buf = header
306            .to_vec()
307            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
308        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
309        let waiter = self.resp_waiter.clone();
310        let resp_waiter = waiter
311            .create_timeout_result_future(resp_id, timeout)
312            .map_err(into_cmd_err!(
313                CmdErrorCode::Failed,
314                "create timeout result future error"
315            ))?;
316        let ret = self.send_inner2(buf.as_slice(), body).await;
317        if let Err(e) = ret {
318            self.set_disable();
319            return Err(e);
320        }
321        let resp = resp_waiter
322            .await
323            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
324        Ok(resp)
325    }
326
327    pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
328        log::trace!(
329            "client {:?} send cmd: {:?}, len: {}",
330            self.tunnel_id,
331            cmd,
332            body.len()
333        );
334        let header = CmdHeader::<LEN, CMD>::new(
335            version,
336            false,
337            None,
338            cmd,
339            LEN::from_u64(body.len()).unwrap(),
340        );
341        let buf = header
342            .to_vec()
343            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
344        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
345        if let Err(e) = ret {
346            self.set_disable();
347            return Err(e);
348        }
349        Ok(())
350    }
351
352    pub async fn send_cmd_with_resp(
353        &mut self,
354        cmd: CMD,
355        version: u8,
356        body: CmdBody,
357        timeout: Duration,
358    ) -> CmdResult<CmdBody> {
359        if let Some(id) = tokio::task::try_id() {
360            if id == self.recv_handle.id() {
361                return Err(cmd_err!(
362                    CmdErrorCode::Failed,
363                    "can't send with resp in recv task"
364                ));
365            }
366        }
367        log::trace!(
368            "client {:?} send cmd: {:?}, len: {}",
369            self.tunnel_id,
370            cmd,
371            body.len()
372        );
373        let seq = gen_seq();
374        let header = CmdHeader::<LEN, CMD>::new(
375            version,
376            false,
377            Some(seq),
378            cmd,
379            LEN::from_u64(body.len()).unwrap(),
380        );
381        let buf = header
382            .to_vec()
383            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
384        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
385        let waiter = self.resp_waiter.clone();
386        let resp_waiter = waiter
387            .create_timeout_result_future(resp_id, timeout)
388            .map_err(into_cmd_err!(
389                CmdErrorCode::Failed,
390                "create timeout result future error"
391            ))?;
392        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
393        if let Err(e) = ret {
394            self.set_disable();
395            return Err(e);
396        }
397        let resp = resp_waiter
398            .await
399            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
400        Ok(resp)
401    }
402
403    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
404        let mut write = self.write.get().await;
405        if header.len() > 255 {
406            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
407        }
408        write
409            .write_u8(header.len() as u8)
410            .await
411            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
412        write
413            .write_all(header)
414            .await
415            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
416        write
417            .write_all(body)
418            .await
419            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
420        write
421            .flush()
422            .await
423            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
424        Ok(())
425    }
426
427    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
428        let mut write = self.write.get().await;
429        if header.len() > 255 {
430            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
431        }
432        write
433            .write_u8(header.len() as u8)
434            .await
435            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
436        write
437            .write_all(header)
438            .await
439            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
440        for b in body.iter() {
441            write
442                .write_all(b)
443                .await
444                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
445        }
446        write
447            .flush()
448            .await
449            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
450        Ok(())
451    }
452
453    async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
454        let mut write = self.write.get().await;
455        if header.len() > 255 {
456            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
457        }
458        write
459            .write_u8(header.len() as u8)
460            .await
461            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
462        write
463            .write_all(header)
464            .await
465            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
466        tokio::io::copy(&mut body, write.deref_mut().deref_mut())
467            .await
468            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
469        write
470            .flush()
471            .await
472            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
473        Ok(())
474    }
475}
476
477impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
478where
479    C: WorkerClassification,
480    M: CmdTunnelMeta,
481    R: ClassifiedCmdTunnelRead<C, M>,
482    W: ClassifiedCmdTunnelWrite<C, M>,
483    LEN: RawEncode
484        + for<'a> RawDecode<'a>
485        + Copy
486        + Send
487        + Sync
488        + 'static
489        + FromPrimitive
490        + ToPrimitive,
491    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
492{
493    fn drop(&mut self) {
494        self.set_disable();
495    }
496}
497
498impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
499where
500    C: WorkerClassification,
501    M: CmdTunnelMeta,
502    R: ClassifiedCmdTunnelRead<C, M>,
503    W: ClassifiedCmdTunnelWrite<C, M>,
504    LEN: RawEncode
505        + for<'a> RawDecode<'a>
506        + Copy
507        + Send
508        + Sync
509        + 'static
510        + FromPrimitive
511        + ToPrimitive,
512    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
513{
514    fn get_tunnel_meta(&self) -> Option<Arc<M>> {
515        self.tunnel_meta.clone()
516    }
517
518    fn get_remote_peer_id(&self) -> PeerId {
519        self.remote_id.clone()
520    }
521}
522
523impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>>
524    for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
525where
526    C: WorkerClassification,
527    M: CmdTunnelMeta,
528    R: ClassifiedCmdTunnelRead<C, M>,
529    W: ClassifiedCmdTunnelWrite<C, M>,
530    LEN: RawEncode
531        + for<'a> RawDecode<'a>
532        + Copy
533        + Send
534        + Sync
535        + 'static
536        + FromPrimitive
537        + ToPrimitive,
538    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
539{
540    fn is_work(&self) -> bool {
541        self.is_work && !self.recv_handle.is_finished()
542    }
543
544    fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
545        if c.tunnel_id.is_some() {
546            self.tunnel_id == c.tunnel_id.unwrap()
547        } else {
548            if c.classification.is_some() {
549                self.classification == c.classification.unwrap()
550            } else {
551                true
552            }
553        }
554    }
555
556    fn classification(&self) -> CmdClientTunnelClassification<C> {
557        CmdClientTunnelClassification {
558            tunnel_id: Some(self.tunnel_id),
559            classification: Some(self.classification.clone()),
560        }
561    }
562}
563
564pub struct ClassifiedCmdWriteFactory<
565    C: WorkerClassification,
566    M: CmdTunnelMeta,
567    R: ClassifiedCmdTunnelRead<C, M>,
568    W: ClassifiedCmdTunnelWrite<C, M>,
569    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
570    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
571    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
572> {
573    tunnel_factory: F,
574    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
575    resp_waiter: RespWaiterRef,
576    tunnel_id_generator: TunnelIdGenerator,
577    _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
578}
579
580impl<
581    C: WorkerClassification,
582    M: CmdTunnelMeta,
583    R: ClassifiedCmdTunnelRead<C, M>,
584    W: ClassifiedCmdTunnelWrite<C, M>,
585    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
586    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
587    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
588> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
589{
590    pub(crate) fn new(
591        tunnel_factory: F,
592        cmd_handler: impl CmdHandler<LEN, CMD>,
593        resp_waiter: RespWaiterRef,
594    ) -> Self {
595        Self {
596            tunnel_factory,
597            cmd_handler: Arc::new(cmd_handler),
598            resp_waiter,
599            tunnel_id_generator: TunnelIdGenerator::new(),
600            _p: Default::default(),
601        }
602    }
603}
604
605#[async_trait::async_trait]
606impl<
607    C: WorkerClassification,
608    M: CmdTunnelMeta,
609    R: ClassifiedCmdTunnelRead<C, M>,
610    W: ClassifiedCmdTunnelWrite<C, M>,
611    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
612    LEN: RawEncode
613        + for<'a> RawDecode<'a>
614        + Copy
615        + Send
616        + Sync
617        + 'static
618        + FromPrimitive
619        + ToPrimitive
620        + RawFixedBytes,
621    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
622> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>>
623    for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
624{
625    async fn create(
626        &self,
627        classification: Option<CmdClientTunnelClassification<C>>,
628    ) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
629        if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
630            return Err(pool_err!(
631                PoolErrorCode::Failed,
632                "tunnel {:?} not found",
633                classification.as_ref().unwrap().tunnel_id.unwrap()
634            ));
635        }
636
637        let classification = if classification.is_some()
638            && classification.as_ref().unwrap().classification.is_some()
639        {
640            classification.unwrap().classification
641        } else {
642            None
643        };
644        let tunnel = self
645            .tunnel_factory
646            .create_tunnel(classification)
647            .await
648            .map_err(into_pool_err!(PoolErrorCode::Failed))?;
649        let classification = tunnel.get_classification();
650        let peer_id = tunnel.get_remote_peer_id();
651        let tunnel_id = self.tunnel_id_generator.generate();
652        let (mut recv, write) = tunnel.split();
653        let remote_id = peer_id.clone();
654        let tunnel_meta = recv.get_tunnel_meta();
655        let write = ObjectHolder::new(write);
656        let resp_write = write.clone();
657        let cmd_handler = self.cmd_handler.clone();
658        let handle = spawn(async move {
659            let ret: CmdResult<()> = async move {
660                loop {
661                    let header_len = recv
662                        .read_u8()
663                        .await
664                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
665                    let mut header = vec![0u8; header_len as usize];
666                    let n = recv
667                        .read_exact(header.as_mut())
668                        .await
669                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
670                    if n == 0 {
671                        break;
672                    }
673                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
674                        .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
675                    log::trace!(
676                        "recv cmd {:?} from {} len {} tunnel {:?}",
677                        header.cmd_code(),
678                        peer_id,
679                        header.pkg_len().to_u64().unwrap(),
680                        tunnel_id
681                    );
682                    let body_len = header.pkg_len().to_u64().unwrap();
683                    let cmd_read =
684                        CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
685                    let waiter = cmd_read.get_waiter();
686                    let future = waiter
687                        .create_result_future()
688                        .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
689                    let version = header.version();
690                    let seq = header.seq();
691                    let cmd_code = header.cmd_code();
692                    match cmd_handler
693                        .handle(
694                            peer_id.clone(),
695                            tunnel_id,
696                            header,
697                            CmdBody::from_reader(BufReader::new(cmd_read), body_len),
698                        )
699                        .await
700                    {
701                        Ok(Some(mut body)) => {
702                            let mut write = resp_write.get().await;
703                            let header = CmdHeader::<LEN, CMD>::new(
704                                version,
705                                true,
706                                seq,
707                                cmd_code,
708                                LEN::from_u64(body.len()).unwrap(),
709                            );
710                            let buf = header
711                                .to_vec()
712                                .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
713                            if buf.len() > 255 {
714                                return Err(cmd_err!(
715                                    CmdErrorCode::InvalidParam,
716                                    "header len too large"
717                                ));
718                            }
719                            write
720                                .write_u8(buf.len() as u8)
721                                .await
722                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
723                            write
724                                .write_all(buf.as_slice())
725                                .await
726                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
727                            tokio::io::copy(&mut body, write.deref_mut().deref_mut())
728                                .await
729                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
730                            write
731                                .flush()
732                                .await
733                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
734                        }
735                        Err(e) => {
736                            log::error!("handle cmd error: {:?}", e);
737                        }
738                        _ => {}
739                    }
740                    recv = future
741                        .await
742                        .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
743                    log::debug!(
744                        "handle cmd {:?} from {} len {} tunnel {:?} complete",
745                        cmd_code,
746                        peer_id,
747                        body_len,
748                        tunnel_id
749                    );
750                }
751                Ok(())
752            }
753            .await;
754            if ret.is_err() {
755                log::error!("recv cmd error: {:?}", ret.as_ref().err().unwrap());
756            }
757            ret
758        });
759        Ok(ClassifiedCmdSend::new(
760            tunnel_id,
761            classification,
762            handle,
763            write,
764            self.resp_waiter.clone(),
765            remote_id,
766            tunnel_meta,
767        ))
768    }
769}
770
771pub struct DefaultClassifiedCmdClient<
772    C: WorkerClassification,
773    M: CmdTunnelMeta,
774    R: ClassifiedCmdTunnelRead<C, M>,
775    W: ClassifiedCmdTunnelWrite<C, M>,
776    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
777    LEN: RawEncode
778        + for<'a> RawDecode<'a>
779        + Copy
780        + Send
781        + Sync
782        + 'static
783        + FromPrimitive
784        + ToPrimitive
785        + RawFixedBytes,
786    CMD: RawEncode
787        + for<'a> RawDecode<'a>
788        + Copy
789        + Send
790        + Sync
791        + 'static
792        + RawFixedBytes
793        + Eq
794        + Hash
795        + Debug,
796> {
797    tunnel_pool: ClassifiedWorkerPoolRef<
798        CmdClientTunnelClassification<C>,
799        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
800        ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
801    >,
802    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
803}
804
805impl<
806    C: WorkerClassification,
807    M: CmdTunnelMeta,
808    R: ClassifiedCmdTunnelRead<C, M>,
809    W: ClassifiedCmdTunnelWrite<C, M>,
810    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
811    LEN: RawEncode
812        + for<'a> RawDecode<'a>
813        + Copy
814        + Send
815        + Sync
816        + 'static
817        + FromPrimitive
818        + ToPrimitive
819        + RawFixedBytes,
820    CMD: RawEncode
821        + for<'a> RawDecode<'a>
822        + Copy
823        + Send
824        + Sync
825        + 'static
826        + RawFixedBytes
827        + Eq
828        + Hash
829        + Debug,
830> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
831{
832    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
833        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
834        let resp_waiter = Arc::new(RespWaiter::new());
835        let handler_map = cmd_handler_map.clone();
836        let waiter = resp_waiter.clone();
837        Arc::new(Self {
838            tunnel_pool: ClassifiedWorkerPool::new(
839                tunnel_count,
840                ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(
841                    factory,
842                    move |peer_id: PeerId,
843                          tunnel_id: TunnelId,
844                          header: CmdHeader<LEN, CMD>,
845                          body_read: CmdBody| {
846                        let handler_map = handler_map.clone();
847                        let waiter = waiter.clone();
848                        async move {
849                            if header.is_resp() && header.seq().is_some() {
850                                let resp_id = gen_resp_id(
851                                    tunnel_id,
852                                    header.cmd_code(),
853                                    header.seq().unwrap(),
854                                );
855                                let _ = waiter.set_result(resp_id, body_read);
856                                Ok(None)
857                            } else {
858                                if let Some(handler) = handler_map.get(header.cmd_code()) {
859                                    handler.handle(peer_id, tunnel_id, header, body_read).await
860                                } else {
861                                    Ok(None)
862                                }
863                            }
864                        }
865                    },
866                    resp_waiter.clone(),
867                ),
868            ),
869            cmd_handler_map,
870        })
871    }
872
873    async fn get_send(
874        &self,
875    ) -> CmdResult<
876        ClassifiedWorkerGuard<
877            CmdClientTunnelClassification<C>,
878            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
879            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
880        >,
881    > {
882        self.tunnel_pool
883            .get_worker()
884            .await
885            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
886    }
887
888    async fn get_send_of_tunnel_id(
889        &self,
890        tunnel_id: TunnelId,
891    ) -> CmdResult<
892        ClassifiedWorkerGuard<
893            CmdClientTunnelClassification<C>,
894            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
895            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
896        >,
897    > {
898        self.tunnel_pool
899            .get_classified_worker(CmdClientTunnelClassification {
900                tunnel_id: Some(tunnel_id),
901                classification: None,
902            })
903            .await
904            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
905    }
906
907    async fn get_classified_send(
908        &self,
909        classification: C,
910    ) -> CmdResult<
911        ClassifiedWorkerGuard<
912            CmdClientTunnelClassification<C>,
913            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
914            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
915        >,
916    > {
917        self.tunnel_pool
918            .get_classified_worker(CmdClientTunnelClassification {
919                tunnel_id: None,
920                classification: Some(classification),
921            })
922            .await
923            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
924    }
925}
926
/// Send guard type handed out by `DefaultClassifiedCmdClient`.
///
/// It is `ClassifiedSendGuard` specialised with:
/// - `CmdClientTunnelClassification<C>` as the pool classification key,
/// - `ClassifiedCmdSend` as the worker (sender) type,
/// - `ClassifiedCmdWriteFactory` as the worker factory.
pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
    CmdClientTunnelClassification<C>,
    M,
    ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
    ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
>;
933#[async_trait::async_trait]
934impl<
935    C: WorkerClassification,
936    M: CmdTunnelMeta,
937    R: ClassifiedCmdTunnelRead<C, M>,
938    W: ClassifiedCmdTunnelWrite<C, M>,
939    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
940    LEN: RawEncode
941        + for<'a> RawDecode<'a>
942        + Copy
943        + Send
944        + Sync
945        + 'static
946        + FromPrimitive
947        + ToPrimitive
948        + RawFixedBytes,
949    CMD: RawEncode
950        + for<'a> RawDecode<'a>
951        + Copy
952        + Send
953        + Sync
954        + 'static
955        + RawFixedBytes
956        + Eq
957        + Hash
958        + Debug,
959>
960    CmdClient<
961        LEN,
962        CMD,
963        M,
964        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
965        ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
966    > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
967{
968    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
969        self.cmd_handler_map.insert(cmd, handler);
970    }
971
972    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
973        let mut send = self.get_send().await?;
974        send.send(cmd, version, body).await
975    }
976
977    async fn send_with_resp(
978        &self,
979        cmd: CMD,
980        version: u8,
981        body: &[u8],
982        timeout: Duration,
983    ) -> CmdResult<CmdBody> {
984        let mut send = self.get_send().await?;
985        send.send_with_resp(cmd, version, body, timeout).await
986    }
987
988    async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
989        let mut send = self.get_send().await?;
990        send.send2(cmd, version, body).await
991    }
992
993    async fn send2_with_resp(
994        &self,
995        cmd: CMD,
996        version: u8,
997        body: &[&[u8]],
998        timeout: Duration,
999    ) -> CmdResult<CmdBody> {
1000        let mut send = self.get_send().await?;
1001        send.send2_with_resp(cmd, version, body, timeout).await
1002    }
1003
1004    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
1005        let mut send = self.get_send().await?;
1006        send.send_cmd(cmd, version, body).await
1007    }
1008
1009    async fn send_cmd_with_resp(
1010        &self,
1011        cmd: CMD,
1012        version: u8,
1013        body: CmdBody,
1014        timeout: Duration,
1015    ) -> CmdResult<CmdBody> {
1016        let mut send = self.get_send().await?;
1017        send.send_cmd_with_resp(cmd, version, body, timeout).await
1018    }
1019
1020    async fn send_by_specify_tunnel(
1021        &self,
1022        tunnel_id: TunnelId,
1023        cmd: CMD,
1024        version: u8,
1025        body: &[u8],
1026    ) -> CmdResult<()> {
1027        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1028        send.send(cmd, version, body).await
1029    }
1030
1031    async fn send_by_specify_tunnel_with_resp(
1032        &self,
1033        tunnel_id: TunnelId,
1034        cmd: CMD,
1035        version: u8,
1036        body: &[u8],
1037        timeout: Duration,
1038    ) -> CmdResult<CmdBody> {
1039        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1040        send.send_with_resp(cmd, version, body, timeout).await
1041    }
1042
1043    async fn send2_by_specify_tunnel(
1044        &self,
1045        tunnel_id: TunnelId,
1046        cmd: CMD,
1047        version: u8,
1048        body: &[&[u8]],
1049    ) -> CmdResult<()> {
1050        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1051        send.send2(cmd, version, body).await
1052    }
1053
1054    async fn send2_by_specify_tunnel_with_resp(
1055        &self,
1056        tunnel_id: TunnelId,
1057        cmd: CMD,
1058        version: u8,
1059        body: &[&[u8]],
1060        timeout: Duration,
1061    ) -> CmdResult<CmdBody> {
1062        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1063        send.send2_with_resp(cmd, version, body, timeout).await
1064    }
1065
1066    async fn send_cmd_by_specify_tunnel(
1067        &self,
1068        tunnel_id: TunnelId,
1069        cmd: CMD,
1070        version: u8,
1071        body: CmdBody,
1072    ) -> CmdResult<()> {
1073        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1074        send.send_cmd(cmd, version, body).await
1075    }
1076
1077    async fn send_cmd_by_specify_tunnel_with_resp(
1078        &self,
1079        tunnel_id: TunnelId,
1080        cmd: CMD,
1081        version: u8,
1082        body: CmdBody,
1083        timeout: Duration,
1084    ) -> CmdResult<CmdBody> {
1085        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1086        send.send_cmd_with_resp(cmd, version, body, timeout).await
1087    }
1088
1089    async fn clear_all_tunnel(&self) {
1090        self.tunnel_pool.clear_all_worker().await;
1091    }
1092
1093    async fn get_send(
1094        &self,
1095        tunnel_id: TunnelId,
1096    ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
1097        Ok(ClassifiedSendGuard {
1098            worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1099            _p: std::marker::PhantomData,
1100        })
1101    }
1102}
1103
1104#[async_trait::async_trait]
1105impl<
1106    C: WorkerClassification,
1107    M: CmdTunnelMeta,
1108    R: ClassifiedCmdTunnelRead<C, M>,
1109    W: ClassifiedCmdTunnelWrite<C, M>,
1110    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
1111    LEN: RawEncode
1112        + for<'a> RawDecode<'a>
1113        + Copy
1114        + Send
1115        + Sync
1116        + 'static
1117        + FromPrimitive
1118        + ToPrimitive
1119        + RawFixedBytes,
1120    CMD: RawEncode
1121        + for<'a> RawDecode<'a>
1122        + Copy
1123        + Send
1124        + Sync
1125        + 'static
1126        + RawFixedBytes
1127        + Eq
1128        + Hash
1129        + Debug,
1130>
1131    ClassifiedCmdClient<
1132        LEN,
1133        CMD,
1134        C,
1135        M,
1136        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1137        ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1138    > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1139{
1140    async fn send_by_classified_tunnel(
1141        &self,
1142        classification: C,
1143        cmd: CMD,
1144        version: u8,
1145        body: &[u8],
1146    ) -> CmdResult<()> {
1147        let mut send = self.get_classified_send(classification).await?;
1148        send.send(cmd, version, body).await
1149    }
1150
1151    async fn send_by_classified_tunnel_with_resp(
1152        &self,
1153        classification: C,
1154        cmd: CMD,
1155        version: u8,
1156        body: &[u8],
1157        timeout: Duration,
1158    ) -> CmdResult<CmdBody> {
1159        let mut send = self.get_classified_send(classification).await?;
1160        send.send_with_resp(cmd, version, body, timeout).await
1161    }
1162
1163    async fn send2_by_classified_tunnel(
1164        &self,
1165        classification: C,
1166        cmd: CMD,
1167        version: u8,
1168        body: &[&[u8]],
1169    ) -> CmdResult<()> {
1170        let mut send = self.get_classified_send(classification).await?;
1171        send.send2(cmd, version, body).await
1172    }
1173
1174    async fn send2_by_classified_tunnel_with_resp(
1175        &self,
1176        classification: C,
1177        cmd: CMD,
1178        version: u8,
1179        body: &[&[u8]],
1180        timeout: Duration,
1181    ) -> CmdResult<CmdBody> {
1182        let mut send = self.get_classified_send(classification).await?;
1183        send.send2_with_resp(cmd, version, body, timeout).await
1184    }
1185
1186    async fn send_cmd_by_classified_tunnel(
1187        &self,
1188        classification: C,
1189        cmd: CMD,
1190        version: u8,
1191        body: CmdBody,
1192    ) -> CmdResult<()> {
1193        let mut send = self.get_classified_send(classification).await?;
1194        send.send_cmd(cmd, version, body).await
1195    }
1196
1197    async fn send_cmd_by_classified_tunnel_with_resp(
1198        &self,
1199        classification: C,
1200        cmd: CMD,
1201        version: u8,
1202        body: CmdBody,
1203        timeout: Duration,
1204    ) -> CmdResult<CmdBody> {
1205        let mut send = self.get_classified_send(classification).await?;
1206        send.send_cmd_with_resp(cmd, version, body, timeout).await
1207    }
1208
1209    async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
1210        let send = self.get_classified_send(classification).await?;
1211        Ok(send.get_tunnel_id())
1212    }
1213
1214    async fn get_send_by_classified(
1215        &self,
1216        classification: C,
1217    ) -> CmdResult<
1218        ClassifiedSendGuard<
1219            CmdClientTunnelClassification<C>,
1220            M,
1221            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1222            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
1223        >,
1224    > {
1225        Ok(ClassifiedSendGuard {
1226            worker_guard: self.get_classified_send(classification).await?,
1227            _p: std::marker::PhantomData,
1228        })
1229    }
1230}