Skip to main content

sfo_cmd_server/client/
client.rs

1use crate::client::{
2    CmdClient, CmdSend, RespWaiter, RespWaiterRef, SendGuard, gen_resp_id, gen_seq,
3};
4use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
5use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
6use crate::peer_id::PeerId;
7use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
8use async_named_locker::ObjectHolder;
9use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
10use num::{FromPrimitive, ToPrimitive};
11use sfo_pool::{
12    ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
13    ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
14    pool_err,
15};
16use sfo_split::{Splittable, WHalf};
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::marker::PhantomData;
20use std::ops::{Deref, DerefMut};
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27#[async_trait::async_trait]
28pub trait CmdTunnelFactory<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>>:
29    Send + Sync + 'static
30{
31    async fn create_tunnel(&self) -> CmdResult<Splittable<R, W>>;
32}
33
34pub struct CommonCmdSend<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>, LEN, CMD>
35where
36    LEN: RawEncode
37        + for<'a> RawDecode<'a>
38        + Copy
39        + Send
40        + Sync
41        + 'static
42        + FromPrimitive
43        + ToPrimitive,
44    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
45{
46    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
47    pub(crate) write: ObjectHolder<WHalf<R, W>>,
48    pub(crate) is_work: bool,
49    pub(crate) tunnel_id: TunnelId,
50    pub(crate) remote_id: PeerId,
51    pub(crate) resp_waiter: RespWaiterRef,
52    pub(crate) tunnel_meta: Option<Arc<M>>,
53    _p: std::marker::PhantomData<(LEN, CMD)>,
54}
55
56// impl<R, W, LEN, CMD> Deref for CmdSend<R, W, LEN, CMD>
57// where R: CmdTunnelRead,
58//       W: CmdTunnelWrite,
59//       LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
60//       CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
61//     type Target = W;
62//
63//     fn deref(&self) -> &Self::Target {
64//         self.write.deref()
65//     }
66// }
67
68impl<M, R, W, LEN, CMD> CommonCmdSend<M, R, W, LEN, CMD>
69where
70    M: CmdTunnelMeta,
71    R: CmdTunnelRead<M>,
72    W: CmdTunnelWrite<M>,
73    LEN: RawEncode
74        + for<'a> RawDecode<'a>
75        + Copy
76        + Send
77        + Sync
78        + 'static
79        + FromPrimitive
80        + ToPrimitive,
81    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
82{
83    pub fn new(
84        tunnel_id: TunnelId,
85        recv_handle: JoinHandle<CmdResult<()>>,
86        write: ObjectHolder<WHalf<R, W>>,
87        resp_waiter: RespWaiterRef,
88        remote_id: PeerId,
89        tunnel_meta: Option<Arc<M>>,
90    ) -> Self {
91        Self {
92            recv_handle,
93            write,
94            is_work: true,
95            tunnel_id,
96            remote_id,
97            resp_waiter,
98            tunnel_meta,
99            _p: Default::default(),
100        }
101    }
102
103    pub fn get_tunnel_id(&self) -> TunnelId {
104        self.tunnel_id
105    }
106
107    pub fn set_disable(&mut self) {
108        self.is_work = false;
109        self.recv_handle.abort();
110    }
111
112    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
113        log::trace!(
114            "client {:?} send cmd: {:?}, len: {} data:{}",
115            self.tunnel_id,
116            cmd,
117            body.len(),
118            hex::encode(body)
119        );
120        let header = CmdHeader::<LEN, CMD>::new(
121            version,
122            false,
123            None,
124            cmd,
125            LEN::from_u64(body.len() as u64).unwrap(),
126        );
127        let buf = header
128            .to_vec()
129            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
130        let ret = self.send_inner(buf.as_slice(), body).await;
131        if let Err(e) = ret {
132            self.set_disable();
133            return Err(e);
134        }
135        Ok(())
136    }
137
138    pub async fn send_with_resp(
139        &mut self,
140        cmd: CMD,
141        version: u8,
142        body: &[u8],
143        timeout: Duration,
144    ) -> CmdResult<CmdBody> {
145        if let Some(id) = tokio::task::try_id() {
146            if id == self.recv_handle.id() {
147                return Err(cmd_err!(
148                    CmdErrorCode::Failed,
149                    "can't send with resp in recv task"
150                ));
151            }
152        }
153        log::trace!(
154            "client {:?} send cmd: {:?}, len: {}, data: {}",
155            self.tunnel_id,
156            cmd,
157            body.len(),
158            hex::encode(body)
159        );
160        let seq = gen_seq();
161        let header = CmdHeader::<LEN, CMD>::new(
162            version,
163            false,
164            Some(seq),
165            cmd,
166            LEN::from_u64(body.len() as u64).unwrap(),
167        );
168        let buf = header
169            .to_vec()
170            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
171        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
172        let waiter = self.resp_waiter.clone();
173        let resp_waiter = waiter
174            .create_timeout_result_future(resp_id, timeout)
175            .map_err(into_cmd_err!(
176                CmdErrorCode::Failed,
177                "create timeout result future error"
178            ))?;
179        let ret = self.send_inner(buf.as_slice(), body).await;
180        if let Err(e) = ret {
181            self.set_disable();
182            return Err(e);
183        }
184        let resp = resp_waiter
185            .await
186            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
187        Ok(resp)
188    }
189
190    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
191        let mut len = 0;
192        for b in body.iter() {
193            len += b.len();
194            log::trace!(
195                "client {:?} send2 cmd: {:?}, data {}",
196                self.tunnel_id,
197                cmd,
198                hex::encode(b)
199            );
200        }
201        log::trace!(
202            "client {:?} send2 cmd: {:?}, len {}",
203            self.tunnel_id,
204            cmd,
205            len
206        );
207        let header = CmdHeader::<LEN, CMD>::new(
208            version,
209            false,
210            None,
211            cmd,
212            LEN::from_u64(len as u64).unwrap(),
213        );
214        let buf = header
215            .to_vec()
216            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
217        let ret = self.send_inner2(buf.as_slice(), body).await;
218        if let Err(e) = ret {
219            self.set_disable();
220            return Err(e);
221        }
222        Ok(())
223    }
224
225    pub async fn send2_with_resp(
226        &mut self,
227        cmd: CMD,
228        version: u8,
229        body: &[&[u8]],
230        timeout: Duration,
231    ) -> CmdResult<CmdBody> {
232        if let Some(id) = tokio::task::try_id() {
233            if id == self.recv_handle.id() {
234                return Err(cmd_err!(
235                    CmdErrorCode::Failed,
236                    "can't send with resp in recv task"
237                ));
238            }
239        }
240        let mut len = 0;
241        for b in body.iter() {
242            len += b.len();
243            log::trace!(
244                "client {:?} send2 cmd {:?} body: {}",
245                self.tunnel_id,
246                cmd,
247                hex::encode(b)
248            );
249        }
250        log::trace!(
251            "client {:?} send2 cmd: {:?}, len {}",
252            self.tunnel_id,
253            cmd,
254            len
255        );
256        let seq = gen_seq();
257        let header = CmdHeader::<LEN, CMD>::new(
258            version,
259            false,
260            Some(seq),
261            cmd,
262            LEN::from_u64(len as u64).unwrap(),
263        );
264        let buf = header
265            .to_vec()
266            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
267        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
268        let waiter = self.resp_waiter.clone();
269        let resp_waiter = waiter
270            .create_timeout_result_future(resp_id, timeout)
271            .map_err(into_cmd_err!(
272                CmdErrorCode::Failed,
273                "create timeout result future error"
274            ))?;
275        let ret = self.send_inner2(buf.as_slice(), body).await;
276        if let Err(e) = ret {
277            self.set_disable();
278            return Err(e);
279        }
280        let resp = resp_waiter
281            .await
282            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
283        Ok(resp)
284    }
285
286    pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
287        log::trace!(
288            "client {:?} send cmd: {:?}, len: {}",
289            self.tunnel_id,
290            cmd,
291            body.len()
292        );
293        let header = CmdHeader::<LEN, CMD>::new(
294            version,
295            false,
296            None,
297            cmd,
298            LEN::from_u64(body.len()).unwrap(),
299        );
300        let buf = header
301            .to_vec()
302            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
303        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
304        if let Err(e) = ret {
305            self.set_disable();
306            return Err(e);
307        }
308        Ok(())
309    }
310
311    pub async fn send_cmd_with_resp(
312        &mut self,
313        cmd: CMD,
314        version: u8,
315        body: CmdBody,
316        timeout: Duration,
317    ) -> CmdResult<CmdBody> {
318        if let Some(id) = tokio::task::try_id() {
319            if id == self.recv_handle.id() {
320                return Err(cmd_err!(
321                    CmdErrorCode::Failed,
322                    "can't send with resp in recv task"
323                ));
324            }
325        }
326        log::trace!(
327            "client {:?} send cmd: {:?}, len: {}",
328            self.tunnel_id,
329            cmd,
330            body.len()
331        );
332        let seq = gen_seq();
333        let header = CmdHeader::<LEN, CMD>::new(
334            version,
335            false,
336            Some(seq),
337            cmd,
338            LEN::from_u64(body.len()).unwrap(),
339        );
340        let buf = header
341            .to_vec()
342            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
343        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
344        let waiter = self.resp_waiter.clone();
345        let resp_waiter = waiter
346            .create_timeout_result_future(resp_id, timeout)
347            .map_err(into_cmd_err!(
348                CmdErrorCode::Failed,
349                "create timeout result future error"
350            ))?;
351        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
352        if let Err(e) = ret {
353            self.set_disable();
354            return Err(e);
355        }
356        let resp = resp_waiter
357            .await
358            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
359        Ok(resp)
360    }
361
362    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
363        let mut write = self.write.get().await;
364        if header.len() > 255 {
365            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
366        }
367        write
368            .write_u8(header.len() as u8)
369            .await
370            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
371        write
372            .write_all(header)
373            .await
374            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
375        write
376            .write_all(body)
377            .await
378            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
379        write
380            .flush()
381            .await
382            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
383        Ok(())
384    }
385
386    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
387        let mut write = self.write.get().await;
388        if header.len() > 255 {
389            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
390        }
391        write
392            .write_u8(header.len() as u8)
393            .await
394            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
395        write
396            .write_all(header)
397            .await
398            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
399        for b in body.iter() {
400            write
401                .write_all(b)
402                .await
403                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
404        }
405        write
406            .flush()
407            .await
408            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
409        Ok(())
410    }
411
412    async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
413        let mut write = self.write.get().await;
414        if header.len() > 255 {
415            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
416        }
417        write
418            .write_u8(header.len() as u8)
419            .await
420            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
421        write
422            .write_all(header)
423            .await
424            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
425        tokio::io::copy(&mut body, write.deref_mut().deref_mut())
426            .await
427            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
428        write
429            .flush()
430            .await
431            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
432        Ok(())
433    }
434}
435
436impl<M, R, W, LEN, CMD> Drop for CommonCmdSend<M, R, W, LEN, CMD>
437where
438    M: CmdTunnelMeta,
439    R: CmdTunnelRead<M>,
440    W: CmdTunnelWrite<M>,
441    LEN: RawEncode
442        + for<'a> RawDecode<'a>
443        + Copy
444        + Send
445        + Sync
446        + 'static
447        + FromPrimitive
448        + ToPrimitive,
449    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
450{
451    fn drop(&mut self) {
452        self.set_disable();
453    }
454}
455
456impl<M, R, W, LEN, CMD> CmdSend<M> for CommonCmdSend<M, R, W, LEN, CMD>
457where
458    M: CmdTunnelMeta,
459    R: CmdTunnelRead<M>,
460    W: CmdTunnelWrite<M>,
461    LEN: RawEncode
462        + for<'a> RawDecode<'a>
463        + Copy
464        + Send
465        + Sync
466        + 'static
467        + FromPrimitive
468        + ToPrimitive,
469    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
470{
471    fn get_tunnel_meta(&self) -> Option<Arc<M>> {
472        self.tunnel_meta.clone()
473    }
474
475    fn get_remote_peer_id(&self) -> PeerId {
476        self.remote_id.clone()
477    }
478}
479
480impl<M, R, W, LEN, CMD> ClassifiedWorker<TunnelId> for CommonCmdSend<M, R, W, LEN, CMD>
481where
482    M: CmdTunnelMeta,
483    R: CmdTunnelRead<M>,
484    W: CmdTunnelWrite<M>,
485    LEN: RawEncode
486        + for<'a> RawDecode<'a>
487        + Copy
488        + Send
489        + Sync
490        + 'static
491        + FromPrimitive
492        + ToPrimitive,
493    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
494{
495    fn is_work(&self) -> bool {
496        self.is_work && !self.recv_handle.is_finished()
497    }
498
499    fn is_valid(&self, c: TunnelId) -> bool {
500        self.tunnel_id == c
501    }
502
503    fn classification(&self) -> TunnelId {
504        self.tunnel_id
505    }
506}
507
508pub struct ClassifiedSendGuard<
509    C: WorkerClassification,
510    M: CmdTunnelMeta,
511    CW: ClassifiedWorker<C> + CmdSend<M>,
512    F: ClassifiedWorkerFactory<C, CW>,
513> {
514    pub(crate) worker_guard: ClassifiedWorkerGuard<C, CW, F>,
515    pub(crate) _p: PhantomData<M>,
516}
517
518impl<
519    C: WorkerClassification,
520    M: CmdTunnelMeta,
521    CW: ClassifiedWorker<C> + CmdSend<M>,
522    F: ClassifiedWorkerFactory<C, CW>,
523> Deref for ClassifiedSendGuard<C, M, CW, F>
524{
525    type Target = CW;
526
527    fn deref(&self) -> &Self::Target {
528        &self.worker_guard.deref()
529    }
530}
531
532impl<
533    C: WorkerClassification,
534    M: CmdTunnelMeta,
535    CW: ClassifiedWorker<C> + CmdSend<M>,
536    F: ClassifiedWorkerFactory<C, CW>,
537> SendGuard<M, CW> for ClassifiedSendGuard<C, M, CW, F>
538{
539}
540
541pub struct CmdWriteFactory<
542    M: CmdTunnelMeta,
543    R: CmdTunnelRead<M>,
544    W: CmdTunnelWrite<M>,
545    F: CmdTunnelFactory<M, R, W>,
546    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
547    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
548> {
549    tunnel_factory: F,
550    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
551    resp_waiter: RespWaiterRef,
552    tunnel_id_generator: TunnelIdGenerator,
553    p: std::marker::PhantomData<Mutex<(R, W, M)>>,
554}
555
556impl<
557    M: CmdTunnelMeta,
558    R: CmdTunnelRead<M>,
559    W: CmdTunnelWrite<M>,
560    F: CmdTunnelFactory<M, R, W>,
561    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
562    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
563> CmdWriteFactory<M, R, W, F, LEN, CMD>
564{
565    pub(crate) fn new(
566        tunnel_factory: F,
567        cmd_handler: impl CmdHandler<LEN, CMD>,
568        resp_waiter: RespWaiterRef,
569    ) -> Self {
570        Self {
571            tunnel_factory,
572            cmd_handler: Arc::new(cmd_handler),
573            resp_waiter,
574            tunnel_id_generator: TunnelIdGenerator::new(),
575            p: Default::default(),
576        }
577    }
578}
579
580#[async_trait::async_trait]
581impl<
582    M: CmdTunnelMeta,
583    R: CmdTunnelRead<M>,
584    W: CmdTunnelWrite<M>,
585    F: CmdTunnelFactory<M, R, W>,
586    LEN: RawEncode
587        + for<'a> RawDecode<'a>
588        + Copy
589        + Send
590        + Sync
591        + 'static
592        + FromPrimitive
593        + ToPrimitive
594        + RawFixedBytes,
595    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
596> ClassifiedWorkerFactory<TunnelId, CommonCmdSend<M, R, W, LEN, CMD>>
597    for CmdWriteFactory<M, R, W, F, LEN, CMD>
598{
599    async fn create(&self, c: Option<TunnelId>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
600        if c.is_some() {
601            return Err(pool_err!(
602                PoolErrorCode::Failed,
603                "tunnel {:?} not found",
604                c.unwrap()
605            ));
606        }
607        let tunnel = self
608            .tunnel_factory
609            .create_tunnel()
610            .await
611            .map_err(into_pool_err!(PoolErrorCode::Failed))?;
612        let peer_id = tunnel.get_remote_peer_id();
613        let tunnel_id = self.tunnel_id_generator.generate();
614        let (mut recv, write) = tunnel.split();
615        let remote_id = write.get_remote_peer_id();
616        let meta = write.get_tunnel_meta();
617        let write = ObjectHolder::new(write);
618        let resp_write = write.clone();
619        let cmd_handler = self.cmd_handler.clone();
620        let handle = spawn(async move {
621            let ret: CmdResult<()> = async move {
622                loop {
623                    let header_len = recv
624                        .read_u8()
625                        .await
626                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
627                    let mut header = vec![0u8; header_len as usize];
628                    let n = recv
629                        .read_exact(header.as_mut())
630                        .await
631                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
632                    if n == 0 {
633                        break;
634                    }
635                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
636                        .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
637                    log::trace!(
638                        "recv cmd {:?} from {} len {}",
639                        header.cmd_code(),
640                        peer_id.to_base58(),
641                        header.pkg_len().to_u64().unwrap()
642                    );
643                    let body_len = header.pkg_len().to_u64().unwrap();
644                    let cmd_read =
645                        CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
646                    let waiter = cmd_read.get_waiter();
647                    let future = waiter
648                        .create_result_future()
649                        .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
650                    let version = header.version();
651                    let seq = header.seq();
652                    let cmd_code = header.cmd_code();
653                    match cmd_handler
654                        .handle(
655                            peer_id.clone(),
656                            tunnel_id,
657                            header,
658                            CmdBody::from_reader(BufReader::new(cmd_read), body_len),
659                        )
660                        .await
661                    {
662                        Ok(Some(mut body)) => {
663                            let mut write = resp_write.get().await;
664                            let header = CmdHeader::<LEN, CMD>::new(
665                                version,
666                                true,
667                                seq,
668                                cmd_code,
669                                LEN::from_u64(body.len()).unwrap(),
670                            );
671                            let buf = header
672                                .to_vec()
673                                .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
674                            if buf.len() > 255 {
675                                return Err(cmd_err!(
676                                    CmdErrorCode::InvalidParam,
677                                    "header len too long"
678                                ));
679                            }
680                            write
681                                .write_u8(buf.len() as u8)
682                                .await
683                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
684                            write
685                                .write_all(buf.as_slice())
686                                .await
687                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
688                            tokio::io::copy(&mut body, write.deref_mut().deref_mut())
689                                .await
690                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
691                            write
692                                .flush()
693                                .await
694                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
695                        }
696                        Ok(None) => {}
697                        Err(e) => {
698                            log::error!("handle cmd error: {:?}", e);
699                        }
700                    }
701                    recv = future
702                        .await
703                        .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
704                }
705                Ok(())
706            }
707            .await;
708            ret
709        });
710        Ok(CommonCmdSend::new(
711            tunnel_id,
712            handle,
713            write,
714            self.resp_waiter.clone(),
715            remote_id,
716            meta,
717        ))
718    }
719}
720
721pub struct DefaultCmdClient<
722    M: CmdTunnelMeta,
723    R: CmdTunnelRead<M>,
724    W: CmdTunnelWrite<M>,
725    F: CmdTunnelFactory<M, R, W>,
726    LEN: RawEncode
727        + for<'a> RawDecode<'a>
728        + Copy
729        + Send
730        + Sync
731        + 'static
732        + FromPrimitive
733        + ToPrimitive
734        + RawFixedBytes,
735    CMD: RawEncode
736        + for<'a> RawDecode<'a>
737        + Copy
738        + Send
739        + Sync
740        + 'static
741        + RawFixedBytes
742        + Eq
743        + Hash
744        + Debug,
745> {
746    tunnel_pool: ClassifiedWorkerPoolRef<
747        TunnelId,
748        CommonCmdSend<M, R, W, LEN, CMD>,
749        CmdWriteFactory<M, R, W, F, LEN, CMD>,
750    >,
751    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
752}
753
754impl<
755    M: CmdTunnelMeta,
756    R: CmdTunnelRead<M>,
757    W: CmdTunnelWrite<M>,
758    F: CmdTunnelFactory<M, R, W>,
759    LEN: RawEncode
760        + for<'a> RawDecode<'a>
761        + Copy
762        + Send
763        + Sync
764        + 'static
765        + FromPrimitive
766        + ToPrimitive
767        + RawFixedBytes,
768    CMD: RawEncode
769        + for<'a> RawDecode<'a>
770        + Copy
771        + Send
772        + Sync
773        + 'static
774        + RawFixedBytes
775        + Eq
776        + Hash
777        + Debug,
778> DefaultCmdClient<M, R, W, F, LEN, CMD>
779{
780    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
781        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
782        let handler_map = cmd_handler_map.clone();
783        let resp_waiter = Arc::new(RespWaiter::new());
784        let waiter = resp_waiter.clone();
785        Arc::new(Self {
786            tunnel_pool: ClassifiedWorkerPool::new(
787                tunnel_count,
788                CmdWriteFactory::<M, R, W, _, LEN, CMD>::new(
789                    factory,
790                    move |peer_id: PeerId,
791                          tunnel_id: TunnelId,
792                          header: CmdHeader<LEN, CMD>,
793                          body_read: CmdBody| {
794                        let handler_map = handler_map.clone();
795                        let waiter = waiter.clone();
796                        async move {
797                            if header.is_resp() && header.seq().is_some() {
798                                let resp_id = gen_resp_id(
799                                    tunnel_id,
800                                    header.cmd_code(),
801                                    header.seq().unwrap(),
802                                );
803                                let _ = waiter.set_result(resp_id, body_read);
804                                Ok(None)
805                            } else {
806                                if let Some(handler) = handler_map.get(header.cmd_code()) {
807                                    handler.handle(peer_id, tunnel_id, header, body_read).await
808                                } else {
809                                    Ok(None)
810                                }
811                            }
812                        }
813                    },
814                    resp_waiter.clone(),
815                ),
816            ),
817            cmd_handler_map,
818        })
819    }
820
821    async fn get_send(
822        &self,
823    ) -> CmdResult<
824        ClassifiedWorkerGuard<
825            TunnelId,
826            CommonCmdSend<M, R, W, LEN, CMD>,
827            CmdWriteFactory<M, R, W, F, LEN, CMD>,
828        >,
829    > {
830        self.tunnel_pool
831            .get_worker()
832            .await
833            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
834    }
835
836    async fn get_send_of_tunnel_id(
837        &self,
838        tunnel_id: TunnelId,
839    ) -> CmdResult<
840        ClassifiedWorkerGuard<
841            TunnelId,
842            CommonCmdSend<M, R, W, LEN, CMD>,
843            CmdWriteFactory<M, R, W, F, LEN, CMD>,
844        >,
845    > {
846        self.tunnel_pool
847            .get_classified_worker(tunnel_id)
848            .await
849            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
850    }
851}
852
853pub type CmdClientSendGuard<M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
854    TunnelId,
855    M,
856    CommonCmdSend<M, R, W, LEN, CMD>,
857    CmdWriteFactory<M, R, W, F, LEN, CMD>,
858>;
859#[async_trait::async_trait]
860impl<
861    M: CmdTunnelMeta,
862    R: CmdTunnelRead<M>,
863    W: CmdTunnelWrite<M>,
864    F: CmdTunnelFactory<M, R, W>,
865    LEN: RawEncode
866        + for<'a> RawDecode<'a>
867        + Copy
868        + Send
869        + Sync
870        + 'static
871        + FromPrimitive
872        + ToPrimitive
873        + RawFixedBytes,
874    CMD: RawEncode
875        + for<'a> RawDecode<'a>
876        + Copy
877        + Send
878        + Sync
879        + 'static
880        + RawFixedBytes
881        + Eq
882        + Hash
883        + Debug,
884> CmdClient<LEN, CMD, M, CommonCmdSend<M, R, W, LEN, CMD>, CmdClientSendGuard<M, R, W, F, LEN, CMD>>
885    for DefaultCmdClient<M, R, W, F, LEN, CMD>
886{
887    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
888        self.cmd_handler_map.insert(cmd, handler);
889    }
890
891    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
892        let mut send = self.get_send().await?;
893        send.send(cmd, version, body).await
894    }
895
896    async fn send_with_resp(
897        &self,
898        cmd: CMD,
899        version: u8,
900        body: &[u8],
901        timeout: Duration,
902    ) -> CmdResult<CmdBody> {
903        let mut send = self.get_send().await?;
904        send.send_with_resp(cmd, version, body, timeout).await
905    }
906
907    async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
908        let mut send = self.get_send().await?;
909        send.send2(cmd, version, body).await
910    }
911
912    async fn send2_with_resp(
913        &self,
914        cmd: CMD,
915        version: u8,
916        body: &[&[u8]],
917        timeout: Duration,
918    ) -> CmdResult<CmdBody> {
919        let mut send = self.get_send().await?;
920        send.send2_with_resp(cmd, version, body, timeout).await
921    }
922
923    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
924        let mut send = self.get_send().await?;
925        send.send_cmd(cmd, version, body).await
926    }
927
928    async fn send_cmd_with_resp(
929        &self,
930        cmd: CMD,
931        version: u8,
932        body: CmdBody,
933        timeout: Duration,
934    ) -> CmdResult<CmdBody> {
935        let mut send = self.get_send().await?;
936        send.send_cmd_with_resp(cmd, version, body, timeout).await
937    }
938
939    async fn send_by_specify_tunnel(
940        &self,
941        tunnel_id: TunnelId,
942        cmd: CMD,
943        version: u8,
944        body: &[u8],
945    ) -> CmdResult<()> {
946        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
947        send.send(cmd, version, body).await
948    }
949
950    async fn send_by_specify_tunnel_with_resp(
951        &self,
952        tunnel_id: TunnelId,
953        cmd: CMD,
954        version: u8,
955        body: &[u8],
956        timeout: Duration,
957    ) -> CmdResult<CmdBody> {
958        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
959        send.send_with_resp(cmd, version, body, timeout).await
960    }
961
962    async fn send2_by_specify_tunnel(
963        &self,
964        tunnel_id: TunnelId,
965        cmd: CMD,
966        version: u8,
967        body: &[&[u8]],
968    ) -> CmdResult<()> {
969        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
970        send.send2(cmd, version, body).await
971    }
972
973    async fn send2_by_specify_tunnel_with_resp(
974        &self,
975        tunnel_id: TunnelId,
976        cmd: CMD,
977        version: u8,
978        body: &[&[u8]],
979        timeout: Duration,
980    ) -> CmdResult<CmdBody> {
981        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
982        send.send2_with_resp(cmd, version, body, timeout).await
983    }
984
985    async fn send_cmd_by_specify_tunnel(
986        &self,
987        tunnel_id: TunnelId,
988        cmd: CMD,
989        version: u8,
990        body: CmdBody,
991    ) -> CmdResult<()> {
992        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
993        send.send_cmd(cmd, version, body).await
994    }
995
996    async fn send_cmd_by_specify_tunnel_with_resp(
997        &self,
998        tunnel_id: TunnelId,
999        cmd: CMD,
1000        version: u8,
1001        body: CmdBody,
1002        timeout: Duration,
1003    ) -> CmdResult<CmdBody> {
1004        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1005        send.send_cmd_with_resp(cmd, version, body, timeout).await
1006    }
1007
1008    async fn clear_all_tunnel(&self) {
1009        self.tunnel_pool.clear_all_worker().await;
1010    }
1011
1012    async fn get_send(
1013        &self,
1014        tunnel_id: TunnelId,
1015    ) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
1016        Ok(ClassifiedSendGuard {
1017            worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1018            _p: Default::default(),
1019        })
1020    }
1021}