Skip to main content

sfo_cmd_server/client/
client.rs

1use crate::client::{
2    CmdClient, CmdSend, RespWaiter, RespWaiterRef, SendGuard, TrackedSendGuard,
3    TunnelReserveResult, TunnelRuntimeRegistry, gen_resp_id, gen_seq,
4};
5use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
6use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
7use crate::peer_id::PeerId;
8use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
9use async_named_locker::ObjectHolder;
10use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
11use num::{FromPrimitive, ToPrimitive};
12use sfo_pool::{
13    ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
14    ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
15    pool_err,
16};
17use sfo_split::{Splittable, WHalf};
18use std::fmt::Debug;
19use std::hash::Hash;
20use std::marker::PhantomData;
21use std::ops::{Deref, DerefMut};
22use std::sync::{Arc, Mutex};
23use std::time::Duration;
24use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
25use tokio::spawn;
26use tokio::task::JoinHandle;
27use tokio::task::yield_now;
28
29#[async_trait::async_trait]
30pub trait CmdTunnelFactory<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>>:
31    Send + Sync + 'static
32{
33    async fn create_tunnel(&self) -> CmdResult<Splittable<R, W>>;
34}
35
36pub struct CommonCmdSend<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>, LEN, CMD>
37where
38    LEN: RawEncode
39        + for<'a> RawDecode<'a>
40        + Copy
41        + Send
42        + Sync
43        + 'static
44        + FromPrimitive
45        + ToPrimitive,
46    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
47{
48    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
49    pub(crate) write: ObjectHolder<WHalf<R, W>>,
50    pub(crate) is_work: bool,
51    pub(crate) tunnel_id: TunnelId,
52    pub(crate) remote_id: PeerId,
53    pub(crate) resp_waiter: RespWaiterRef,
54    pub(crate) tunnel_meta: Option<Arc<M>>,
55    _p: std::marker::PhantomData<(LEN, CMD)>,
56}
57
58// impl<R, W, LEN, CMD> Deref for CmdSend<R, W, LEN, CMD>
59// where R: CmdTunnelRead,
60//       W: CmdTunnelWrite,
61//       LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
62//       CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
63//     type Target = W;
64//
65//     fn deref(&self) -> &Self::Target {
66//         self.write.deref()
67//     }
68// }
69
70impl<M, R, W, LEN, CMD> CommonCmdSend<M, R, W, LEN, CMD>
71where
72    M: CmdTunnelMeta,
73    R: CmdTunnelRead<M>,
74    W: CmdTunnelWrite<M>,
75    LEN: RawEncode
76        + for<'a> RawDecode<'a>
77        + Copy
78        + Send
79        + Sync
80        + 'static
81        + FromPrimitive
82        + ToPrimitive,
83    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
84{
85    pub fn new(
86        tunnel_id: TunnelId,
87        recv_handle: JoinHandle<CmdResult<()>>,
88        write: ObjectHolder<WHalf<R, W>>,
89        resp_waiter: RespWaiterRef,
90        remote_id: PeerId,
91        tunnel_meta: Option<Arc<M>>,
92    ) -> Self {
93        Self {
94            recv_handle,
95            write,
96            is_work: true,
97            tunnel_id,
98            remote_id,
99            resp_waiter,
100            tunnel_meta,
101            _p: Default::default(),
102        }
103    }
104
105    pub fn get_tunnel_id(&self) -> TunnelId {
106        self.tunnel_id
107    }
108
109    pub fn set_disable(&mut self) {
110        self.is_work = false;
111        self.recv_handle.abort();
112    }
113
114    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
115        log::trace!(
116            "client {:?} send cmd: {:?}, len: {} data:{}",
117            self.tunnel_id,
118            cmd,
119            body.len(),
120            hex::encode(body)
121        );
122        let header = CmdHeader::<LEN, CMD>::new(
123            version,
124            false,
125            None,
126            cmd,
127            LEN::from_u64(body.len() as u64).unwrap(),
128        );
129        let buf = header
130            .to_vec()
131            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
132        let ret = self.send_inner(buf.as_slice(), body).await;
133        if let Err(e) = ret {
134            self.set_disable();
135            return Err(e);
136        }
137        Ok(())
138    }
139
140    pub async fn send_with_resp(
141        &mut self,
142        cmd: CMD,
143        version: u8,
144        body: &[u8],
145        timeout: Duration,
146    ) -> CmdResult<CmdBody> {
147        if let Some(id) = tokio::task::try_id() {
148            if id == self.recv_handle.id() {
149                return Err(cmd_err!(
150                    CmdErrorCode::Failed,
151                    "can't send with resp in recv task"
152                ));
153            }
154        }
155        log::trace!(
156            "client {:?} send cmd: {:?}, len: {}, data: {}",
157            self.tunnel_id,
158            cmd,
159            body.len(),
160            hex::encode(body)
161        );
162        let seq = gen_seq();
163        let header = CmdHeader::<LEN, CMD>::new(
164            version,
165            false,
166            Some(seq),
167            cmd,
168            LEN::from_u64(body.len() as u64).unwrap(),
169        );
170        let buf = header
171            .to_vec()
172            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
173        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
174        let waiter = self.resp_waiter.clone();
175        let resp_waiter = waiter
176            .create_timeout_result_future(resp_id, timeout)
177            .map_err(into_cmd_err!(
178                CmdErrorCode::Failed,
179                "create timeout result future error"
180            ))?;
181        let ret = self.send_inner(buf.as_slice(), body).await;
182        if let Err(e) = ret {
183            self.set_disable();
184            return Err(e);
185        }
186        let resp = resp_waiter
187            .await
188            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
189        Ok(resp)
190    }
191
192    pub async fn send_parts(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
193        let mut len = 0;
194        for b in body.iter() {
195            len += b.len();
196            log::trace!(
197                "client {:?} send2 cmd: {:?}, data {}",
198                self.tunnel_id,
199                cmd,
200                hex::encode(b)
201            );
202        }
203        log::trace!(
204            "client {:?} send2 cmd: {:?}, len {}",
205            self.tunnel_id,
206            cmd,
207            len
208        );
209        let header = CmdHeader::<LEN, CMD>::new(
210            version,
211            false,
212            None,
213            cmd,
214            LEN::from_u64(len as u64).unwrap(),
215        );
216        let buf = header
217            .to_vec()
218            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
219        let ret = self.send_inner2(buf.as_slice(), body).await;
220        if let Err(e) = ret {
221            self.set_disable();
222            return Err(e);
223        }
224        Ok(())
225    }
226
227    pub async fn send_parts_with_resp(
228        &mut self,
229        cmd: CMD,
230        version: u8,
231        body: &[&[u8]],
232        timeout: Duration,
233    ) -> CmdResult<CmdBody> {
234        if let Some(id) = tokio::task::try_id() {
235            if id == self.recv_handle.id() {
236                return Err(cmd_err!(
237                    CmdErrorCode::Failed,
238                    "can't send with resp in recv task"
239                ));
240            }
241        }
242        let mut len = 0;
243        for b in body.iter() {
244            len += b.len();
245            log::trace!(
246                "client {:?} send2 cmd {:?} body: {}",
247                self.tunnel_id,
248                cmd,
249                hex::encode(b)
250            );
251        }
252        log::trace!(
253            "client {:?} send2 cmd: {:?}, len {}",
254            self.tunnel_id,
255            cmd,
256            len
257        );
258        let seq = gen_seq();
259        let header = CmdHeader::<LEN, CMD>::new(
260            version,
261            false,
262            Some(seq),
263            cmd,
264            LEN::from_u64(len as u64).unwrap(),
265        );
266        let buf = header
267            .to_vec()
268            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
269        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
270        let waiter = self.resp_waiter.clone();
271        let resp_waiter = waiter
272            .create_timeout_result_future(resp_id, timeout)
273            .map_err(into_cmd_err!(
274                CmdErrorCode::Failed,
275                "create timeout result future error"
276            ))?;
277        let ret = self.send_inner2(buf.as_slice(), body).await;
278        if let Err(e) = ret {
279            self.set_disable();
280            return Err(e);
281        }
282        let resp = resp_waiter
283            .await
284            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
285        Ok(resp)
286    }
287
288    #[allow(deprecated)]
289    #[deprecated(note = "use send_parts instead")]
290    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
291        self.send_parts(cmd, version, body).await
292    }
293
294    #[allow(deprecated)]
295    #[deprecated(note = "use send_parts_with_resp instead")]
296    pub async fn send2_with_resp(
297        &mut self,
298        cmd: CMD,
299        version: u8,
300        body: &[&[u8]],
301        timeout: Duration,
302    ) -> CmdResult<CmdBody> {
303        self.send_parts_with_resp(cmd, version, body, timeout).await
304    }
305
306    pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
307        log::trace!(
308            "client {:?} send cmd: {:?}, len: {}",
309            self.tunnel_id,
310            cmd,
311            body.len()
312        );
313        let header = CmdHeader::<LEN, CMD>::new(
314            version,
315            false,
316            None,
317            cmd,
318            LEN::from_u64(body.len()).unwrap(),
319        );
320        let buf = header
321            .to_vec()
322            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
323        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
324        if let Err(e) = ret {
325            self.set_disable();
326            return Err(e);
327        }
328        Ok(())
329    }
330
331    pub async fn send_cmd_with_resp(
332        &mut self,
333        cmd: CMD,
334        version: u8,
335        body: CmdBody,
336        timeout: Duration,
337    ) -> CmdResult<CmdBody> {
338        if let Some(id) = tokio::task::try_id() {
339            if id == self.recv_handle.id() {
340                return Err(cmd_err!(
341                    CmdErrorCode::Failed,
342                    "can't send with resp in recv task"
343                ));
344            }
345        }
346        log::trace!(
347            "client {:?} send cmd: {:?}, len: {}",
348            self.tunnel_id,
349            cmd,
350            body.len()
351        );
352        let seq = gen_seq();
353        let header = CmdHeader::<LEN, CMD>::new(
354            version,
355            false,
356            Some(seq),
357            cmd,
358            LEN::from_u64(body.len()).unwrap(),
359        );
360        let buf = header
361            .to_vec()
362            .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
363        let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
364        let waiter = self.resp_waiter.clone();
365        let resp_waiter = waiter
366            .create_timeout_result_future(resp_id, timeout)
367            .map_err(into_cmd_err!(
368                CmdErrorCode::Failed,
369                "create timeout result future error"
370            ))?;
371        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
372        if let Err(e) = ret {
373            self.set_disable();
374            return Err(e);
375        }
376        let resp = resp_waiter
377            .await
378            .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
379        Ok(resp)
380    }
381
382    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
383        let mut write = self.write.get().await;
384        if header.len() > 255 {
385            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
386        }
387        write
388            .write_u8(header.len() as u8)
389            .await
390            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
391        write
392            .write_all(header)
393            .await
394            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
395        write
396            .write_all(body)
397            .await
398            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
399        write
400            .flush()
401            .await
402            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
403        Ok(())
404    }
405
406    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
407        let mut write = self.write.get().await;
408        if header.len() > 255 {
409            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
410        }
411        write
412            .write_u8(header.len() as u8)
413            .await
414            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
415        write
416            .write_all(header)
417            .await
418            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
419        for b in body.iter() {
420            write
421                .write_all(b)
422                .await
423                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
424        }
425        write
426            .flush()
427            .await
428            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
429        Ok(())
430    }
431
432    async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
433        let mut write = self.write.get().await;
434        if header.len() > 255 {
435            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
436        }
437        write
438            .write_u8(header.len() as u8)
439            .await
440            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
441        write
442            .write_all(header)
443            .await
444            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
445        tokio::io::copy(&mut body, write.deref_mut().deref_mut())
446            .await
447            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
448        write
449            .flush()
450            .await
451            .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
452        Ok(())
453    }
454}
455
456impl<M, R, W, LEN, CMD> Drop for CommonCmdSend<M, R, W, LEN, CMD>
457where
458    M: CmdTunnelMeta,
459    R: CmdTunnelRead<M>,
460    W: CmdTunnelWrite<M>,
461    LEN: RawEncode
462        + for<'a> RawDecode<'a>
463        + Copy
464        + Send
465        + Sync
466        + 'static
467        + FromPrimitive
468        + ToPrimitive,
469    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
470{
471    fn drop(&mut self) {
472        self.set_disable();
473    }
474}
475
476impl<M, R, W, LEN, CMD> CmdSend<M> for CommonCmdSend<M, R, W, LEN, CMD>
477where
478    M: CmdTunnelMeta,
479    R: CmdTunnelRead<M>,
480    W: CmdTunnelWrite<M>,
481    LEN: RawEncode
482        + for<'a> RawDecode<'a>
483        + Copy
484        + Send
485        + Sync
486        + 'static
487        + FromPrimitive
488        + ToPrimitive,
489    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
490{
491    fn get_tunnel_meta(&self) -> Option<Arc<M>> {
492        self.tunnel_meta.clone()
493    }
494
495    fn get_remote_peer_id(&self) -> PeerId {
496        self.remote_id.clone()
497    }
498}
499
500impl<M, R, W, LEN, CMD> ClassifiedWorker<TunnelId> for CommonCmdSend<M, R, W, LEN, CMD>
501where
502    M: CmdTunnelMeta,
503    R: CmdTunnelRead<M>,
504    W: CmdTunnelWrite<M>,
505    LEN: RawEncode
506        + for<'a> RawDecode<'a>
507        + Copy
508        + Send
509        + Sync
510        + 'static
511        + FromPrimitive
512        + ToPrimitive,
513    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
514{
515    fn is_work(&self) -> bool {
516        self.is_work && !self.recv_handle.is_finished()
517    }
518
519    fn is_valid(&self, c: TunnelId) -> bool {
520        self.tunnel_id == c
521    }
522
523    fn classification(&self) -> TunnelId {
524        self.tunnel_id
525    }
526}
527
528pub struct ClassifiedSendGuard<
529    C: WorkerClassification,
530    M: CmdTunnelMeta,
531    CW: ClassifiedWorker<C> + CmdSend<M>,
532    F: ClassifiedWorkerFactory<C, CW>,
533> {
534    pub(crate) worker_guard: ClassifiedWorkerGuard<C, CW, F>,
535    pub(crate) _p: PhantomData<M>,
536}
537
538impl<
539    C: WorkerClassification,
540    M: CmdTunnelMeta,
541    CW: ClassifiedWorker<C> + CmdSend<M>,
542    F: ClassifiedWorkerFactory<C, CW>,
543> Deref for ClassifiedSendGuard<C, M, CW, F>
544{
545    type Target = CW;
546
547    fn deref(&self) -> &Self::Target {
548        &self.worker_guard.deref()
549    }
550}
551
552impl<
553    C: WorkerClassification,
554    M: CmdTunnelMeta,
555    CW: ClassifiedWorker<C> + CmdSend<M>,
556    F: ClassifiedWorkerFactory<C, CW>,
557> SendGuard<M, CW> for ClassifiedSendGuard<C, M, CW, F>
558{
559}
560
561pub struct CmdWriteFactory<
562    M: CmdTunnelMeta,
563    R: CmdTunnelRead<M>,
564    W: CmdTunnelWrite<M>,
565    F: CmdTunnelFactory<M, R, W>,
566    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
567    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
568> {
569    tunnel_factory: F,
570    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
571    resp_waiter: RespWaiterRef,
572    tunnel_id_generator: TunnelIdGenerator,
573    p: std::marker::PhantomData<Mutex<(R, W, M)>>,
574}
575
576impl<
577    M: CmdTunnelMeta,
578    R: CmdTunnelRead<M>,
579    W: CmdTunnelWrite<M>,
580    F: CmdTunnelFactory<M, R, W>,
581    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
582    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
583> CmdWriteFactory<M, R, W, F, LEN, CMD>
584{
585    pub(crate) fn new(
586        tunnel_factory: F,
587        cmd_handler: impl CmdHandler<LEN, CMD>,
588        resp_waiter: RespWaiterRef,
589    ) -> Self {
590        Self {
591            tunnel_factory,
592            cmd_handler: Arc::new(cmd_handler),
593            resp_waiter,
594            tunnel_id_generator: TunnelIdGenerator::new(),
595            p: Default::default(),
596        }
597    }
598}
599
600#[async_trait::async_trait]
601impl<
602    M: CmdTunnelMeta,
603    R: CmdTunnelRead<M>,
604    W: CmdTunnelWrite<M>,
605    F: CmdTunnelFactory<M, R, W>,
606    LEN: RawEncode
607        + for<'a> RawDecode<'a>
608        + Copy
609        + Send
610        + Sync
611        + 'static
612        + FromPrimitive
613        + ToPrimitive
614        + RawFixedBytes,
615    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
616> ClassifiedWorkerFactory<TunnelId, CommonCmdSend<M, R, W, LEN, CMD>>
617    for CmdWriteFactory<M, R, W, F, LEN, CMD>
618{
619    async fn create(&self, c: Option<TunnelId>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
620        if c.is_some() {
621            return Err(pool_err!(
622                PoolErrorCode::Failed,
623                "tunnel {:?} not found",
624                c.unwrap()
625            ));
626        }
627        let tunnel = self
628            .tunnel_factory
629            .create_tunnel()
630            .await
631            .map_err(into_pool_err!(PoolErrorCode::Failed))?;
632        let peer_id = tunnel.get_remote_peer_id();
633        let tunnel_id = self.tunnel_id_generator.generate();
634        let (mut recv, write) = tunnel.split();
635        let local_id = recv.get_local_peer_id();
636        let remote_id = write.get_remote_peer_id();
637        let meta = write.get_tunnel_meta();
638        let write = ObjectHolder::new(write);
639        let resp_write = write.clone();
640        let cmd_handler = self.cmd_handler.clone();
641        let handle = spawn(async move {
642            let ret: CmdResult<()> = async move {
643                loop {
644                    let header_len = recv
645                        .read_u8()
646                        .await
647                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
648                    let mut header = vec![0u8; header_len as usize];
649                    let n = recv
650                        .read_exact(header.as_mut())
651                        .await
652                        .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
653                    if n == 0 {
654                        break;
655                    }
656                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
657                        .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
658                    log::trace!(
659                        "recv cmd {:?} from {} len {}",
660                        header.cmd_code(),
661                        peer_id.to_base58(),
662                        header.pkg_len().to_u64().unwrap()
663                    );
664                    let body_len = header.pkg_len().to_u64().unwrap();
665                    let cmd_read =
666                        CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
667                    let waiter = cmd_read.get_waiter();
668                    let future = waiter
669                        .create_result_future()
670                        .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
671                    let version = header.version();
672                    let seq = header.seq();
673                    let cmd_code = header.cmd_code();
674                    match cmd_handler
675                        .handle(
676                            local_id.clone(),
677                            peer_id.clone(),
678                            tunnel_id,
679                            header,
680                            CmdBody::from_reader(BufReader::new(cmd_read), body_len),
681                        )
682                        .await
683                    {
684                        Ok(Some(mut body)) => {
685                            let mut write = resp_write.get().await;
686                            let header = CmdHeader::<LEN, CMD>::new(
687                                version,
688                                true,
689                                seq,
690                                cmd_code,
691                                LEN::from_u64(body.len()).unwrap(),
692                            );
693                            let buf = header
694                                .to_vec()
695                                .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
696                            if buf.len() > 255 {
697                                return Err(cmd_err!(
698                                    CmdErrorCode::InvalidParam,
699                                    "header len too long"
700                                ));
701                            }
702                            write
703                                .write_u8(buf.len() as u8)
704                                .await
705                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
706                            write
707                                .write_all(buf.as_slice())
708                                .await
709                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
710                            tokio::io::copy(&mut body, write.deref_mut().deref_mut())
711                                .await
712                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
713                            write
714                                .flush()
715                                .await
716                                .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
717                        }
718                        Ok(None) => {}
719                        Err(e) => {
720                            log::error!("handle cmd error: {:?}", e);
721                        }
722                    }
723                    recv = future
724                        .await
725                        .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
726                }
727                Ok(())
728            }
729            .await;
730            ret
731        });
732        Ok(CommonCmdSend::new(
733            tunnel_id,
734            handle,
735            write,
736            self.resp_waiter.clone(),
737            remote_id,
738            meta,
739        ))
740    }
741}
742
743pub struct DefaultCmdClient<
744    M: CmdTunnelMeta,
745    R: CmdTunnelRead<M>,
746    W: CmdTunnelWrite<M>,
747    F: CmdTunnelFactory<M, R, W>,
748    LEN: RawEncode
749        + for<'a> RawDecode<'a>
750        + Copy
751        + Send
752        + Sync
753        + 'static
754        + FromPrimitive
755        + ToPrimitive
756        + RawFixedBytes,
757    CMD: RawEncode
758        + for<'a> RawDecode<'a>
759        + Copy
760        + Send
761        + Sync
762        + 'static
763        + RawFixedBytes
764        + Eq
765        + Hash
766        + Debug,
767> {
768    tunnel_pool: ClassifiedWorkerPoolRef<
769        TunnelId,
770        CommonCmdSend<M, R, W, LEN, CMD>,
771        CmdWriteFactory<M, R, W, F, LEN, CMD>,
772    >,
773    runtime_tunnels: Arc<TunnelRuntimeRegistry>,
774    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
775}
776
777impl<
778    M: CmdTunnelMeta,
779    R: CmdTunnelRead<M>,
780    W: CmdTunnelWrite<M>,
781    F: CmdTunnelFactory<M, R, W>,
782    LEN: RawEncode
783        + for<'a> RawDecode<'a>
784        + Copy
785        + Send
786        + Sync
787        + 'static
788        + FromPrimitive
789        + ToPrimitive
790        + RawFixedBytes,
791    CMD: RawEncode
792        + for<'a> RawDecode<'a>
793        + Copy
794        + Send
795        + Sync
796        + 'static
797        + RawFixedBytes
798        + Eq
799        + Hash
800        + Debug,
801> DefaultCmdClient<M, R, W, F, LEN, CMD>
802{
803    fn tunnel_not_found(tunnel_id: TunnelId) -> sfo_result::Error<CmdErrorCode> {
804        cmd_err!(CmdErrorCode::Failed, "tunnel {:?} not found", tunnel_id)
805    }
806
807    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
808        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
809        let handler_map = cmd_handler_map.clone();
810        let resp_waiter = Arc::new(RespWaiter::new());
811        let waiter = resp_waiter.clone();
812        Arc::new(Self {
813            tunnel_pool: ClassifiedWorkerPool::new(
814                tunnel_count,
815                CmdWriteFactory::<M, R, W, _, LEN, CMD>::new(
816                    factory,
817                    move |local_id: PeerId,
818                          peer_id: PeerId,
819                          tunnel_id: TunnelId,
820                          header: CmdHeader<LEN, CMD>,
821                          body_read: CmdBody| {
822                        let handler_map = handler_map.clone();
823                        let waiter = waiter.clone();
824                        async move {
825                            if header.is_resp() && header.seq().is_some() {
826                                let resp_id = gen_resp_id(
827                                    tunnel_id,
828                                    header.cmd_code(),
829                                    header.seq().unwrap(),
830                                );
831                                let _ = waiter.set_result(resp_id, body_read);
832                                Ok(None)
833                            } else {
834                                if let Some(handler) = handler_map.get(header.cmd_code()) {
835                                    handler
836                                        .handle(local_id, peer_id, tunnel_id, header, body_read)
837                                        .await
838                                } else {
839                                    Ok(None)
840                                }
841                            }
842                        }
843                    },
844                    resp_waiter.clone(),
845                ),
846            ),
847            runtime_tunnels: TunnelRuntimeRegistry::new(),
848            cmd_handler_map,
849        })
850    }
851
852    fn tracked_send_guard(
853        &self,
854        worker_guard: ClassifiedWorkerGuard<
855            TunnelId,
856            CommonCmdSend<M, R, W, LEN, CMD>,
857            CmdWriteFactory<M, R, W, F, LEN, CMD>,
858        >,
859    ) -> CmdClientSendGuard<M, R, W, F, LEN, CMD> {
860        let tunnel_id = worker_guard.get_tunnel_id();
861        TrackedSendGuard::new(worker_guard, self.runtime_tunnels.clone(), tunnel_id)
862    }
863
864    async fn reserve_tunnel(&self, tunnel_id: TunnelId) -> CmdResult<()> {
865        loop {
866            match self.runtime_tunnels.reserve_existing(tunnel_id) {
867                TunnelReserveResult::Acquired => return Ok(()),
868                TunnelReserveResult::Wait(waiter) => waiter.notified().await,
869                TunnelReserveResult::Missing => return Err(Self::tunnel_not_found(tunnel_id)),
870            }
871        }
872    }
873
874    async fn get_send(&self) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
875        loop {
876            let worker_guard = self
877                .tunnel_pool
878                .get_worker()
879                .await
880                .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))?;
881            let tunnel_id = worker_guard.get_tunnel_id();
882            if self.runtime_tunnels.mark_borrowed(tunnel_id) {
883                return Ok(self.tracked_send_guard(worker_guard));
884            }
885            drop(worker_guard);
886            yield_now().await;
887        }
888    }
889
890    async fn get_send_of_tunnel_id(
891        &self,
892        tunnel_id: TunnelId,
893    ) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
894        self.reserve_tunnel(tunnel_id).await?;
895        match self.tunnel_pool.get_classified_worker(tunnel_id).await {
896            Ok(worker_guard) => Ok(self.tracked_send_guard(worker_guard)),
897            Err(_) => {
898                self.runtime_tunnels.remove(tunnel_id);
899                Err(Self::tunnel_not_found(tunnel_id))
900            }
901        }
902    }
903}
904
905pub type CmdClientSendGuard<M, R, W, F, LEN, CMD> = TrackedSendGuard<
906    TunnelId,
907    M,
908    CommonCmdSend<M, R, W, LEN, CMD>,
909    CmdWriteFactory<M, R, W, F, LEN, CMD>,
910>;
911#[async_trait::async_trait]
912impl<
913    M: CmdTunnelMeta,
914    R: CmdTunnelRead<M>,
915    W: CmdTunnelWrite<M>,
916    F: CmdTunnelFactory<M, R, W>,
917    LEN: RawEncode
918        + for<'a> RawDecode<'a>
919        + Copy
920        + Send
921        + Sync
922        + 'static
923        + FromPrimitive
924        + ToPrimitive
925        + RawFixedBytes,
926    CMD: RawEncode
927        + for<'a> RawDecode<'a>
928        + Copy
929        + Send
930        + Sync
931        + 'static
932        + RawFixedBytes
933        + Eq
934        + Hash
935        + Debug,
936> CmdClient<LEN, CMD, M, CommonCmdSend<M, R, W, LEN, CMD>, CmdClientSendGuard<M, R, W, F, LEN, CMD>>
937    for DefaultCmdClient<M, R, W, F, LEN, CMD>
938{
939    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
940        self.cmd_handler_map.insert(cmd, handler);
941    }
942
943    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
944        let mut send = self.get_send().await?;
945        send.send(cmd, version, body).await
946    }
947
948    async fn send_with_resp(
949        &self,
950        cmd: CMD,
951        version: u8,
952        body: &[u8],
953        timeout: Duration,
954    ) -> CmdResult<CmdBody> {
955        let mut send = self.get_send().await?;
956        send.send_with_resp(cmd, version, body, timeout).await
957    }
958
959    async fn send_parts(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
960        let mut send = self.get_send().await?;
961        send.send_parts(cmd, version, body).await
962    }
963
964    async fn send_parts_with_resp(
965        &self,
966        cmd: CMD,
967        version: u8,
968        body: &[&[u8]],
969        timeout: Duration,
970    ) -> CmdResult<CmdBody> {
971        let mut send = self.get_send().await?;
972        send.send_parts_with_resp(cmd, version, body, timeout).await
973    }
974
975    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
976        let mut send = self.get_send().await?;
977        send.send_cmd(cmd, version, body).await
978    }
979
980    async fn send_cmd_with_resp(
981        &self,
982        cmd: CMD,
983        version: u8,
984        body: CmdBody,
985        timeout: Duration,
986    ) -> CmdResult<CmdBody> {
987        let mut send = self.get_send().await?;
988        send.send_cmd_with_resp(cmd, version, body, timeout).await
989    }
990
991    async fn send_by_specify_tunnel(
992        &self,
993        tunnel_id: TunnelId,
994        cmd: CMD,
995        version: u8,
996        body: &[u8],
997    ) -> CmdResult<()> {
998        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
999        send.send(cmd, version, body).await
1000    }
1001
1002    async fn send_by_specify_tunnel_with_resp(
1003        &self,
1004        tunnel_id: TunnelId,
1005        cmd: CMD,
1006        version: u8,
1007        body: &[u8],
1008        timeout: Duration,
1009    ) -> CmdResult<CmdBody> {
1010        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1011        send.send_with_resp(cmd, version, body, timeout).await
1012    }
1013
1014    async fn send_parts_by_specify_tunnel(
1015        &self,
1016        tunnel_id: TunnelId,
1017        cmd: CMD,
1018        version: u8,
1019        body: &[&[u8]],
1020    ) -> CmdResult<()> {
1021        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1022        send.send_parts(cmd, version, body).await
1023    }
1024
1025    async fn send_parts_by_specify_tunnel_with_resp(
1026        &self,
1027        tunnel_id: TunnelId,
1028        cmd: CMD,
1029        version: u8,
1030        body: &[&[u8]],
1031        timeout: Duration,
1032    ) -> CmdResult<CmdBody> {
1033        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1034        send.send_parts_with_resp(cmd, version, body, timeout).await
1035    }
1036
1037    async fn send_cmd_by_specify_tunnel(
1038        &self,
1039        tunnel_id: TunnelId,
1040        cmd: CMD,
1041        version: u8,
1042        body: CmdBody,
1043    ) -> CmdResult<()> {
1044        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1045        send.send_cmd(cmd, version, body).await
1046    }
1047
1048    async fn send_cmd_by_specify_tunnel_with_resp(
1049        &self,
1050        tunnel_id: TunnelId,
1051        cmd: CMD,
1052        version: u8,
1053        body: CmdBody,
1054        timeout: Duration,
1055    ) -> CmdResult<CmdBody> {
1056        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1057        send.send_cmd_with_resp(cmd, version, body, timeout).await
1058    }
1059
1060    async fn clear_all_tunnel(&self) {
1061        self.tunnel_pool.clear_all_worker().await;
1062        self.runtime_tunnels.clear();
1063    }
1064
1065    async fn get_send(
1066        &self,
1067        tunnel_id: TunnelId,
1068    ) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
1069        self.get_send_of_tunnel_id(tunnel_id).await
1070    }
1071}