Skip to main content

sfo_cmd_server/client/
classified_client.rs

1use crate::client::{
2    ClassifiedCmdClient, ClassifiedSendGuard, CmdClient, CmdSend, RespWaiter, RespWaiterRef,
3    gen_resp_id, gen_seq,
4};
5use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
6use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
7use crate::peer_id::PeerId;
8use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
9use async_named_locker::ObjectHolder;
10use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
11use num::{FromPrimitive, ToPrimitive};
12use sfo_pool::{
13    ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
14    ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
15    pool_err,
16};
17use sfo_split::{RHalf, Splittable, WHalf};
18use std::fmt::Debug;
19use std::hash::Hash;
20use std::ops::DerefMut;
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>:
28    CmdTunnelRead<M> + 'static + Send
29{
30    fn get_classification(&self) -> C;
31}
32
33pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>:
34    CmdTunnelWrite<M> + 'static + Send
35{
36    fn get_classification(&self) -> C;
37}
38
39pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
40pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
41pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
42
43#[derive(Debug, Clone, Copy, Eq, Hash)]
44pub struct CmdClientTunnelClassification<C: WorkerClassification> {
45    tunnel_id: Option<TunnelId>,
46    classification: Option<C>,
47}
48
49impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
50    fn eq(&self, other: &Self) -> bool {
51        self.tunnel_id == other.tunnel_id && self.classification == other.classification
52    }
53}
54
55#[async_trait::async_trait]
56pub trait ClassifiedCmdTunnelFactory<
57    C: WorkerClassification,
58    M: CmdTunnelMeta,
59    R: ClassifiedCmdTunnelRead<C, M>,
60    W: ClassifiedCmdTunnelWrite<C, M>,
61>: Send + Sync + 'static
62{
63    async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
64}
65
66pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
67where
68    C: WorkerClassification,
69    M: CmdTunnelMeta,
70    R: ClassifiedCmdTunnelRead<C, M>,
71    W: ClassifiedCmdTunnelWrite<C, M>,
72    LEN: RawEncode
73        + for<'a> RawDecode<'a>
74        + Copy
75        + Send
76        + Sync
77        + 'static
78        + FromPrimitive
79        + ToPrimitive,
80    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
81{
82    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
83    pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
84    pub(crate) is_work: bool,
85    pub(crate) classification: C,
86    pub(crate) pool_tunnel_id: Option<TunnelId>,
87    pub(crate) pool_classification: Option<C>,
88    pub(crate) tunnel_id: TunnelId,
89    pub(crate) resp_waiter: RespWaiterRef,
90    pub(crate) remote_id: PeerId,
91    pub(crate) pool_peer_id: Option<PeerId>,
92    pub(crate) tunnel_meta: Option<Arc<M>>,
93    _p: std::marker::PhantomData<(LEN, CMD)>,
94}
95
96// impl<C, R, W, LEN, CMD> Deref for ClassifiedCmdSend<C, R, W, LEN, CMD>
97// where C: WorkerClassification,
98//       R: ClassifiedCmdTunnelRead<C>,
99//       W: ClassifiedCmdTunnelWrite<C>,
100//       LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
101//       CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
102//     type Target = W;
103//
104//     fn deref(&self) -> &Self::Target {
105//         self.write.deref()
106//     }
107// }
108
109impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
110where
111    C: WorkerClassification,
112    M: CmdTunnelMeta,
113    R: ClassifiedCmdTunnelRead<C, M>,
114    W: ClassifiedCmdTunnelWrite<C, M>,
115    LEN: RawEncode
116        + for<'a> RawDecode<'a>
117        + Copy
118        + Send
119        + Sync
120        + 'static
121        + FromPrimitive
122        + ToPrimitive,
123    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
124{
125    pub(crate) fn new(
126        tunnel_id: TunnelId,
127        classification: C,
128        pool_peer_id: Option<PeerId>,
129        pool_tunnel_id: Option<TunnelId>,
130        pool_classification: Option<C>,
131        recv_handle: JoinHandle<CmdResult<()>>,
132        write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
133        resp_waiter: RespWaiterRef,
134        remote_id: PeerId,
135        tunnel_meta: Option<Arc<M>>,
136    ) -> Self {
137        Self {
138            recv_handle,
139            write,
140            is_work: true,
141            classification,
142            pool_tunnel_id,
143            pool_classification,
144            tunnel_id,
145            resp_waiter,
146            remote_id,
147            pool_peer_id,
148            tunnel_meta,
149            _p: Default::default(),
150        }
151    }
152
153    pub fn get_tunnel_id(&self) -> TunnelId {
154        self.tunnel_id
155    }
156
157    pub fn set_disable(&mut self) {
158        self.is_work = false;
159        self.recv_handle.abort();
160    }
161
162    pub(crate) fn set_pool_key(
163        &mut self,
164        peer_id: Option<PeerId>,
165        tunnel_id: Option<TunnelId>,
166        classification: Option<C>,
167    ) {
168        self.pool_peer_id = peer_id;
169        self.pool_tunnel_id = tunnel_id;
170        self.pool_classification = classification;
171    }
172
173    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
174        log::trace!(
175            "client {:?} send cmd: {:?}, len: {}, data: {}",
176            self.tunnel_id,
177            cmd,
178            body.len(),
179            hex::encode(body)
180        );
181        let header = CmdHeader::<LEN, CMD>::new(
182            version,
183            false,
184            None,
185            cmd,
186            LEN::from_u64(body.len() as u64).unwrap(),
187        );
188        let buf = header
189            .to_vec()
190            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
191        let ret = self.send_inner(buf.as_slice(), body).await;
192        if let Err(e) = ret {
193            self.set_disable();
194            return Err(e);
195        }
196        Ok(())
197    }
198
199    pub async fn send_with_resp(
200        &mut self,
201        cmd: CMD,
202        version: u8,
203        body: &[u8],
204        timeout: Duration,
205    ) -> CmdResult<CmdBody> {
206        if let Some(id) = tokio::task::try_id() {
207            if id == self.recv_handle.id() {
208                return Err(cmd_err!(
209                    CmdErrorCode::Failed,
210                    "can't send with resp in recv task"
211                ));
212            }
213        }
214        log::trace!(
215            "client {:?} send cmd: {:?}, len: {}, data: {}",
216            self.tunnel_id,
217            cmd,
218            body.len(),
219            hex::encode(body)
220        );
221        let seq = gen_seq();
222        let header = CmdHeader::<LEN, CMD>::new(
223            version,
224            false,
225            Some(seq),
226            cmd,
227            LEN::from_u64(body.len() as u64).unwrap(),
228        );
229        let buf = header
230            .to_vec()
231            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
232        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
233        let waiter = self.resp_waiter.clone();
234        let resp_waiter = waiter
235            .create_timeout_result_future(resp_id, timeout)
236            .map_err(into_cmd_err!(
237                CmdErrorCode::Failed,
238                "create timeout result future error"
239            ))?;
240        let ret = self.send_inner(buf.as_slice(), body).await;
241        if let Err(e) = ret {
242            self.set_disable();
243            return Err(e);
244        }
245        let resp = resp_waiter
246            .await
247            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
248        Ok(resp)
249    }
250
251    pub async fn send_parts(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
252        let mut len = 0;
253        for b in body.iter() {
254            len += b.len();
255            log::trace!(
256                "client {:?} send2 cmd {:?} body: {}",
257                self.tunnel_id,
258                cmd,
259                hex::encode(b)
260            );
261        }
262        log::trace!(
263            "client {:?} send2 cmd: {:?}, len {}",
264            self.tunnel_id,
265            cmd,
266            len
267        );
268        let header = CmdHeader::<LEN, CMD>::new(
269            version,
270            false,
271            None,
272            cmd,
273            LEN::from_u64(len as u64).unwrap(),
274        );
275        let buf = header
276            .to_vec()
277            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
278        let ret = self.send_inner2(buf.as_slice(), body).await;
279        if let Err(e) = ret {
280            self.set_disable();
281            return Err(e);
282        }
283        Ok(())
284    }
285
286    pub async fn send_parts_with_resp(
287        &mut self,
288        cmd: CMD,
289        version: u8,
290        body: &[&[u8]],
291        timeout: Duration,
292    ) -> CmdResult<CmdBody> {
293        if let Some(id) = tokio::task::try_id() {
294            if id == self.recv_handle.id() {
295                return Err(cmd_err!(
296                    CmdErrorCode::Failed,
297                    "can't send with resp in recv task"
298                ));
299            }
300        }
301        let mut len = 0;
302        for b in body.iter() {
303            len += b.len();
304            log::trace!(
305                "client {:?} send2 cmd {:?} body: {}",
306                self.tunnel_id,
307                cmd,
308                hex::encode(b)
309            );
310        }
311        log::trace!(
312            "client {:?} send2 cmd: {:?}, len {}",
313            self.tunnel_id,
314            cmd,
315            len
316        );
317        let seq = gen_seq();
318        let header = CmdHeader::<LEN, CMD>::new(
319            version,
320            false,
321            Some(seq),
322            cmd,
323            LEN::from_u64(len as u64).unwrap(),
324        );
325        let buf = header
326            .to_vec()
327            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
328        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
329        let waiter = self.resp_waiter.clone();
330        let resp_waiter = waiter
331            .create_timeout_result_future(resp_id, timeout)
332            .map_err(into_cmd_err!(
333                CmdErrorCode::Failed,
334                "create timeout result future error"
335            ))?;
336        let ret = self.send_inner2(buf.as_slice(), body).await;
337        if let Err(e) = ret {
338            self.set_disable();
339            return Err(e);
340        }
341        let resp = resp_waiter
342            .await
343            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
344        Ok(resp)
345    }
346
347    #[allow(deprecated)]
348    #[deprecated(note = "use send_parts instead")]
349    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
350        self.send_parts(cmd, version, body).await
351    }
352
353    #[allow(deprecated)]
354    #[deprecated(note = "use send_parts_with_resp instead")]
355    pub async fn send2_with_resp(
356        &mut self,
357        cmd: CMD,
358        version: u8,
359        body: &[&[u8]],
360        timeout: Duration,
361    ) -> CmdResult<CmdBody> {
362        self.send_parts_with_resp(cmd, version, body, timeout).await
363    }
364
365    pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
366        log::trace!(
367            "client {:?} send cmd: {:?}, len: {}",
368            self.tunnel_id,
369            cmd,
370            body.len()
371        );
372        let header = CmdHeader::<LEN, CMD>::new(
373            version,
374            false,
375            None,
376            cmd,
377            LEN::from_u64(body.len()).unwrap(),
378        );
379        let buf = header
380            .to_vec()
381            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
382        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
383        if let Err(e) = ret {
384            self.set_disable();
385            return Err(e);
386        }
387        Ok(())
388    }
389
390    pub async fn send_cmd_with_resp(
391        &mut self,
392        cmd: CMD,
393        version: u8,
394        body: CmdBody,
395        timeout: Duration,
396    ) -> CmdResult<CmdBody> {
397        if let Some(id) = tokio::task::try_id() {
398            if id == self.recv_handle.id() {
399                return Err(cmd_err!(
400                    CmdErrorCode::Failed,
401                    "can't send with resp in recv task"
402                ));
403            }
404        }
405        log::trace!(
406            "client {:?} send cmd: {:?}, len: {}",
407            self.tunnel_id,
408            cmd,
409            body.len()
410        );
411        let seq = gen_seq();
412        let header = CmdHeader::<LEN, CMD>::new(
413            version,
414            false,
415            Some(seq),
416            cmd,
417            LEN::from_u64(body.len()).unwrap(),
418        );
419        let buf = header
420            .to_vec()
421            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
422        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
423        let waiter = self.resp_waiter.clone();
424        let resp_waiter = waiter
425            .create_timeout_result_future(resp_id, timeout)
426            .map_err(into_cmd_err!(
427                CmdErrorCode::Failed,
428                "create timeout result future error"
429            ))?;
430        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
431        if let Err(e) = ret {
432            self.set_disable();
433            return Err(e);
434        }
435        let resp = resp_waiter
436            .await
437            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
438        Ok(resp)
439    }
440
441    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
442        let mut write = self.write.get().await;
443        if header.len() > 255 {
444            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
445        }
446        write
447            .write_u8(header.len() as u8)
448            .await
449            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
450        write
451            .write_all(header)
452            .await
453            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
454        write
455            .write_all(body)
456            .await
457            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
458        write
459            .flush()
460            .await
461            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
462        Ok(())
463    }
464
465    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
466        let mut write = self.write.get().await;
467        if header.len() > 255 {
468            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
469        }
470        write
471            .write_u8(header.len() as u8)
472            .await
473            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
474        write
475            .write_all(header)
476            .await
477            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
478        for b in body.iter() {
479            write
480                .write_all(b)
481                .await
482                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
483        }
484        write
485            .flush()
486            .await
487            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
488        Ok(())
489    }
490
491    async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
492        let mut write = self.write.get().await;
493        if header.len() > 255 {
494            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
495        }
496        write
497            .write_u8(header.len() as u8)
498            .await
499            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
500        write
501            .write_all(header)
502            .await
503            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
504        tokio::io::copy(&mut body, write.deref_mut().deref_mut())
505            .await
506            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
507        write
508            .flush()
509            .await
510            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
511        Ok(())
512    }
513}
514
515impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
516where
517    C: WorkerClassification,
518    M: CmdTunnelMeta,
519    R: ClassifiedCmdTunnelRead<C, M>,
520    W: ClassifiedCmdTunnelWrite<C, M>,
521    LEN: RawEncode
522        + for<'a> RawDecode<'a>
523        + Copy
524        + Send
525        + Sync
526        + 'static
527        + FromPrimitive
528        + ToPrimitive,
529    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
530{
531    fn drop(&mut self) {
532        self.set_disable();
533    }
534}
535
536impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
537where
538    C: WorkerClassification,
539    M: CmdTunnelMeta,
540    R: ClassifiedCmdTunnelRead<C, M>,
541    W: ClassifiedCmdTunnelWrite<C, M>,
542    LEN: RawEncode
543        + for<'a> RawDecode<'a>
544        + Copy
545        + Send
546        + Sync
547        + 'static
548        + FromPrimitive
549        + ToPrimitive,
550    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
551{
552    fn get_tunnel_meta(&self) -> Option<Arc<M>> {
553        self.tunnel_meta.clone()
554    }
555
556    fn get_remote_peer_id(&self) -> PeerId {
557        self.remote_id.clone()
558    }
559}
560
561impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>>
562    for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
563where
564    C: WorkerClassification,
565    M: CmdTunnelMeta,
566    R: ClassifiedCmdTunnelRead<C, M>,
567    W: ClassifiedCmdTunnelWrite<C, M>,
568    LEN: RawEncode
569        + for<'a> RawDecode<'a>
570        + Copy
571        + Send
572        + Sync
573        + 'static
574        + FromPrimitive
575        + ToPrimitive,
576    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
577{
578    fn is_work(&self) -> bool {
579        self.is_work && !self.recv_handle.is_finished()
580    }
581
582    fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
583        if c.tunnel_id.is_some() {
584            self.tunnel_id == c.tunnel_id.unwrap()
585        } else {
586            if c.classification.is_some() {
587                self.classification == c.classification.unwrap()
588            } else {
589                true
590            }
591        }
592    }
593
594    fn classification(&self) -> CmdClientTunnelClassification<C> {
595        CmdClientTunnelClassification {
596            tunnel_id: self.pool_tunnel_id,
597            classification: self.pool_classification.clone(),
598        }
599    }
600}
601
602pub struct ClassifiedCmdWriteFactory<
603    C: WorkerClassification,
604    M: CmdTunnelMeta,
605    R: ClassifiedCmdTunnelRead<C, M>,
606    W: ClassifiedCmdTunnelWrite<C, M>,
607    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
608    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
609    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
610> {
611    tunnel_factory: F,
612    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
613    resp_waiter: RespWaiterRef,
614    tunnel_id_generator: TunnelIdGenerator,
615    _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
616}
617
618impl<
619    C: WorkerClassification,
620    M: CmdTunnelMeta,
621    R: ClassifiedCmdTunnelRead<C, M>,
622    W: ClassifiedCmdTunnelWrite<C, M>,
623    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
624    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
625    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
626> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
627{
628    pub(crate) fn new(
629        tunnel_factory: F,
630        cmd_handler: impl CmdHandler<LEN, CMD>,
631        resp_waiter: RespWaiterRef,
632    ) -> Self {
633        Self {
634            tunnel_factory,
635            cmd_handler: Arc::new(cmd_handler),
636            resp_waiter,
637            tunnel_id_generator: TunnelIdGenerator::new(),
638            _p: Default::default(),
639        }
640    }
641}
642
643#[async_trait::async_trait]
644impl<
645    C: WorkerClassification,
646    M: CmdTunnelMeta,
647    R: ClassifiedCmdTunnelRead<C, M>,
648    W: ClassifiedCmdTunnelWrite<C, M>,
649    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
650    LEN: RawEncode
651        + for<'a> RawDecode<'a>
652        + Copy
653        + Send
654        + Sync
655        + 'static
656        + FromPrimitive
657        + ToPrimitive
658        + RawFixedBytes,
659    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
660> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>>
661    for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
662{
663    async fn create(
664        &self,
665        classification: Option<CmdClientTunnelClassification<C>>,
666    ) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
667        if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
668            return Err(pool_err!(
669                PoolErrorCode::Failed,
670                "tunnel {:?} not found",
671                classification.as_ref().unwrap().tunnel_id.unwrap()
672            ));
673        }
674
675        let requested_classification = classification;
676        let classification = requested_classification
677            .as_ref()
678            .and_then(|key| key.classification.clone());
679        let tunnel = self
680            .tunnel_factory
681            .create_tunnel(classification)
682            .await
683            .map_err(into_pool_err!(PoolErrorCode::Failed))?;
684        let classification = tunnel.get_classification();
685        let pool_tunnel_id = requested_classification
686            .as_ref()
687            .and_then(|key| key.tunnel_id);
688        let pool_classification = requested_classification
689            .as_ref()
690            .and_then(|key| key.classification.clone())
691            .or(Some(classification.clone()));
692        let peer_id = tunnel.get_remote_peer_id();
693        let tunnel_id = self.tunnel_id_generator.generate();
694        let (mut recv, write) = tunnel.split();
695        let local_id = recv.get_local_peer_id();
696        let remote_id = peer_id.clone();
697        let tunnel_meta = recv.get_tunnel_meta();
698        let write = ObjectHolder::new(write);
699        let resp_write = write.clone();
700        let cmd_handler = self.cmd_handler.clone();
701        let handle = spawn(async move {
702            let ret: CmdResult<()> = async move {
703                loop {
704                    let header_len = recv
705                        .read_u8()
706                        .await
707                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
708                    let mut header = vec![0u8; header_len as usize];
709                    let n = recv
710                        .read_exact(header.as_mut())
711                        .await
712                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
713                    if n == 0 {
714                        break;
715                    }
716                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
717                        .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
718                    log::trace!(
719                        "recv cmd {:?} from {} len {} tunnel {:?}",
720                        header.cmd_code(),
721                        peer_id,
722                        header.pkg_len().to_u64().unwrap(),
723                        tunnel_id
724                    );
725                    let body_len = header.pkg_len().to_u64().unwrap();
726                    let cmd_read =
727                        CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
728                    let waiter = cmd_read.get_waiter();
729                    let future = waiter
730                        .create_result_future()
731                        .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
732                    let version = header.version();
733                    let seq = header.seq();
734                    let cmd_code = header.cmd_code();
735                    match cmd_handler
736                        .handle(
737                            local_id.clone(),
738                            peer_id.clone(),
739                            tunnel_id,
740                            header,
741                            CmdBody::from_reader(BufReader::new(cmd_read), body_len),
742                        )
743                        .await
744                    {
745                        Ok(Some(mut body)) => {
746                            let mut write = resp_write.get().await;
747                            let header = CmdHeader::<LEN, CMD>::new(
748                                version,
749                                true,
750                                seq,
751                                cmd_code,
752                                LEN::from_u64(body.len()).unwrap(),
753                            );
754                            let buf = header
755                                .to_vec()
756                                .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
757                            if buf.len() > 255 {
758                                return Err(cmd_err!(
759                                    CmdErrorCode::InvalidParam,
760                                    "header len too large"
761                                ));
762                            }
763                            write
764                                .write_u8(buf.len() as u8)
765                                .await
766                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
767                            write
768                                .write_all(buf.as_slice())
769                                .await
770                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
771                            tokio::io::copy(&mut body, write.deref_mut().deref_mut())
772                                .await
773                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
774                            write
775                                .flush()
776                                .await
777                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
778                        }
779                        Err(e) => {
780                            log::error!("handle cmd error: {:?}", e);
781                        }
782                        _ => {}
783                    }
784                    recv = future
785                        .await
786                        .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
787                    log::debug!(
788                        "handle cmd {:?} from {} len {} tunnel {:?} complete",
789                        cmd_code,
790                        peer_id,
791                        body_len,
792                        tunnel_id
793                    );
794                }
795                Ok(())
796            }
797            .await;
798            if ret.is_err() {
799                log::error!("recv cmd error: {:?}", ret.as_ref().err().unwrap());
800            }
801            ret
802        });
803        Ok(ClassifiedCmdSend::new(
804            tunnel_id,
805            classification,
806            None,
807            pool_tunnel_id,
808            pool_classification,
809            handle,
810            write,
811            self.resp_waiter.clone(),
812            remote_id,
813            tunnel_meta,
814        ))
815    }
816}
817
818pub struct DefaultClassifiedCmdClient<
819    C: WorkerClassification,
820    M: CmdTunnelMeta,
821    R: ClassifiedCmdTunnelRead<C, M>,
822    W: ClassifiedCmdTunnelWrite<C, M>,
823    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
824    LEN: RawEncode
825        + for<'a> RawDecode<'a>
826        + Copy
827        + Send
828        + Sync
829        + 'static
830        + FromPrimitive
831        + ToPrimitive
832        + RawFixedBytes,
833    CMD: RawEncode
834        + for<'a> RawDecode<'a>
835        + Copy
836        + Send
837        + Sync
838        + 'static
839        + RawFixedBytes
840        + Eq
841        + Hash
842        + Debug,
843> {
844    tunnel_pool: ClassifiedWorkerPoolRef<
845        CmdClientTunnelClassification<C>,
846        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
847        ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
848    >,
849    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
850}
851
852impl<
853    C: WorkerClassification,
854    M: CmdTunnelMeta,
855    R: ClassifiedCmdTunnelRead<C, M>,
856    W: ClassifiedCmdTunnelWrite<C, M>,
857    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
858    LEN: RawEncode
859        + for<'a> RawDecode<'a>
860        + Copy
861        + Send
862        + Sync
863        + 'static
864        + FromPrimitive
865        + ToPrimitive
866        + RawFixedBytes,
867    CMD: RawEncode
868        + for<'a> RawDecode<'a>
869        + Copy
870        + Send
871        + Sync
872        + 'static
873        + RawFixedBytes
874        + Eq
875        + Hash
876        + Debug,
877> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
878{
879    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
880        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
881        let resp_waiter = Arc::new(RespWaiter::new());
882        let handler_map = cmd_handler_map.clone();
883        let waiter = resp_waiter.clone();
884        Arc::new(Self {
885            tunnel_pool: ClassifiedWorkerPool::new(
886                tunnel_count,
887                ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(
888                    factory,
889                    move |local_id: PeerId,
890                          peer_id: PeerId,
891                          tunnel_id: TunnelId,
892                          header: CmdHeader<LEN, CMD>,
893                          body_read: CmdBody| {
894                        let handler_map = handler_map.clone();
895                        let waiter = waiter.clone();
896                        async move {
897                            if header.is_resp() && header.seq().is_some() {
898                                let resp_id = gen_resp_id(
899                                    tunnel_id,
900                                    header.cmd_code(),
901                                    header.seq().unwrap(),
902                                );
903                                let _ = waiter.set_result(resp_id, body_read);
904                                Ok(None)
905                            } else {
906                                if let Some(handler) = handler_map.get(header.cmd_code()) {
907                                    handler
908                                        .handle(local_id, peer_id, tunnel_id, header, body_read)
909                                        .await
910                                } else {
911                                    Ok(None)
912                                }
913                            }
914                        }
915                    },
916                    resp_waiter.clone(),
917                ),
918            ),
919            cmd_handler_map,
920        })
921    }
922
923    async fn get_send(
924        &self,
925    ) -> CmdResult<
926        ClassifiedWorkerGuard<
927            CmdClientTunnelClassification<C>,
928            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
929            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
930        >,
931    > {
932        self.tunnel_pool
933            .get_worker()
934            .await
935            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
936    }
937
938    async fn get_send_of_tunnel_id(
939        &self,
940        tunnel_id: TunnelId,
941    ) -> CmdResult<
942        ClassifiedWorkerGuard<
943            CmdClientTunnelClassification<C>,
944            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
945            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
946        >,
947    > {
948        self.tunnel_pool
949            .get_classified_worker(CmdClientTunnelClassification {
950                tunnel_id: Some(tunnel_id),
951                classification: None,
952            })
953            .await
954            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
955    }
956
957    async fn get_classified_send(
958        &self,
959        classification: C,
960    ) -> CmdResult<
961        ClassifiedWorkerGuard<
962            CmdClientTunnelClassification<C>,
963            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
964            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
965        >,
966    > {
967        self.tunnel_pool
968            .get_classified_worker(CmdClientTunnelClassification {
969                tunnel_id: None,
970                classification: Some(classification),
971            })
972            .await
973            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
974    }
975}
976
/// Send guard returned by [`DefaultClassifiedCmdClient`] from `CmdClient::get_send` and
/// `ClassifiedCmdClient::get_send_by_classified`.
///
/// It is a [`ClassifiedSendGuard`] specialised to this client's pool key
/// (`CmdClientTunnelClassification<C>`), sender (`ClassifiedCmdSend`) and worker factory
/// (`ClassifiedCmdWriteFactory`) types.
pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
    CmdClientTunnelClassification<C>,
    M,
    ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
    ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
