Skip to main content

sfo_cmd_server/client/
client.rs

1use crate::client::{
2    CmdClient, CmdSend, RespWaiter, RespWaiterRef, SendGuard, gen_resp_id, gen_seq,
3};
4use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
5use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
6use crate::peer_id::PeerId;
7use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
8use async_named_locker::ObjectHolder;
9use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
10use num::{FromPrimitive, ToPrimitive};
11use sfo_pool::{
12    ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
13    ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
14    pool_err,
15};
16use sfo_split::{Splittable, WHalf};
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::marker::PhantomData;
20use std::ops::{Deref, DerefMut};
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27#[async_trait::async_trait]
28pub trait CmdTunnelFactory<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>>:
29    Send + Sync + 'static
30{
31    async fn create_tunnel(&self) -> CmdResult<Splittable<R, W>>;
32}
33
34pub struct CommonCmdSend<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>, LEN, CMD>
35where
36    LEN: RawEncode
37        + for<'a> RawDecode<'a>
38        + Copy
39        + Send
40        + Sync
41        + 'static
42        + FromPrimitive
43        + ToPrimitive,
44    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
45{
46    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
47    pub(crate) write: ObjectHolder<WHalf<R, W>>,
48    pub(crate) is_work: bool,
49    pub(crate) tunnel_id: TunnelId,
50    pub(crate) remote_id: PeerId,
51    pub(crate) resp_waiter: RespWaiterRef,
52    pub(crate) tunnel_meta: Option<Arc<M>>,
53    _p: std::marker::PhantomData<(LEN, CMD)>,
54}
55
56// impl<R, W, LEN, CMD> Deref for CmdSend<R, W, LEN, CMD>
57// where R: CmdTunnelRead,
58//       W: CmdTunnelWrite,
59//       LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
60//       CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
61//     type Target = W;
62//
63//     fn deref(&self) -> &Self::Target {
64//         self.write.deref()
65//     }
66// }
67
68impl<M, R, W, LEN, CMD> CommonCmdSend<M, R, W, LEN, CMD>
69where
70    M: CmdTunnelMeta,
71    R: CmdTunnelRead<M>,
72    W: CmdTunnelWrite<M>,
73    LEN: RawEncode
74        + for<'a> RawDecode<'a>
75        + Copy
76        + Send
77        + Sync
78        + 'static
79        + FromPrimitive
80        + ToPrimitive,
81    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
82{
83    pub fn new(
84        tunnel_id: TunnelId,
85        recv_handle: JoinHandle<CmdResult<()>>,
86        write: ObjectHolder<WHalf<R, W>>,
87        resp_waiter: RespWaiterRef,
88        remote_id: PeerId,
89        tunnel_meta: Option<Arc<M>>,
90    ) -> Self {
91        Self {
92            recv_handle,
93            write,
94            is_work: true,
95            tunnel_id,
96            remote_id,
97            resp_waiter,
98            tunnel_meta,
99            _p: Default::default(),
100        }
101    }
102
103    pub fn get_tunnel_id(&self) -> TunnelId {
104        self.tunnel_id
105    }
106
107    pub fn set_disable(&mut self) {
108        self.is_work = false;
109        self.recv_handle.abort();
110    }
111
112    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
113        log::trace!(
114            "client {:?} send cmd: {:?}, len: {} data:{}",
115            self.tunnel_id,
116            cmd,
117            body.len(),
118            hex::encode(body)
119        );
120        let header = CmdHeader::<LEN, CMD>::new(
121            version,
122            false,
123            None,
124            cmd,
125            LEN::from_u64(body.len() as u64).unwrap(),
126        );
127        let buf = header
128            .to_vec()
129            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
130        let ret = self.send_inner(buf.as_slice(), body).await;
131        if let Err(e) = ret {
132            self.set_disable();
133            return Err(e);
134        }
135        Ok(())
136    }
137
138    pub async fn send_with_resp(
139        &mut self,
140        cmd: CMD,
141        version: u8,
142        body: &[u8],
143        timeout: Duration,
144    ) -> CmdResult<CmdBody> {
145        if let Some(id) = tokio::task::try_id() {
146            if id == self.recv_handle.id() {
147                return Err(cmd_err!(
148                    CmdErrorCode::Failed,
149                    "can't send with resp in recv task"
150                ));
151            }
152        }
153        log::trace!(
154            "client {:?} send cmd: {:?}, len: {}, data: {}",
155            self.tunnel_id,
156            cmd,
157            body.len(),
158            hex::encode(body)
159        );
160        let seq = gen_seq();
161        let header = CmdHeader::<LEN, CMD>::new(
162            version,
163            false,
164            Some(seq),
165            cmd,
166            LEN::from_u64(body.len() as u64).unwrap(),
167        );
168        let buf = header
169            .to_vec()
170            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
171        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
172        let waiter = self.resp_waiter.clone();
173        let resp_waiter = waiter
174            .create_timeout_result_future(resp_id, timeout)
175            .map_err(into_cmd_err!(
176                CmdErrorCode::Failed,
177                "create timeout result future error"
178            ))?;
179        let ret = self.send_inner(buf.as_slice(), body).await;
180        if let Err(e) = ret {
181            self.set_disable();
182            return Err(e);
183        }
184        let resp = resp_waiter
185            .await
186            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
187        Ok(resp)
188    }
189
190    pub async fn send_parts(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
191        let mut len = 0;
192        for b in body.iter() {
193            len += b.len();
194            log::trace!(
195                "client {:?} send2 cmd: {:?}, data {}",
196                self.tunnel_id,
197                cmd,
198                hex::encode(b)
199            );
200        }
201        log::trace!(
202            "client {:?} send2 cmd: {:?}, len {}",
203            self.tunnel_id,
204            cmd,
205            len
206        );
207        let header = CmdHeader::<LEN, CMD>::new(
208            version,
209            false,
210            None,
211            cmd,
212            LEN::from_u64(len as u64).unwrap(),
213        );
214        let buf = header
215            .to_vec()
216            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
217        let ret = self.send_inner2(buf.as_slice(), body).await;
218        if let Err(e) = ret {
219            self.set_disable();
220            return Err(e);
221        }
222        Ok(())
223    }
224
225    pub async fn send_parts_with_resp(
226        &mut self,
227        cmd: CMD,
228        version: u8,
229        body: &[&[u8]],
230        timeout: Duration,
231    ) -> CmdResult<CmdBody> {
232        if let Some(id) = tokio::task::try_id() {
233            if id == self.recv_handle.id() {
234                return Err(cmd_err!(
235                    CmdErrorCode::Failed,
236                    "can't send with resp in recv task"
237                ));
238            }
239        }
240        let mut len = 0;
241        for b in body.iter() {
242            len += b.len();
243            log::trace!(
244                "client {:?} send2 cmd {:?} body: {}",
245                self.tunnel_id,
246                cmd,
247                hex::encode(b)
248            );
249        }
250        log::trace!(
251            "client {:?} send2 cmd: {:?}, len {}",
252            self.tunnel_id,
253            cmd,
254            len
255        );
256        let seq = gen_seq();
257        let header = CmdHeader::<LEN, CMD>::new(
258            version,
259            false,
260            Some(seq),
261            cmd,
262            LEN::from_u64(len as u64).unwrap(),
263        );
264        let buf = header
265            .to_vec()
266            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
267        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
268        let waiter = self.resp_waiter.clone();
269        let resp_waiter = waiter
270            .create_timeout_result_future(resp_id, timeout)
271            .map_err(into_cmd_err!(
272                CmdErrorCode::Failed,
273                "create timeout result future error"
274            ))?;
275        let ret = self.send_inner2(buf.as_slice(), body).await;
276        if let Err(e) = ret {
277            self.set_disable();
278            return Err(e);
279        }
280        let resp = resp_waiter
281            .await
282            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
283        Ok(resp)
284    }
285
286    #[allow(deprecated)]
287    #[deprecated(note = "use send_parts instead")]
288    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
289        self.send_parts(cmd, version, body).await
290    }
291
292    #[allow(deprecated)]
293    #[deprecated(note = "use send_parts_with_resp instead")]
294    pub async fn send2_with_resp(
295        &mut self,
296        cmd: CMD,
297        version: u8,
298        body: &[&[u8]],
299        timeout: Duration,
300    ) -> CmdResult<CmdBody> {
301        self.send_parts_with_resp(cmd, version, body, timeout).await
302    }
303
304    pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
305        log::trace!(
306            "client {:?} send cmd: {:?}, len: {}",
307            self.tunnel_id,
308            cmd,
309            body.len()
310        );
311        let header = CmdHeader::<LEN, CMD>::new(
312            version,
313            false,
314            None,
315            cmd,
316            LEN::from_u64(body.len()).unwrap(),
317        );
318        let buf = header
319            .to_vec()
320            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
321        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
322        if let Err(e) = ret {
323            self.set_disable();
324            return Err(e);
325        }
326        Ok(())
327    }
328
329    pub async fn send_cmd_with_resp(
330        &mut self,
331        cmd: CMD,
332        version: u8,
333        body: CmdBody,
334        timeout: Duration,
335    ) -> CmdResult<CmdBody> {
336        if let Some(id) = tokio::task::try_id() {
337            if id == self.recv_handle.id() {
338                return Err(cmd_err!(
339                    CmdErrorCode::Failed,
340                    "can't send with resp in recv task"
341                ));
342            }
343        }
344        log::trace!(
345            "client {:?} send cmd: {:?}, len: {}",
346            self.tunnel_id,
347            cmd,
348            body.len()
349        );
350        let seq = gen_seq();
351        let header = CmdHeader::<LEN, CMD>::new(
352            version,
353            false,
354            Some(seq),
355            cmd,
356            LEN::from_u64(body.len()).unwrap(),
357        );
358        let buf = header
359            .to_vec()
360            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
361        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
362        let waiter = self.resp_waiter.clone();
363        let resp_waiter = waiter
364            .create_timeout_result_future(resp_id, timeout)
365            .map_err(into_cmd_err!(
366                CmdErrorCode::Failed,
367                "create timeout result future error"
368            ))?;
369        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
370        if let Err(e) = ret {
371            self.set_disable();
372            return Err(e);
373        }
374        let resp = resp_waiter
375            .await
376            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
377        Ok(resp)
378    }
379
380    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
381        let mut write = self.write.get().await;
382        if header.len() > 255 {
383            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
384        }
385        write
386            .write_u8(header.len() as u8)
387            .await
388            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
389        write
390            .write_all(header)
391            .await
392            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
393        write
394            .write_all(body)
395            .await
396            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
397        write
398            .flush()
399            .await
400            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
401        Ok(())
402    }
403
404    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
405        let mut write = self.write.get().await;
406        if header.len() > 255 {
407            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
408        }
409        write
410            .write_u8(header.len() as u8)
411            .await
412            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
413        write
414            .write_all(header)
415            .await
416            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
417        for b in body.iter() {
418            write
419                .write_all(b)
420                .await
421                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
422        }
423        write
424            .flush()
425            .await
426            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
427        Ok(())
428    }
429
430    async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
431        let mut write = self.write.get().await;
432        if header.len() > 255 {
433            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
434        }
435        write
436            .write_u8(header.len() as u8)
437            .await
438            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
439        write
440            .write_all(header)
441            .await
442            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
443        tokio::io::copy(&mut body, write.deref_mut().deref_mut())
444            .await
445            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
446        write
447            .flush()
448            .await
449            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
450        Ok(())
451    }
452}
453
454impl<M, R, W, LEN, CMD> Drop for CommonCmdSend<M, R, W, LEN, CMD>
455where
456    M: CmdTunnelMeta,
457    R: CmdTunnelRead<M>,
458    W: CmdTunnelWrite<M>,
459    LEN: RawEncode
460        + for<'a> RawDecode<'a>
461        + Copy
462        + Send
463        + Sync
464        + 'static
465        + FromPrimitive
466        + ToPrimitive,
467    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
468{
469    fn drop(&mut self) {
470        self.set_disable();
471    }
472}
473
474impl<M, R, W, LEN, CMD> CmdSend<M> for CommonCmdSend<M, R, W, LEN, CMD>
475where
476    M: CmdTunnelMeta,
477    R: CmdTunnelRead<M>,
478    W: CmdTunnelWrite<M>,
479    LEN: RawEncode
480        + for<'a> RawDecode<'a>
481        + Copy
482        + Send
483        + Sync
484        + 'static
485        + FromPrimitive
486        + ToPrimitive,
487    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
488{
489    fn get_tunnel_meta(&self) -> Option<Arc<M>> {
490        self.tunnel_meta.clone()
491    }
492
493    fn get_remote_peer_id(&self) -> PeerId {
494        self.remote_id.clone()
495    }
496}
497
498impl<M, R, W, LEN, CMD> ClassifiedWorker<TunnelId> for CommonCmdSend<M, R, W, LEN, CMD>
499where
500    M: CmdTunnelMeta,
501    R: CmdTunnelRead<M>,
502    W: CmdTunnelWrite<M>,
503    LEN: RawEncode
504        + for<'a> RawDecode<'a>
505        + Copy
506        + Send
507        + Sync
508        + 'static
509        + FromPrimitive
510        + ToPrimitive,
511    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
512{
513    fn is_work(&self) -> bool {
514        self.is_work && !self.recv_handle.is_finished()
515    }
516
517    fn is_valid(&self, c: TunnelId) -> bool {
518        self.tunnel_id == c
519    }
520
521    fn classification(&self) -> TunnelId {
522        self.tunnel_id
523    }
524}
525
526pub struct ClassifiedSendGuard<
527    C: WorkerClassification,
528    M: CmdTunnelMeta,
529    CW: ClassifiedWorker<C> + CmdSend<M>,
530    F: ClassifiedWorkerFactory<C, CW>,
531> {
532    pub(crate) worker_guard: ClassifiedWorkerGuard<C, CW, F>,
533    pub(crate) _p: PhantomData<M>,
534}
535
536impl<
537    C: WorkerClassification,
538    M: CmdTunnelMeta,
539    CW: ClassifiedWorker<C> + CmdSend<M>,
540    F: ClassifiedWorkerFactory<C, CW>,
541> Deref for ClassifiedSendGuard<C, M, CW, F>
542{
543    type Target = CW;
544
545    fn deref(&self) -> &Self::Target {
546        &self.worker_guard.deref()
547    }
548}
549
550impl<
551    C: WorkerClassification,
552    M: CmdTunnelMeta,
553    CW: ClassifiedWorker<C> + CmdSend<M>,
554    F: ClassifiedWorkerFactory<C, CW>,
555> SendGuard<M, CW> for ClassifiedSendGuard<C, M, CW, F>
556{
557}
558
559pub struct CmdWriteFactory<
560    M: CmdTunnelMeta,
561    R: CmdTunnelRead<M>,
562    W: CmdTunnelWrite<M>,
563    F: CmdTunnelFactory<M, R, W>,
564    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
565    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
566> {
567    tunnel_factory: F,
568    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
569    resp_waiter: RespWaiterRef,
570    tunnel_id_generator: TunnelIdGenerator,
571    p: std::marker::PhantomData<Mutex<(R, W, M)>>,
572}
573
574impl<
575    M: CmdTunnelMeta,
576    R: CmdTunnelRead<M>,
577    W: CmdTunnelWrite<M>,
578    F: CmdTunnelFactory<M, R, W>,
579    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
580    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
581> CmdWriteFactory<M, R, W, F, LEN, CMD>
582{
583    pub(crate) fn new(
584        tunnel_factory: F,
585        cmd_handler: impl CmdHandler<LEN, CMD>,
586        resp_waiter: RespWaiterRef,
587    ) -> Self {
588        Self {
589            tunnel_factory,
590            cmd_handler: Arc::new(cmd_handler),
591            resp_waiter,
592            tunnel_id_generator: TunnelIdGenerator::new(),
593            p: Default::default(),
594        }
595    }
596}
597
598#[async_trait::async_trait]
599impl<
600    M: CmdTunnelMeta,
601    R: CmdTunnelRead<M>,
602    W: CmdTunnelWrite<M>,
603    F: CmdTunnelFactory<M, R, W>,
604    LEN: RawEncode
605        + for<'a> RawDecode<'a>
606        + Copy
607        + Send
608        + Sync
609        + 'static
610        + FromPrimitive
611        + ToPrimitive
612        + RawFixedBytes,
613    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
614> ClassifiedWorkerFactory<TunnelId, CommonCmdSend<M, R, W, LEN, CMD>>
615    for CmdWriteFactory<M, R, W, F, LEN, CMD>
616{
617    async fn create(&self, c: Option<TunnelId>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
618        if c.is_some() {
619            return Err(pool_err!(
620                PoolErrorCode::Failed,
621                "tunnel {:?} not found",
622                c.unwrap()
623            ));
624        }
625        let tunnel = self
626            .tunnel_factory
627            .create_tunnel()
628            .await
629            .map_err(into_pool_err!(PoolErrorCode::Failed))?;
630        let peer_id = tunnel.get_remote_peer_id();
631        let tunnel_id = self.tunnel_id_generator.generate();
632        let (mut recv, write) = tunnel.split();
633        let local_id = recv.get_local_peer_id();
634        let remote_id = write.get_remote_peer_id();
635        let meta = write.get_tunnel_meta();
636        let write = ObjectHolder::new(write);
637        let resp_write = write.clone();
638        let cmd_handler = self.cmd_handler.clone();
639        let handle = spawn(async move {
640            let ret: CmdResult<()> = async move {
641                loop {
642                    let header_len = recv
643                        .read_u8()
644                        .await
645                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
646                    let mut header = vec![0u8; header_len as usize];
647                    let n = recv
648                        .read_exact(header.as_mut())
649                        .await
650                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
651                    if n == 0 {
652                        break;
653                    }
654                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
655                        .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
656                    log::trace!(
657                        "recv cmd {:?} from {} len {}",
658                        header.cmd_code(),
659                        peer_id.to_base58(),
660                        header.pkg_len().to_u64().unwrap()
661                    );
662                    let body_len = header.pkg_len().to_u64().unwrap();
663                    let cmd_read =
664                        CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
665                    let waiter = cmd_read.get_waiter();
666                    let future = waiter
667                        .create_result_future()
668                        .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
669                    let version = header.version();
670                    let seq = header.seq();
671                    let cmd_code = header.cmd_code();
672                    match cmd_handler
673                        .handle(
674                            local_id.clone(),
675                            peer_id.clone(),
676                            tunnel_id,
677                            header,
678                            CmdBody::from_reader(BufReader::new(cmd_read), body_len),
679                        )
680                        .await
681                    {
682                        Ok(Some(mut body)) => {
683                            let mut write = resp_write.get().await;
684                            let header = CmdHeader::<LEN, CMD>::new(
685                                version,
686                                true,
687                                seq,
688                                cmd_code,
689                                LEN::from_u64(body.len()).unwrap(),
690                            );
691                            let buf = header
692                                .to_vec()
693                                .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
694                            if buf.len() > 255 {
695                                return Err(cmd_err!(
696                                    CmdErrorCode::InvalidParam,
697                                    "header len too long"
698                                ));
699                            }
700                            write
701                                .write_u8(buf.len() as u8)
702                                .await
703                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
704                            write
705                                .write_all(buf.as_slice())
706                                .await
707                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
708                            tokio::io::copy(&mut body, write.deref_mut().deref_mut())
709                                .await
710                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
711                            write
712                                .flush()
713                                .await
714                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
715                        }
716                        Ok(None) => {}
717                        Err(e) => {
718                            log::error!("handle cmd error: {:?}", e);
719                        }
720                    }
721                    recv = future
722                        .await
723                        .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
724                }
725                Ok(())
726            }
727            .await;
728            ret
729        });
730        Ok(CommonCmdSend::new(
731            tunnel_id,
732            handle,
733            write,
734            self.resp_waiter.clone(),
735            remote_id,
736            meta,
737        ))
738    }
739}
740
741pub struct DefaultCmdClient<
742    M: CmdTunnelMeta,
743    R: CmdTunnelRead<M>,
744    W: CmdTunnelWrite<M>,
745    F: CmdTunnelFactory<M, R, W>,
746    LEN: RawEncode
747        + for<'a> RawDecode<'a>
748        + Copy
749        + Send
750        + Sync
751        + 'static
752        + FromPrimitive
753        + ToPrimitive
754        + RawFixedBytes,
755    CMD: RawEncode
756        + for<'a> RawDecode<'a>
757        + Copy
758        + Send
759        + Sync
760        + 'static
761        + RawFixedBytes
762        + Eq
763        + Hash
764        + Debug,
765> {
766    tunnel_pool: ClassifiedWorkerPoolRef<
767        TunnelId,
768        CommonCmdSend<M, R, W, LEN, CMD>,
769        CmdWriteFactory<M, R, W, F, LEN, CMD>,
770    >,
771    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
772}
773
774impl<
775    M: CmdTunnelMeta,
776    R: CmdTunnelRead<M>,
777    W: CmdTunnelWrite<M>,
778    F: CmdTunnelFactory<M, R, W>,
779    LEN: RawEncode
780        + for<'a> RawDecode<'a>
781        + Copy
782        + Send
783        + Sync
784        + 'static
785        + FromPrimitive
786        + ToPrimitive
787        + RawFixedBytes,
788    CMD: RawEncode
789        + for<'a> RawDecode<'a>
790        + Copy
791        + Send
792        + Sync
793        + 'static
794        + RawFixedBytes
795        + Eq
796        + Hash
797        + Debug,
798> DefaultCmdClient<M, R, W, F, LEN, CMD>
799{
800    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
801        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
802        let handler_map = cmd_handler_map.clone();
803        let resp_waiter = Arc::new(RespWaiter::new());
804        let waiter = resp_waiter.clone();
805        Arc::new(Self {
806            tunnel_pool: ClassifiedWorkerPool::new(
807                tunnel_count,
808                CmdWriteFactory::<M, R, W, _, LEN, CMD>::new(
809                    factory,
810                    move |local_id: PeerId,
811                          peer_id: PeerId,
812                          tunnel_id: TunnelId,
813                          header: CmdHeader<LEN, CMD>,
814                          body_read: CmdBody| {
815                        let handler_map = handler_map.clone();
816                        let waiter = waiter.clone();
817                        async move {
818                            if header.is_resp() && header.seq().is_some() {
819                                let resp_id = gen_resp_id(
820                                    tunnel_id,
821                                    header.cmd_code(),
822                                    header.seq().unwrap(),
823                                );
824                                let _ = waiter.set_result(resp_id, body_read);
825                                Ok(None)
826                            } else {
827                                if let Some(handler) = handler_map.get(header.cmd_code()) {
828                                    handler
829                                        .handle(local_id, peer_id, tunnel_id, header, body_read)
830                                        .await
831                                } else {
832                                    Ok(None)
833                                }
834                            }
835                        }
836                    },
837                    resp_waiter.clone(),
838                ),
839            ),
840            cmd_handler_map,
841        })
842    }
843
844    async fn get_send(
845        &self,
846    ) -> CmdResult<
847        ClassifiedWorkerGuard<
848            TunnelId,
849            CommonCmdSend<M, R, W, LEN, CMD>,
850            CmdWriteFactory<M, R, W, F, LEN, CMD>,
851        >,
852    > {
853        self.tunnel_pool
854            .get_worker()
855            .await
856            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
857    }
858
859    async fn get_send_of_tunnel_id(
860        &self,
861        tunnel_id: TunnelId,
862    ) -> CmdResult<
863        ClassifiedWorkerGuard<
864            TunnelId,
865            CommonCmdSend<M, R, W, LEN, CMD>,
866            CmdWriteFactory<M, R, W, F, LEN, CMD>,
867        >,
868    > {
869        self.tunnel_pool
870            .get_classified_worker(tunnel_id)
871            .await
872            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
873    }
874}
875
/// Send guard type used by [`DefaultCmdClient`]: a [`ClassifiedSendGuard`]
/// classified by [`TunnelId`] over [`CommonCmdSend`] workers created by
/// [`CmdWriteFactory`].
pub type CmdClientSendGuard<M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
    TunnelId,
    M,
    CommonCmdSend<M, R, W, LEN, CMD>,
    CmdWriteFactory<M, R, W, F, LEN, CMD>,
