Skip to main content

sfo_cmd_server/client/
client.rs

1use crate::client::{
2    CmdClient, CmdSend, RespWaiter, RespWaiterRef, SendGuard, gen_resp_id, gen_seq,
3};
4use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
5use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
6use crate::peer_id::PeerId;
7use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
8use async_named_locker::ObjectHolder;
9use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
10use num::{FromPrimitive, ToPrimitive};
11use sfo_pool::{
12    ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
13    ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
14    pool_err,
15};
16use sfo_split::{Splittable, WHalf};
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::marker::PhantomData;
20use std::ops::{Deref, DerefMut};
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27#[async_trait::async_trait]
28pub trait CmdTunnelFactory<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>>:
29    Send + Sync + 'static
30{
31    async fn create_tunnel(&self) -> CmdResult<Splittable<R, W>>;
32}
33
34pub struct CommonCmdSend<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>, LEN, CMD>
35where
36    LEN: RawEncode
37        + for<'a> RawDecode<'a>
38        + Copy
39        + Send
40        + Sync
41        + 'static
42        + FromPrimitive
43        + ToPrimitive,
44    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
45{
46    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
47    pub(crate) write: ObjectHolder<WHalf<R, W>>,
48    pub(crate) is_work: bool,
49    pub(crate) tunnel_id: TunnelId,
50    pub(crate) remote_id: PeerId,
51    pub(crate) resp_waiter: RespWaiterRef,
52    pub(crate) tunnel_meta: Option<Arc<M>>,
53    _p: std::marker::PhantomData<(LEN, CMD)>,
54}
55
56// impl<R, W, LEN, CMD> Deref for CmdSend<R, W, LEN, CMD>
57// where R: CmdTunnelRead,
58//       W: CmdTunnelWrite,
59//       LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
60//       CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
61//     type Target = W;
62//
63//     fn deref(&self) -> &Self::Target {
64//         self.write.deref()
65//     }
66// }
67
68impl<M, R, W, LEN, CMD> CommonCmdSend<M, R, W, LEN, CMD>
69where
70    M: CmdTunnelMeta,
71    R: CmdTunnelRead<M>,
72    W: CmdTunnelWrite<M>,
73    LEN: RawEncode
74        + for<'a> RawDecode<'a>
75        + Copy
76        + Send
77        + Sync
78        + 'static
79        + FromPrimitive
80        + ToPrimitive,
81    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
82{
83    pub fn new(
84        tunnel_id: TunnelId,
85        recv_handle: JoinHandle<CmdResult<()>>,
86        write: ObjectHolder<WHalf<R, W>>,
87        resp_waiter: RespWaiterRef,
88        remote_id: PeerId,
89        tunnel_meta: Option<Arc<M>>,
90    ) -> Self {
91        Self {
92            recv_handle,
93            write,
94            is_work: true,
95            tunnel_id,
96            remote_id,
97            resp_waiter,
98            tunnel_meta,
99            _p: Default::default(),
100        }
101    }
102
103    pub fn get_tunnel_id(&self) -> TunnelId {
104        self.tunnel_id
105    }
106
107    pub fn set_disable(&mut self) {
108        self.is_work = false;
109        self.recv_handle.abort();
110    }
111
112    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
113        log::trace!(
114            "client {:?} send cmd: {:?}, len: {} data:{}",
115            self.tunnel_id,
116            cmd,
117            body.len(),
118            hex::encode(body)
119        );
120        let header = CmdHeader::<LEN, CMD>::new(
121            version,
122            false,
123            None,
124            cmd,
125            LEN::from_u64(body.len() as u64).unwrap(),
126        );
127        let buf = header
128            .to_vec()
129            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
130        let ret = self.send_inner(buf.as_slice(), body).await;
131        if let Err(e) = ret {
132            self.set_disable();
133            return Err(e);
134        }
135        Ok(())
136    }
137
138    pub async fn send_with_resp(
139        &mut self,
140        cmd: CMD,
141        version: u8,
142        body: &[u8],
143        timeout: Duration,
144    ) -> CmdResult<CmdBody> {
145        if let Some(id) = tokio::task::try_id() {
146            if id == self.recv_handle.id() {
147                return Err(cmd_err!(
148                    CmdErrorCode::Failed,
149                    "can't send with resp in recv task"
150                ));
151            }
152        }
153        log::trace!(
154            "client {:?} send cmd: {:?}, len: {}, data: {}",
155            self.tunnel_id,
156            cmd,
157            body.len(),
158            hex::encode(body)
159        );
160        let seq = gen_seq();
161        let header = CmdHeader::<LEN, CMD>::new(
162            version,
163            false,
164            Some(seq),
165            cmd,
166            LEN::from_u64(body.len() as u64).unwrap(),
167        );
168        let buf = header
169            .to_vec()
170            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
171        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
172        let waiter = self.resp_waiter.clone();
173        let resp_waiter = waiter
174            .create_timeout_result_future(resp_id, timeout)
175            .map_err(into_cmd_err!(
176                CmdErrorCode::Failed,
177                "create timeout result future error"
178            ))?;
179        let ret = self.send_inner(buf.as_slice(), body).await;
180        if let Err(e) = ret {
181            self.set_disable();
182            return Err(e);
183        }
184        let resp = resp_waiter
185            .await
186            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
187        Ok(resp)
188    }
189
190    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
191        let mut len = 0;
192        for b in body.iter() {
193            len += b.len();
194            log::trace!(
195                "client {:?} send2 cmd: {:?}, data {}",
196                self.tunnel_id,
197                cmd,
198                hex::encode(b)
199            );
200        }
201        log::trace!(
202            "client {:?} send2 cmd: {:?}, len {}",
203            self.tunnel_id,
204            cmd,
205            len
206        );
207        let header = CmdHeader::<LEN, CMD>::new(
208            version,
209            false,
210            None,
211            cmd,
212            LEN::from_u64(len as u64).unwrap(),
213        );
214        let buf = header
215            .to_vec()
216            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
217        let ret = self.send_inner2(buf.as_slice(), body).await;
218        if let Err(e) = ret {
219            self.set_disable();
220            return Err(e);
221        }
222        Ok(())
223    }
224
225    pub async fn send2_with_resp(
226        &mut self,
227        cmd: CMD,
228        version: u8,
229        body: &[&[u8]],
230        timeout: Duration,
231    ) -> CmdResult<CmdBody> {
232        if let Some(id) = tokio::task::try_id() {
233            if id == self.recv_handle.id() {
234                return Err(cmd_err!(
235                    CmdErrorCode::Failed,
236                    "can't send with resp in recv task"
237                ));
238            }
239        }
240        let mut len = 0;
241        for b in body.iter() {
242            len += b.len();
243            log::trace!(
244                "client {:?} send2 cmd {:?} body: {}",
245                self.tunnel_id,
246                cmd,
247                hex::encode(b)
248            );
249        }
250        log::trace!(
251            "client {:?} send2 cmd: {:?}, len {}",
252            self.tunnel_id,
253            cmd,
254            len
255        );
256        let seq = gen_seq();
257        let header = CmdHeader::<LEN, CMD>::new(
258            version,
259            false,
260            Some(seq),
261            cmd,
262            LEN::from_u64(len as u64).unwrap(),
263        );
264        let buf = header
265            .to_vec()
266            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
267        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
268        let waiter = self.resp_waiter.clone();
269        let resp_waiter = waiter
270            .create_timeout_result_future(resp_id, timeout)
271            .map_err(into_cmd_err!(
272                CmdErrorCode::Failed,
273                "create timeout result future error"
274            ))?;
275        let ret = self.send_inner2(buf.as_slice(), body).await;
276        if let Err(e) = ret {
277            self.set_disable();
278            return Err(e);
279        }
280        let resp = resp_waiter
281            .await
282            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
283        Ok(resp)
284    }
285
286    pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
287        log::trace!(
288            "client {:?} send cmd: {:?}, len: {}",
289            self.tunnel_id,
290            cmd,
291            body.len()
292        );
293        let header = CmdHeader::<LEN, CMD>::new(
294            version,
295            false,
296            None,
297            cmd,
298            LEN::from_u64(body.len()).unwrap(),
299        );
300        let buf = header
301            .to_vec()
302            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
303        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
304        if let Err(e) = ret {
305            self.set_disable();
306            return Err(e);
307        }
308        Ok(())
309    }
310
311    pub async fn send_cmd_with_resp(
312        &mut self,
313        cmd: CMD,
314        version: u8,
315        body: CmdBody,
316        timeout: Duration,
317    ) -> CmdResult<CmdBody> {
318        if let Some(id) = tokio::task::try_id() {
319            if id == self.recv_handle.id() {
320                return Err(cmd_err!(
321                    CmdErrorCode::Failed,
322                    "can't send with resp in recv task"
323                ));
324            }
325        }
326        log::trace!(
327            "client {:?} send cmd: {:?}, len: {}",
328            self.tunnel_id,
329            cmd,
330            body.len()
331        );
332        let seq = gen_seq();
333        let header = CmdHeader::<LEN, CMD>::new(
334            version,
335            false,
336            Some(seq),
337            cmd,
338            LEN::from_u64(body.len()).unwrap(),
339        );
340        let buf = header
341            .to_vec()
342            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
343        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
344        let waiter = self.resp_waiter.clone();
345        let resp_waiter = waiter
346            .create_timeout_result_future(resp_id, timeout)
347            .map_err(into_cmd_err!(
348                CmdErrorCode::Failed,
349                "create timeout result future error"
350            ))?;
351        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
352        if let Err(e) = ret {
353            self.set_disable();
354            return Err(e);
355        }
356        let resp = resp_waiter
357            .await
358            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
359        Ok(resp)
360    }
361
362    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
363        let mut write = self.write.get().await;
364        if header.len() > 255 {
365            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
366        }
367        write
368            .write_u8(header.len() as u8)
369            .await
370            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
371        write
372            .write_all(header)
373            .await
374            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
375        write
376            .write_all(body)
377            .await
378            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
379        write
380            .flush()
381            .await
382            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
383        Ok(())
384    }
385
386    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
387        let mut write = self.write.get().await;
388        if header.len() > 255 {
389            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
390        }
391        write
392            .write_u8(header.len() as u8)
393            .await
394            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
395        write
396            .write_all(header)
397            .await
398            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
399        for b in body.iter() {
400            write
401                .write_all(b)
402                .await
403                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
404        }
405        write
406            .flush()
407            .await
408            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
409        Ok(())
410    }
411
412    async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
413        let mut write = self.write.get().await;
414        if header.len() > 255 {
415            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
416        }
417        write
418            .write_u8(header.len() as u8)
419            .await
420            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
421        write
422            .write_all(header)
423            .await
424            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
425        tokio::io::copy(&mut body, write.deref_mut().deref_mut())
426            .await
427            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
428        write
429            .flush()
430            .await
431            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
432        Ok(())
433    }
434}
435
436impl<M, R, W, LEN, CMD> Drop for CommonCmdSend<M, R, W, LEN, CMD>
437where
438    M: CmdTunnelMeta,
439    R: CmdTunnelRead<M>,
440    W: CmdTunnelWrite<M>,
441    LEN: RawEncode
442        + for<'a> RawDecode<'a>
443        + Copy
444        + Send
445        + Sync
446        + 'static
447        + FromPrimitive
448        + ToPrimitive,
449    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
450{
451    fn drop(&mut self) {
452        self.set_disable();
453    }
454}
455
456impl<M, R, W, LEN, CMD> CmdSend<M> for CommonCmdSend<M, R, W, LEN, CMD>
457where
458    M: CmdTunnelMeta,
459    R: CmdTunnelRead<M>,
460    W: CmdTunnelWrite<M>,
461    LEN: RawEncode
462        + for<'a> RawDecode<'a>
463        + Copy
464        + Send
465        + Sync
466        + 'static
467        + FromPrimitive
468        + ToPrimitive,
469    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
470{
471    fn get_tunnel_meta(&self) -> Option<Arc<M>> {
472        self.tunnel_meta.clone()
473    }
474
475    fn get_remote_peer_id(&self) -> PeerId {
476        self.remote_id.clone()
477    }
478}
479
480impl<M, R, W, LEN, CMD> ClassifiedWorker<TunnelId> for CommonCmdSend<M, R, W, LEN, CMD>
481where
482    M: CmdTunnelMeta,
483    R: CmdTunnelRead<M>,
484    W: CmdTunnelWrite<M>,
485    LEN: RawEncode
486        + for<'a> RawDecode<'a>
487        + Copy
488        + Send
489        + Sync
490        + 'static
491        + FromPrimitive
492        + ToPrimitive,
493    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
494{
495    fn is_work(&self) -> bool {
496        self.is_work && !self.recv_handle.is_finished()
497    }
498
499    fn is_valid(&self, c: TunnelId) -> bool {
500        self.tunnel_id == c
501    }
502
503    fn classification(&self) -> TunnelId {
504        self.tunnel_id
505    }
506}
507
508pub struct ClassifiedSendGuard<
509    C: WorkerClassification,
510    M: CmdTunnelMeta,
511    CW: ClassifiedWorker<C> + CmdSend<M>,
512    F: ClassifiedWorkerFactory<C, CW>,
513> {
514    pub(crate) worker_guard: ClassifiedWorkerGuard<C, CW, F>,
515    pub(crate) _p: PhantomData<M>,
516}
517
518impl<
519    C: WorkerClassification,
520    M: CmdTunnelMeta,
521    CW: ClassifiedWorker<C> + CmdSend<M>,
522    F: ClassifiedWorkerFactory<C, CW>,
523> Deref for ClassifiedSendGuard<C, M, CW, F>
524{
525    type Target = CW;
526
527    fn deref(&self) -> &Self::Target {
528        &self.worker_guard.deref()
529    }
530}
531
532impl<
533    C: WorkerClassification,
534    M: CmdTunnelMeta,
535    CW: ClassifiedWorker<C> + CmdSend<M>,
536    F: ClassifiedWorkerFactory<C, CW>,
537> SendGuard<M, CW> for ClassifiedSendGuard<C, M, CW, F>
538{
539}
540
541pub struct CmdWriteFactory<
542    M: CmdTunnelMeta,
543    R: CmdTunnelRead<M>,
544    W: CmdTunnelWrite<M>,
545    F: CmdTunnelFactory<M, R, W>,
546    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
547    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
548> {
549    tunnel_factory: F,
550    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
551    resp_waiter: RespWaiterRef,
552    tunnel_id_generator: TunnelIdGenerator,
553    p: std::marker::PhantomData<Mutex<(R, W, M)>>,
554}
555
556impl<
557    M: CmdTunnelMeta,
558    R: CmdTunnelRead<M>,
559    W: CmdTunnelWrite<M>,
560    F: CmdTunnelFactory<M, R, W>,
561    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
562    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
563> CmdWriteFactory<M, R, W, F, LEN, CMD>
564{
565    pub(crate) fn new(
566        tunnel_factory: F,
567        cmd_handler: impl CmdHandler<LEN, CMD>,
568        resp_waiter: RespWaiterRef,
569    ) -> Self {
570        Self {
571            tunnel_factory,
572            cmd_handler: Arc::new(cmd_handler),
573            resp_waiter,
574            tunnel_id_generator: TunnelIdGenerator::new(),
575            p: Default::default(),
576        }
577    }
578}
579
580#[async_trait::async_trait]
581impl<
582    M: CmdTunnelMeta,
583    R: CmdTunnelRead<M>,
584    W: CmdTunnelWrite<M>,
585    F: CmdTunnelFactory<M, R, W>,
586    LEN: RawEncode
587        + for<'a> RawDecode<'a>
588        + Copy
589        + Send
590        + Sync
591        + 'static
592        + FromPrimitive
593        + ToPrimitive
594        + RawFixedBytes,
595    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
596> ClassifiedWorkerFactory<TunnelId, CommonCmdSend<M, R, W, LEN, CMD>>
597    for CmdWriteFactory<M, R, W, F, LEN, CMD>
598{
599    async fn create(&self, c: Option<TunnelId>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
600        if c.is_some() {
601            return Err(pool_err!(
602                PoolErrorCode::Failed,
603                "tunnel {:?} not found",
604                c.unwrap()
605            ));
606        }
607        let tunnel = self
608            .tunnel_factory
609            .create_tunnel()
610            .await
611            .map_err(into_pool_err!(PoolErrorCode::Failed))?;
612        let peer_id = tunnel.get_remote_peer_id();
613        let tunnel_id = self.tunnel_id_generator.generate();
614        let (mut recv, write) = tunnel.split();
615        let local_id = recv.get_local_peer_id();
616        let remote_id = write.get_remote_peer_id();
617        let meta = write.get_tunnel_meta();
618        let write = ObjectHolder::new(write);
619        let resp_write = write.clone();
620        let cmd_handler = self.cmd_handler.clone();
621        let handle = spawn(async move {
622            let ret: CmdResult<()> = async move {
623                loop {
624                    let header_len = recv
625                        .read_u8()
626                        .await
627                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
628                    let mut header = vec![0u8; header_len as usize];
629                    let n = recv
630                        .read_exact(header.as_mut())
631                        .await
632                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
633                    if n == 0 {
634                        break;
635                    }
636                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
637                        .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
638                    log::trace!(
639                        "recv cmd {:?} from {} len {}",
640                        header.cmd_code(),
641                        peer_id.to_base58(),
642                        header.pkg_len().to_u64().unwrap()
643                    );
644                    let body_len = header.pkg_len().to_u64().unwrap();
645                    let cmd_read =
646                        CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
647                    let waiter = cmd_read.get_waiter();
648                    let future = waiter
649                        .create_result_future()
650                        .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
651                    let version = header.version();
652                    let seq = header.seq();
653                    let cmd_code = header.cmd_code();
654                    match cmd_handler
655                        .handle(
656                            local_id.clone(),
657                            peer_id.clone(),
658                            tunnel_id,
659                            header,
660                            CmdBody::from_reader(BufReader::new(cmd_read), body_len),
661                        )
662                        .await
663                    {
664                        Ok(Some(mut body)) => {
665                            let mut write = resp_write.get().await;
666                            let header = CmdHeader::<LEN, CMD>::new(
667                                version,
668                                true,
669                                seq,
670                                cmd_code,
671                                LEN::from_u64(body.len()).unwrap(),
672                            );
673                            let buf = header
674                                .to_vec()
675                                .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
676                            if buf.len() > 255 {
677                                return Err(cmd_err!(
678                                    CmdErrorCode::InvalidParam,
679                                    "header len too long"
680                                ));
681                            }
682                            write
683                                .write_u8(buf.len() as u8)
684                                .await
685                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
686                            write
687                                .write_all(buf.as_slice())
688                                .await
689                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
690                            tokio::io::copy(&mut body, write.deref_mut().deref_mut())
691                                .await
692                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
693                            write
694                                .flush()
695                                .await
696                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
697                        }
698                        Ok(None) => {}
699                        Err(e) => {
700                            log::error!("handle cmd error: {:?}", e);
701                        }
702                    }
703                    recv = future
704                        .await
705                        .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
706                }
707                Ok(())
708            }
709            .await;
710            ret
711        });
712        Ok(CommonCmdSend::new(
713            tunnel_id,
714            handle,
715            write,
716            self.resp_waiter.clone(),
717            remote_id,
718            meta,
719        ))
720    }
721}
722
723pub struct DefaultCmdClient<
724    M: CmdTunnelMeta,
725    R: CmdTunnelRead<M>,
726    W: CmdTunnelWrite<M>,
727    F: CmdTunnelFactory<M, R, W>,
728    LEN: RawEncode
729        + for<'a> RawDecode<'a>
730        + Copy
731        + Send
732        + Sync
733        + 'static
734        + FromPrimitive
735        + ToPrimitive
736        + RawFixedBytes,
737    CMD: RawEncode
738        + for<'a> RawDecode<'a>
739        + Copy
740        + Send
741        + Sync
742        + 'static
743        + RawFixedBytes
744        + Eq
745        + Hash
746        + Debug,
747> {
748    tunnel_pool: ClassifiedWorkerPoolRef<
749        TunnelId,
750        CommonCmdSend<M, R, W, LEN, CMD>,
751        CmdWriteFactory<M, R, W, F, LEN, CMD>,
752    >,
753    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
754}
755
756impl<
757    M: CmdTunnelMeta,
758    R: CmdTunnelRead<M>,
759    W: CmdTunnelWrite<M>,
760    F: CmdTunnelFactory<M, R, W>,
761    LEN: RawEncode
762        + for<'a> RawDecode<'a>
763        + Copy
764        + Send
765        + Sync
766        + 'static
767        + FromPrimitive
768        + ToPrimitive
769        + RawFixedBytes,
770    CMD: RawEncode
771        + for<'a> RawDecode<'a>
772        + Copy
773        + Send
774        + Sync
775        + 'static
776        + RawFixedBytes
777        + Eq
778        + Hash
779        + Debug,
780> DefaultCmdClient<M, R, W, F, LEN, CMD>
781{
782    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
783        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
784        let handler_map = cmd_handler_map.clone();
785        let resp_waiter = Arc::new(RespWaiter::new());
786        let waiter = resp_waiter.clone();
787        Arc::new(Self {
788            tunnel_pool: ClassifiedWorkerPool::new(
789                tunnel_count,
790                CmdWriteFactory::<M, R, W, _, LEN, CMD>::new(
791                    factory,
792                    move |local_id: PeerId,
793                          peer_id: PeerId,
794                          tunnel_id: TunnelId,
795                          header: CmdHeader<LEN, CMD>,
796                          body_read: CmdBody| {
797                        let handler_map = handler_map.clone();
798                        let waiter = waiter.clone();
799                        async move {
800                            if header.is_resp() && header.seq().is_some() {
801                                let resp_id = gen_resp_id(
802                                    tunnel_id,
803                                    header.cmd_code(),
804                                    header.seq().unwrap(),
805                                );
806                                let _ = waiter.set_result(resp_id, body_read);
807                                Ok(None)
808                            } else {
809                                if let Some(handler) = handler_map.get(header.cmd_code()) {
810                                    handler
811                                        .handle(local_id, peer_id, tunnel_id, header, body_read)
812                                        .await
813                                } else {
814                                    Ok(None)
815                                }
816                            }
817                        }
818                    },
819                    resp_waiter.clone(),
820                ),
821            ),
822            cmd_handler_map,
823        })
824    }
825
826    async fn get_send(
827        &self,
828    ) -> CmdResult<
829        ClassifiedWorkerGuard<
830            TunnelId,
831            CommonCmdSend<M, R, W, LEN, CMD>,
832            CmdWriteFactory<M, R, W, F, LEN, CMD>,
833        >,
834    > {
835        self.tunnel_pool
836            .get_worker()
837            .await
838            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
839    }
840
841    async fn get_send_of_tunnel_id(
842        &self,
843        tunnel_id: TunnelId,
844    ) -> CmdResult<
845        ClassifiedWorkerGuard<
846            TunnelId,
847            CommonCmdSend<M, R, W, LEN, CMD>,
848            CmdWriteFactory<M, R, W, F, LEN, CMD>,
849        >,
850    > {
851        self.tunnel_pool
852            .get_classified_worker(tunnel_id)
853            .await
854            .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
855    }
856}
857
858pub type CmdClientSendGuard<M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
859    TunnelId,
860    M,
861    CommonCmdSend<M, R, W, LEN, CMD>,
862    CmdWriteFactory<M, R, W, F, LEN, CMD>,
863>;
864#[async_trait::async_trait]
865impl<
866    M: CmdTunnelMeta,
867    R: CmdTunnelRead<M>,
868    W: CmdTunnelWrite<M>,
869    F: CmdTunnelFactory<M, R, W>,
870    LEN: RawEncode
871        + for<'a> RawDecode<'a>
872        + Copy
873        + Send
874        + Sync
875        + 'static
876        + FromPrimitive
877        + ToPrimitive
878        + RawFixedBytes,
879    CMD: RawEncode
880        + for<'a> RawDecode<'a>
881        + Copy
882        + Send
883        + Sync
884        + 'static
885        + RawFixedBytes
886        + Eq
887        + Hash
888        + Debug,
889> CmdClient<LEN, CMD, M, CommonCmdSend<M, R, W, LEN, CMD>, CmdClientSendGuard<M, R, W, F, LEN, CMD>>
890    for DefaultCmdClient<M, R, W, F, LEN, CMD>
891{
892    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
893        self.cmd_handler_map.insert(cmd, handler);
894    }
895
896    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
897        let mut send = self.get_send().await?;
898        send.send(cmd, version, body).await
899    }
900
901    async fn send_with_resp(
902        &self,
903        cmd: CMD,
904        version: u8,
905        body: &[u8],
906        timeout: Duration,
907    ) -> CmdResult<CmdBody> {
908        let mut send = self.get_send().await?;
909        send.send_with_resp(cmd, version, body, timeout).await
910    }
911
912    async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
913        let mut send = self.get_send().await?;
914        send.send2(cmd, version, body).await
915    }
916
917    async fn send2_with_resp(
918        &self,
919        cmd: CMD,
920        version: u8,
921        body: &[&[u8]],
922        timeout: Duration,
923    ) -> CmdResult<CmdBody> {
924        let mut send = self.get_send().await?;
925        send.send2_with_resp(cmd, version, body, timeout).await
926    }
927
928    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
929        let mut send = self.get_send().await?;
930        send.send_cmd(cmd, version, body).await
931    }
932
933    async fn send_cmd_with_resp(
934        &self,
935        cmd: CMD,
936        version: u8,
937        body: CmdBody,
938        timeout: Duration,
939    ) -> CmdResult<CmdBody> {
940        let mut send = self.get_send().await?;
941        send.send_cmd_with_resp(cmd, version, body, timeout).await
942    }
943
944    async fn send_by_specify_tunnel(
945        &self,
946        tunnel_id: TunnelId,
947        cmd: CMD,
948        version: u8,
949        body: &[u8],
950    ) -> CmdResult<()> {
951        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
952        send.send(cmd, version, body).await
953    }
954
955    async fn send_by_specify_tunnel_with_resp(
956        &self,
957        tunnel_id: TunnelId,
958        cmd: CMD,
959        version: u8,
960        body: &[u8],
961        timeout: Duration,
962    ) -> CmdResult<CmdBody> {
963        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
964        send.send_with_resp(cmd, version, body, timeout).await
965    }
966
967    async fn send2_by_specify_tunnel(
968        &self,
969        tunnel_id: TunnelId,
970        cmd: CMD,
971        version: u8,
972        body: &[&[u8]],
973    ) -> CmdResult<()> {
974        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
975        send.send2(cmd, version, body).await
976    }
977
978    async fn send2_by_specify_tunnel_with_resp(
979        &self,
980        tunnel_id: TunnelId,
981        cmd: CMD,
982        version: u8,
983        body: &[&[u8]],
984        timeout: Duration,
985    ) -> CmdResult<CmdBody> {
986        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
987        send.send2_with_resp(cmd, version, body, timeout).await
988    }
989
990    async fn send_cmd_by_specify_tunnel(
991        &self,
992        tunnel_id: TunnelId,
993        cmd: CMD,
994        version: u8,
995        body: CmdBody,
996    ) -> CmdResult<()> {
997        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
998        send.send_cmd(cmd, version, body).await
999    }
1000
1001    async fn send_cmd_by_specify_tunnel_with_resp(
1002        &self,
1003        tunnel_id: TunnelId,
1004        cmd: CMD,
1005        version: u8,
1006        body: CmdBody,
1007        timeout: Duration,
1008    ) -> CmdResult<CmdBody> {
1009        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1010        send.send_cmd_with_resp(cmd, version, body, timeout).await
1011    }
1012
1013    async fn clear_all_tunnel(&self) {
1014        self.tunnel_pool.clear_all_worker().await;
1015    }
1016
1017    async fn get_send(
1018        &self,
1019        tunnel_id: TunnelId,
1020    ) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
1021        Ok(ClassifiedSendGuard {
1022            worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1023            _p: Default::default(),
1024        })
1025    }
1026}