1use crate::client::{
2 CmdClient, CmdSend, RespWaiter, RespWaiterRef, SendGuard, gen_resp_id, gen_seq,
3};
4use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
5use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
6use crate::peer_id::PeerId;
7use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
8use async_named_locker::ObjectHolder;
9use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
10use num::{FromPrimitive, ToPrimitive};
11use sfo_pool::{
12 ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
13 ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
14 pool_err,
15};
16use sfo_split::{Splittable, WHalf};
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::marker::PhantomData;
20use std::ops::{Deref, DerefMut};
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27#[async_trait::async_trait]
28pub trait CmdTunnelFactory<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>>:
29 Send + Sync + 'static
30{
31 async fn create_tunnel(&self) -> CmdResult<Splittable<R, W>>;
32}
33
34pub struct CommonCmdSend<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>, LEN, CMD>
35where
36 LEN: RawEncode
37 + for<'a> RawDecode<'a>
38 + Copy
39 + Send
40 + Sync
41 + 'static
42 + FromPrimitive
43 + ToPrimitive,
44 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
45{
46 pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
47 pub(crate) write: ObjectHolder<WHalf<R, W>>,
48 pub(crate) is_work: bool,
49 pub(crate) tunnel_id: TunnelId,
50 pub(crate) remote_id: PeerId,
51 pub(crate) resp_waiter: RespWaiterRef,
52 pub(crate) tunnel_meta: Option<Arc<M>>,
53 _p: std::marker::PhantomData<(LEN, CMD)>,
54}
55
56impl<M, R, W, LEN, CMD> CommonCmdSend<M, R, W, LEN, CMD>
69where
70 M: CmdTunnelMeta,
71 R: CmdTunnelRead<M>,
72 W: CmdTunnelWrite<M>,
73 LEN: RawEncode
74 + for<'a> RawDecode<'a>
75 + Copy
76 + Send
77 + Sync
78 + 'static
79 + FromPrimitive
80 + ToPrimitive,
81 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
82{
83 pub fn new(
84 tunnel_id: TunnelId,
85 recv_handle: JoinHandle<CmdResult<()>>,
86 write: ObjectHolder<WHalf<R, W>>,
87 resp_waiter: RespWaiterRef,
88 remote_id: PeerId,
89 tunnel_meta: Option<Arc<M>>,
90 ) -> Self {
91 Self {
92 recv_handle,
93 write,
94 is_work: true,
95 tunnel_id,
96 remote_id,
97 resp_waiter,
98 tunnel_meta,
99 _p: Default::default(),
100 }
101 }
102
103 pub fn get_tunnel_id(&self) -> TunnelId {
104 self.tunnel_id
105 }
106
107 pub fn set_disable(&mut self) {
108 self.is_work = false;
109 self.recv_handle.abort();
110 }
111
112 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
113 log::trace!(
114 "client {:?} send cmd: {:?}, len: {} data:{}",
115 self.tunnel_id,
116 cmd,
117 body.len(),
118 hex::encode(body)
119 );
120 let header = CmdHeader::<LEN, CMD>::new(
121 version,
122 false,
123 None,
124 cmd,
125 LEN::from_u64(body.len() as u64).unwrap(),
126 );
127 let buf = header
128 .to_vec()
129 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
130 let ret = self.send_inner(buf.as_slice(), body).await;
131 if let Err(e) = ret {
132 self.set_disable();
133 return Err(e);
134 }
135 Ok(())
136 }
137
138 pub async fn send_with_resp(
139 &mut self,
140 cmd: CMD,
141 version: u8,
142 body: &[u8],
143 timeout: Duration,
144 ) -> CmdResult<CmdBody> {
145 if let Some(id) = tokio::task::try_id() {
146 if id == self.recv_handle.id() {
147 return Err(cmd_err!(
148 CmdErrorCode::Failed,
149 "can't send with resp in recv task"
150 ));
151 }
152 }
153 log::trace!(
154 "client {:?} send cmd: {:?}, len: {}, data: {}",
155 self.tunnel_id,
156 cmd,
157 body.len(),
158 hex::encode(body)
159 );
160 let seq = gen_seq();
161 let header = CmdHeader::<LEN, CMD>::new(
162 version,
163 false,
164 Some(seq),
165 cmd,
166 LEN::from_u64(body.len() as u64).unwrap(),
167 );
168 let buf = header
169 .to_vec()
170 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
171 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
172 let waiter = self.resp_waiter.clone();
173 let resp_waiter = waiter
174 .create_timeout_result_future(resp_id, timeout)
175 .map_err(into_cmd_err!(
176 CmdErrorCode::Failed,
177 "create timeout result future error"
178 ))?;
179 let ret = self.send_inner(buf.as_slice(), body).await;
180 if let Err(e) = ret {
181 self.set_disable();
182 return Err(e);
183 }
184 let resp = resp_waiter
185 .await
186 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
187 Ok(resp)
188 }
189
190 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
191 let mut len = 0;
192 for b in body.iter() {
193 len += b.len();
194 log::trace!(
195 "client {:?} send2 cmd: {:?}, data {}",
196 self.tunnel_id,
197 cmd,
198 hex::encode(b)
199 );
200 }
201 log::trace!(
202 "client {:?} send2 cmd: {:?}, len {}",
203 self.tunnel_id,
204 cmd,
205 len
206 );
207 let header = CmdHeader::<LEN, CMD>::new(
208 version,
209 false,
210 None,
211 cmd,
212 LEN::from_u64(len as u64).unwrap(),
213 );
214 let buf = header
215 .to_vec()
216 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
217 let ret = self.send_inner2(buf.as_slice(), body).await;
218 if let Err(e) = ret {
219 self.set_disable();
220 return Err(e);
221 }
222 Ok(())
223 }
224
225 pub async fn send2_with_resp(
226 &mut self,
227 cmd: CMD,
228 version: u8,
229 body: &[&[u8]],
230 timeout: Duration,
231 ) -> CmdResult<CmdBody> {
232 if let Some(id) = tokio::task::try_id() {
233 if id == self.recv_handle.id() {
234 return Err(cmd_err!(
235 CmdErrorCode::Failed,
236 "can't send with resp in recv task"
237 ));
238 }
239 }
240 let mut len = 0;
241 for b in body.iter() {
242 len += b.len();
243 log::trace!(
244 "client {:?} send2 cmd {:?} body: {}",
245 self.tunnel_id,
246 cmd,
247 hex::encode(b)
248 );
249 }
250 log::trace!(
251 "client {:?} send2 cmd: {:?}, len {}",
252 self.tunnel_id,
253 cmd,
254 len
255 );
256 let seq = gen_seq();
257 let header = CmdHeader::<LEN, CMD>::new(
258 version,
259 false,
260 Some(seq),
261 cmd,
262 LEN::from_u64(len as u64).unwrap(),
263 );
264 let buf = header
265 .to_vec()
266 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
267 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
268 let waiter = self.resp_waiter.clone();
269 let resp_waiter = waiter
270 .create_timeout_result_future(resp_id, timeout)
271 .map_err(into_cmd_err!(
272 CmdErrorCode::Failed,
273 "create timeout result future error"
274 ))?;
275 let ret = self.send_inner2(buf.as_slice(), body).await;
276 if let Err(e) = ret {
277 self.set_disable();
278 return Err(e);
279 }
280 let resp = resp_waiter
281 .await
282 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
283 Ok(resp)
284 }
285
286 pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
287 log::trace!(
288 "client {:?} send cmd: {:?}, len: {}",
289 self.tunnel_id,
290 cmd,
291 body.len()
292 );
293 let header = CmdHeader::<LEN, CMD>::new(
294 version,
295 false,
296 None,
297 cmd,
298 LEN::from_u64(body.len()).unwrap(),
299 );
300 let buf = header
301 .to_vec()
302 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
303 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
304 if let Err(e) = ret {
305 self.set_disable();
306 return Err(e);
307 }
308 Ok(())
309 }
310
311 pub async fn send_cmd_with_resp(
312 &mut self,
313 cmd: CMD,
314 version: u8,
315 body: CmdBody,
316 timeout: Duration,
317 ) -> CmdResult<CmdBody> {
318 if let Some(id) = tokio::task::try_id() {
319 if id == self.recv_handle.id() {
320 return Err(cmd_err!(
321 CmdErrorCode::Failed,
322 "can't send with resp in recv task"
323 ));
324 }
325 }
326 log::trace!(
327 "client {:?} send cmd: {:?}, len: {}",
328 self.tunnel_id,
329 cmd,
330 body.len()
331 );
332 let seq = gen_seq();
333 let header = CmdHeader::<LEN, CMD>::new(
334 version,
335 false,
336 Some(seq),
337 cmd,
338 LEN::from_u64(body.len()).unwrap(),
339 );
340 let buf = header
341 .to_vec()
342 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
343 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
344 let waiter = self.resp_waiter.clone();
345 let resp_waiter = waiter
346 .create_timeout_result_future(resp_id, timeout)
347 .map_err(into_cmd_err!(
348 CmdErrorCode::Failed,
349 "create timeout result future error"
350 ))?;
351 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
352 if let Err(e) = ret {
353 self.set_disable();
354 return Err(e);
355 }
356 let resp = resp_waiter
357 .await
358 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
359 Ok(resp)
360 }
361
362 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
363 let mut write = self.write.get().await;
364 if header.len() > 255 {
365 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
366 }
367 write
368 .write_u8(header.len() as u8)
369 .await
370 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
371 write
372 .write_all(header)
373 .await
374 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
375 write
376 .write_all(body)
377 .await
378 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
379 write
380 .flush()
381 .await
382 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
383 Ok(())
384 }
385
386 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
387 let mut write = self.write.get().await;
388 if header.len() > 255 {
389 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
390 }
391 write
392 .write_u8(header.len() as u8)
393 .await
394 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
395 write
396 .write_all(header)
397 .await
398 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
399 for b in body.iter() {
400 write
401 .write_all(b)
402 .await
403 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
404 }
405 write
406 .flush()
407 .await
408 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
409 Ok(())
410 }
411
412 async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
413 let mut write = self.write.get().await;
414 if header.len() > 255 {
415 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
416 }
417 write
418 .write_u8(header.len() as u8)
419 .await
420 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
421 write
422 .write_all(header)
423 .await
424 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
425 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
426 .await
427 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
428 write
429 .flush()
430 .await
431 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
432 Ok(())
433 }
434}
435
436impl<M, R, W, LEN, CMD> Drop for CommonCmdSend<M, R, W, LEN, CMD>
437where
438 M: CmdTunnelMeta,
439 R: CmdTunnelRead<M>,
440 W: CmdTunnelWrite<M>,
441 LEN: RawEncode
442 + for<'a> RawDecode<'a>
443 + Copy
444 + Send
445 + Sync
446 + 'static
447 + FromPrimitive
448 + ToPrimitive,
449 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
450{
451 fn drop(&mut self) {
452 self.set_disable();
453 }
454}
455
456impl<M, R, W, LEN, CMD> CmdSend<M> for CommonCmdSend<M, R, W, LEN, CMD>
457where
458 M: CmdTunnelMeta,
459 R: CmdTunnelRead<M>,
460 W: CmdTunnelWrite<M>,
461 LEN: RawEncode
462 + for<'a> RawDecode<'a>
463 + Copy
464 + Send
465 + Sync
466 + 'static
467 + FromPrimitive
468 + ToPrimitive,
469 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
470{
471 fn get_tunnel_meta(&self) -> Option<Arc<M>> {
472 self.tunnel_meta.clone()
473 }
474
475 fn get_remote_peer_id(&self) -> PeerId {
476 self.remote_id.clone()
477 }
478}
479
480impl<M, R, W, LEN, CMD> ClassifiedWorker<TunnelId> for CommonCmdSend<M, R, W, LEN, CMD>
481where
482 M: CmdTunnelMeta,
483 R: CmdTunnelRead<M>,
484 W: CmdTunnelWrite<M>,
485 LEN: RawEncode
486 + for<'a> RawDecode<'a>
487 + Copy
488 + Send
489 + Sync
490 + 'static
491 + FromPrimitive
492 + ToPrimitive,
493 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
494{
495 fn is_work(&self) -> bool {
496 self.is_work && !self.recv_handle.is_finished()
497 }
498
499 fn is_valid(&self, c: TunnelId) -> bool {
500 self.tunnel_id == c
501 }
502
503 fn classification(&self) -> TunnelId {
504 self.tunnel_id
505 }
506}
507
508pub struct ClassifiedSendGuard<
509 C: WorkerClassification,
510 M: CmdTunnelMeta,
511 CW: ClassifiedWorker<C> + CmdSend<M>,
512 F: ClassifiedWorkerFactory<C, CW>,
513> {
514 pub(crate) worker_guard: ClassifiedWorkerGuard<C, CW, F>,
515 pub(crate) _p: PhantomData<M>,
516}
517
518impl<
519 C: WorkerClassification,
520 M: CmdTunnelMeta,
521 CW: ClassifiedWorker<C> + CmdSend<M>,
522 F: ClassifiedWorkerFactory<C, CW>,
523> Deref for ClassifiedSendGuard<C, M, CW, F>
524{
525 type Target = CW;
526
527 fn deref(&self) -> &Self::Target {
528 &self.worker_guard.deref()
529 }
530}
531
532impl<
533 C: WorkerClassification,
534 M: CmdTunnelMeta,
535 CW: ClassifiedWorker<C> + CmdSend<M>,
536 F: ClassifiedWorkerFactory<C, CW>,
537> SendGuard<M, CW> for ClassifiedSendGuard<C, M, CW, F>
538{
539}
540
541pub struct CmdWriteFactory<
542 M: CmdTunnelMeta,
543 R: CmdTunnelRead<M>,
544 W: CmdTunnelWrite<M>,
545 F: CmdTunnelFactory<M, R, W>,
546 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
547 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
548> {
549 tunnel_factory: F,
550 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
551 resp_waiter: RespWaiterRef,
552 tunnel_id_generator: TunnelIdGenerator,
553 p: std::marker::PhantomData<Mutex<(R, W, M)>>,
554}
555
556impl<
557 M: CmdTunnelMeta,
558 R: CmdTunnelRead<M>,
559 W: CmdTunnelWrite<M>,
560 F: CmdTunnelFactory<M, R, W>,
561 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
562 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
563> CmdWriteFactory<M, R, W, F, LEN, CMD>
564{
565 pub(crate) fn new(
566 tunnel_factory: F,
567 cmd_handler: impl CmdHandler<LEN, CMD>,
568 resp_waiter: RespWaiterRef,
569 ) -> Self {
570 Self {
571 tunnel_factory,
572 cmd_handler: Arc::new(cmd_handler),
573 resp_waiter,
574 tunnel_id_generator: TunnelIdGenerator::new(),
575 p: Default::default(),
576 }
577 }
578}
579
580#[async_trait::async_trait]
581impl<
582 M: CmdTunnelMeta,
583 R: CmdTunnelRead<M>,
584 W: CmdTunnelWrite<M>,
585 F: CmdTunnelFactory<M, R, W>,
586 LEN: RawEncode
587 + for<'a> RawDecode<'a>
588 + Copy
589 + Send
590 + Sync
591 + 'static
592 + FromPrimitive
593 + ToPrimitive
594 + RawFixedBytes,
595 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
596> ClassifiedWorkerFactory<TunnelId, CommonCmdSend<M, R, W, LEN, CMD>>
597 for CmdWriteFactory<M, R, W, F, LEN, CMD>
598{
599 async fn create(&self, c: Option<TunnelId>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
600 if c.is_some() {
601 return Err(pool_err!(
602 PoolErrorCode::Failed,
603 "tunnel {:?} not found",
604 c.unwrap()
605 ));
606 }
607 let tunnel = self
608 .tunnel_factory
609 .create_tunnel()
610 .await
611 .map_err(into_pool_err!(PoolErrorCode::Failed))?;
612 let peer_id = tunnel.get_remote_peer_id();
613 let tunnel_id = self.tunnel_id_generator.generate();
614 let (mut recv, write) = tunnel.split();
615 let remote_id = write.get_remote_peer_id();
616 let meta = write.get_tunnel_meta();
617 let write = ObjectHolder::new(write);
618 let resp_write = write.clone();
619 let cmd_handler = self.cmd_handler.clone();
620 let handle = spawn(async move {
621 let ret: CmdResult<()> = async move {
622 loop {
623 let header_len = recv
624 .read_u8()
625 .await
626 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
627 let mut header = vec![0u8; header_len as usize];
628 let n = recv
629 .read_exact(header.as_mut())
630 .await
631 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
632 if n == 0 {
633 break;
634 }
635 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
636 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
637 log::trace!(
638 "recv cmd {:?} from {} len {}",
639 header.cmd_code(),
640 peer_id.to_base58(),
641 header.pkg_len().to_u64().unwrap()
642 );
643 let body_len = header.pkg_len().to_u64().unwrap();
644 let cmd_read =
645 CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
646 let waiter = cmd_read.get_waiter();
647 let future = waiter
648 .create_result_future()
649 .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
650 let version = header.version();
651 let seq = header.seq();
652 let cmd_code = header.cmd_code();
653 match cmd_handler
654 .handle(
655 peer_id.clone(),
656 tunnel_id,
657 header,
658 CmdBody::from_reader(BufReader::new(cmd_read), body_len),
659 )
660 .await
661 {
662 Ok(Some(mut body)) => {
663 let mut write = resp_write.get().await;
664 let header = CmdHeader::<LEN, CMD>::new(
665 version,
666 true,
667 seq,
668 cmd_code,
669 LEN::from_u64(body.len()).unwrap(),
670 );
671 let buf = header
672 .to_vec()
673 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
674 if buf.len() > 255 {
675 return Err(cmd_err!(
676 CmdErrorCode::InvalidParam,
677 "header len too long"
678 ));
679 }
680 write
681 .write_u8(buf.len() as u8)
682 .await
683 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
684 write
685 .write_all(buf.as_slice())
686 .await
687 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
688 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
689 .await
690 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
691 write
692 .flush()
693 .await
694 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
695 }
696 Ok(None) => {}
697 Err(e) => {
698 log::error!("handle cmd error: {:?}", e);
699 }
700 }
701 recv = future
702 .await
703 .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
704 }
705 Ok(())
706 }
707 .await;
708 ret
709 });
710 Ok(CommonCmdSend::new(
711 tunnel_id,
712 handle,
713 write,
714 self.resp_waiter.clone(),
715 remote_id,
716 meta,
717 ))
718 }
719}
720
721pub struct DefaultCmdClient<
722 M: CmdTunnelMeta,
723 R: CmdTunnelRead<M>,
724 W: CmdTunnelWrite<M>,
725 F: CmdTunnelFactory<M, R, W>,
726 LEN: RawEncode
727 + for<'a> RawDecode<'a>
728 + Copy
729 + Send
730 + Sync
731 + 'static
732 + FromPrimitive
733 + ToPrimitive
734 + RawFixedBytes,
735 CMD: RawEncode
736 + for<'a> RawDecode<'a>
737 + Copy
738 + Send
739 + Sync
740 + 'static
741 + RawFixedBytes
742 + Eq
743 + Hash
744 + Debug,
745> {
746 tunnel_pool: ClassifiedWorkerPoolRef<
747 TunnelId,
748 CommonCmdSend<M, R, W, LEN, CMD>,
749 CmdWriteFactory<M, R, W, F, LEN, CMD>,
750 >,
751 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
752}
753
754impl<
755 M: CmdTunnelMeta,
756 R: CmdTunnelRead<M>,
757 W: CmdTunnelWrite<M>,
758 F: CmdTunnelFactory<M, R, W>,
759 LEN: RawEncode
760 + for<'a> RawDecode<'a>
761 + Copy
762 + Send
763 + Sync
764 + 'static
765 + FromPrimitive
766 + ToPrimitive
767 + RawFixedBytes,
768 CMD: RawEncode
769 + for<'a> RawDecode<'a>
770 + Copy
771 + Send
772 + Sync
773 + 'static
774 + RawFixedBytes
775 + Eq
776 + Hash
777 + Debug,
778> DefaultCmdClient<M, R, W, F, LEN, CMD>
779{
780 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
781 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
782 let handler_map = cmd_handler_map.clone();
783 let resp_waiter = Arc::new(RespWaiter::new());
784 let waiter = resp_waiter.clone();
785 Arc::new(Self {
786 tunnel_pool: ClassifiedWorkerPool::new(
787 tunnel_count,
788 CmdWriteFactory::<M, R, W, _, LEN, CMD>::new(
789 factory,
790 move |peer_id: PeerId,
791 tunnel_id: TunnelId,
792 header: CmdHeader<LEN, CMD>,
793 body_read: CmdBody| {
794 let handler_map = handler_map.clone();
795 let waiter = waiter.clone();
796 async move {
797 if header.is_resp() && header.seq().is_some() {
798 let resp_id = gen_resp_id(
799 tunnel_id,
800 header.cmd_code(),
801 header.seq().unwrap(),
802 );
803 let _ = waiter.set_result(resp_id, body_read);
804 Ok(None)
805 } else {
806 if let Some(handler) = handler_map.get(header.cmd_code()) {
807 handler.handle(peer_id, tunnel_id, header, body_read).await
808 } else {
809 Ok(None)
810 }
811 }
812 }
813 },
814 resp_waiter.clone(),
815 ),
816 ),
817 cmd_handler_map,
818 })
819 }
820
821 async fn get_send(
822 &self,
823 ) -> CmdResult<
824 ClassifiedWorkerGuard<
825 TunnelId,
826 CommonCmdSend<M, R, W, LEN, CMD>,
827 CmdWriteFactory<M, R, W, F, LEN, CMD>,
828 >,
829 > {
830 self.tunnel_pool
831 .get_worker()
832 .await
833 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
834 }
835
836 async fn get_send_of_tunnel_id(
837 &self,
838 tunnel_id: TunnelId,
839 ) -> CmdResult<
840 ClassifiedWorkerGuard<
841 TunnelId,
842 CommonCmdSend<M, R, W, LEN, CMD>,
843 CmdWriteFactory<M, R, W, F, LEN, CMD>,
844 >,
845 > {
846 self.tunnel_pool
847 .get_classified_worker(tunnel_id)
848 .await
849 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
850 }
851}
852
/// Send guard type handed out by [`DefaultCmdClient`]: a
/// [`ClassifiedSendGuard`] keyed by [`TunnelId`] that wraps a
/// [`CommonCmdSend`] created by [`CmdWriteFactory`].
pub type CmdClientSendGuard<M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
    TunnelId,
    M,
    CommonCmdSend<M, R, W, LEN, CMD>,
    CmdWriteFactory<M, R, W, F, LEN, CMD>,
