sfo_cmd_server/client/
classified_client.rs

1use std::hash::Hash;
2use std::sync::{Arc, Mutex};
3use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
4use num::{FromPrimitive, ToPrimitive};
5use sfo_pool::{into_pool_err, pool_err, ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
7use tokio::spawn;
8use tokio::task::JoinHandle;
9use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
10use crate::client::{gen_resp_id, gen_seq, ClassifiedCmdClient, ClassifiedSendGuard, CmdClient, CmdSend, RespWaiter, RespWaiterRef};
11use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
12use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
13use crate::peer_id::PeerId;
14use std::fmt::Debug;
15use std::ops::{DerefMut};
16use std::time::Duration;
17use async_named_locker::ObjectHolder;
18use sfo_split::{RHalf, Splittable, WHalf};
19
20pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>: CmdTunnelRead<M> + 'static + Send {
21    fn get_classification(&self) -> C;
22}
23
24pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>: CmdTunnelWrite<M> + 'static + Send {
25    fn get_classification(&self) -> C;
26}
27
28pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
29pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
30pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
31
32#[derive(Debug, Clone, Copy, Eq, Hash)]
33pub struct CmdClientTunnelClassification<C: WorkerClassification> {
34    tunnel_id: Option<TunnelId>,
35    classification: Option<C>,
36}
37
38impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
39    fn eq(&self, other: &Self) -> bool {
40        self.tunnel_id == other.tunnel_id && self.classification == other.classification
41    }
42}
43
44
45#[async_trait::async_trait]
46pub trait ClassifiedCmdTunnelFactory<C: WorkerClassification, M: CmdTunnelMeta, R: ClassifiedCmdTunnelRead<C, M>, W: ClassifiedCmdTunnelWrite<C, M>>: Send + Sync + 'static {
47    async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
48}
49
50pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
51where
52    C: WorkerClassification,
53    M: CmdTunnelMeta,
54    R: ClassifiedCmdTunnelRead<C, M>,
55    W: ClassifiedCmdTunnelWrite<C, M>,
56    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
57    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
58{
59    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
60    pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
61    pub(crate) is_work: bool,
62    pub(crate) classification: C,
63    pub(crate) tunnel_id: TunnelId,
64    pub(crate) resp_waiter: RespWaiterRef,
65    pub(crate) remote_id: PeerId,
66    pub(crate) tunnel_meta: Option<Arc<M>>,
67    _p: std::marker::PhantomData<(LEN, CMD)>,
68
69}
70
71// impl<C, R, W, LEN, CMD> Deref for ClassifiedCmdSend<C, R, W, LEN, CMD>
72// where C: WorkerClassification,
73//       R: ClassifiedCmdTunnelRead<C>,
74//       W: ClassifiedCmdTunnelWrite<C>,
75//       LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
76//       CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
77//     type Target = W;
78//
79//     fn deref(&self) -> &Self::Target {
80//         self.write.deref()
81//     }
82// }
83
84impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
85where C: WorkerClassification,
86      M: CmdTunnelMeta,
87      R: ClassifiedCmdTunnelRead<C, M>,
88      W: ClassifiedCmdTunnelWrite<C, M>,
89      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
90      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
91    pub(crate) fn new(tunnel_id: TunnelId,
92                      classification: C,
93                      recv_handle: JoinHandle<CmdResult<()>>,
94                      write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
95                      resp_waiter: RespWaiterRef,
96                      remote_id: PeerId,
97                      tunnel_meta: Option<Arc<M>>) -> Self {
98        Self {
99            recv_handle,
100            write,
101            is_work: true,
102            classification,
103            tunnel_id,
104            resp_waiter,
105            remote_id,
106            tunnel_meta,
107            _p: Default::default(),
108        }
109    }
110
111    pub fn get_tunnel_id(&self) -> TunnelId {
112        self.tunnel_id
113    }
114
115    pub fn set_disable(&mut self) {
116        self.is_work = false;
117        self.recv_handle.abort();
118    }
119
120    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
121        log::trace!("client {:?} send cmd: {:?}, len: {}, data: {}", self.tunnel_id, cmd, body.len(), hex::encode(body));
122        let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(body.len() as u64).unwrap());
123        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
124        let ret = self.send_inner(buf.as_slice(), body).await;
125        if let Err(e) = ret {
126            self.set_disable();
127            return Err(e);
128        }
129        Ok(())
130    }
131
132    pub async fn send_with_resp(&mut self, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
133        if let Some(id) = tokio::task::try_id() {
134            if id == self.recv_handle.id() {
135                return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
136            }
137        }
138        log::trace!("client {:?} send cmd: {:?}, len: {}, data: {}", self.tunnel_id, cmd, body.len(), hex::encode(body));
139        let seq = gen_seq();
140        let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(body.len() as u64).unwrap());
141        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
142        let resp_id = gen_resp_id(cmd, seq);
143        let waiter = self.resp_waiter.clone();
144        let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
145            .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
146        let ret = self.send_inner(buf.as_slice(), body).await;
147        if let Err(e) = ret {
148            self.set_disable();
149            return Err(e);
150        }
151        let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
152        Ok(resp)
153    }
154
155    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
156        let mut len = 0;
157        for b in body.iter() {
158            len += b.len();
159            log::trace!("client {:?} send2 cmd {:?} body: {}", self.tunnel_id, cmd, hex::encode(b));
160        }
161        log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
162        let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(len as u64).unwrap());
163        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
164        let ret = self.send_inner2(buf.as_slice(), body).await;
165        if let Err(e) = ret {
166            self.set_disable();
167            return Err(e);
168        }
169        Ok(())
170    }
171
172    pub async fn send2_with_resp(&mut self, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
173        if let Some(id) = tokio::task::try_id() {
174            if id == self.recv_handle.id() {
175                return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
176            }
177        }
178        let mut len = 0;
179        for b in body.iter() {
180            len += b.len();
181            log::trace!("client {:?} send2 cmd {:?} body: {}", self.tunnel_id, cmd, hex::encode(b));
182        }
183        log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
184        let seq = gen_seq();
185        let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(len as u64).unwrap());
186        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
187        let resp_id = gen_resp_id(cmd, seq);
188        let waiter = self.resp_waiter.clone();
189        let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
190            .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
191        let ret = self.send_inner2(buf.as_slice(), body).await;
192        if let Err(e) = ret {
193            self.set_disable();
194            return Err(e);
195        }
196        let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
197        Ok(resp)
198    }
199
200    pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
201        log::trace!("client {:?} send cmd: {:?}, len: {}", self.tunnel_id, cmd, body.len());
202        let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(body.len()).unwrap());
203        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
204        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
205        if let Err(e) = ret {
206            self.set_disable();
207            return Err(e);
208        }
209        Ok(())
210    }
211
212    pub async fn send_cmd_with_resp(&mut self, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
213        if let Some(id) = tokio::task::try_id() {
214            if id == self.recv_handle.id() {
215                return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
216            }
217        }
218        log::trace!("client {:?} send cmd: {:?}, len: {}", self.tunnel_id, cmd, body.len());
219        let seq = gen_seq();
220        let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(body.len()).unwrap());
221        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
222        let resp_id = gen_resp_id(cmd, seq);
223        let waiter = self.resp_waiter.clone();
224        let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
225           .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
226        let ret = self.send_inner_cmd(buf.as_slice(), body).await;
227        if let Err(e) = ret {
228            self.set_disable();
229            return Err(e);
230        }
231        let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
232        Ok(resp)
233    }
234
235    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
236        let mut write = self.write.get().await;
237        if header.len() > 255 {
238            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
239        }
240        write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
241        write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
242        write.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
243        write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
244        Ok(())
245    }
246
247    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
248        let mut write = self.write.get().await;
249        if header.len() > 255 {
250            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
251        }
252        write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
253        write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
254        for b in body.iter() {
255            write.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
256        }
257        write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
258        Ok(())
259    }
260
261    async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
262        let mut write = self.write.get().await;
263        if header.len() > 255 {
264            return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
265        }
266        write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
267        write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
268        tokio::io::copy(&mut body, write.deref_mut().deref_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
269        write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
270        Ok(())
271    }
272}
273
274impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
275where C: WorkerClassification,
276      M: CmdTunnelMeta,
277      R: ClassifiedCmdTunnelRead<C, M>,
278      W: ClassifiedCmdTunnelWrite<C, M>,
279      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
280      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
281    fn drop(&mut self) {
282        self.set_disable();
283    }
284}
285
286impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
287where C: WorkerClassification,
288      M: CmdTunnelMeta,
289      R: ClassifiedCmdTunnelRead<C, M>,
290      W: ClassifiedCmdTunnelWrite<C, M>,
291      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
292      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
293    fn get_tunnel_meta(&self) -> Option<Arc<M>> {
294        self.tunnel_meta.clone()
295    }
296
297    fn get_remote_peer_id(&self) -> PeerId {
298        self.remote_id.clone()
299    }
300}
301
302impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
303where C: WorkerClassification,
304      M: CmdTunnelMeta,
305      R: ClassifiedCmdTunnelRead<C, M>,
306      W: ClassifiedCmdTunnelWrite<C, M>,
307      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
308      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
309    fn is_work(&self) -> bool {
310        self.is_work && !self.recv_handle.is_finished()
311    }
312
313    fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
314        if c.tunnel_id.is_some() {
315            self.tunnel_id == c.tunnel_id.unwrap()
316        } else {
317            if c.classification.is_some() {
318                self.classification == c.classification.unwrap()
319            } else {
320                true
321            }
322        }
323    }
324
325    fn classification(&self) -> CmdClientTunnelClassification<C> {
326        CmdClientTunnelClassification {
327            tunnel_id: Some(self.tunnel_id),
328            classification: Some(self.classification.clone()),
329        }
330    }
331}
332
333pub struct ClassifiedCmdWriteFactory<C: WorkerClassification,
334    M: CmdTunnelMeta,
335    R: ClassifiedCmdTunnelRead<C, M>,
336    W: ClassifiedCmdTunnelWrite<C, M>,
337    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
338    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
339    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes> {
340    tunnel_factory: F,
341    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
342    resp_waiter: RespWaiterRef,
343    tunnel_id_generator: TunnelIdGenerator,
344    _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
345}
346
347impl<
348    C: WorkerClassification,
349    M: CmdTunnelMeta,
350    R: ClassifiedCmdTunnelRead<C, M>,
351    W: ClassifiedCmdTunnelWrite<C, M>,
352    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
353    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
354    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes
355> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD> {
356    pub(crate) fn new(tunnel_factory: F,
357                      cmd_handler: impl CmdHandler<LEN, CMD>,
358                      resp_waiter: RespWaiterRef,) -> Self {
359        Self {
360            tunnel_factory,
361            cmd_handler: Arc::new(cmd_handler),
362            resp_waiter,
363            tunnel_id_generator: TunnelIdGenerator::new(),
364            _p: Default::default(),
365        }
366    }
367}
368
369#[async_trait::async_trait]
370impl<C: WorkerClassification,
371    M: CmdTunnelMeta,
372    R: ClassifiedCmdTunnelRead<C, M>,
373    W: ClassifiedCmdTunnelWrite<C, M>,
374    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
375    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
376    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>
377> for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD> {
378    async fn create(&self, classification: Option<CmdClientTunnelClassification<C>>) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
379        if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
380            return Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", classification.as_ref().unwrap().tunnel_id.unwrap()));
381        }
382
383        let classification = if classification.is_some() && classification.as_ref().unwrap().classification.is_some() {
384            classification.unwrap().classification
385        } else {
386            None
387        };
388        let tunnel = self.tunnel_factory.create_tunnel(classification).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
389        let classification = tunnel.get_classification();
390        let peer_id = tunnel.get_remote_peer_id();
391        let tunnel_id = self.tunnel_id_generator.generate();
392        let (mut recv, write) = tunnel.split();
393        let remote_id = peer_id.clone();
394        let tunnel_meta = recv.get_tunnel_meta();
395        let write = ObjectHolder::new(write);
396        let resp_write = write.clone();
397        let cmd_handler = self.cmd_handler.clone();
398        let handle = spawn(async move {
399            let ret: CmdResult<()> = async move {
400                loop {
401                    let header_len = recv.read_u8().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
402                    let mut header = vec![0u8; header_len as usize];
403                    let n = recv.read_exact(header.as_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
404                    if n == 0 {
405                        break;
406                    }
407                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice()).map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
408                    log::trace!("recv cmd {:?} from {} len {} tunnel {:?}", header.cmd_code(), peer_id, header.pkg_len().to_u64().unwrap(), tunnel_id);
409                    let body_len = header.pkg_len().to_u64().unwrap();
410                    let cmd_read = CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
411                    let waiter = cmd_read.get_waiter();
412                    let future = waiter.create_result_future().map_err(into_cmd_err!(CmdErrorCode::Failed))?;
413                    let version = header.version();
414                    let seq = header.seq();
415                    let cmd_code = header.cmd_code();
416                    match cmd_handler.handle(peer_id.clone(), tunnel_id, header, CmdBody::from_reader(BufReader::new(cmd_read), body_len)).await {
417                        Ok(Some(mut body)) => {
418                            let mut write = resp_write.get().await;
419                            let header = CmdHeader::<LEN, CMD>::new(version, true, seq, cmd_code, LEN::from_u64(body.len()).unwrap());
420                            let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
421                            if buf.len() > 255 {
422                                return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
423                            }
424                            write.write_u8(buf.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
425                            write.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
426                            tokio::io::copy(&mut body, write.deref_mut().deref_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
427                            write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
428                        }
429                        Err(e) => {
430                            log::error!("handle cmd error: {:?}", e);
431                        }
432                        _ => {}
433                    }
434                    recv = future.await.map_err(into_cmd_err!(CmdErrorCode::Failed))??;
435                }
436                Ok(())
437            }.await;
438            ret
439        });
440        Ok(ClassifiedCmdSend::new(tunnel_id, classification, handle, write, self.resp_waiter.clone(), remote_id, tunnel_meta))
441    }
442}
443
444pub struct DefaultClassifiedCmdClient<C: WorkerClassification,
445    M: CmdTunnelMeta,
446    R: ClassifiedCmdTunnelRead<C, M>,
447    W: ClassifiedCmdTunnelWrite<C, M>,
448    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
449    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
450    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> {
451    tunnel_pool: ClassifiedWorkerPoolRef<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>,
452    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
453}
454
455impl<C: WorkerClassification,
456    M: CmdTunnelMeta,
457    R: ClassifiedCmdTunnelRead<C, M>,
458    W: ClassifiedCmdTunnelWrite<C, M>,
459    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
460    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
461    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
462    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
463        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
464        let resp_waiter = Arc::new(RespWaiter::new());
465        let handler_map = cmd_handler_map.clone();
466        let waiter = resp_waiter.clone();
467        Arc::new(Self {
468            tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(factory, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read: CmdBody| {
469                let handler_map = handler_map.clone();
470                let waiter = waiter.clone();
471                async move {
472                    if header.is_resp() && header.seq().is_some() {
473                        let resp_id = gen_resp_id(header.cmd_code(), header.seq().unwrap());
474                        let _ = waiter.set_result(resp_id, body_read);
475                        Ok(None)
476                    } else {
477                        if let Some(handler) = handler_map.get(header.cmd_code()) {
478                            handler.handle(peer_id, tunnel_id, header, body_read).await
479                        } else {
480                            Ok(None)
481                        }
482                    }
483                }
484            }, resp_waiter.clone())),
485            cmd_handler_map,
486        })
487    }
488
489    async fn get_send(&self) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
490        self.tunnel_pool.get_worker().await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
491    }
492
493    async fn get_send_of_tunnel_id(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
494        self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
495            tunnel_id: Some(tunnel_id),
496            classification: None,
497        }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
498    }
499
500    async fn get_classified_send(&self, classification: C) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
501        self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
502            tunnel_id: None,
503            classification: Some(classification),
504        }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
505    }
506}
507
508pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = ClassifiedSendGuard<CmdClientTunnelClassification<C>, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>;
509#[async_trait::async_trait]
510impl<C: WorkerClassification,
511    M: CmdTunnelMeta,
512    R: ClassifiedCmdTunnelRead<C, M>,
513    W: ClassifiedCmdTunnelWrite<C, M>,
514    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
515    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
516    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug,
517> CmdClient<LEN, CMD, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
518    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
519        self.cmd_handler_map.insert(cmd, handler);
520    }
521
522    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
523        let mut send = self.get_send().await?;
524        send.send(cmd, version, body).await
525    }
526
527    async fn send_with_resp(&self, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
528        let mut send = self.get_send().await?;
529        send.send_with_resp(cmd, version, body, timeout).await
530    }
531
532    async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
533        let mut send = self.get_send().await?;
534        send.send2(cmd, version, body).await
535    }
536
537    async fn send2_with_resp(&self, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
538        let mut send = self.get_send().await?;
539        send.send2_with_resp(cmd, version, body, timeout).await
540    }
541
542    async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
543        let mut send = self.get_send().await?;
544        send.send_cmd(cmd, version, body).await
545    }
546
547    async fn send_cmd_with_resp(&self, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
548        let mut send = self.get_send().await?;
549        send.send_cmd_with_resp(cmd, version, body, timeout).await
550    }
551
552    async fn send_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
553        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
554        send.send(cmd, version, body).await
555    }
556
557    async fn send_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
558        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
559        send.send_with_resp(cmd, version, body, timeout).await
560    }
561
562    async fn send2_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
563        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
564        send.send2(cmd, version, body).await
565    }
566
567    async fn send2_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
568        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
569        send.send2_with_resp(cmd, version, body, timeout).await
570    }
571
572    async fn send_cmd_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
573        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
574        send.send_cmd(cmd, version, body).await
575    }
576
577    async fn send_cmd_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
578        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
579        send.send_cmd_with_resp(cmd, version, body, timeout).await
580    }
581
582    async fn clear_all_tunnel(&self) {
583        self.tunnel_pool.clear_all_worker().await;
584    }
585
586    async fn get_send(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
587        Ok(ClassifiedSendGuard {
588            worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
589            _p: std::marker::PhantomData,
590        })
591    }
592}
593
594#[async_trait::async_trait]
595impl<C: WorkerClassification,
596    M: CmdTunnelMeta,
597    R: ClassifiedCmdTunnelRead<C, M>,
598    W: ClassifiedCmdTunnelWrite<C, M>,
599    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
600    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
601    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug
602> ClassifiedCmdClient<LEN, CMD, C, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
603    async fn send_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
604        let mut send = self.get_classified_send(classification).await?;
605        send.send(cmd, version, body).await
606    }
607
608    async fn send_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
609        let mut send = self.get_classified_send(classification).await?;
610        send.send_with_resp(cmd, version, body, timeout).await
611    }
612
613    async fn send2_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
614        let mut send = self.get_classified_send(classification).await?;
615        send.send2(cmd, version, body).await
616    }
617
618    async fn send2_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
619        let mut send = self.get_classified_send(classification).await?;
620        send.send2_with_resp(cmd, version, body, timeout).await
621    }
622
623    async fn send_cmd_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
624        let mut send = self.get_classified_send(classification).await?;
625        send.send_cmd(cmd, version, body).await
626    }
627
628    async fn send_cmd_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
629        let mut send = self.get_classified_send(classification).await?;
630        send.send_cmd_with_resp(cmd, version, body, timeout).await
631    }
632
633    async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
634        let send = self.get_classified_send(classification).await?;
635        Ok(send.get_tunnel_id())
636    }
637
638    async fn get_send_by_classified(&self, classification: C) -> CmdResult<ClassifiedSendGuard<CmdClientTunnelClassification<C>, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
639        Ok(ClassifiedSendGuard {
640            worker_guard: self.get_classified_send(classification).await?,
641            _p: std::marker::PhantomData,
642        })
643    }
644}