1use crate::client::{
2 CmdClient, CmdSend, RespWaiter, RespWaiterRef, SendGuard, gen_resp_id, gen_seq,
3};
4use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
5use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
6use crate::peer_id::PeerId;
7use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
8use async_named_locker::ObjectHolder;
9use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
10use num::{FromPrimitive, ToPrimitive};
11use sfo_pool::{
12 ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
13 ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
14 pool_err,
15};
16use sfo_split::{Splittable, WHalf};
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::marker::PhantomData;
20use std::ops::{Deref, DerefMut};
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27#[async_trait::async_trait]
28pub trait CmdTunnelFactory<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>>:
29 Send + Sync + 'static
30{
31 async fn create_tunnel(&self) -> CmdResult<Splittable<R, W>>;
32}
33
34pub struct CommonCmdSend<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>, LEN, CMD>
35where
36 LEN: RawEncode
37 + for<'a> RawDecode<'a>
38 + Copy
39 + Send
40 + Sync
41 + 'static
42 + FromPrimitive
43 + ToPrimitive,
44 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
45{
46 pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
47 pub(crate) write: ObjectHolder<WHalf<R, W>>,
48 pub(crate) is_work: bool,
49 pub(crate) tunnel_id: TunnelId,
50 pub(crate) remote_id: PeerId,
51 pub(crate) resp_waiter: RespWaiterRef,
52 pub(crate) tunnel_meta: Option<Arc<M>>,
53 _p: std::marker::PhantomData<(LEN, CMD)>,
54}
55
56impl<M, R, W, LEN, CMD> CommonCmdSend<M, R, W, LEN, CMD>
69where
70 M: CmdTunnelMeta,
71 R: CmdTunnelRead<M>,
72 W: CmdTunnelWrite<M>,
73 LEN: RawEncode
74 + for<'a> RawDecode<'a>
75 + Copy
76 + Send
77 + Sync
78 + 'static
79 + FromPrimitive
80 + ToPrimitive,
81 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
82{
83 pub fn new(
84 tunnel_id: TunnelId,
85 recv_handle: JoinHandle<CmdResult<()>>,
86 write: ObjectHolder<WHalf<R, W>>,
87 resp_waiter: RespWaiterRef,
88 remote_id: PeerId,
89 tunnel_meta: Option<Arc<M>>,
90 ) -> Self {
91 Self {
92 recv_handle,
93 write,
94 is_work: true,
95 tunnel_id,
96 remote_id,
97 resp_waiter,
98 tunnel_meta,
99 _p: Default::default(),
100 }
101 }
102
103 pub fn get_tunnel_id(&self) -> TunnelId {
104 self.tunnel_id
105 }
106
107 pub fn set_disable(&mut self) {
108 self.is_work = false;
109 self.recv_handle.abort();
110 }
111
112 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
113 log::trace!(
114 "client {:?} send cmd: {:?}, len: {} data:{}",
115 self.tunnel_id,
116 cmd,
117 body.len(),
118 hex::encode(body)
119 );
120 let header = CmdHeader::<LEN, CMD>::new(
121 version,
122 false,
123 None,
124 cmd,
125 LEN::from_u64(body.len() as u64).unwrap(),
126 );
127 let buf = header
128 .to_vec()
129 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
130 let ret = self.send_inner(buf.as_slice(), body).await;
131 if let Err(e) = ret {
132 self.set_disable();
133 return Err(e);
134 }
135 Ok(())
136 }
137
138 pub async fn send_with_resp(
139 &mut self,
140 cmd: CMD,
141 version: u8,
142 body: &[u8],
143 timeout: Duration,
144 ) -> CmdResult<CmdBody> {
145 if let Some(id) = tokio::task::try_id() {
146 if id == self.recv_handle.id() {
147 return Err(cmd_err!(
148 CmdErrorCode::Failed,
149 "can't send with resp in recv task"
150 ));
151 }
152 }
153 log::trace!(
154 "client {:?} send cmd: {:?}, len: {}, data: {}",
155 self.tunnel_id,
156 cmd,
157 body.len(),
158 hex::encode(body)
159 );
160 let seq = gen_seq();
161 let header = CmdHeader::<LEN, CMD>::new(
162 version,
163 false,
164 Some(seq),
165 cmd,
166 LEN::from_u64(body.len() as u64).unwrap(),
167 );
168 let buf = header
169 .to_vec()
170 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
171 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
172 let waiter = self.resp_waiter.clone();
173 let resp_waiter = waiter
174 .create_timeout_result_future(resp_id, timeout)
175 .map_err(into_cmd_err!(
176 CmdErrorCode::Failed,
177 "create timeout result future error"
178 ))?;
179 let ret = self.send_inner(buf.as_slice(), body).await;
180 if let Err(e) = ret {
181 self.set_disable();
182 return Err(e);
183 }
184 let resp = resp_waiter
185 .await
186 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
187 Ok(resp)
188 }
189
190 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
191 let mut len = 0;
192 for b in body.iter() {
193 len += b.len();
194 log::trace!(
195 "client {:?} send2 cmd: {:?}, data {}",
196 self.tunnel_id,
197 cmd,
198 hex::encode(b)
199 );
200 }
201 log::trace!(
202 "client {:?} send2 cmd: {:?}, len {}",
203 self.tunnel_id,
204 cmd,
205 len
206 );
207 let header = CmdHeader::<LEN, CMD>::new(
208 version,
209 false,
210 None,
211 cmd,
212 LEN::from_u64(len as u64).unwrap(),
213 );
214 let buf = header
215 .to_vec()
216 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
217 let ret = self.send_inner2(buf.as_slice(), body).await;
218 if let Err(e) = ret {
219 self.set_disable();
220 return Err(e);
221 }
222 Ok(())
223 }
224
225 pub async fn send2_with_resp(
226 &mut self,
227 cmd: CMD,
228 version: u8,
229 body: &[&[u8]],
230 timeout: Duration,
231 ) -> CmdResult<CmdBody> {
232 if let Some(id) = tokio::task::try_id() {
233 if id == self.recv_handle.id() {
234 return Err(cmd_err!(
235 CmdErrorCode::Failed,
236 "can't send with resp in recv task"
237 ));
238 }
239 }
240 let mut len = 0;
241 for b in body.iter() {
242 len += b.len();
243 log::trace!(
244 "client {:?} send2 cmd {:?} body: {}",
245 self.tunnel_id,
246 cmd,
247 hex::encode(b)
248 );
249 }
250 log::trace!(
251 "client {:?} send2 cmd: {:?}, len {}",
252 self.tunnel_id,
253 cmd,
254 len
255 );
256 let seq = gen_seq();
257 let header = CmdHeader::<LEN, CMD>::new(
258 version,
259 false,
260 Some(seq),
261 cmd,
262 LEN::from_u64(len as u64).unwrap(),
263 );
264 let buf = header
265 .to_vec()
266 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
267 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
268 let waiter = self.resp_waiter.clone();
269 let resp_waiter = waiter
270 .create_timeout_result_future(resp_id, timeout)
271 .map_err(into_cmd_err!(
272 CmdErrorCode::Failed,
273 "create timeout result future error"
274 ))?;
275 let ret = self.send_inner2(buf.as_slice(), body).await;
276 if let Err(e) = ret {
277 self.set_disable();
278 return Err(e);
279 }
280 let resp = resp_waiter
281 .await
282 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
283 Ok(resp)
284 }
285
286 pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
287 log::trace!(
288 "client {:?} send cmd: {:?}, len: {}",
289 self.tunnel_id,
290 cmd,
291 body.len()
292 );
293 let header = CmdHeader::<LEN, CMD>::new(
294 version,
295 false,
296 None,
297 cmd,
298 LEN::from_u64(body.len()).unwrap(),
299 );
300 let buf = header
301 .to_vec()
302 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
303 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
304 if let Err(e) = ret {
305 self.set_disable();
306 return Err(e);
307 }
308 Ok(())
309 }
310
311 pub async fn send_cmd_with_resp(
312 &mut self,
313 cmd: CMD,
314 version: u8,
315 body: CmdBody,
316 timeout: Duration,
317 ) -> CmdResult<CmdBody> {
318 if let Some(id) = tokio::task::try_id() {
319 if id == self.recv_handle.id() {
320 return Err(cmd_err!(
321 CmdErrorCode::Failed,
322 "can't send with resp in recv task"
323 ));
324 }
325 }
326 log::trace!(
327 "client {:?} send cmd: {:?}, len: {}",
328 self.tunnel_id,
329 cmd,
330 body.len()
331 );
332 let seq = gen_seq();
333 let header = CmdHeader::<LEN, CMD>::new(
334 version,
335 false,
336 Some(seq),
337 cmd,
338 LEN::from_u64(body.len()).unwrap(),
339 );
340 let buf = header
341 .to_vec()
342 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
343 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
344 let waiter = self.resp_waiter.clone();
345 let resp_waiter = waiter
346 .create_timeout_result_future(resp_id, timeout)
347 .map_err(into_cmd_err!(
348 CmdErrorCode::Failed,
349 "create timeout result future error"
350 ))?;
351 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
352 if let Err(e) = ret {
353 self.set_disable();
354 return Err(e);
355 }
356 let resp = resp_waiter
357 .await
358 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
359 Ok(resp)
360 }
361
362 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
363 let mut write = self.write.get().await;
364 if header.len() > 255 {
365 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
366 }
367 write
368 .write_u8(header.len() as u8)
369 .await
370 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
371 write
372 .write_all(header)
373 .await
374 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
375 write
376 .write_all(body)
377 .await
378 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
379 write
380 .flush()
381 .await
382 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
383 Ok(())
384 }
385
386 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
387 let mut write = self.write.get().await;
388 if header.len() > 255 {
389 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
390 }
391 write
392 .write_u8(header.len() as u8)
393 .await
394 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
395 write
396 .write_all(header)
397 .await
398 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
399 for b in body.iter() {
400 write
401 .write_all(b)
402 .await
403 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
404 }
405 write
406 .flush()
407 .await
408 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
409 Ok(())
410 }
411
412 async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
413 let mut write = self.write.get().await;
414 if header.len() > 255 {
415 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
416 }
417 write
418 .write_u8(header.len() as u8)
419 .await
420 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
421 write
422 .write_all(header)
423 .await
424 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
425 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
426 .await
427 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
428 write
429 .flush()
430 .await
431 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
432 Ok(())
433 }
434}
435
436impl<M, R, W, LEN, CMD> Drop for CommonCmdSend<M, R, W, LEN, CMD>
437where
438 M: CmdTunnelMeta,
439 R: CmdTunnelRead<M>,
440 W: CmdTunnelWrite<M>,
441 LEN: RawEncode
442 + for<'a> RawDecode<'a>
443 + Copy
444 + Send
445 + Sync
446 + 'static
447 + FromPrimitive
448 + ToPrimitive,
449 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
450{
451 fn drop(&mut self) {
452 self.set_disable();
453 }
454}
455
456impl<M, R, W, LEN, CMD> CmdSend<M> for CommonCmdSend<M, R, W, LEN, CMD>
457where
458 M: CmdTunnelMeta,
459 R: CmdTunnelRead<M>,
460 W: CmdTunnelWrite<M>,
461 LEN: RawEncode
462 + for<'a> RawDecode<'a>
463 + Copy
464 + Send
465 + Sync
466 + 'static
467 + FromPrimitive
468 + ToPrimitive,
469 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
470{
471 fn get_tunnel_meta(&self) -> Option<Arc<M>> {
472 self.tunnel_meta.clone()
473 }
474
475 fn get_remote_peer_id(&self) -> PeerId {
476 self.remote_id.clone()
477 }
478}
479
480impl<M, R, W, LEN, CMD> ClassifiedWorker<TunnelId> for CommonCmdSend<M, R, W, LEN, CMD>
481where
482 M: CmdTunnelMeta,
483 R: CmdTunnelRead<M>,
484 W: CmdTunnelWrite<M>,
485 LEN: RawEncode
486 + for<'a> RawDecode<'a>
487 + Copy
488 + Send
489 + Sync
490 + 'static
491 + FromPrimitive
492 + ToPrimitive,
493 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
494{
495 fn is_work(&self) -> bool {
496 self.is_work && !self.recv_handle.is_finished()
497 }
498
499 fn is_valid(&self, c: TunnelId) -> bool {
500 self.tunnel_id == c
501 }
502
503 fn classification(&self) -> TunnelId {
504 self.tunnel_id
505 }
506}
507
508pub struct ClassifiedSendGuard<
509 C: WorkerClassification,
510 M: CmdTunnelMeta,
511 CW: ClassifiedWorker<C> + CmdSend<M>,
512 F: ClassifiedWorkerFactory<C, CW>,
513> {
514 pub(crate) worker_guard: ClassifiedWorkerGuard<C, CW, F>,
515 pub(crate) _p: PhantomData<M>,
516}
517
518impl<
519 C: WorkerClassification,
520 M: CmdTunnelMeta,
521 CW: ClassifiedWorker<C> + CmdSend<M>,
522 F: ClassifiedWorkerFactory<C, CW>,
523> Deref for ClassifiedSendGuard<C, M, CW, F>
524{
525 type Target = CW;
526
527 fn deref(&self) -> &Self::Target {
528 &self.worker_guard.deref()
529 }
530}
531
532impl<
533 C: WorkerClassification,
534 M: CmdTunnelMeta,
535 CW: ClassifiedWorker<C> + CmdSend<M>,
536 F: ClassifiedWorkerFactory<C, CW>,
537> SendGuard<M, CW> for ClassifiedSendGuard<C, M, CW, F>
538{
539}
540
541pub struct CmdWriteFactory<
542 M: CmdTunnelMeta,
543 R: CmdTunnelRead<M>,
544 W: CmdTunnelWrite<M>,
545 F: CmdTunnelFactory<M, R, W>,
546 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
547 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
548> {
549 tunnel_factory: F,
550 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
551 resp_waiter: RespWaiterRef,
552 tunnel_id_generator: TunnelIdGenerator,
553 p: std::marker::PhantomData<Mutex<(R, W, M)>>,
554}
555
556impl<
557 M: CmdTunnelMeta,
558 R: CmdTunnelRead<M>,
559 W: CmdTunnelWrite<M>,
560 F: CmdTunnelFactory<M, R, W>,
561 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
562 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
563> CmdWriteFactory<M, R, W, F, LEN, CMD>
564{
565 pub(crate) fn new(
566 tunnel_factory: F,
567 cmd_handler: impl CmdHandler<LEN, CMD>,
568 resp_waiter: RespWaiterRef,
569 ) -> Self {
570 Self {
571 tunnel_factory,
572 cmd_handler: Arc::new(cmd_handler),
573 resp_waiter,
574 tunnel_id_generator: TunnelIdGenerator::new(),
575 p: Default::default(),
576 }
577 }
578}
579
580#[async_trait::async_trait]
581impl<
582 M: CmdTunnelMeta,
583 R: CmdTunnelRead<M>,
584 W: CmdTunnelWrite<M>,
585 F: CmdTunnelFactory<M, R, W>,
586 LEN: RawEncode
587 + for<'a> RawDecode<'a>
588 + Copy
589 + Send
590 + Sync
591 + 'static
592 + FromPrimitive
593 + ToPrimitive
594 + RawFixedBytes,
595 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
596> ClassifiedWorkerFactory<TunnelId, CommonCmdSend<M, R, W, LEN, CMD>>
597 for CmdWriteFactory<M, R, W, F, LEN, CMD>
598{
599 async fn create(&self, c: Option<TunnelId>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
600 if c.is_some() {
601 return Err(pool_err!(
602 PoolErrorCode::Failed,
603 "tunnel {:?} not found",
604 c.unwrap()
605 ));
606 }
607 let tunnel = self
608 .tunnel_factory
609 .create_tunnel()
610 .await
611 .map_err(into_pool_err!(PoolErrorCode::Failed))?;
612 let peer_id = tunnel.get_remote_peer_id();
613 let tunnel_id = self.tunnel_id_generator.generate();
614 let (mut recv, write) = tunnel.split();
615 let local_id = recv.get_local_peer_id();
616 let remote_id = write.get_remote_peer_id();
617 let meta = write.get_tunnel_meta();
618 let write = ObjectHolder::new(write);
619 let resp_write = write.clone();
620 let cmd_handler = self.cmd_handler.clone();
621 let handle = spawn(async move {
622 let ret: CmdResult<()> = async move {
623 loop {
624 let header_len = recv
625 .read_u8()
626 .await
627 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
628 let mut header = vec![0u8; header_len as usize];
629 let n = recv
630 .read_exact(header.as_mut())
631 .await
632 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
633 if n == 0 {
634 break;
635 }
636 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
637 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
638 log::trace!(
639 "recv cmd {:?} from {} len {}",
640 header.cmd_code(),
641 peer_id.to_base58(),
642 header.pkg_len().to_u64().unwrap()
643 );
644 let body_len = header.pkg_len().to_u64().unwrap();
645 let cmd_read =
646 CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
647 let waiter = cmd_read.get_waiter();
648 let future = waiter
649 .create_result_future()
650 .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
651 let version = header.version();
652 let seq = header.seq();
653 let cmd_code = header.cmd_code();
654 match cmd_handler
655 .handle(
656 local_id.clone(),
657 peer_id.clone(),
658 tunnel_id,
659 header,
660 CmdBody::from_reader(BufReader::new(cmd_read), body_len),
661 )
662 .await
663 {
664 Ok(Some(mut body)) => {
665 let mut write = resp_write.get().await;
666 let header = CmdHeader::<LEN, CMD>::new(
667 version,
668 true,
669 seq,
670 cmd_code,
671 LEN::from_u64(body.len()).unwrap(),
672 );
673 let buf = header
674 .to_vec()
675 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
676 if buf.len() > 255 {
677 return Err(cmd_err!(
678 CmdErrorCode::InvalidParam,
679 "header len too long"
680 ));
681 }
682 write
683 .write_u8(buf.len() as u8)
684 .await
685 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
686 write
687 .write_all(buf.as_slice())
688 .await
689 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
690 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
691 .await
692 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
693 write
694 .flush()
695 .await
696 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
697 }
698 Ok(None) => {}
699 Err(e) => {
700 log::error!("handle cmd error: {:?}", e);
701 }
702 }
703 recv = future
704 .await
705 .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
706 }
707 Ok(())
708 }
709 .await;
710 ret
711 });
712 Ok(CommonCmdSend::new(
713 tunnel_id,
714 handle,
715 write,
716 self.resp_waiter.clone(),
717 remote_id,
718 meta,
719 ))
720 }
721}
722
723pub struct DefaultCmdClient<
724 M: CmdTunnelMeta,
725 R: CmdTunnelRead<M>,
726 W: CmdTunnelWrite<M>,
727 F: CmdTunnelFactory<M, R, W>,
728 LEN: RawEncode
729 + for<'a> RawDecode<'a>
730 + Copy
731 + Send
732 + Sync
733 + 'static
734 + FromPrimitive
735 + ToPrimitive
736 + RawFixedBytes,
737 CMD: RawEncode
738 + for<'a> RawDecode<'a>
739 + Copy
740 + Send
741 + Sync
742 + 'static
743 + RawFixedBytes
744 + Eq
745 + Hash
746 + Debug,
747> {
748 tunnel_pool: ClassifiedWorkerPoolRef<
749 TunnelId,
750 CommonCmdSend<M, R, W, LEN, CMD>,
751 CmdWriteFactory<M, R, W, F, LEN, CMD>,
752 >,
753 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
754}
755
756impl<
757 M: CmdTunnelMeta,
758 R: CmdTunnelRead<M>,
759 W: CmdTunnelWrite<M>,
760 F: CmdTunnelFactory<M, R, W>,
761 LEN: RawEncode
762 + for<'a> RawDecode<'a>
763 + Copy
764 + Send
765 + Sync
766 + 'static
767 + FromPrimitive
768 + ToPrimitive
769 + RawFixedBytes,
770 CMD: RawEncode
771 + for<'a> RawDecode<'a>
772 + Copy
773 + Send
774 + Sync
775 + 'static
776 + RawFixedBytes
777 + Eq
778 + Hash
779 + Debug,
780> DefaultCmdClient<M, R, W, F, LEN, CMD>
781{
782 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
783 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
784 let handler_map = cmd_handler_map.clone();
785 let resp_waiter = Arc::new(RespWaiter::new());
786 let waiter = resp_waiter.clone();
787 Arc::new(Self {
788 tunnel_pool: ClassifiedWorkerPool::new(
789 tunnel_count,
790 CmdWriteFactory::<M, R, W, _, LEN, CMD>::new(
791 factory,
792 move |local_id: PeerId,
793 peer_id: PeerId,
794 tunnel_id: TunnelId,
795 header: CmdHeader<LEN, CMD>,
796 body_read: CmdBody| {
797 let handler_map = handler_map.clone();
798 let waiter = waiter.clone();
799 async move {
800 if header.is_resp() && header.seq().is_some() {
801 let resp_id = gen_resp_id(
802 tunnel_id,
803 header.cmd_code(),
804 header.seq().unwrap(),
805 );
806 let _ = waiter.set_result(resp_id, body_read);
807 Ok(None)
808 } else {
809 if let Some(handler) = handler_map.get(header.cmd_code()) {
810 handler
811 .handle(local_id, peer_id, tunnel_id, header, body_read)
812 .await
813 } else {
814 Ok(None)
815 }
816 }
817 }
818 },
819 resp_waiter.clone(),
820 ),
821 ),
822 cmd_handler_map,
823 })
824 }
825
826 async fn get_send(
827 &self,
828 ) -> CmdResult<
829 ClassifiedWorkerGuard<
830 TunnelId,
831 CommonCmdSend<M, R, W, LEN, CMD>,
832 CmdWriteFactory<M, R, W, F, LEN, CMD>,
833 >,
834 > {
835 self.tunnel_pool
836 .get_worker()
837 .await
838 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
839 }
840
841 async fn get_send_of_tunnel_id(
842 &self,
843 tunnel_id: TunnelId,
844 ) -> CmdResult<
845 ClassifiedWorkerGuard<
846 TunnelId,
847 CommonCmdSend<M, R, W, LEN, CMD>,
848 CmdWriteFactory<M, R, W, F, LEN, CMD>,
849 >,
850 > {
851 self.tunnel_pool
852 .get_classified_worker(tunnel_id)
853 .await
854 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
855 }
856}
857
/// Send guard type handed out by [`DefaultCmdClient`]: a [`ClassifiedSendGuard`]
/// keyed by [`TunnelId`] that wraps a [`CommonCmdSend`] created by [`CmdWriteFactory`].
pub type CmdClientSendGuard<M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
    TunnelId,
    M,
    CommonCmdSend<M, R, W, LEN, CMD>,
    CmdWriteFactory<M, R, W, F, LEN, CMD>,
