sfo_cmd_server/client/
classified_client.rs

1use std::hash::Hash;
2use std::sync::{Arc, Mutex};
3use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
4use num::{FromPrimitive, ToPrimitive};
5use sfo_pool::{into_pool_err, pool_err, ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
7use tokio::spawn;
8use tokio::task::JoinHandle;
9use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
10use crate::client::{gen_resp_id, gen_seq, ClassifiedCmdClient, ClassifiedSendGuard, CmdClient, CmdSend, RespWaiter, RespWaiterRef};
11use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
12use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
13use crate::peer_id::PeerId;
14use std::fmt::Debug;
15use std::ops::{DerefMut};
16use std::time::Duration;
17use async_named_locker::ObjectHolder;
18use sfo_split::{RHalf, Splittable, WHalf};
19
20pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>: CmdTunnelRead<M> + 'static + Send {
21    fn get_classification(&self) -> C;
22}
23
24pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>: CmdTunnelWrite<M> + 'static + Send {
25    fn get_classification(&self) -> C;
26}
27
28pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
29pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
30pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
31
32#[derive(Debug, Clone, Copy, Eq, Hash)]
33pub struct CmdClientTunnelClassification<C: WorkerClassification> {
34    tunnel_id: Option<TunnelId>,
35    classification: Option<C>,
36}
37
38impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
39    fn eq(&self, other: &Self) -> bool {
40        self.tunnel_id == other.tunnel_id && self.classification == other.classification
41    }
42}
43
44
45#[async_trait::async_trait]
46pub trait ClassifiedCmdTunnelFactory<C: WorkerClassification, M: CmdTunnelMeta, R: ClassifiedCmdTunnelRead<C, M>, W: ClassifiedCmdTunnelWrite<C, M>>: Send + Sync + 'static {
47    async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
48}
49
50pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
51where
52    C: WorkerClassification,
53    M: CmdTunnelMeta,
54    R: ClassifiedCmdTunnelRead<C, M>,
55    W: ClassifiedCmdTunnelWrite<C, M>,
56    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
57    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
58{
59    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
60    pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
61    pub(crate) is_work: bool,
62    pub(crate) classification: C,
63    pub(crate) tunnel_id: TunnelId,
64    pub(crate) resp_waiter: RespWaiterRef,
65    pub(crate) remote_id: PeerId,
66    pub(crate) tunnel_meta: Option<Arc<M>>,
67    _p: std::marker::PhantomData<(LEN, CMD)>,
68
69}
70
71// impl<C, R, W, LEN, CMD> Deref for ClassifiedCmdSend<C, R, W, LEN, CMD>
72// where C: WorkerClassification,
73//       R: ClassifiedCmdTunnelRead<C>,
74//       W: ClassifiedCmdTunnelWrite<C>,
75//       LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
76//       CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
77//     type Target = W;
78//
79//     fn deref(&self) -> &Self::Target {
80//         self.write.deref()
81//     }
82// }
83
84impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
85where C: WorkerClassification,
86      M: CmdTunnelMeta,
87      R: ClassifiedCmdTunnelRead<C, M>,
88      W: ClassifiedCmdTunnelWrite<C, M>,
89      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
90      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
91    pub(crate) fn new(tunnel_id: TunnelId,
92                      classification: C,
93                      recv_handle: JoinHandle<CmdResult<()>>,
94                      write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
95                      resp_waiter: RespWaiterRef,
96                      remote_id: PeerId,
97                      tunnel_meta: Option<Arc<M>>) -> Self {
98        Self {
99            recv_handle,
100            write,
101            is_work: true,
102            classification,
103            tunnel_id,
104            resp_waiter,
105            remote_id,
106            tunnel_meta,
107            _p: Default::default(),
108        }
109    }
110
111    pub fn get_tunnel_id(&self) -> TunnelId {
112        self.tunnel_id
113    }
114
115    pub fn set_disable(&mut self) {
116        self.is_work = false;
117        self.recv_handle.abort();
118    }
119
120    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
121        log::trace!("client {:?} send cmd: {:?}, len: {}, data: {}", self.tunnel_id, cmd, body.len(), hex::encode(body));
122        let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(body.len() as u64).unwrap());
123        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
124        let ret = self.send_inner(buf.as_slice(), body).await;
125        if let Err(e) = ret {
126            self.set_disable();
127            return Err(e);
128        }
129        Ok(())
130    }
131
132    pub async fn send_with_resp(&mut self, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
133        if let Some(id) = tokio::task::try_id() {
134            if id == self.recv_handle.id() {
135                return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
136            }
137        }
138        log::trace!("client {:?} send cmd: {:?}, len: {}, data: {}", self.tunnel_id, cmd, body.len(), hex::encode(body));
139        let seq = gen_seq();
140        let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(body.len() as u64).unwrap());
141        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
142        let resp_id = gen_resp_id(cmd, seq);
143        let waiter = self.resp_waiter.clone();
144        let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
145            .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
146        let ret = self.send_inner(buf.as_slice(), body).await;
147        if let Err(e) = ret {
148            self.set_disable();
149            return Err(e);
150        }
151        let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
152        Ok(resp)
153    }
154
155    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
156        let mut len = 0;
157        for b in body.iter() {
158            len += b.len();
159            log::trace!("client {:?} send2 cmd {:?} body: {}", self.tunnel_id, cmd, hex::encode(b));
160        }
161        log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
162        let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(len as u64).unwrap());
163        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
164        let ret = self.send_inner2(buf.as_slice(), body).await;
165        if let Err(e) = ret {
166            self.set_disable();
167            return Err(e);
168        }
169        Ok(())
170    }
171
172    pub async fn send2_with_resp(&mut self, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
173        if let Some(id) = tokio::task::try_id() {
174            if id == self.recv_handle.id() {
175                return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
176            }
177        }
178        let mut len = 0;
179        for b in body.iter() {
180            len += b.len();
181            log::trace!("client {:?} send2 cmd {:?} body: {}", self.tunnel_id, cmd, hex::encode(b));
182        }
183        log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
184        let seq = gen_seq();
185        let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(len as u64).unwrap());
186        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
187        let resp_id = gen_resp_id(cmd, seq);
188        let waiter = self.resp_waiter.clone();
189        let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
190            .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
191        let ret = self.send_inner2(buf.as_slice(), body).await;
192        if let Err(e) = ret {
193            self.set_disable();
194            return Err(e);
195        }
196        let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
197        Ok(resp)
198    }
199
200    pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
201        log::trace!("client {:?} send cmd: {:?}, len: {}", self.tunnel_id, cmd, body.len());
202        let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(body.len()).unwrap());
203        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
204        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
205        if let Err(e) = ret {
206            self.set_disable();
207            return Err(e);
208        }
209        Ok(())
210    }
211
212    pub async fn send_cmd_with_resp(&mut self, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
213        if let Some(id) = tokio::task::try_id() {
214            if id == self.recv_handle.id() {
215                return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
216            }
217        }
218        log::trace!("client {:?} send cmd: {:?}, len: {}", self.tunnel_id, cmd, body.len());
219        let seq = gen_seq();
220        let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(body.len()).unwrap());
221        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
222        let resp_id = gen_resp_id(cmd, seq);
223        let waiter = self.resp_waiter.clone();
224        let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
225           .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
226        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
227        if let Err(e) = ret {
228            self.set_disable();
229            return Err(e);
230        }
231        let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
232        Ok(resp)
233    }
234
235    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
236        let mut write = self.write.get().await;
237        if header.len() > 255 {
238            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
239        }
240        write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
241        write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
242        write.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
243        write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
244        Ok(())
245    }
246
247    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
248        let mut write = self.write.get().await;
249        if header.len() > 255 {
250            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
251        }
252        write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
253        write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
254        for b in body.iter() {
255            write.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
256        }
257        write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
258        Ok(())
259    }
260
261    async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
262        let mut write = self.write.get().await;
263        if header.len() > 255 {
264            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
265        }
266        write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
267        write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
268        tokio::io::copy(&mut body, write.deref_mut().deref_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
269        write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
270        Ok(())
271    }
272}
273
274impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
275where C: WorkerClassification,
276      M: CmdTunnelMeta,
277      R: ClassifiedCmdTunnelRead<C, M>,
278      W: ClassifiedCmdTunnelWrite<C, M>,
279      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
280      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
281    fn drop(&mut self) {
282        self.set_disable();
283    }
284}
285
286impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
287where C: WorkerClassification,
288      M: CmdTunnelMeta,
289      R: ClassifiedCmdTunnelRead<C, M>,
290      W: ClassifiedCmdTunnelWrite<C, M>,
291      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
292      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
293    fn get_tunnel_meta(&self) -> Option<Arc<M>> {
294        self.tunnel_meta.clone()
295    }
296
297    fn get_remote_peer_id(&self) -> PeerId {
298        self.remote_id.clone()
299    }
300}
301
302impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
303where C: WorkerClassification,
304      M: CmdTunnelMeta,
305      R: ClassifiedCmdTunnelRead<C, M>,
306      W: ClassifiedCmdTunnelWrite<C, M>,
307      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
308      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
309    fn is_work(&self) -> bool {
310        self.is_work && !self.recv_handle.is_finished()
311    }
312
313    fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
314        if c.tunnel_id.is_some() {
315            self.tunnel_id == c.tunnel_id.unwrap()
316        } else {
317            if c.classification.is_some() {
318                self.classification == c.classification.unwrap()
319            } else {
320                true
321            }
322        }
323    }
324
325    fn classification(&self) -> CmdClientTunnelClassification<C> {
326        CmdClientTunnelClassification {
327            tunnel_id: Some(self.tunnel_id),
328            classification: Some(self.classification.clone()),
329        }
330    }
331}
332
333pub struct ClassifiedCmdWriteFactory<C: WorkerClassification,
334    M: CmdTunnelMeta,
335    R: ClassifiedCmdTunnelRead<C, M>,
336    W: ClassifiedCmdTunnelWrite<C, M>,
337    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
338    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
339    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes> {
340    tunnel_factory: F,
341    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
342    resp_waiter: RespWaiterRef,
343    tunnel_id_generator: TunnelIdGenerator,
344    _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
345}
346
347impl<
348    C: WorkerClassification,
349    M: CmdTunnelMeta,
350    R: ClassifiedCmdTunnelRead<C, M>,
351    W: ClassifiedCmdTunnelWrite<C, M>,
352    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
353    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
354    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes
355> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD> {
356    pub(crate) fn new(tunnel_factory: F,
357                      cmd_handler: impl CmdHandler<LEN, CMD>,
358                      resp_waiter: RespWaiterRef,) -> Self {
359        Self {
360            tunnel_factory,
361            cmd_handler: Arc::new(cmd_handler),
362            resp_waiter,
363            tunnel_id_generator: TunnelIdGenerator::new(),
364            _p: Default::default(),
365        }
366    }
367}
368
369#[async_trait::async_trait]
370impl<C: WorkerClassification,
371    M: CmdTunnelMeta,
372    R: ClassifiedCmdTunnelRead<C, M>,
373    W: ClassifiedCmdTunnelWrite<C, M>,
374    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
375    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
376    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>
377> for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD> {
378    async fn create(&self, classification: Option<CmdClientTunnelClassification<C>>) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
379        if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
380            return Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", classification.as_ref().unwrap().tunnel_id.unwrap()));
381        }
382
383        let classification = if classification.is_some() && classification.as_ref().unwrap().classification.is_some() {
384            classification.unwrap().classification
385        } else {
386            None
387        };
388        let tunnel = self.tunnel_factory.create_tunnel(classification).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
389        let classification = tunnel.get_classification();
390        let peer_id = tunnel.get_remote_peer_id();
391        let tunnel_id = self.tunnel_id_generator.generate();
392        let (mut recv, write) = tunnel.split();
393        let remote_id = peer_id.clone();
394        let tunnel_meta = recv.get_tunnel_meta();
395        let write = ObjectHolder::new(write);
396        let resp_write = write.clone();
397        let cmd_handler = self.cmd_handler.clone();
398        let handle = spawn(async move {
399            let ret: CmdResult<()> = async move {
400                loop {
401                    let header_len = recv.read_u8().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
402                    let mut header = vec![0u8; header_len as usize];
403                    let n = recv.read_exact(header.as_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
404                    if n == 0 {
405                        break;
406                    }
407                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice()).map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
408                    log::trace!("recv cmd {:?} from {} len {} tunnel {:?}", header.cmd_code(), peer_id, header.pkg_len().to_u64().unwrap(), tunnel_id);
409                    let body_len = header.pkg_len().to_u64().unwrap();
410                    let cmd_read = CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
411                    let waiter = cmd_read.get_waiter();
412                    let future = waiter.create_result_future().map_err(into_cmd_err!(CmdErrorCode::Failed))?;
413                    let version = header.version();
414                    let seq = header.seq();
415                    let cmd_code = header.cmd_code();
416                    match cmd_handler.handle(peer_id.clone(), tunnel_id, header, CmdBody::from_reader(BufReader::new(cmd_read), body_len)).await {
417                        Ok(Some(mut body)) => {
418                            let mut write = resp_write.get().await;
419                            let header = CmdHeader::<LEN, CMD>::new(version, true, seq, cmd_code, LEN::from_u64(body.len()).unwrap());
420                            let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
421                            if buf.len() > 255 {
422                                return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
423                            }
424                            write.write_u8(buf.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
425                            write.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
426                            tokio::io::copy(&mut body, write.deref_mut().deref_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
427                            write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
428                        }
429                        Err(e) => {
430                            log::error!("handle cmd error: {:?}", e);
431                        }
432                        _ => {}
433                    }
434                    recv = future.await.map_err(into_cmd_err!(CmdErrorCode::Failed))??;
435                    log::debug!("handle cmd {:?} from {} len {} tunnel {:?} complete", cmd_code, peer_id, body_len, tunnel_id);
436                }
437                Ok(())
438            }.await;
439            if ret.is_err() {
440                log::error!("recv cmd error: {:?}", ret.as_ref().err().unwrap());
441            }
442            ret
443        });
444        Ok(ClassifiedCmdSend::new(tunnel_id, classification, handle, write, self.resp_waiter.clone(), remote_id, tunnel_meta))
445    }
446}
447
448pub struct DefaultClassifiedCmdClient<C: WorkerClassification,
449    M: CmdTunnelMeta,
450    R: ClassifiedCmdTunnelRead<C, M>,
451    W: ClassifiedCmdTunnelWrite<C, M>,
452    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
453    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
454    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> {
455    tunnel_pool: ClassifiedWorkerPoolRef<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>,
456    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
457}
458
459impl<C: WorkerClassification,
460    M: CmdTunnelMeta,
461    R: ClassifiedCmdTunnelRead<C, M>,
462    W: ClassifiedCmdTunnelWrite<C, M>,
463    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
464    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
465    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
466    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
467        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
468        let resp_waiter = Arc::new(RespWaiter::new());
469        let handler_map = cmd_handler_map.clone();
470        let waiter = resp_waiter.clone();
471        Arc::new(Self {
472            tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(factory, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read: CmdBody| {
473                let handler_map = handler_map.clone();
474                let waiter = waiter.clone();
475                async move {
476                    if header.is_resp() && header.seq().is_some() {
477                        let resp_id = gen_resp_id(header.cmd_code(), header.seq().unwrap());
478                        let _ = waiter.set_result(resp_id, body_read);
479                        Ok(None)
480                    } else {
481                        if let Some(handler) = handler_map.get(header.cmd_code()) {
482                            handler.handle(peer_id, tunnel_id, header, body_read).await
483                        } else {
484                            Ok(None)
485                        }
486                    }
487                }
488            }, resp_waiter.clone())),
489            cmd_handler_map,
490        })
491    }
492
493    async fn get_send(&self) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
494        self.tunnel_pool.get_worker().await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
495    }
496
497    async fn get_send_of_tunnel_id(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
498        self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
499            tunnel_id: Some(tunnel_id),
500            classification: None,
501        }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
502    }
503
504    async fn get_classified_send(&self, classification: C) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
505        self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
506            tunnel_id: None,
507            classification: Some(classification),
508        }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
509    }
510}
511
512pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = ClassifiedSendGuard<CmdClientTunnelClassification<C>, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>;
513#[async_trait::async_trait]
514impl<C: WorkerClassification,
515    M: CmdTunnelMeta,
516    R: ClassifiedCmdTunnelRead<C, M>,
517    W: ClassifiedCmdTunnelWrite<C, M>,
518    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
519    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
520    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug,
521> CmdClient<LEN, CMD, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
522    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
523        self.cmd_handler_map.insert(cmd, handler);
524    }
525
526    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
527        let mut send = self.get_send().await?;
528        send.send(cmd, version, body).await
529    }
530
531    async fn send_with_resp(&self, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
532        let mut send = self.get_send().await?;
533        send.send_with_resp(cmd, version, body, timeout).await
534    }
535
536    async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
537        let mut send = self.get_send().await?;
538        send.send2(cmd, version, body).await
539    }
540
541    async fn send2_with_resp(&self, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
542        let mut send = self.get_send().await?;
543        send.send2_with_resp(cmd, version, body, timeout).await
544    }
545
546    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
547        let mut send = self.get_send().await?;
548        send.send_cmd(cmd, version, body).await
549    }
550
551    async fn send_cmd_with_resp(&self, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
552        let mut send = self.get_send().await?;
553        send.send_cmd_with_resp(cmd, version, body, timeout).await
554    }
555
556    async fn send_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
557        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
558        send.send(cmd, version, body).await
559    }
560
561    async fn send_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
562        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
563        send.send_with_resp(cmd, version, body, timeout).await
564    }
565
566    async fn send2_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
567        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
568        send.send2(cmd, version, body).await
569    }
570
571    async fn send2_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
572        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
573        send.send2_with_resp(cmd, version, body, timeout).await
574    }
575
576    async fn send_cmd_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
577        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
578        send.send_cmd(cmd, version, body).await
579    }
580
581    async fn send_cmd_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
582        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
583        send.send_cmd_with_resp(cmd, version, body, timeout).await
584    }
585
586    async fn clear_all_tunnel(&self) {
587        self.tunnel_pool.clear_all_worker().await;
588    }
589
590    async fn get_send(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
591        Ok(ClassifiedSendGuard {
592            worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
593            _p: std::marker::PhantomData,
594        })
595    }
596}
597
598#[async_trait::async_trait]
599impl<C: WorkerClassification,
600    M: CmdTunnelMeta,
601    R: ClassifiedCmdTunnelRead<C, M>,
602    W: ClassifiedCmdTunnelWrite<C, M>,
603    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
604    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
605    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug
606> ClassifiedCmdClient<LEN, CMD, C, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
607    async fn send_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
608        let mut send = self.get_classified_send(classification).await?;
609        send.send(cmd, version, body).await
610    }
611
612    async fn send_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
613        let mut send = self.get_classified_send(classification).await?;
614        send.send_with_resp(cmd, version, body, timeout).await
615    }
616
617    async fn send2_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
618        let mut send = self.get_classified_send(classification).await?;
619        send.send2(cmd, version, body).await
620    }
621
622    async fn send2_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
623        let mut send = self.get_classified_send(classification).await?;
624        send.send2_with_resp(cmd, version, body, timeout).await
625    }
626
627    async fn send_cmd_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
628        let mut send = self.get_classified_send(classification).await?;
629        send.send_cmd(cmd, version, body).await
630    }
631
632    async fn send_cmd_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
633        let mut send = self.get_classified_send(classification).await?;
634        send.send_cmd_with_resp(cmd, version, body, timeout).await
635    }
636
637    async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
638        let send = self.get_classified_send(classification).await?;
639        Ok(send.get_tunnel_id())
640    }
641
642    async fn get_send_by_classified(&self, classification: C) -> CmdResult<ClassifiedSendGuard<CmdClientTunnelClassification<C>, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
643        Ok(ClassifiedSendGuard {
644            worker_guard: self.get_classified_send(classification).await?,
645            _p: std::marker::PhantomData,
646        })
647    }
648}