1use crate::client::{
2 CmdClient, CmdSend, RespWaiter, RespWaiterRef, SendGuard, gen_resp_id, gen_seq,
3};
4use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
5use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
6use crate::peer_id::PeerId;
7use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
8use async_named_locker::ObjectHolder;
9use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
10use num::{FromPrimitive, ToPrimitive};
11use sfo_pool::{
12 ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
13 ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
14 pool_err,
15};
16use sfo_split::{Splittable, WHalf};
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::marker::PhantomData;
20use std::ops::{Deref, DerefMut};
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27#[async_trait::async_trait]
28pub trait CmdTunnelFactory<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>>:
29 Send + Sync + 'static
30{
31 async fn create_tunnel(&self) -> CmdResult<Splittable<R, W>>;
32}
33
34pub struct CommonCmdSend<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>, LEN, CMD>
35where
36 LEN: RawEncode
37 + for<'a> RawDecode<'a>
38 + Copy
39 + Send
40 + Sync
41 + 'static
42 + FromPrimitive
43 + ToPrimitive,
44 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
45{
46 pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
47 pub(crate) write: ObjectHolder<WHalf<R, W>>,
48 pub(crate) is_work: bool,
49 pub(crate) tunnel_id: TunnelId,
50 pub(crate) remote_id: PeerId,
51 pub(crate) resp_waiter: RespWaiterRef,
52 pub(crate) tunnel_meta: Option<Arc<M>>,
53 _p: std::marker::PhantomData<(LEN, CMD)>,
54}
55
56impl<M, R, W, LEN, CMD> CommonCmdSend<M, R, W, LEN, CMD>
69where
70 M: CmdTunnelMeta,
71 R: CmdTunnelRead<M>,
72 W: CmdTunnelWrite<M>,
73 LEN: RawEncode
74 + for<'a> RawDecode<'a>
75 + Copy
76 + Send
77 + Sync
78 + 'static
79 + FromPrimitive
80 + ToPrimitive,
81 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
82{
83 pub fn new(
84 tunnel_id: TunnelId,
85 recv_handle: JoinHandle<CmdResult<()>>,
86 write: ObjectHolder<WHalf<R, W>>,
87 resp_waiter: RespWaiterRef,
88 remote_id: PeerId,
89 tunnel_meta: Option<Arc<M>>,
90 ) -> Self {
91 Self {
92 recv_handle,
93 write,
94 is_work: true,
95 tunnel_id,
96 remote_id,
97 resp_waiter,
98 tunnel_meta,
99 _p: Default::default(),
100 }
101 }
102
103 pub fn get_tunnel_id(&self) -> TunnelId {
104 self.tunnel_id
105 }
106
107 pub fn set_disable(&mut self) {
108 self.is_work = false;
109 self.recv_handle.abort();
110 }
111
112 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
113 log::trace!(
114 "client {:?} send cmd: {:?}, len: {} data:{}",
115 self.tunnel_id,
116 cmd,
117 body.len(),
118 hex::encode(body)
119 );
120 let header = CmdHeader::<LEN, CMD>::new(
121 version,
122 false,
123 None,
124 cmd,
125 LEN::from_u64(body.len() as u64).unwrap(),
126 );
127 let buf = header
128 .to_vec()
129 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
130 let ret = self.send_inner(buf.as_slice(), body).await;
131 if let Err(e) = ret {
132 self.set_disable();
133 return Err(e);
134 }
135 Ok(())
136 }
137
138 pub async fn send_with_resp(
139 &mut self,
140 cmd: CMD,
141 version: u8,
142 body: &[u8],
143 timeout: Duration,
144 ) -> CmdResult<CmdBody> {
145 if let Some(id) = tokio::task::try_id() {
146 if id == self.recv_handle.id() {
147 return Err(cmd_err!(
148 CmdErrorCode::Failed,
149 "can't send with resp in recv task"
150 ));
151 }
152 }
153 log::trace!(
154 "client {:?} send cmd: {:?}, len: {}, data: {}",
155 self.tunnel_id,
156 cmd,
157 body.len(),
158 hex::encode(body)
159 );
160 let seq = gen_seq();
161 let header = CmdHeader::<LEN, CMD>::new(
162 version,
163 false,
164 Some(seq),
165 cmd,
166 LEN::from_u64(body.len() as u64).unwrap(),
167 );
168 let buf = header
169 .to_vec()
170 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
171 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
172 let waiter = self.resp_waiter.clone();
173 let resp_waiter = waiter
174 .create_timeout_result_future(resp_id, timeout)
175 .map_err(into_cmd_err!(
176 CmdErrorCode::Failed,
177 "create timeout result future error"
178 ))?;
179 let ret = self.send_inner(buf.as_slice(), body).await;
180 if let Err(e) = ret {
181 self.set_disable();
182 return Err(e);
183 }
184 let resp = resp_waiter
185 .await
186 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
187 Ok(resp)
188 }
189
190 pub async fn send_parts(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
191 let mut len = 0;
192 for b in body.iter() {
193 len += b.len();
194 log::trace!(
195 "client {:?} send2 cmd: {:?}, data {}",
196 self.tunnel_id,
197 cmd,
198 hex::encode(b)
199 );
200 }
201 log::trace!(
202 "client {:?} send2 cmd: {:?}, len {}",
203 self.tunnel_id,
204 cmd,
205 len
206 );
207 let header = CmdHeader::<LEN, CMD>::new(
208 version,
209 false,
210 None,
211 cmd,
212 LEN::from_u64(len as u64).unwrap(),
213 );
214 let buf = header
215 .to_vec()
216 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
217 let ret = self.send_inner2(buf.as_slice(), body).await;
218 if let Err(e) = ret {
219 self.set_disable();
220 return Err(e);
221 }
222 Ok(())
223 }
224
225 pub async fn send_parts_with_resp(
226 &mut self,
227 cmd: CMD,
228 version: u8,
229 body: &[&[u8]],
230 timeout: Duration,
231 ) -> CmdResult<CmdBody> {
232 if let Some(id) = tokio::task::try_id() {
233 if id == self.recv_handle.id() {
234 return Err(cmd_err!(
235 CmdErrorCode::Failed,
236 "can't send with resp in recv task"
237 ));
238 }
239 }
240 let mut len = 0;
241 for b in body.iter() {
242 len += b.len();
243 log::trace!(
244 "client {:?} send2 cmd {:?} body: {}",
245 self.tunnel_id,
246 cmd,
247 hex::encode(b)
248 );
249 }
250 log::trace!(
251 "client {:?} send2 cmd: {:?}, len {}",
252 self.tunnel_id,
253 cmd,
254 len
255 );
256 let seq = gen_seq();
257 let header = CmdHeader::<LEN, CMD>::new(
258 version,
259 false,
260 Some(seq),
261 cmd,
262 LEN::from_u64(len as u64).unwrap(),
263 );
264 let buf = header
265 .to_vec()
266 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
267 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
268 let waiter = self.resp_waiter.clone();
269 let resp_waiter = waiter
270 .create_timeout_result_future(resp_id, timeout)
271 .map_err(into_cmd_err!(
272 CmdErrorCode::Failed,
273 "create timeout result future error"
274 ))?;
275 let ret = self.send_inner2(buf.as_slice(), body).await;
276 if let Err(e) = ret {
277 self.set_disable();
278 return Err(e);
279 }
280 let resp = resp_waiter
281 .await
282 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
283 Ok(resp)
284 }
285
286 #[allow(deprecated)]
287 #[deprecated(note = "use send_parts instead")]
288 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
289 self.send_parts(cmd, version, body).await
290 }
291
292 #[allow(deprecated)]
293 #[deprecated(note = "use send_parts_with_resp instead")]
294 pub async fn send2_with_resp(
295 &mut self,
296 cmd: CMD,
297 version: u8,
298 body: &[&[u8]],
299 timeout: Duration,
300 ) -> CmdResult<CmdBody> {
301 self.send_parts_with_resp(cmd, version, body, timeout).await
302 }
303
304 pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
305 log::trace!(
306 "client {:?} send cmd: {:?}, len: {}",
307 self.tunnel_id,
308 cmd,
309 body.len()
310 );
311 let header = CmdHeader::<LEN, CMD>::new(
312 version,
313 false,
314 None,
315 cmd,
316 LEN::from_u64(body.len()).unwrap(),
317 );
318 let buf = header
319 .to_vec()
320 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
321 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
322 if let Err(e) = ret {
323 self.set_disable();
324 return Err(e);
325 }
326 Ok(())
327 }
328
329 pub async fn send_cmd_with_resp(
330 &mut self,
331 cmd: CMD,
332 version: u8,
333 body: CmdBody,
334 timeout: Duration,
335 ) -> CmdResult<CmdBody> {
336 if let Some(id) = tokio::task::try_id() {
337 if id == self.recv_handle.id() {
338 return Err(cmd_err!(
339 CmdErrorCode::Failed,
340 "can't send with resp in recv task"
341 ));
342 }
343 }
344 log::trace!(
345 "client {:?} send cmd: {:?}, len: {}",
346 self.tunnel_id,
347 cmd,
348 body.len()
349 );
350 let seq = gen_seq();
351 let header = CmdHeader::<LEN, CMD>::new(
352 version,
353 false,
354 Some(seq),
355 cmd,
356 LEN::from_u64(body.len()).unwrap(),
357 );
358 let buf = header
359 .to_vec()
360 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
361 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
362 let waiter = self.resp_waiter.clone();
363 let resp_waiter = waiter
364 .create_timeout_result_future(resp_id, timeout)
365 .map_err(into_cmd_err!(
366 CmdErrorCode::Failed,
367 "create timeout result future error"
368 ))?;
369 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
370 if let Err(e) = ret {
371 self.set_disable();
372 return Err(e);
373 }
374 let resp = resp_waiter
375 .await
376 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
377 Ok(resp)
378 }
379
380 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
381 let mut write = self.write.get().await;
382 if header.len() > 255 {
383 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
384 }
385 write
386 .write_u8(header.len() as u8)
387 .await
388 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
389 write
390 .write_all(header)
391 .await
392 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
393 write
394 .write_all(body)
395 .await
396 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
397 write
398 .flush()
399 .await
400 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
401 Ok(())
402 }
403
404 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
405 let mut write = self.write.get().await;
406 if header.len() > 255 {
407 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
408 }
409 write
410 .write_u8(header.len() as u8)
411 .await
412 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
413 write
414 .write_all(header)
415 .await
416 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
417 for b in body.iter() {
418 write
419 .write_all(b)
420 .await
421 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
422 }
423 write
424 .flush()
425 .await
426 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
427 Ok(())
428 }
429
430 async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
431 let mut write = self.write.get().await;
432 if header.len() > 255 {
433 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
434 }
435 write
436 .write_u8(header.len() as u8)
437 .await
438 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
439 write
440 .write_all(header)
441 .await
442 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
443 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
444 .await
445 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
446 write
447 .flush()
448 .await
449 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
450 Ok(())
451 }
452}
453
454impl<M, R, W, LEN, CMD> Drop for CommonCmdSend<M, R, W, LEN, CMD>
455where
456 M: CmdTunnelMeta,
457 R: CmdTunnelRead<M>,
458 W: CmdTunnelWrite<M>,
459 LEN: RawEncode
460 + for<'a> RawDecode<'a>
461 + Copy
462 + Send
463 + Sync
464 + 'static
465 + FromPrimitive
466 + ToPrimitive,
467 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
468{
469 fn drop(&mut self) {
470 self.set_disable();
471 }
472}
473
474impl<M, R, W, LEN, CMD> CmdSend<M> for CommonCmdSend<M, R, W, LEN, CMD>
475where
476 M: CmdTunnelMeta,
477 R: CmdTunnelRead<M>,
478 W: CmdTunnelWrite<M>,
479 LEN: RawEncode
480 + for<'a> RawDecode<'a>
481 + Copy
482 + Send
483 + Sync
484 + 'static
485 + FromPrimitive
486 + ToPrimitive,
487 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
488{
489 fn get_tunnel_meta(&self) -> Option<Arc<M>> {
490 self.tunnel_meta.clone()
491 }
492
493 fn get_remote_peer_id(&self) -> PeerId {
494 self.remote_id.clone()
495 }
496}
497
498impl<M, R, W, LEN, CMD> ClassifiedWorker<TunnelId> for CommonCmdSend<M, R, W, LEN, CMD>
499where
500 M: CmdTunnelMeta,
501 R: CmdTunnelRead<M>,
502 W: CmdTunnelWrite<M>,
503 LEN: RawEncode
504 + for<'a> RawDecode<'a>
505 + Copy
506 + Send
507 + Sync
508 + 'static
509 + FromPrimitive
510 + ToPrimitive,
511 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
512{
513 fn is_work(&self) -> bool {
514 self.is_work && !self.recv_handle.is_finished()
515 }
516
517 fn is_valid(&self, c: TunnelId) -> bool {
518 self.tunnel_id == c
519 }
520
521 fn classification(&self) -> TunnelId {
522 self.tunnel_id
523 }
524}
525
526pub struct ClassifiedSendGuard<
527 C: WorkerClassification,
528 M: CmdTunnelMeta,
529 CW: ClassifiedWorker<C> + CmdSend<M>,
530 F: ClassifiedWorkerFactory<C, CW>,
531> {
532 pub(crate) worker_guard: ClassifiedWorkerGuard<C, CW, F>,
533 pub(crate) _p: PhantomData<M>,
534}
535
536impl<
537 C: WorkerClassification,
538 M: CmdTunnelMeta,
539 CW: ClassifiedWorker<C> + CmdSend<M>,
540 F: ClassifiedWorkerFactory<C, CW>,
541> Deref for ClassifiedSendGuard<C, M, CW, F>
542{
543 type Target = CW;
544
545 fn deref(&self) -> &Self::Target {
546 &self.worker_guard.deref()
547 }
548}
549
550impl<
551 C: WorkerClassification,
552 M: CmdTunnelMeta,
553 CW: ClassifiedWorker<C> + CmdSend<M>,
554 F: ClassifiedWorkerFactory<C, CW>,
555> SendGuard<M, CW> for ClassifiedSendGuard<C, M, CW, F>
556{
557}
558
559pub struct CmdWriteFactory<
560 M: CmdTunnelMeta,
561 R: CmdTunnelRead<M>,
562 W: CmdTunnelWrite<M>,
563 F: CmdTunnelFactory<M, R, W>,
564 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
565 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
566> {
567 tunnel_factory: F,
568 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
569 resp_waiter: RespWaiterRef,
570 tunnel_id_generator: TunnelIdGenerator,
571 p: std::marker::PhantomData<Mutex<(R, W, M)>>,
572}
573
574impl<
575 M: CmdTunnelMeta,
576 R: CmdTunnelRead<M>,
577 W: CmdTunnelWrite<M>,
578 F: CmdTunnelFactory<M, R, W>,
579 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
580 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
581> CmdWriteFactory<M, R, W, F, LEN, CMD>
582{
583 pub(crate) fn new(
584 tunnel_factory: F,
585 cmd_handler: impl CmdHandler<LEN, CMD>,
586 resp_waiter: RespWaiterRef,
587 ) -> Self {
588 Self {
589 tunnel_factory,
590 cmd_handler: Arc::new(cmd_handler),
591 resp_waiter,
592 tunnel_id_generator: TunnelIdGenerator::new(),
593 p: Default::default(),
594 }
595 }
596}
597
598#[async_trait::async_trait]
599impl<
600 M: CmdTunnelMeta,
601 R: CmdTunnelRead<M>,
602 W: CmdTunnelWrite<M>,
603 F: CmdTunnelFactory<M, R, W>,
604 LEN: RawEncode
605 + for<'a> RawDecode<'a>
606 + Copy
607 + Send
608 + Sync
609 + 'static
610 + FromPrimitive
611 + ToPrimitive
612 + RawFixedBytes,
613 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
614> ClassifiedWorkerFactory<TunnelId, CommonCmdSend<M, R, W, LEN, CMD>>
615 for CmdWriteFactory<M, R, W, F, LEN, CMD>
616{
617 async fn create(&self, c: Option<TunnelId>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
618 if c.is_some() {
619 return Err(pool_err!(
620 PoolErrorCode::Failed,
621 "tunnel {:?} not found",
622 c.unwrap()
623 ));
624 }
625 let tunnel = self
626 .tunnel_factory
627 .create_tunnel()
628 .await
629 .map_err(into_pool_err!(PoolErrorCode::Failed))?;
630 let peer_id = tunnel.get_remote_peer_id();
631 let tunnel_id = self.tunnel_id_generator.generate();
632 let (mut recv, write) = tunnel.split();
633 let local_id = recv.get_local_peer_id();
634 let remote_id = write.get_remote_peer_id();
635 let meta = write.get_tunnel_meta();
636 let write = ObjectHolder::new(write);
637 let resp_write = write.clone();
638 let cmd_handler = self.cmd_handler.clone();
639 let handle = spawn(async move {
640 let ret: CmdResult<()> = async move {
641 loop {
642 let header_len = recv
643 .read_u8()
644 .await
645 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
646 let mut header = vec![0u8; header_len as usize];
647 let n = recv
648 .read_exact(header.as_mut())
649 .await
650 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
651 if n == 0 {
652 break;
653 }
654 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
655 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
656 log::trace!(
657 "recv cmd {:?} from {} len {}",
658 header.cmd_code(),
659 peer_id.to_base58(),
660 header.pkg_len().to_u64().unwrap()
661 );
662 let body_len = header.pkg_len().to_u64().unwrap();
663 let cmd_read =
664 CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
665 let waiter = cmd_read.get_waiter();
666 let future = waiter
667 .create_result_future()
668 .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
669 let version = header.version();
670 let seq = header.seq();
671 let cmd_code = header.cmd_code();
672 match cmd_handler
673 .handle(
674 local_id.clone(),
675 peer_id.clone(),
676 tunnel_id,
677 header,
678 CmdBody::from_reader(BufReader::new(cmd_read), body_len),
679 )
680 .await
681 {
682 Ok(Some(mut body)) => {
683 let mut write = resp_write.get().await;
684 let header = CmdHeader::<LEN, CMD>::new(
685 version,
686 true,
687 seq,
688 cmd_code,
689 LEN::from_u64(body.len()).unwrap(),
690 );
691 let buf = header
692 .to_vec()
693 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
694 if buf.len() > 255 {
695 return Err(cmd_err!(
696 CmdErrorCode::InvalidParam,
697 "header len too long"
698 ));
699 }
700 write
701 .write_u8(buf.len() as u8)
702 .await
703 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
704 write
705 .write_all(buf.as_slice())
706 .await
707 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
708 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
709 .await
710 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
711 write
712 .flush()
713 .await
714 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
715 }
716 Ok(None) => {}
717 Err(e) => {
718 log::error!("handle cmd error: {:?}", e);
719 }
720 }
721 recv = future
722 .await
723 .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
724 }
725 Ok(())
726 }
727 .await;
728 ret
729 });
730 Ok(CommonCmdSend::new(
731 tunnel_id,
732 handle,
733 write,
734 self.resp_waiter.clone(),
735 remote_id,
736 meta,
737 ))
738 }
739}
740
741pub struct DefaultCmdClient<
742 M: CmdTunnelMeta,
743 R: CmdTunnelRead<M>,
744 W: CmdTunnelWrite<M>,
745 F: CmdTunnelFactory<M, R, W>,
746 LEN: RawEncode
747 + for<'a> RawDecode<'a>
748 + Copy
749 + Send
750 + Sync
751 + 'static
752 + FromPrimitive
753 + ToPrimitive
754 + RawFixedBytes,
755 CMD: RawEncode
756 + for<'a> RawDecode<'a>
757 + Copy
758 + Send
759 + Sync
760 + 'static
761 + RawFixedBytes
762 + Eq
763 + Hash
764 + Debug,
765> {
766 tunnel_pool: ClassifiedWorkerPoolRef<
767 TunnelId,
768 CommonCmdSend<M, R, W, LEN, CMD>,
769 CmdWriteFactory<M, R, W, F, LEN, CMD>,
770 >,
771 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
772}
773
774impl<
775 M: CmdTunnelMeta,
776 R: CmdTunnelRead<M>,
777 W: CmdTunnelWrite<M>,
778 F: CmdTunnelFactory<M, R, W>,
779 LEN: RawEncode
780 + for<'a> RawDecode<'a>
781 + Copy
782 + Send
783 + Sync
784 + 'static
785 + FromPrimitive
786 + ToPrimitive
787 + RawFixedBytes,
788 CMD: RawEncode
789 + for<'a> RawDecode<'a>
790 + Copy
791 + Send
792 + Sync
793 + 'static
794 + RawFixedBytes
795 + Eq
796 + Hash
797 + Debug,
798> DefaultCmdClient<M, R, W, F, LEN, CMD>
799{
800 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
801 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
802 let handler_map = cmd_handler_map.clone();
803 let resp_waiter = Arc::new(RespWaiter::new());
804 let waiter = resp_waiter.clone();
805 Arc::new(Self {
806 tunnel_pool: ClassifiedWorkerPool::new(
807 tunnel_count,
808 CmdWriteFactory::<M, R, W, _, LEN, CMD>::new(
809 factory,
810 move |local_id: PeerId,
811 peer_id: PeerId,
812 tunnel_id: TunnelId,
813 header: CmdHeader<LEN, CMD>,
814 body_read: CmdBody| {
815 let handler_map = handler_map.clone();
816 let waiter = waiter.clone();
817 async move {
818 if header.is_resp() && header.seq().is_some() {
819 let resp_id = gen_resp_id(
820 tunnel_id,
821 header.cmd_code(),
822 header.seq().unwrap(),
823 );
824 let _ = waiter.set_result(resp_id, body_read);
825 Ok(None)
826 } else {
827 if let Some(handler) = handler_map.get(header.cmd_code()) {
828 handler
829 .handle(local_id, peer_id, tunnel_id, header, body_read)
830 .await
831 } else {
832 Ok(None)
833 }
834 }
835 }
836 },
837 resp_waiter.clone(),
838 ),
839 ),
840 cmd_handler_map,
841 })
842 }
843
844 async fn get_send(
845 &self,
846 ) -> CmdResult<
847 ClassifiedWorkerGuard<
848 TunnelId,
849 CommonCmdSend<M, R, W, LEN, CMD>,
850 CmdWriteFactory<M, R, W, F, LEN, CMD>,
851 >,
852 > {
853 self.tunnel_pool
854 .get_worker()
855 .await
856 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
857 }
858
859 async fn get_send_of_tunnel_id(
860 &self,
861 tunnel_id: TunnelId,
862 ) -> CmdResult<
863 ClassifiedWorkerGuard<
864 TunnelId,
865 CommonCmdSend<M, R, W, LEN, CMD>,
866 CmdWriteFactory<M, R, W, F, LEN, CMD>,
867 >,
868 > {
869 self.tunnel_pool
870 .get_classified_worker(tunnel_id)
871 .await
872 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
873 }
874}
875
876pub type CmdClientSendGuard<M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
877 TunnelId,
878 M,
879 CommonCmdSend<M, R, W, LEN, CMD>,
880 CmdWriteFactory<M, R, W, F, LEN, CMD>,
881>;
882#[async_trait::async_trait]
883impl<
884 M: CmdTunnelMeta,
885 R: CmdTunnelRead<M>,
886 W: CmdTunnelWrite<M>,
887 F: CmdTunnelFactory<M, R, W>,
888 LEN: RawEncode
889 + for<'a> RawDecode<'a>
890 + Copy
891 + Send
892 + Sync
893 + 'static
894 + FromPrimitive
895 + ToPrimitive
896 + RawFixedBytes,
897 CMD: RawEncode
898 + for<'a> RawDecode<'a>
899 + Copy
900 + Send
901 + Sync
902 + 'static
903 + RawFixedBytes
904 + Eq
905 + Hash
906 + Debug,
907> CmdClient<LEN, CMD, M, CommonCmdSend<M, R, W, LEN, CMD>, CmdClientSendGuard<M, R, W, F, LEN, CMD>>
908 for DefaultCmdClient<M, R, W, F, LEN, CMD>
909{
910 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
911 self.cmd_handler_map.insert(cmd, handler);
912 }
913
914 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
915 let mut send = self.get_send().await?;
916 send.send(cmd, version, body).await
917 }
918
919 async fn send_with_resp(
920 &self,
921 cmd: CMD,
922 version: u8,
923 body: &[u8],
924 timeout: Duration,
925 ) -> CmdResult<CmdBody> {
926 let mut send = self.get_send().await?;
927 send.send_with_resp(cmd, version, body, timeout).await
928 }
929
930 async fn send_parts(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
931 let mut send = self.get_send().await?;
932 send.send_parts(cmd, version, body).await
933 }
934
935 async fn send_parts_with_resp(
936 &self,
937 cmd: CMD,
938 version: u8,
939 body: &[&[u8]],
940 timeout: Duration,
941 ) -> CmdResult<CmdBody> {
942 let mut send = self.get_send().await?;
943 send.send_parts_with_resp(cmd, version, body, timeout).await
944 }
945
946 async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
947 let mut send = self.get_send().await?;
948 send.send_cmd(cmd, version, body).await
949 }
950
951 async fn send_cmd_with_resp(
952 &self,
953 cmd: CMD,
954 version: u8,
955 body: CmdBody,
956 timeout: Duration,
957 ) -> CmdResult<CmdBody> {
958 let mut send = self.get_send().await?;
959 send.send_cmd_with_resp(cmd, version, body, timeout).await
960 }
961
962 async fn send_by_specify_tunnel(
963 &self,
964 tunnel_id: TunnelId,
965 cmd: CMD,
966 version: u8,
967 body: &[u8],
968 ) -> CmdResult<()> {
969 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
970 send.send(cmd, version, body).await
971 }
972
973 async fn send_by_specify_tunnel_with_resp(
974 &self,
975 tunnel_id: TunnelId,
976 cmd: CMD,
977 version: u8,
978 body: &[u8],
979 timeout: Duration,
980 ) -> CmdResult<CmdBody> {
981 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
982 send.send_with_resp(cmd, version, body, timeout).await
983 }
984
985 async fn send_parts_by_specify_tunnel(
986 &self,
987 tunnel_id: TunnelId,
988 cmd: CMD,
989 version: u8,
990 body: &[&[u8]],
991 ) -> CmdResult<()> {
992 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
993 send.send_parts(cmd, version, body).await
994 }
995
996 async fn send_parts_by_specify_tunnel_with_resp(
997 &self,
998 tunnel_id: TunnelId,
999 cmd: CMD,
1000 version: u8,
1001 body: &[&[u8]],
1002 timeout: Duration,
1003 ) -> CmdResult<CmdBody> {
1004 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1005 send.send_parts_with_resp(cmd, version, body, timeout).await
1006 }
1007
1008 async fn send_cmd_by_specify_tunnel(
1009 &self,
1010 tunnel_id: TunnelId,
1011 cmd: CMD,
1012 version: u8,
1013 body: CmdBody,
1014 ) -> CmdResult<()> {
1015 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1016 send.send_cmd(cmd, version, body).await
1017 }
1018
1019 async fn send_cmd_by_specify_tunnel_with_resp(
1020 &self,
1021 tunnel_id: TunnelId,
1022 cmd: CMD,
1023 version: u8,
1024 body: CmdBody,
1025 timeout: Duration,
1026 ) -> CmdResult<CmdBody> {
1027 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1028 send.send_cmd_with_resp(cmd, version, body, timeout).await
1029 }
1030
1031 async fn clear_all_tunnel(&self) {
1032 self.tunnel_pool.clear_all_worker().await;
1033 }
1034
1035 async fn get_send(
1036 &self,
1037 tunnel_id: TunnelId,
1038 ) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
1039 Ok(ClassifiedSendGuard {
1040 worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1041 _p: Default::default(),
1042 })
1043 }
1044}