>;
864#[async_trait::async_trait]
865impl<
866 M: CmdTunnelMeta,
867 R: CmdTunnelRead<M>,
868 W: CmdTunnelWrite<M>,
869 F: CmdTunnelFactory<M, R, W>,
870 LEN: RawEncode
871 + for<'a> RawDecode<'a>
872 + Copy
873 + Send
874 + Sync
875 + 'static
876 + FromPrimitive
877 + ToPrimitive
878 + RawFixedBytes,
879 CMD: RawEncode
880 + for<'a> RawDecode<'a>
881 + Copy
882 + Send
883 + Sync
884 + 'static
885 + RawFixedBytes
886 + Eq
887 + Hash
888 + Debug,
889> CmdClient<LEN, CMD, M, CommonCmdSend<M, R, W, LEN, CMD>, CmdClientSendGuard<M, R, W, F, LEN, CMD>>
890 for DefaultCmdClient<M, R, W, F, LEN, CMD>
891{
892 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
893 self.cmd_handler_map.insert(cmd, handler);
894 }
895
896 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
897 let mut send = self.get_send().await?;
898 send.send(cmd, version, body).await
899 }
900
901 async fn send_with_resp(
902 &self,
903 cmd: CMD,
904 version: u8,
905 body: &[u8],
906 timeout: Duration,
907 ) -> CmdResult<CmdBody> {
908 let mut send = self.get_send().await?;
909 send.send_with_resp(cmd, version, body, timeout).await
910 }
911
912 async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
913 let mut send = self.get_send().await?;
914 send.send2(cmd, version, body).await
915 }
916
917 async fn send2_with_resp(
918 &self,
919 cmd: CMD,
920 version: u8,
921 body: &[&[u8]],
922 timeout: Duration,
923 ) -> CmdResult<CmdBody> {
924 let mut send = self.get_send().await?;
925 send.send2_with_resp(cmd, version, body, timeout).await
926 }
927
928 async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
929 let mut send = self.get_send().await?;
930 send.send_cmd(cmd, version, body).await
931 }
932
933 async fn send_cmd_with_resp(
934 &self,
935 cmd: CMD,
936 version: u8,
937 body: CmdBody,
938 timeout: Duration,
939 ) -> CmdResult<CmdBody> {
940 let mut send = self.get_send().await?;
941 send.send_cmd_with_resp(cmd, version, body, timeout).await
942 }
943
944 async fn send_by_specify_tunnel(
945 &self,
946 tunnel_id: TunnelId,
947 cmd: CMD,
948 version: u8,
949 body: &[u8],
950 ) -> CmdResult<()> {
951 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
952 send.send(cmd, version, body).await
953 }
954
955 async fn send_by_specify_tunnel_with_resp(
956 &self,
957 tunnel_id: TunnelId,
958 cmd: CMD,
959 version: u8,
960 body: &[u8],
961 timeout: Duration,
962 ) -> CmdResult<CmdBody> {
963 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
964 send.send_with_resp(cmd, version, body, timeout).await
965 }
966
967 async fn send2_by_specify_tunnel(
968 &self,
969 tunnel_id: TunnelId,
970 cmd: CMD,
971 version: u8,
972 body: &[&[u8]],
973 ) -> CmdResult<()> {
974 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
975 send.send2(cmd, version, body).await
976 }
977
978 async fn send2_by_specify_tunnel_with_resp(
979 &self,
980 tunnel_id: TunnelId,
981 cmd: CMD,
982 version: u8,
983 body: &[&[u8]],
984 timeout: Duration,
985 ) -> CmdResult<CmdBody> {
986 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
987 send.send2_with_resp(cmd, version, body, timeout).await
988 }
989
990 async fn send_cmd_by_specify_tunnel(
991 &self,
992 tunnel_id: TunnelId,
993 cmd: CMD,
994 version: u8,
995 body: CmdBody,
996 ) -> CmdResult<()> {
997 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
998 send.send_cmd(cmd, version, body).await
999 }
1000
1001 async fn send_cmd_by_specify_tunnel_with_resp(
1002 &self,
1003 tunnel_id: TunnelId,
1004 cmd: CMD,
1005 version: u8,
1006 body: CmdBody,
1007 timeout: Duration,
1008 ) -> CmdResult<CmdBody> {
1009 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1010 send.send_cmd_with_resp(cmd, version, body, timeout).await
1011 }
1012
1013 async fn clear_all_tunnel(&self) {
1014 self.tunnel_pool.clear_all_worker().await;
1015 }
1016
1017 async fn get_send(
1018 &self,
1019 tunnel_id: TunnelId,
1020 ) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
1021 Ok(ClassifiedSendGuard {
1022 worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1023 _p: Default::default(),
1024 })
1025 }
1026}