>;
859#[async_trait::async_trait]
860impl<
861 M: CmdTunnelMeta,
862 R: CmdTunnelRead<M>,
863 W: CmdTunnelWrite<M>,
864 F: CmdTunnelFactory<M, R, W>,
865 LEN: RawEncode
866 + for<'a> RawDecode<'a>
867 + Copy
868 + Send
869 + Sync
870 + 'static
871 + FromPrimitive
872 + ToPrimitive
873 + RawFixedBytes,
874 CMD: RawEncode
875 + for<'a> RawDecode<'a>
876 + Copy
877 + Send
878 + Sync
879 + 'static
880 + RawFixedBytes
881 + Eq
882 + Hash
883 + Debug,
884> CmdClient<LEN, CMD, M, CommonCmdSend<M, R, W, LEN, CMD>, CmdClientSendGuard<M, R, W, F, LEN, CMD>>
885 for DefaultCmdClient<M, R, W, F, LEN, CMD>
886{
887 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
888 self.cmd_handler_map.insert(cmd, handler);
889 }
890
891 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
892 let mut send = self.get_send().await?;
893 send.send(cmd, version, body).await
894 }
895
896 async fn send_with_resp(
897 &self,
898 cmd: CMD,
899 version: u8,
900 body: &[u8],
901 timeout: Duration,
902 ) -> CmdResult<CmdBody> {
903 let mut send = self.get_send().await?;
904 send.send_with_resp(cmd, version, body, timeout).await
905 }
906
907 async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
908 let mut send = self.get_send().await?;
909 send.send2(cmd, version, body).await
910 }
911
912 async fn send2_with_resp(
913 &self,
914 cmd: CMD,
915 version: u8,
916 body: &[&[u8]],
917 timeout: Duration,
918 ) -> CmdResult<CmdBody> {
919 let mut send = self.get_send().await?;
920 send.send2_with_resp(cmd, version, body, timeout).await
921 }
922
923 async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
924 let mut send = self.get_send().await?;
925 send.send_cmd(cmd, version, body).await
926 }
927
928 async fn send_cmd_with_resp(
929 &self,
930 cmd: CMD,
931 version: u8,
932 body: CmdBody,
933 timeout: Duration,
934 ) -> CmdResult<CmdBody> {
935 let mut send = self.get_send().await?;
936 send.send_cmd_with_resp(cmd, version, body, timeout).await
937 }
938
939 async fn send_by_specify_tunnel(
940 &self,
941 tunnel_id: TunnelId,
942 cmd: CMD,
943 version: u8,
944 body: &[u8],
945 ) -> CmdResult<()> {
946 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
947 send.send(cmd, version, body).await
948 }
949
950 async fn send_by_specify_tunnel_with_resp(
951 &self,
952 tunnel_id: TunnelId,
953 cmd: CMD,
954 version: u8,
955 body: &[u8],
956 timeout: Duration,
957 ) -> CmdResult<CmdBody> {
958 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
959 send.send_with_resp(cmd, version, body, timeout).await
960 }
961
962 async fn send2_by_specify_tunnel(
963 &self,
964 tunnel_id: TunnelId,
965 cmd: CMD,
966 version: u8,
967 body: &[&[u8]],
968 ) -> CmdResult<()> {
969 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
970 send.send2(cmd, version, body).await
971 }
972
973 async fn send2_by_specify_tunnel_with_resp(
974 &self,
975 tunnel_id: TunnelId,
976 cmd: CMD,
977 version: u8,
978 body: &[&[u8]],
979 timeout: Duration,
980 ) -> CmdResult<CmdBody> {
981 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
982 send.send2_with_resp(cmd, version, body, timeout).await
983 }
984
985 async fn send_cmd_by_specify_tunnel(
986 &self,
987 tunnel_id: TunnelId,
988 cmd: CMD,
989 version: u8,
990 body: CmdBody,
991 ) -> CmdResult<()> {
992 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
993 send.send_cmd(cmd, version, body).await
994 }
995
996 async fn send_cmd_by_specify_tunnel_with_resp(
997 &self,
998 tunnel_id: TunnelId,
999 cmd: CMD,
1000 version: u8,
1001 body: CmdBody,
1002 timeout: Duration,
1003 ) -> CmdResult<CmdBody> {
1004 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1005 send.send_cmd_with_resp(cmd, version, body, timeout).await
1006 }
1007
1008 async fn clear_all_tunnel(&self) {
1009 self.tunnel_pool.clear_all_worker().await;
1010 }
1011
1012 async fn get_send(
1013 &self,
1014 tunnel_id: TunnelId,
1015 ) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
1016 Ok(ClassifiedSendGuard {
1017 worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1018 _p: Default::default(),
1019 })
1020 }
1021}