>;
882#[async_trait::async_trait]
883impl<
884    M: CmdTunnelMeta,
885    R: CmdTunnelRead<M>,
886    W: CmdTunnelWrite<M>,
887    F: CmdTunnelFactory<M, R, W>,
888    LEN: RawEncode
889        + for<'a> RawDecode<'a>
890        + Copy
891        + Send
892        + Sync
893        + 'static
894        + FromPrimitive
895        + ToPrimitive
896        + RawFixedBytes,
897    CMD: RawEncode
898        + for<'a> RawDecode<'a>
899        + Copy
900        + Send
901        + Sync
902        + 'static
903        + RawFixedBytes
904        + Eq
905        + Hash
906        + Debug,
907> CmdClient<LEN, CMD, M, CommonCmdSend<M, R, W, LEN, CMD>, CmdClientSendGuard<M, R, W, F, LEN, CMD>>
908    for DefaultCmdClient<M, R, W, F, LEN, CMD>
909{
910    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
911        self.cmd_handler_map.insert(cmd, handler);
912    }
913
914    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
915        let mut send = self.get_send().await?;
916        send.send(cmd, version, body).await
917    }
918
919    async fn send_with_resp(
920        &self,
921        cmd: CMD,
922        version: u8,
923        body: &[u8],
924        timeout: Duration,
925    ) -> CmdResult<CmdBody> {
926        let mut send = self.get_send().await?;
927        send.send_with_resp(cmd, version, body, timeout).await
928    }
929
930    async fn send_parts(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
931        let mut send = self.get_send().await?;
932        send.send_parts(cmd, version, body).await
933    }
934
935    async fn send_parts_with_resp(
936        &self,
937        cmd: CMD,
938        version: u8,
939        body: &[&[u8]],
940        timeout: Duration,
941    ) -> CmdResult<CmdBody> {
942        let mut send = self.get_send().await?;
943        send.send_parts_with_resp(cmd, version, body, timeout).await
944    }
945
946    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
947        let mut send = self.get_send().await?;
948        send.send_cmd(cmd, version, body).await
949    }
950
951    async fn send_cmd_with_resp(
952        &self,
953        cmd: CMD,
954        version: u8,
955        body: CmdBody,
956        timeout: Duration,
957    ) -> CmdResult<CmdBody> {
958        let mut send = self.get_send().await?;
959        send.send_cmd_with_resp(cmd, version, body, timeout).await
960    }
961
962    async fn send_by_specify_tunnel(
963        &self,
964        tunnel_id: TunnelId,
965        cmd: CMD,
966        version: u8,
967        body: &[u8],
968    ) -> CmdResult<()> {
969        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
970        send.send(cmd, version, body).await
971    }
972
973    async fn send_by_specify_tunnel_with_resp(
974        &self,
975        tunnel_id: TunnelId,
976        cmd: CMD,
977        version: u8,
978        body: &[u8],
979        timeout: Duration,
980    ) -> CmdResult<CmdBody> {
981        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
982        send.send_with_resp(cmd, version, body, timeout).await
983    }
984
985    async fn send_parts_by_specify_tunnel(
986        &self,
987        tunnel_id: TunnelId,
988        cmd: CMD,
989        version: u8,
990        body: &[&[u8]],
991    ) -> CmdResult<()> {
992        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
993        send.send_parts(cmd, version, body).await
994    }
995
996    async fn send_parts_by_specify_tunnel_with_resp(
997        &self,
998        tunnel_id: TunnelId,
999        cmd: CMD,
1000        version: u8,
1001        body: &[&[u8]],
1002        timeout: Duration,
1003    ) -> CmdResult<CmdBody> {
1004        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1005        send.send_parts_with_resp(cmd, version, body, timeout).await
1006    }
1007
1008    async fn send_cmd_by_specify_tunnel(
1009        &self,
1010        tunnel_id: TunnelId,
1011        cmd: CMD,
1012        version: u8,
1013        body: CmdBody,
1014    ) -> CmdResult<()> {
1015        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1016        send.send_cmd(cmd, version, body).await
1017    }
1018
1019    async fn send_cmd_by_specify_tunnel_with_resp(
1020        &self,
1021        tunnel_id: TunnelId,
1022        cmd: CMD,
1023        version: u8,
1024        body: CmdBody,
1025        timeout: Duration,
1026    ) -> CmdResult<CmdBody> {
1027        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1028        send.send_cmd_with_resp(cmd, version, body, timeout).await
1029    }
1030
1031    async fn clear_all_tunnel(&self) {
1032        self.tunnel_pool.clear_all_worker().await;
1033    }
1034
1035    async fn get_send(
1036        &self,
1037        tunnel_id: TunnelId,
1038    ) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
1039        Ok(ClassifiedSendGuard {
1040            worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1041            _p: Default::default(),
1042        })
1043    }
1044}