Skip to main content

sfo_cmd_server/client/
classified_client.rs

1use crate::client::{
2    ClassifiedCmdClient, CmdClient, CmdSend, RespWaiter, RespWaiterRef, TrackedSendGuard,
3    TunnelReserveResult, TunnelRuntimeRegistry, gen_resp_id, gen_seq,
4};
5use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
6use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
7use crate::peer_id::PeerId;
8use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
9use async_named_locker::ObjectHolder;
10use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
11use num::{FromPrimitive, ToPrimitive};
12use sfo_pool::{
13    ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
14    ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
15    pool_err,
16};
17use sfo_split::{RHalf, Splittable, WHalf};
18use std::fmt::Debug;
19use std::hash::Hash;
20use std::ops::DerefMut;
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26use tokio::task::yield_now;
27
28pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>:
29    CmdTunnelRead<M> + 'static + Send
30{
31    fn get_classification(&self) -> C;
32}
33
34pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>:
35    CmdTunnelWrite<M> + 'static + Send
36{
37    fn get_classification(&self) -> C;
38}
39
40pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
41pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
42pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
43
44#[derive(Debug, Clone, Copy, Eq, Hash)]
45pub struct CmdClientTunnelClassification<C: WorkerClassification> {
46    tunnel_id: Option<TunnelId>,
47    classification: Option<C>,
48}
49
50impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
51    fn eq(&self, other: &Self) -> bool {
52        self.tunnel_id == other.tunnel_id && self.classification == other.classification
53    }
54}
55
56#[async_trait::async_trait]
57pub trait ClassifiedCmdTunnelFactory<
58    C: WorkerClassification,
59    M: CmdTunnelMeta,
60    R: ClassifiedCmdTunnelRead<C, M>,
61    W: ClassifiedCmdTunnelWrite<C, M>,
62>: Send + Sync + 'static
63{
64    async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
65}
66
67pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
68where
69    C: WorkerClassification,
70    M: CmdTunnelMeta,
71    R: ClassifiedCmdTunnelRead<C, M>,
72    W: ClassifiedCmdTunnelWrite<C, M>,
73    LEN: RawEncode
74        + for<'a> RawDecode<'a>
75        + Copy
76        + Send
77        + Sync
78        + 'static
79        + FromPrimitive
80        + ToPrimitive,
81    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
82{
83    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
84    pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
85    pub(crate) is_work: bool,
86    pub(crate) classification: C,
87    pub(crate) pool_tunnel_id: Option<TunnelId>,
88    pub(crate) pool_classification: Option<C>,
89    pub(crate) tunnel_id: TunnelId,
90    pub(crate) resp_waiter: RespWaiterRef,
91    pub(crate) remote_id: PeerId,
92    pub(crate) pool_peer_id: Option<PeerId>,
93    pub(crate) tunnel_meta: Option<Arc<M>>,
94    _p: std::marker::PhantomData<(LEN, CMD)>,
95}
96
97// impl<C, R, W, LEN, CMD> Deref for ClassifiedCmdSend<C, R, W, LEN, CMD>
98// where C: WorkerClassification,
99//       R: ClassifiedCmdTunnelRead<C>,
100//       W: ClassifiedCmdTunnelWrite<C>,
101//       LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
102//       CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
103//     type Target = W;
104//
105//     fn deref(&self) -> &Self::Target {
106//         self.write.deref()
107//     }
108// }
109
110impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
111where
112    C: WorkerClassification,
113    M: CmdTunnelMeta,
114    R: ClassifiedCmdTunnelRead<C, M>,
115    W: ClassifiedCmdTunnelWrite<C, M>,
116    LEN: RawEncode
117        + for<'a> RawDecode<'a>
118        + Copy
119        + Send
120        + Sync
121        + 'static
122        + FromPrimitive
123        + ToPrimitive,
124    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
125{
126    pub(crate) fn new(
127        tunnel_id: TunnelId,
128        classification: C,
129        pool_peer_id: Option<PeerId>,
130        pool_tunnel_id: Option<TunnelId>,
131        pool_classification: Option<C>,
132        recv_handle: JoinHandle<CmdResult<()>>,
133        write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
134        resp_waiter: RespWaiterRef,
135        remote_id: PeerId,
136        tunnel_meta: Option<Arc<M>>,
137    ) -> Self {
138        Self {
139            recv_handle,
140            write,
141            is_work: true,
142            classification,
143            pool_tunnel_id,
144            pool_classification,
145            tunnel_id,
146            resp_waiter,
147            remote_id,
148            pool_peer_id,
149            tunnel_meta,
150            _p: Default::default(),
151        }
152    }
153
154    pub fn get_tunnel_id(&self) -> TunnelId {
155        self.tunnel_id
156    }
157
158    pub fn set_disable(&mut self) {
159        self.is_work = false;
160        self.recv_handle.abort();
161    }
162
163    pub(crate) fn set_pool_key(
164        &mut self,
165        peer_id: Option<PeerId>,
166        tunnel_id: Option<TunnelId>,
167        classification: Option<C>,
168    ) {
169        self.pool_peer_id = peer_id;
170        self.pool_tunnel_id = tunnel_id;
171        self.pool_classification = classification;
172    }
173
174    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
175        log::trace!(
176            "client {:?} send cmd: {:?}, len: {}, data: {}",
177            self.tunnel_id,
178            cmd,
179            body.len(),
180            hex::encode(body)
181        );
182        let header = CmdHeader::<LEN, CMD>::new(
183            version,
184            false,
185            None,
186            cmd,
187            LEN::from_u64(body.len() as u64).unwrap(),
188        );
189        let buf = header
190            .to_vec()
191            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
192        let ret = self.send_inner(buf.as_slice(), body).await;
193        if let Err(e) = ret {
194            self.set_disable();
195            return Err(e);
196        }
197        Ok(())
198    }
199
200    pub async fn send_with_resp(
201        &mut self,
202        cmd: CMD,
203        version: u8,
204        body: &[u8],
205        timeout: Duration,
206    ) -> CmdResult<CmdBody> {
207        if let Some(id) = tokio::task::try_id() {
208            if id == self.recv_handle.id() {
209                return Err(cmd_err!(
210                    CmdErrorCode::Failed,
211                    "can't send with resp in recv task"
212                ));
213            }
214        }
215        log::trace!(
216            "client {:?} send cmd: {:?}, len: {}, data: {}",
217            self.tunnel_id,
218            cmd,
219            body.len(),
220            hex::encode(body)
221        );
222        let seq = gen_seq();
223        let header = CmdHeader::<LEN, CMD>::new(
224            version,
225            false,
226            Some(seq),
227            cmd,
228            LEN::from_u64(body.len() as u64).unwrap(),
229        );
230        let buf = header
231            .to_vec()
232            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
233        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
234        let waiter = self.resp_waiter.clone();
235        let resp_waiter = waiter
236            .create_timeout_result_future(resp_id, timeout)
237            .map_err(into_cmd_err!(
238                CmdErrorCode::Failed,
239                "create timeout result future error"
240            ))?;
241        let ret = self.send_inner(buf.as_slice(), body).await;
242        if let Err(e) = ret {
243            self.set_disable();
244            return Err(e);
245        }
246        let resp = resp_waiter
247            .await
248            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
249        Ok(resp)
250    }
251
252    pub async fn send_parts(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
253        let mut len = 0;
254        for b in body.iter() {
255            len += b.len();
256            log::trace!(
257                "client {:?} send2 cmd {:?} body: {}",
258                self.tunnel_id,
259                cmd,
260                hex::encode(b)
261            );
262        }
263        log::trace!(
264            "client {:?} send2 cmd: {:?}, len {}",
265            self.tunnel_id,
266            cmd,
267            len
268        );
269        let header = CmdHeader::<LEN, CMD>::new(
270            version,
271            false,
272            None,
273            cmd,
274            LEN::from_u64(len as u64).unwrap(),
275        );
276        let buf = header
277            .to_vec()
278            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
279        let ret = self.send_inner2(buf.as_slice(), body).await;
280        if let Err(e) = ret {
281            self.set_disable();
282            return Err(e);
283        }
284        Ok(())
285    }
286
287    pub async fn send_parts_with_resp(
288        &mut self,
289        cmd: CMD,
290        version: u8,
291        body: &[&[u8]],
292        timeout: Duration,
293    ) -> CmdResult<CmdBody> {
294        if let Some(id) = tokio::task::try_id() {
295            if id == self.recv_handle.id() {
296                return Err(cmd_err!(
297                    CmdErrorCode::Failed,
298                    "can't send with resp in recv task"
299                ));
300            }
301        }
302        let mut len = 0;
303        for b in body.iter() {
304            len += b.len();
305            log::trace!(
306                "client {:?} send2 cmd {:?} body: {}",
307                self.tunnel_id,
308                cmd,
309                hex::encode(b)
310            );
311        }
312        log::trace!(
313            "client {:?} send2 cmd: {:?}, len {}",
314            self.tunnel_id,
315            cmd,
316            len
317        );
318        let seq = gen_seq();
319        let header = CmdHeader::<LEN, CMD>::new(
320            version,
321            false,
322            Some(seq),
323            cmd,
324            LEN::from_u64(len as u64).unwrap(),
325        );
326        let buf = header
327            .to_vec()
328            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
329        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
330        let waiter = self.resp_waiter.clone();
331        let resp_waiter = waiter
332            .create_timeout_result_future(resp_id, timeout)
333            .map_err(into_cmd_err!(
334                CmdErrorCode::Failed,
335                "create timeout result future error"
336            ))?;
337        let ret = self.send_inner2(buf.as_slice(), body).await;
338        if let Err(e) = ret {
339            self.set_disable();
340            return Err(e);
341        }
342        let resp = resp_waiter
343            .await
344            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
345        Ok(resp)
346    }
347
348    #[allow(deprecated)]
349    #[deprecated(note = "use send_parts instead")]
350    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
351        self.send_parts(cmd, version, body).await
352    }
353
354    #[allow(deprecated)]
355    #[deprecated(note = "use send_parts_with_resp instead")]
356    pub async fn send2_with_resp(
357        &mut self,
358        cmd: CMD,
359        version: u8,
360        body: &[&[u8]],
361        timeout: Duration,
362    ) -> CmdResult<CmdBody> {
363        self.send_parts_with_resp(cmd, version, body, timeout).await
364    }
365
366    pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
367        log::trace!(
368            "client {:?} send cmd: {:?}, len: {}",
369            self.tunnel_id,
370            cmd,
371            body.len()
372        );
373        let header = CmdHeader::<LEN, CMD>::new(
374            version,
375            false,
376            None,
377            cmd,
378            LEN::from_u64(body.len()).unwrap(),
379        );
380        let buf = header
381            .to_vec()
382            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
383        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
384        if let Err(e) = ret {
385            self.set_disable();
386            return Err(e);
387        }
388        Ok(())
389    }
390
391    pub async fn send_cmd_with_resp(
392        &mut self,
393        cmd: CMD,
394        version: u8,
395        body: CmdBody,
396        timeout: Duration,
397    ) -> CmdResult<CmdBody> {
398        if let Some(id) = tokio::task::try_id() {
399            if id == self.recv_handle.id() {
400                return Err(cmd_err!(
401                    CmdErrorCode::Failed,
402                    "can't send with resp in recv task"
403                ));
404            }
405        }
406        log::trace!(
407            "client {:?} send cmd: {:?}, len: {}",
408            self.tunnel_id,
409            cmd,
410            body.len()
411        );
412        let seq = gen_seq();
413        let header = CmdHeader::<LEN, CMD>::new(
414            version,
415            false,
416            Some(seq),
417            cmd,
418            LEN::from_u64(body.len()).unwrap(),
419        );
420        let buf = header
421            .to_vec()
422            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
423        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
424        let waiter = self.resp_waiter.clone();
425        let resp_waiter = waiter
426            .create_timeout_result_future(resp_id, timeout)
427            .map_err(into_cmd_err!(
428                CmdErrorCode::Failed,
429                "create timeout result future error"
430            ))?;
431        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
432        if let Err(e) = ret {
433            self.set_disable();
434            return Err(e);
435        }
436        let resp = resp_waiter
437            .await
438            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
439        Ok(resp)
440    }
441
442    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
443        let mut write = self.write.get().await;
444        if header.len() > 255 {
445            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
446        }
447        write
448            .write_u8(header.len() as u8)
449            .await
450            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
451        write
452            .write_all(header)
453            .await
454            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
455        write
456            .write_all(body)
457            .await
458            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
459        write
460            .flush()
461            .await
462            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
463        Ok(())
464    }
465
466    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
467        let mut write = self.write.get().await;
468        if header.len() > 255 {
469            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
470        }
471        write
472            .write_u8(header.len() as u8)
473            .await
474            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
475        write
476            .write_all(header)
477            .await
478            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
479        for b in body.iter() {
480            write
481                .write_all(b)
482                .await
483                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
484        }
485        write
486            .flush()
487            .await
488            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
489        Ok(())
490    }
491
492    async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
493        let mut write = self.write.get().await;
494        if header.len() > 255 {
495            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
496        }
497        write
498            .write_u8(header.len() as u8)
499            .await
500            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
501        write
502            .write_all(header)
503            .await
504            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
505        tokio::io::copy(&mut body, write.deref_mut().deref_mut())
506            .await
507            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
508        write
509            .flush()
510            .await
511            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
512        Ok(())
513    }
514}
515
516impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
517where
518    C: WorkerClassification,
519    M: CmdTunnelMeta,
520    R: ClassifiedCmdTunnelRead<C, M>,
521    W: ClassifiedCmdTunnelWrite<C, M>,
522    LEN: RawEncode
523        + for<'a> RawDecode<'a>
524        + Copy
525        + Send
526        + Sync
527        + 'static
528        + FromPrimitive
529        + ToPrimitive,
530    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
531{
532    fn drop(&mut self) {
533        self.set_disable();
534    }
535}
536
537impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
538where
539    C: WorkerClassification,
540    M: CmdTunnelMeta,
541    R: ClassifiedCmdTunnelRead<C, M>,
542    W: ClassifiedCmdTunnelWrite<C, M>,
543    LEN: RawEncode
544        + for<'a> RawDecode<'a>
545        + Copy
546        + Send
547        + Sync
548        + 'static
549        + FromPrimitive
550        + ToPrimitive,
551    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
552{
553    fn get_tunnel_meta(&self) -> Option<Arc<M>> {
554        self.tunnel_meta.clone()
555    }
556
557    fn get_remote_peer_id(&self) -> PeerId {
558        self.remote_id.clone()
559    }
560}
561
562impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>>
563    for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
564where
565    C: WorkerClassification,
566    M: CmdTunnelMeta,
567    R: ClassifiedCmdTunnelRead<C, M>,
568    W: ClassifiedCmdTunnelWrite<C, M>,
569    LEN: RawEncode
570        + for<'a> RawDecode<'a>
571        + Copy
572        + Send
573        + Sync
574        + 'static
575        + FromPrimitive
576        + ToPrimitive,
577    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
578{
579    fn is_work(&self) -> bool {
580        self.is_work && !self.recv_handle.is_finished()
581    }
582
583    fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
584        if c.tunnel_id.is_some() {
585            self.tunnel_id == c.tunnel_id.unwrap()
586        } else {
587            if c.classification.is_some() {
588                self.classification == c.classification.unwrap()
589            } else {
590                true
591            }
592        }
593    }
594
595    fn classification(&self) -> CmdClientTunnelClassification<C> {
596        CmdClientTunnelClassification {
597            tunnel_id: self.pool_tunnel_id,
598            classification: self.pool_classification.clone(),
599        }
600    }
601}
602
603pub struct ClassifiedCmdWriteFactory<
604    C: WorkerClassification,
605    M: CmdTunnelMeta,
606    R: ClassifiedCmdTunnelRead<C, M>,
607    W: ClassifiedCmdTunnelWrite<C, M>,
608    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
609    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
610    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
611> {
612    tunnel_factory: F,
613    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
614    resp_waiter: RespWaiterRef,
615    tunnel_id_generator: TunnelIdGenerator,
616    _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
617}
618
619impl<
620    C: WorkerClassification,
621    M: CmdTunnelMeta,
622    R: ClassifiedCmdTunnelRead<C, M>,
623    W: ClassifiedCmdTunnelWrite<C, M>,
624    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
625    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
626    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
627> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
628{
629    pub(crate) fn new(
630        tunnel_factory: F,
631        cmd_handler: impl CmdHandler<LEN, CMD>,
632        resp_waiter: RespWaiterRef,
633    ) -> Self {
634        Self {
635            tunnel_factory,
636            cmd_handler: Arc::new(cmd_handler),
637            resp_waiter,
638            tunnel_id_generator: TunnelIdGenerator::new(),
639            _p: Default::default(),
640        }
641    }
642}
643
644#[async_trait::async_trait]
645impl<
646    C: WorkerClassification,
647    M: CmdTunnelMeta,
648    R: ClassifiedCmdTunnelRead<C, M>,
649    W: ClassifiedCmdTunnelWrite<C, M>,
650    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
651    LEN: RawEncode
652        + for<'a> RawDecode<'a>
653        + Copy
654        + Send
655        + Sync
656        + 'static
657        + FromPrimitive
658        + ToPrimitive
659        + RawFixedBytes,
660    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
661> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>>
662    for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
663{
664    async fn create(
665        &self,
666        classification: Option<CmdClientTunnelClassification<C>>,
667    ) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
668        if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
669            return Err(pool_err!(
670                PoolErrorCode::Failed,
671                "tunnel {:?} not found",
672                classification.as_ref().unwrap().tunnel_id.unwrap()
673            ));
674        }
675
676        let requested_classification = classification;
677        let classification = requested_classification
678            .as_ref()
679            .and_then(|key| key.classification.clone());
680        let tunnel = self
681            .tunnel_factory
682            .create_tunnel(classification)
683            .await
684            .map_err(into_pool_err!(PoolErrorCode::Failed))?;
685        let classification = tunnel.get_classification();
686        let pool_tunnel_id = requested_classification
687            .as_ref()
688            .and_then(|key| key.tunnel_id);
689        let pool_classification = requested_classification
690            .as_ref()
691            .and_then(|key| key.classification.clone())
692            .or(Some(classification.clone()));
693        let peer_id = tunnel.get_remote_peer_id();
694        let tunnel_id = self.tunnel_id_generator.generate();
695        let (mut recv, write) = tunnel.split();
696        let local_id = recv.get_local_peer_id();
697        let remote_id = peer_id.clone();
698        let tunnel_meta = recv.get_tunnel_meta();
699        let write = ObjectHolder::new(write);
700        let resp_write = write.clone();
701        let cmd_handler = self.cmd_handler.clone();
702        let handle = spawn(async move {
703            let ret: CmdResult<()> = async move {
704                loop {
705                    let header_len = recv
706                        .read_u8()
707                        .await
708                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
709                    let mut header = vec![0u8; header_len as usize];
710                    let n = recv
711                        .read_exact(header.as_mut())
712                        .await
713                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
714                    if n == 0 {
715                        break;
716                    }
717                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
718                        .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
719                    log::trace!(
720                        "recv cmd {:?} from {} len {} tunnel {:?}",
721                        header.cmd_code(),
722                        peer_id,
723                        header.pkg_len().to_u64().unwrap(),
724                        tunnel_id
725                    );
726                    let body_len = header.pkg_len().to_u64().unwrap();
727                    let cmd_read =
728                        CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
729                    let waiter = cmd_read.get_waiter();
730                    let future = waiter
731                        .create_result_future()
732                        .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
733                    let version = header.version();
734                    let seq = header.seq();
735                    let cmd_code = header.cmd_code();
736                    match cmd_handler
737                        .handle(
738                            local_id.clone(),
739                            peer_id.clone(),
740                            tunnel_id,
741                            header,
742                            CmdBody::from_reader(BufReader::new(cmd_read), body_len),
743                        )
744                        .await
745                    {
746                        Ok(Some(mut body)) => {
747                            let mut write = resp_write.get().await;
748                            let header = CmdHeader::<LEN, CMD>::new(
749                                version,
750                                true,
751                                seq,
752                                cmd_code,
753                                LEN::from_u64(body.len()).unwrap(),
754                            );
755                            let buf = header
756                                .to_vec()
757                                .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
758                            if buf.len() > 255 {
759                                return Err(cmd_err!(
760                                    CmdErrorCode::InvalidParam,
761                                    "header len too large"
762                                ));
763                            }
764                            write
765                                .write_u8(buf.len() as u8)
766                                .await
767                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
768                            write
769                                .write_all(buf.as_slice())
770                                .await
771                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
772                            tokio::io::copy(&mut body, write.deref_mut().deref_mut())
773                                .await
774                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
775                            write
776                                .flush()
777                                .await
778                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
779                        }
780                        Err(e) => {
781                            log::error!("handle cmd error: {:?}", e);
782                        }
783                        _ => {}
784                    }
785                    recv = future
786                        .await
787                        .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
788                    log::debug!(
789                        "handle cmd {:?} from {} len {} tunnel {:?} complete",
790                        cmd_code,
791                        peer_id,
792                        body_len,
793                        tunnel_id
794                    );
795                }
796                Ok(())
797            }
798            .await;
799            if ret.is_err() {
800                log::error!("recv cmd error: {:?}", ret.as_ref().err().unwrap());
801            }
802            ret
803        });
804        Ok(ClassifiedCmdSend::new(
805            tunnel_id,
806            classification,
807            None,
808            pool_tunnel_id,
809            pool_classification,
810            handle,
811            write,
812            self.resp_waiter.clone(),
813            remote_id,
814            tunnel_meta,
815        ))
816    }
817}
818
819pub struct DefaultClassifiedCmdClient<
820    C: WorkerClassification,
821    M: CmdTunnelMeta,
822    R: ClassifiedCmdTunnelRead<C, M>,
823    W: ClassifiedCmdTunnelWrite<C, M>,
824    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
825    LEN: RawEncode
826        + for<'a> RawDecode<'a>
827        + Copy
828        + Send
829        + Sync
830        + 'static
831        + FromPrimitive
832        + ToPrimitive
833        + RawFixedBytes,
834    CMD: RawEncode
835        + for<'a> RawDecode<'a>
836        + Copy
837        + Send
838        + Sync
839        + 'static
840        + RawFixedBytes
841        + Eq
842        + Hash
843        + Debug,
844> {
845    tunnel_pool: ClassifiedWorkerPoolRef<
846        CmdClientTunnelClassification<C>,
847        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
848        ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
849    >,
850    runtime_tunnels: Arc<TunnelRuntimeRegistry>,
851    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
852}
853
854impl<
855    C: WorkerClassification,
856    M: CmdTunnelMeta,
857    R: ClassifiedCmdTunnelRead<C, M>,
858    W: ClassifiedCmdTunnelWrite<C, M>,
859    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
860    LEN: RawEncode
861        + for<'a> RawDecode<'a>
862        + Copy
863        + Send
864        + Sync
865        + 'static
866        + FromPrimitive
867        + ToPrimitive
868        + RawFixedBytes,
869    CMD: RawEncode
870        + for<'a> RawDecode<'a>
871        + Copy
872        + Send
873        + Sync
874        + 'static
875        + RawFixedBytes
876        + Eq
877        + Hash
878        + Debug,
879> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
880{
881    fn tunnel_not_found(tunnel_id: TunnelId) -> sfo_result::Error<CmdErrorCode> {
882        cmd_err!(CmdErrorCode::Failed, "tunnel {:?} not found", tunnel_id)
883    }
884
885    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
886        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
887        let resp_waiter = Arc::new(RespWaiter::new());
888        let handler_map = cmd_handler_map.clone();
889        let waiter = resp_waiter.clone();
890        Arc::new(Self {
891            tunnel_pool: ClassifiedWorkerPool::new(
892                tunnel_count,
893                ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(
894                    factory,
895                    move |local_id: PeerId,
896                          peer_id: PeerId,
897                          tunnel_id: TunnelId,
898                          header: CmdHeader<LEN, CMD>,
899                          body_read: CmdBody| {
900                        let handler_map = handler_map.clone();
901                        let waiter = waiter.clone();
902                        async move {
903                            if header.is_resp() && header.seq().is_some() {
904                                let resp_id = gen_resp_id(
905                                    tunnel_id,
906                                    header.cmd_code(),
907                                    header.seq().unwrap(),
908                                );
909                                let _ = waiter.set_result(resp_id, body_read);
910                                Ok(None)
911                            } else {
912                                if let Some(handler) = handler_map.get(header.cmd_code()) {
913                                    handler
914                                        .handle(local_id, peer_id, tunnel_id, header, body_read)
915                                        .await
916                                } else {
917                                    Ok(None)
918                                }
919                            }
920                        }
921                    },
922                    resp_waiter.clone(),
923                ),
924            ),
925            runtime_tunnels: TunnelRuntimeRegistry::new(),
926            cmd_handler_map,
927        })
928    }
929
930    fn tracked_send_guard(
931        &self,
932        worker_guard: ClassifiedWorkerGuard<
933            CmdClientTunnelClassification<C>,
934            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
935            ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
936        >,
937    ) -> ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> {
938        let tunnel_id = worker_guard.get_tunnel_id();
939        TrackedSendGuard::new(worker_guard, self.runtime_tunnels.clone(), tunnel_id)
940    }
941
942    async fn reserve_tunnel(&self, tunnel_id: TunnelId) -> CmdResult<()> {
943        loop {
944            match self.runtime_tunnels.reserve_existing(tunnel_id) {
945                TunnelReserveResult::Acquired => return Ok(()),
946                TunnelReserveResult::Wait(waiter) => waiter.notified().await,
947                TunnelReserveResult::Missing => return Err(Self::tunnel_not_found(tunnel_id)),
948            }
949        }
950    }
951
952    async fn get_send(&self) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
953        loop {
954            let worker_guard = self
955                .tunnel_pool
956                .get_worker()
957                .await
958                .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))?;
959            let tunnel_id = worker_guard.get_tunnel_id();
960            if self.runtime_tunnels.mark_borrowed(tunnel_id) {
961                return Ok(self.tracked_send_guard(worker_guard));
962            }
963            drop(worker_guard);
964            yield_now().await;
965        }
966    }
967
968    async fn get_send_of_tunnel_id(
969        &self,
970        tunnel_id: TunnelId,
971    ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
972        self.reserve_tunnel(tunnel_id).await?;
973        match self
974            .tunnel_pool
975            .get_classified_worker(CmdClientTunnelClassification {
976                tunnel_id: Some(tunnel_id),
977                classification: None,
978            })
979            .await
980        {
981            Ok(worker_guard) => Ok(self.tracked_send_guard(worker_guard)),
982            Err(_) => {
983                self.runtime_tunnels.remove(tunnel_id);
984                Err(Self::tunnel_not_found(tunnel_id))
985            }
986        }
987    }
988
989    async fn get_classified_send(
990        &self,
991        classification: C,
992    ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
993        loop {
994            let worker_guard = self
995                .tunnel_pool
996                .get_classified_worker(CmdClientTunnelClassification {
997                    tunnel_id: None,
998                    classification: Some(classification.clone()),
999                })
1000                .await
1001                .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))?;
1002            let tunnel_id = worker_guard.get_tunnel_id();
1003            if self.runtime_tunnels.mark_borrowed(tunnel_id) {
1004                return Ok(self.tracked_send_guard(worker_guard));
1005            }
1006            drop(worker_guard);
1007            yield_now().await;
1008        }
1009    }
1010}
1011
1012pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = TrackedSendGuard<
1013    CmdClientTunnelClassification<C>,
1014    M,
1015    ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1016    ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
1017>;
1018#[async_trait::async_trait]
1019impl<
1020    C: WorkerClassification,
1021    M: CmdTunnelMeta,
1022    R: ClassifiedCmdTunnelRead<C, M>,
1023    W: ClassifiedCmdTunnelWrite<C, M>,
1024    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
1025    LEN: RawEncode
1026        + for<'a> RawDecode<'a>
1027        + Copy
1028        + Send
1029        + Sync
1030        + 'static
1031        + FromPrimitive
1032        + ToPrimitive
1033        + RawFixedBytes,
1034    CMD: RawEncode
1035        + for<'a> RawDecode<'a>
1036        + Copy
1037        + Send
1038        + Sync
1039        + 'static
1040        + RawFixedBytes
1041        + Eq
1042        + Hash
1043        + Debug,
1044>
1045    CmdClient<
1046        LEN,
1047        CMD,
1048        M,
1049        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1050        ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1051    > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1052{
1053    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
1054        self.cmd_handler_map.insert(cmd, handler);
1055    }
1056
1057    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
1058        let mut send = self.get_send().await?;
1059        send.send(cmd, version, body).await
1060    }
1061
1062    async fn send_with_resp(
1063        &self,
1064        cmd: CMD,
1065        version: u8,
1066        body: &[u8],
1067        timeout: Duration,
1068    ) -> CmdResult<CmdBody> {
1069        let mut send = self.get_send().await?;
1070        send.send_with_resp(cmd, version, body, timeout).await
1071    }
1072
1073    async fn send_parts(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
1074        let mut send = self.get_send().await?;
1075        send.send_parts(cmd, version, body).await
1076    }
1077
1078    async fn send_parts_with_resp(
1079        &self,
1080        cmd: CMD,
1081        version: u8,
1082        body: &[&[u8]],
1083        timeout: Duration,
1084    ) -> CmdResult<CmdBody> {
1085        let mut send = self.get_send().await?;
1086        send.send_parts_with_resp(cmd, version, body, timeout).await
1087    }
1088
1089    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
1090        let mut send = self.get_send().await?;
1091        send.send_cmd(cmd, version, body).await
1092    }
1093
1094    async fn send_cmd_with_resp(
1095        &self,
1096        cmd: CMD,
1097        version: u8,
1098        body: CmdBody,
1099        timeout: Duration,
1100    ) -> CmdResult<CmdBody> {
1101        let mut send = self.get_send().await?;
1102        send.send_cmd_with_resp(cmd, version, body, timeout).await
1103    }
1104
1105    async fn send_by_specify_tunnel(
1106        &self,
1107        tunnel_id: TunnelId,
1108        cmd: CMD,
1109        version: u8,
1110        body: &[u8],
1111    ) -> CmdResult<()> {
1112        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1113        send.send(cmd, version, body).await
1114    }
1115
1116    async fn send_by_specify_tunnel_with_resp(
1117        &self,
1118        tunnel_id: TunnelId,
1119        cmd: CMD,
1120        version: u8,
1121        body: &[u8],
1122        timeout: Duration,
1123    ) -> CmdResult<CmdBody> {
1124        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1125        send.send_with_resp(cmd, version, body, timeout).await
1126    }
1127
1128    async fn send_parts_by_specify_tunnel(
1129        &self,
1130        tunnel_id: TunnelId,
1131        cmd: CMD,
1132        version: u8,
1133        body: &[&[u8]],
1134    ) -> CmdResult<()> {
1135        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1136        send.send_parts(cmd, version, body).await
1137    }
1138
1139    async fn send_parts_by_specify_tunnel_with_resp(
1140        &self,
1141        tunnel_id: TunnelId,
1142        cmd: CMD,
1143        version: u8,
1144        body: &[&[u8]],
1145        timeout: Duration,
1146    ) -> CmdResult<CmdBody> {
1147        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1148        send.send_parts_with_resp(cmd, version, body, timeout).await
1149    }
1150
1151    async fn send_cmd_by_specify_tunnel(
1152        &self,
1153        tunnel_id: TunnelId,
1154        cmd: CMD,
1155        version: u8,
1156        body: CmdBody,
1157    ) -> CmdResult<()> {
1158        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1159        send.send_cmd(cmd, version, body).await
1160    }
1161
1162    async fn send_cmd_by_specify_tunnel_with_resp(
1163        &self,
1164        tunnel_id: TunnelId,
1165        cmd: CMD,
1166        version: u8,
1167        body: CmdBody,
1168        timeout: Duration,
1169    ) -> CmdResult<CmdBody> {
1170        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1171        send.send_cmd_with_resp(cmd, version, body, timeout).await
1172    }
1173
1174    async fn clear_all_tunnel(&self) {
1175        self.tunnel_pool.clear_all_worker().await;
1176        self.runtime_tunnels.clear();
1177    }
1178
1179    async fn get_send(
1180        &self,
1181        tunnel_id: TunnelId,
1182    ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
1183        self.get_send_of_tunnel_id(tunnel_id).await
1184    }
1185}
1186
1187#[async_trait::async_trait]
1188impl<
1189    C: WorkerClassification,
1190    M: CmdTunnelMeta,
1191    R: ClassifiedCmdTunnelRead<C, M>,
1192    W: ClassifiedCmdTunnelWrite<C, M>,
1193    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
1194    LEN: RawEncode
1195        + for<'a> RawDecode<'a>
1196        + Copy
1197        + Send
1198        + Sync
1199        + 'static
1200        + FromPrimitive
1201        + ToPrimitive
1202        + RawFixedBytes,
1203    CMD: RawEncode
1204        + for<'a> RawDecode<'a>
1205        + Copy
1206        + Send
1207        + Sync
1208        + 'static
1209        + RawFixedBytes
1210        + Eq
1211        + Hash
1212        + Debug,
1213>
1214    ClassifiedCmdClient<
1215        LEN,
1216        CMD,
1217        C,
1218        M,
1219        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1220        ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1221    > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1222{
1223    async fn send_by_classified_tunnel(
1224        &self,
1225        classification: C,
1226        cmd: CMD,
1227        version: u8,
1228        body: &[u8],
1229    ) -> CmdResult<()> {
1230        let mut send = self.get_classified_send(classification).await?;
1231        send.send(cmd, version, body).await
1232    }
1233
1234    async fn send_by_classified_tunnel_with_resp(
1235        &self,
1236        classification: C,
1237        cmd: CMD,
1238        version: u8,
1239        body: &[u8],
1240        timeout: Duration,
1241    ) -> CmdResult<CmdBody> {
1242        let mut send = self.get_classified_send(classification).await?;
1243        send.send_with_resp(cmd, version, body, timeout).await
1244    }
1245
1246    async fn send_parts_by_classified_tunnel(
1247        &self,
1248        classification: C,
1249        cmd: CMD,
1250        version: u8,
1251        body: &[&[u8]],
1252    ) -> CmdResult<()> {
1253        let mut send = self.get_classified_send(classification).await?;
1254        send.send_parts(cmd, version, body).await
1255    }
1256
1257    async fn send_parts_by_classified_tunnel_with_resp(
1258        &self,
1259        classification: C,
1260        cmd: CMD,
1261        version: u8,
1262        body: &[&[u8]],
1263        timeout: Duration,
1264    ) -> CmdResult<CmdBody> {
1265        let mut send = self.get_classified_send(classification).await?;
1266        send.send_parts_with_resp(cmd, version, body, timeout).await
1267    }
1268
1269    async fn send_cmd_by_classified_tunnel(
1270        &self,
1271        classification: C,
1272        cmd: CMD,
1273        version: u8,
1274        body: CmdBody,
1275    ) -> CmdResult<()> {
1276        let mut send = self.get_classified_send(classification).await?;
1277        send.send_cmd(cmd, version, body).await
1278    }
1279
1280    async fn send_cmd_by_classified_tunnel_with_resp(
1281        &self,
1282        classification: C,
1283        cmd: CMD,
1284        version: u8,
1285        body: CmdBody,
1286        timeout: Duration,
1287    ) -> CmdResult<CmdBody> {
1288        let mut send = self.get_classified_send(classification).await?;
1289        send.send_cmd_with_resp(cmd, version, body, timeout).await
1290    }
1291
1292    async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
1293        let send = self.get_classified_send(classification).await?;
1294        Ok(send.get_tunnel_id())
1295    }
1296
1297    async fn get_send_by_classified(
1298        &self,
1299        classification: C,
1300    ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
1301        self.get_classified_send(classification).await
1302    }
1303}