>;
983#[async_trait::async_trait]
984impl<
985    C: WorkerClassification,
986    M: CmdTunnelMeta,
987    R: ClassifiedCmdTunnelRead<C, M>,
988    W: ClassifiedCmdTunnelWrite<C, M>,
989    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
990    LEN: RawEncode
991        + for<'a> RawDecode<'a>
992        + Copy
993        + Send
994        + Sync
995        + 'static
996        + FromPrimitive
997        + ToPrimitive
998        + RawFixedBytes,
999    CMD: RawEncode
1000        + for<'a> RawDecode<'a>
1001        + Copy
1002        + Send
1003        + Sync
1004        + 'static
1005        + RawFixedBytes
1006        + Eq
1007        + Hash
1008        + Debug,
1009>
1010    CmdClient<
1011        LEN,
1012        CMD,
1013        M,
1014        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1015        ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1016    > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1017{
1018    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
1019        self.cmd_handler_map.insert(cmd, handler);
1020    }
1021
1022    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
1023        let mut send = self.get_send().await?;
1024        send.send(cmd, version, body).await
1025    }
1026
1027    async fn send_with_resp(
1028        &self,
1029        cmd: CMD,
1030        version: u8,
1031        body: &[u8],
1032        timeout: Duration,
1033    ) -> CmdResult<CmdBody> {
1034        let mut send = self.get_send().await?;
1035        send.send_with_resp(cmd, version, body, timeout).await
1036    }
1037
1038    async fn send_parts(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
1039        let mut send = self.get_send().await?;
1040        send.send_parts(cmd, version, body).await
1041    }
1042
1043    async fn send_parts_with_resp(
1044        &self,
1045        cmd: CMD,
1046        version: u8,
1047        body: &[&[u8]],
1048        timeout: Duration,
1049    ) -> CmdResult<CmdBody> {
1050        let mut send = self.get_send().await?;
1051        send.send_parts_with_resp(cmd, version, body, timeout).await
1052    }
1053
1054    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
1055        let mut send = self.get_send().await?;
1056        send.send_cmd(cmd, version, body).await
1057    }
1058
1059    async fn send_cmd_with_resp(
1060        &self,
1061        cmd: CMD,
1062        version: u8,
1063        body: CmdBody,
1064        timeout: Duration,
1065    ) -> CmdResult<CmdBody> {
1066        let mut send = self.get_send().await?;
1067        send.send_cmd_with_resp(cmd, version, body, timeout).await
1068    }
1069
1070    async fn send_by_specify_tunnel(
1071        &self,
1072        tunnel_id: TunnelId,
1073        cmd: CMD,
1074        version: u8,
1075        body: &[u8],
1076    ) -> CmdResult<()> {
1077        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1078        send.send(cmd, version, body).await
1079    }
1080
1081    async fn send_by_specify_tunnel_with_resp(
1082        &self,
1083        tunnel_id: TunnelId,
1084        cmd: CMD,
1085        version: u8,
1086        body: &[u8],
1087        timeout: Duration,
1088    ) -> CmdResult<CmdBody> {
1089        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1090        send.send_with_resp(cmd, version, body, timeout).await
1091    }
1092
1093    async fn send_parts_by_specify_tunnel(
1094        &self,
1095        tunnel_id: TunnelId,
1096        cmd: CMD,
1097        version: u8,
1098        body: &[&[u8]],
1099    ) -> CmdResult<()> {
1100        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1101        send.send_parts(cmd, version, body).await
1102    }
1103
1104    async fn send_parts_by_specify_tunnel_with_resp(
1105        &self,
1106        tunnel_id: TunnelId,
1107        cmd: CMD,
1108        version: u8,
1109        body: &[&[u8]],
1110        timeout: Duration,
1111    ) -> CmdResult<CmdBody> {
1112        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1113        send.send_parts_with_resp(cmd, version, body, timeout).await
1114    }
1115
1116    async fn send_cmd_by_specify_tunnel(
1117        &self,
1118        tunnel_id: TunnelId,
1119        cmd: CMD,
1120        version: u8,
1121        body: CmdBody,
1122    ) -> CmdResult<()> {
1123        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1124        send.send_cmd(cmd, version, body).await
1125    }
1126
1127    async fn send_cmd_by_specify_tunnel_with_resp(
1128        &self,
1129        tunnel_id: TunnelId,
1130        cmd: CMD,
1131        version: u8,
1132        body: CmdBody,
1133        timeout: Duration,
1134    ) -> CmdResult<CmdBody> {
1135        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1136        send.send_cmd_with_resp(cmd, version, body, timeout).await
1137    }
1138
1139    async fn clear_all_tunnel(&self) {
1140        self.tunnel_pool.clear_all_worker().await;
1141    }
1142
1143    async fn get_send(
1144        &self,
1145        tunnel_id: TunnelId,
1146    ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
1147        Ok(ClassifiedSendGuard {
1148            worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1149            _p: std::marker::PhantomData,
1150        })
1151    }
1152}
1153
1154#[async_trait::async_trait]
1155impl<
1156    C: WorkerClassification,
1157    M: CmdTunnelMeta,
1158    R: ClassifiedCmdTunnelRead<C, M>,
1159    W: ClassifiedCmdTunnelWrite<C, M>,
1160    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
1161    LEN: RawEncode
1162        + for<'a> RawDecode<'a>
1163        + Copy
1164        + Send
1165        + Sync
1166        + 'static
1167        + FromPrimitive
1168        + ToPrimitive
1169        + RawFixedBytes,
1170    CMD: RawEncode
1171        + for<'a> RawDecode<'a>
1172        + Copy
1173        + Send
1174        + Sync
1175        + 'static
1176        + RawFixedBytes
1177        + Eq
1178        + Hash
1179        + Debug,
1180>
1181    ClassifiedCmdClient<
1182        LEN,
1183        CMD,
1184        C,
1185        M,
1186        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1187        ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1188    > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1189{
1190    async fn send_by_classified_tunnel(
1191        &self,
1192        classification: C,
1193        cmd: CMD,
1194        version: u8,
1195        body: &[u8],
1196    ) -> CmdResult<()> {
1197        let mut send = self.get_classified_send(classification).await?;
1198        send.send(cmd, version, body).await
1199    }
1200
1201    async fn send_by_classified_tunnel_with_resp(
1202        &self,
1203        classification: C,
1204        cmd: CMD,
1205        version: u8,
1206        body: &[u8],
1207        timeout: Duration,
1208    ) -> CmdResult<CmdBody> {
1209        let mut send = self.get_classified_send(classification).await?;
1210        send.send_with_resp(cmd, version, body, timeout).await
1211    }
1212
1213    async fn send_parts_by_classified_tunnel(
1214        &self,
1215        classification: C,
1216        cmd: CMD,
1217        version: u8,
1218        body: &[&[u8]],
1219    ) -> CmdResult<()> {
1220        let mut send = self.get_classified_send(classification).await?;
1221        send.send_parts(cmd, version, body).await
1222    }
1223
1224    async fn send_parts_by_classified_tunnel_with_resp(
1225        &self,
1226        classification: C,
1227        cmd: CMD,
1228        version: u8,
1229        body: &[&[u8]],
1230        timeout: Duration,
1231    ) -> CmdResult<CmdBody> {
1232        let mut send = self.get_classified_send(classification).await?;
1233        send.send_parts_with_resp(cmd, version, body, timeout).await
1234    }
1235
1236    async fn send_cmd_by_classified_tunnel(
1237        &self,
1238        classification: C,
1239        cmd: CMD,
1240        version: u8,
1241        body: CmdBody,
1242    ) -> CmdResult<()> {
1243        let mut send = self.get_classified_send(classification).await?;
1244        send.send_cmd(cmd, version, body).await
1245    }
1246
1247    async fn send_cmd_by_classified_tunnel_with_resp(
1248        &self,
1249        classification: C,
1250        cmd: CMD,
1251        version: u8,
1252        body: CmdBody,
1253        timeout: Duration,
1254    ) -> CmdResult<CmdBody> {
1255        let mut send = self.get_classified_send(classification).await?;
1256        send.send_cmd_with_resp(cmd, version, body, timeout).await
1257    }
1258
1259    async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
1260        let send = self.get_classified_send(classification).await?;
1261        Ok(send.get_tunnel_id())
1262    }
1263
1264    async fn get_send_by_classified(
1265        &self,
1266        classification: C,
1267    ) -> CmdResult<
1268        ClassifiedSendGuard<
1269            CmdClientTunnelClassification<C>,
1270            M,
1271            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1272            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
1273        >,
1274    > {
1275        Ok(ClassifiedSendGuard {
1276            worker_guard: self.get_classified_send(classification).await?,
1277            _p: std::marker::PhantomData,
1278        })
1279    }
1280}