1use crate::client::{
2 ClassifiedCmdClient, ClassifiedSendGuard, CmdClient, CmdSend, RespWaiter, RespWaiterRef,
3 gen_resp_id, gen_seq,
4};
5use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
6use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
7use crate::peer_id::PeerId;
8use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
9use async_named_locker::ObjectHolder;
10use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
11use num::{FromPrimitive, ToPrimitive};
12use sfo_pool::{
13 ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
14 ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
15 pool_err,
16};
17use sfo_split::{RHalf, Splittable, WHalf};
18use std::fmt::Debug;
19use std::hash::Hash;
20use std::ops::DerefMut;
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>:
28 CmdTunnelRead<M> + 'static + Send
29{
30 fn get_classification(&self) -> C;
31}
32
33pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>:
34 CmdTunnelWrite<M> + 'static + Send
35{
36 fn get_classification(&self) -> C;
37}
38
/// A complete classified tunnel: read half `R` and write half `W` joined by `sfo_split`.
pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
/// Read half of a split [`ClassifiedCmdTunnel`].
pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
/// Write half of a split [`ClassifiedCmdTunnel`].
pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
42
43#[derive(Debug, Clone, Copy, Eq, Hash)]
44pub struct CmdClientTunnelClassification<C: WorkerClassification> {
45 tunnel_id: Option<TunnelId>,
46 classification: Option<C>,
47}
48
49impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
50 fn eq(&self, other: &Self) -> bool {
51 self.tunnel_id == other.tunnel_id && self.classification == other.classification
52 }
53}
54
55#[async_trait::async_trait]
56pub trait ClassifiedCmdTunnelFactory<
57 C: WorkerClassification,
58 M: CmdTunnelMeta,
59 R: ClassifiedCmdTunnelRead<C, M>,
60 W: ClassifiedCmdTunnelWrite<C, M>,
61>: Send + Sync + 'static
62{
63 async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
64}
65
/// One pooled tunnel: the sending side plus the handle of the task that receives on it.
///
/// Frames are written as `[header_len: u8][encoded CmdHeader][body]`. Any failed write
/// disables the sender (see `set_disable`), and dropping it also aborts the receive task.
// NOTE: field order is also drop order; the bounds must stay identical to the `Drop` impl's.
pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
where
    C: WorkerClassification,
    M: CmdTunnelMeta,
    R: ClassifiedCmdTunnelRead<C, M>,
    W: ClassifiedCmdTunnelWrite<C, M>,
    LEN: RawEncode
        + for<'a> RawDecode<'a>
        + Copy
        + Send
        + Sync
        + 'static
        + FromPrimitive
        + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
{
    /// Task running this tunnel's receive loop; aborted by `set_disable`.
    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
    /// Shared write half; the factory's receive loop holds a clone to write responses.
    pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
    /// Cleared by `set_disable`; `is_work()` also reports false once the receive task ends.
    pub(crate) is_work: bool,
    /// Classification reported by the tunnel when it was opened.
    pub(crate) classification: C,
    /// Id assigned by the factory's `TunnelIdGenerator`.
    pub(crate) tunnel_id: TunnelId,
    /// Registry of pending responses, keyed by `gen_resp_id(tunnel_id, cmd, seq)`.
    pub(crate) resp_waiter: RespWaiterRef,
    /// Peer id of the remote end of the tunnel.
    pub(crate) remote_id: PeerId,
    /// Metadata taken from the read half when the tunnel was opened.
    pub(crate) tunnel_meta: Option<Arc<M>>,
    _p: std::marker::PhantomData<(LEN, CMD)>,
}
92
93impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
107where
108 C: WorkerClassification,
109 M: CmdTunnelMeta,
110 R: ClassifiedCmdTunnelRead<C, M>,
111 W: ClassifiedCmdTunnelWrite<C, M>,
112 LEN: RawEncode
113 + for<'a> RawDecode<'a>
114 + Copy
115 + Send
116 + Sync
117 + 'static
118 + FromPrimitive
119 + ToPrimitive,
120 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
121{
122 pub(crate) fn new(
123 tunnel_id: TunnelId,
124 classification: C,
125 recv_handle: JoinHandle<CmdResult<()>>,
126 write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
127 resp_waiter: RespWaiterRef,
128 remote_id: PeerId,
129 tunnel_meta: Option<Arc<M>>,
130 ) -> Self {
131 Self {
132 recv_handle,
133 write,
134 is_work: true,
135 classification,
136 tunnel_id,
137 resp_waiter,
138 remote_id,
139 tunnel_meta,
140 _p: Default::default(),
141 }
142 }
143
144 pub fn get_tunnel_id(&self) -> TunnelId {
145 self.tunnel_id
146 }
147
148 pub fn set_disable(&mut self) {
149 self.is_work = false;
150 self.recv_handle.abort();
151 }
152
153 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
154 log::trace!(
155 "client {:?} send cmd: {:?}, len: {}, data: {}",
156 self.tunnel_id,
157 cmd,
158 body.len(),
159 hex::encode(body)
160 );
161 let header = CmdHeader::<LEN, CMD>::new(
162 version,
163 false,
164 None,
165 cmd,
166 LEN::from_u64(body.len() as u64).unwrap(),
167 );
168 let buf = header
169 .to_vec()
170 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
171 let ret = self.send_inner(buf.as_slice(), body).await;
172 if let Err(e) = ret {
173 self.set_disable();
174 return Err(e);
175 }
176 Ok(())
177 }
178
179 pub async fn send_with_resp(
180 &mut self,
181 cmd: CMD,
182 version: u8,
183 body: &[u8],
184 timeout: Duration,
185 ) -> CmdResult<CmdBody> {
186 if let Some(id) = tokio::task::try_id() {
187 if id == self.recv_handle.id() {
188 return Err(cmd_err!(
189 CmdErrorCode::Failed,
190 "can't send with resp in recv task"
191 ));
192 }
193 }
194 log::trace!(
195 "client {:?} send cmd: {:?}, len: {}, data: {}",
196 self.tunnel_id,
197 cmd,
198 body.len(),
199 hex::encode(body)
200 );
201 let seq = gen_seq();
202 let header = CmdHeader::<LEN, CMD>::new(
203 version,
204 false,
205 Some(seq),
206 cmd,
207 LEN::from_u64(body.len() as u64).unwrap(),
208 );
209 let buf = header
210 .to_vec()
211 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
212 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
213 let waiter = self.resp_waiter.clone();
214 let resp_waiter = waiter
215 .create_timeout_result_future(resp_id, timeout)
216 .map_err(into_cmd_err!(
217 CmdErrorCode::Failed,
218 "create timeout result future error"
219 ))?;
220 let ret = self.send_inner(buf.as_slice(), body).await;
221 if let Err(e) = ret {
222 self.set_disable();
223 return Err(e);
224 }
225 let resp = resp_waiter
226 .await
227 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
228 Ok(resp)
229 }
230
231 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
232 let mut len = 0;
233 for b in body.iter() {
234 len += b.len();
235 log::trace!(
236 "client {:?} send2 cmd {:?} body: {}",
237 self.tunnel_id,
238 cmd,
239 hex::encode(b)
240 );
241 }
242 log::trace!(
243 "client {:?} send2 cmd: {:?}, len {}",
244 self.tunnel_id,
245 cmd,
246 len
247 );
248 let header = CmdHeader::<LEN, CMD>::new(
249 version,
250 false,
251 None,
252 cmd,
253 LEN::from_u64(len as u64).unwrap(),
254 );
255 let buf = header
256 .to_vec()
257 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
258 let ret = self.send_inner2(buf.as_slice(), body).await;
259 if let Err(e) = ret {
260 self.set_disable();
261 return Err(e);
262 }
263 Ok(())
264 }
265
266 pub async fn send2_with_resp(
267 &mut self,
268 cmd: CMD,
269 version: u8,
270 body: &[&[u8]],
271 timeout: Duration,
272 ) -> CmdResult<CmdBody> {
273 if let Some(id) = tokio::task::try_id() {
274 if id == self.recv_handle.id() {
275 return Err(cmd_err!(
276 CmdErrorCode::Failed,
277 "can't send with resp in recv task"
278 ));
279 }
280 }
281 let mut len = 0;
282 for b in body.iter() {
283 len += b.len();
284 log::trace!(
285 "client {:?} send2 cmd {:?} body: {}",
286 self.tunnel_id,
287 cmd,
288 hex::encode(b)
289 );
290 }
291 log::trace!(
292 "client {:?} send2 cmd: {:?}, len {}",
293 self.tunnel_id,
294 cmd,
295 len
296 );
297 let seq = gen_seq();
298 let header = CmdHeader::<LEN, CMD>::new(
299 version,
300 false,
301 Some(seq),
302 cmd,
303 LEN::from_u64(len as u64).unwrap(),
304 );
305 let buf = header
306 .to_vec()
307 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
308 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
309 let waiter = self.resp_waiter.clone();
310 let resp_waiter = waiter
311 .create_timeout_result_future(resp_id, timeout)
312 .map_err(into_cmd_err!(
313 CmdErrorCode::Failed,
314 "create timeout result future error"
315 ))?;
316 let ret = self.send_inner2(buf.as_slice(), body).await;
317 if let Err(e) = ret {
318 self.set_disable();
319 return Err(e);
320 }
321 let resp = resp_waiter
322 .await
323 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
324 Ok(resp)
325 }
326
327 pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
328 log::trace!(
329 "client {:?} send cmd: {:?}, len: {}",
330 self.tunnel_id,
331 cmd,
332 body.len()
333 );
334 let header = CmdHeader::<LEN, CMD>::new(
335 version,
336 false,
337 None,
338 cmd,
339 LEN::from_u64(body.len()).unwrap(),
340 );
341 let buf = header
342 .to_vec()
343 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
344 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
345 if let Err(e) = ret {
346 self.set_disable();
347 return Err(e);
348 }
349 Ok(())
350 }
351
352 pub async fn send_cmd_with_resp(
353 &mut self,
354 cmd: CMD,
355 version: u8,
356 body: CmdBody,
357 timeout: Duration,
358 ) -> CmdResult<CmdBody> {
359 if let Some(id) = tokio::task::try_id() {
360 if id == self.recv_handle.id() {
361 return Err(cmd_err!(
362 CmdErrorCode::Failed,
363 "can't send with resp in recv task"
364 ));
365 }
366 }
367 log::trace!(
368 "client {:?} send cmd: {:?}, len: {}",
369 self.tunnel_id,
370 cmd,
371 body.len()
372 );
373 let seq = gen_seq();
374 let header = CmdHeader::<LEN, CMD>::new(
375 version,
376 false,
377 Some(seq),
378 cmd,
379 LEN::from_u64(body.len()).unwrap(),
380 );
381 let buf = header
382 .to_vec()
383 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
384 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
385 let waiter = self.resp_waiter.clone();
386 let resp_waiter = waiter
387 .create_timeout_result_future(resp_id, timeout)
388 .map_err(into_cmd_err!(
389 CmdErrorCode::Failed,
390 "create timeout result future error"
391 ))?;
392 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
393 if let Err(e) = ret {
394 self.set_disable();
395 return Err(e);
396 }
397 let resp = resp_waiter
398 .await
399 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
400 Ok(resp)
401 }
402
403 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
404 let mut write = self.write.get().await;
405 if header.len() > 255 {
406 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
407 }
408 write
409 .write_u8(header.len() as u8)
410 .await
411 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
412 write
413 .write_all(header)
414 .await
415 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
416 write
417 .write_all(body)
418 .await
419 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
420 write
421 .flush()
422 .await
423 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
424 Ok(())
425 }
426
427 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
428 let mut write = self.write.get().await;
429 if header.len() > 255 {
430 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
431 }
432 write
433 .write_u8(header.len() as u8)
434 .await
435 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
436 write
437 .write_all(header)
438 .await
439 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
440 for b in body.iter() {
441 write
442 .write_all(b)
443 .await
444 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
445 }
446 write
447 .flush()
448 .await
449 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
450 Ok(())
451 }
452
453 async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
454 let mut write = self.write.get().await;
455 if header.len() > 255 {
456 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
457 }
458 write
459 .write_u8(header.len() as u8)
460 .await
461 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
462 write
463 .write_all(header)
464 .await
465 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
466 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
467 .await
468 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
469 write
470 .flush()
471 .await
472 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
473 Ok(())
474 }
475}
476
477impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
478where
479 C: WorkerClassification,
480 M: CmdTunnelMeta,
481 R: ClassifiedCmdTunnelRead<C, M>,
482 W: ClassifiedCmdTunnelWrite<C, M>,
483 LEN: RawEncode
484 + for<'a> RawDecode<'a>
485 + Copy
486 + Send
487 + Sync
488 + 'static
489 + FromPrimitive
490 + ToPrimitive,
491 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
492{
493 fn drop(&mut self) {
494 self.set_disable();
495 }
496}
497
498impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
499where
500 C: WorkerClassification,
501 M: CmdTunnelMeta,
502 R: ClassifiedCmdTunnelRead<C, M>,
503 W: ClassifiedCmdTunnelWrite<C, M>,
504 LEN: RawEncode
505 + for<'a> RawDecode<'a>
506 + Copy
507 + Send
508 + Sync
509 + 'static
510 + FromPrimitive
511 + ToPrimitive,
512 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
513{
514 fn get_tunnel_meta(&self) -> Option<Arc<M>> {
515 self.tunnel_meta.clone()
516 }
517
518 fn get_remote_peer_id(&self) -> PeerId {
519 self.remote_id.clone()
520 }
521}
522
523impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>>
524 for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
525where
526 C: WorkerClassification,
527 M: CmdTunnelMeta,
528 R: ClassifiedCmdTunnelRead<C, M>,
529 W: ClassifiedCmdTunnelWrite<C, M>,
530 LEN: RawEncode
531 + for<'a> RawDecode<'a>
532 + Copy
533 + Send
534 + Sync
535 + 'static
536 + FromPrimitive
537 + ToPrimitive,
538 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
539{
540 fn is_work(&self) -> bool {
541 self.is_work && !self.recv_handle.is_finished()
542 }
543
544 fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
545 if c.tunnel_id.is_some() {
546 self.tunnel_id == c.tunnel_id.unwrap()
547 } else {
548 if c.classification.is_some() {
549 self.classification == c.classification.unwrap()
550 } else {
551 true
552 }
553 }
554 }
555
556 fn classification(&self) -> CmdClientTunnelClassification<C> {
557 CmdClientTunnelClassification {
558 tunnel_id: Some(self.tunnel_id),
559 classification: Some(self.classification.clone()),
560 }
561 }
562}
563
564pub struct ClassifiedCmdWriteFactory<
565 C: WorkerClassification,
566 M: CmdTunnelMeta,
567 R: ClassifiedCmdTunnelRead<C, M>,
568 W: ClassifiedCmdTunnelWrite<C, M>,
569 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
570 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
571 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
572> {
573 tunnel_factory: F,
574 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
575 resp_waiter: RespWaiterRef,
576 tunnel_id_generator: TunnelIdGenerator,
577 _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
578}
579
580impl<
581 C: WorkerClassification,
582 M: CmdTunnelMeta,
583 R: ClassifiedCmdTunnelRead<C, M>,
584 W: ClassifiedCmdTunnelWrite<C, M>,
585 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
586 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
587 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
588> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
589{
590 pub(crate) fn new(
591 tunnel_factory: F,
592 cmd_handler: impl CmdHandler<LEN, CMD>,
593 resp_waiter: RespWaiterRef,
594 ) -> Self {
595 Self {
596 tunnel_factory,
597 cmd_handler: Arc::new(cmd_handler),
598 resp_waiter,
599 tunnel_id_generator: TunnelIdGenerator::new(),
600 _p: Default::default(),
601 }
602 }
603}
604
605#[async_trait::async_trait]
606impl<
607 C: WorkerClassification,
608 M: CmdTunnelMeta,
609 R: ClassifiedCmdTunnelRead<C, M>,
610 W: ClassifiedCmdTunnelWrite<C, M>,
611 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
612 LEN: RawEncode
613 + for<'a> RawDecode<'a>
614 + Copy
615 + Send
616 + Sync
617 + 'static
618 + FromPrimitive
619 + ToPrimitive
620 + RawFixedBytes,
621 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
622> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>>
623 for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
624{
625 async fn create(
626 &self,
627 classification: Option<CmdClientTunnelClassification<C>>,
628 ) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
629 if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
630 return Err(pool_err!(
631 PoolErrorCode::Failed,
632 "tunnel {:?} not found",
633 classification.as_ref().unwrap().tunnel_id.unwrap()
634 ));
635 }
636
637 let classification = if classification.is_some()
638 && classification.as_ref().unwrap().classification.is_some()
639 {
640 classification.unwrap().classification
641 } else {
642 None
643 };
644 let tunnel = self
645 .tunnel_factory
646 .create_tunnel(classification)
647 .await
648 .map_err(into_pool_err!(PoolErrorCode::Failed))?;
649 let classification = tunnel.get_classification();
650 let peer_id = tunnel.get_remote_peer_id();
651 let tunnel_id = self.tunnel_id_generator.generate();
652 let (mut recv, write) = tunnel.split();
653 let remote_id = peer_id.clone();
654 let tunnel_meta = recv.get_tunnel_meta();
655 let write = ObjectHolder::new(write);
656 let resp_write = write.clone();
657 let cmd_handler = self.cmd_handler.clone();
658 let handle = spawn(async move {
659 let ret: CmdResult<()> = async move {
660 loop {
661 let header_len = recv
662 .read_u8()
663 .await
664 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
665 let mut header = vec![0u8; header_len as usize];
666 let n = recv
667 .read_exact(header.as_mut())
668 .await
669 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
670 if n == 0 {
671 break;
672 }
673 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
674 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
675 log::trace!(
676 "recv cmd {:?} from {} len {} tunnel {:?}",
677 header.cmd_code(),
678 peer_id,
679 header.pkg_len().to_u64().unwrap(),
680 tunnel_id
681 );
682 let body_len = header.pkg_len().to_u64().unwrap();
683 let cmd_read =
684 CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
685 let waiter = cmd_read.get_waiter();
686 let future = waiter
687 .create_result_future()
688 .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
689 let version = header.version();
690 let seq = header.seq();
691 let cmd_code = header.cmd_code();
692 match cmd_handler
693 .handle(
694 peer_id.clone(),
695 tunnel_id,
696 header,
697 CmdBody::from_reader(BufReader::new(cmd_read), body_len),
698 )
699 .await
700 {
701 Ok(Some(mut body)) => {
702 let mut write = resp_write.get().await;
703 let header = CmdHeader::<LEN, CMD>::new(
704 version,
705 true,
706 seq,
707 cmd_code,
708 LEN::from_u64(body.len()).unwrap(),
709 );
710 let buf = header
711 .to_vec()
712 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
713 if buf.len() > 255 {
714 return Err(cmd_err!(
715 CmdErrorCode::InvalidParam,
716 "header len too large"
717 ));
718 }
719 write
720 .write_u8(buf.len() as u8)
721 .await
722 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
723 write
724 .write_all(buf.as_slice())
725 .await
726 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
727 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
728 .await
729 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
730 write
731 .flush()
732 .await
733 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
734 }
735 Err(e) => {
736 log::error!("handle cmd error: {:?}", e);
737 }
738 _ => {}
739 }
740 recv = future
741 .await
742 .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
743 log::debug!(
744 "handle cmd {:?} from {} len {} tunnel {:?} complete",
745 cmd_code,
746 peer_id,
747 body_len,
748 tunnel_id
749 );
750 }
751 Ok(())
752 }
753 .await;
754 if ret.is_err() {
755 log::error!("recv cmd error: {:?}", ret.as_ref().err().unwrap());
756 }
757 ret
758 });
759 Ok(ClassifiedCmdSend::new(
760 tunnel_id,
761 classification,
762 handle,
763 write,
764 self.resp_waiter.clone(),
765 remote_id,
766 tunnel_meta,
767 ))
768 }
769}
770
771pub struct DefaultClassifiedCmdClient<
772 C: WorkerClassification,
773 M: CmdTunnelMeta,
774 R: ClassifiedCmdTunnelRead<C, M>,
775 W: ClassifiedCmdTunnelWrite<C, M>,
776 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
777 LEN: RawEncode
778 + for<'a> RawDecode<'a>
779 + Copy
780 + Send
781 + Sync
782 + 'static
783 + FromPrimitive
784 + ToPrimitive
785 + RawFixedBytes,
786 CMD: RawEncode
787 + for<'a> RawDecode<'a>
788 + Copy
789 + Send
790 + Sync
791 + 'static
792 + RawFixedBytes
793 + Eq
794 + Hash
795 + Debug,
796> {
797 tunnel_pool: ClassifiedWorkerPoolRef<
798 CmdClientTunnelClassification<C>,
799 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
800 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
801 >,
802 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
803}
804
805impl<
806 C: WorkerClassification,
807 M: CmdTunnelMeta,
808 R: ClassifiedCmdTunnelRead<C, M>,
809 W: ClassifiedCmdTunnelWrite<C, M>,
810 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
811 LEN: RawEncode
812 + for<'a> RawDecode<'a>
813 + Copy
814 + Send
815 + Sync
816 + 'static
817 + FromPrimitive
818 + ToPrimitive
819 + RawFixedBytes,
820 CMD: RawEncode
821 + for<'a> RawDecode<'a>
822 + Copy
823 + Send
824 + Sync
825 + 'static
826 + RawFixedBytes
827 + Eq
828 + Hash
829 + Debug,
830> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
831{
832 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
833 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
834 let resp_waiter = Arc::new(RespWaiter::new());
835 let handler_map = cmd_handler_map.clone();
836 let waiter = resp_waiter.clone();
837 Arc::new(Self {
838 tunnel_pool: ClassifiedWorkerPool::new(
839 tunnel_count,
840 ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(
841 factory,
842 move |peer_id: PeerId,
843 tunnel_id: TunnelId,
844 header: CmdHeader<LEN, CMD>,
845 body_read: CmdBody| {
846 let handler_map = handler_map.clone();
847 let waiter = waiter.clone();
848 async move {
849 if header.is_resp() && header.seq().is_some() {
850 let resp_id = gen_resp_id(
851 tunnel_id,
852 header.cmd_code(),
853 header.seq().unwrap(),
854 );
855 let _ = waiter.set_result(resp_id, body_read);
856 Ok(None)
857 } else {
858 if let Some(handler) = handler_map.get(header.cmd_code()) {
859 handler.handle(peer_id, tunnel_id, header, body_read).await
860 } else {
861 Ok(None)
862 }
863 }
864 }
865 },
866 resp_waiter.clone(),
867 ),
868 ),
869 cmd_handler_map,
870 })
871 }
872
873 async fn get_send(
874 &self,
875 ) -> CmdResult<
876 ClassifiedWorkerGuard<
877 CmdClientTunnelClassification<C>,
878 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
879 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
880 >,
881 > {
882 self.tunnel_pool
883 .get_worker()
884 .await
885 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
886 }
887
888 async fn get_send_of_tunnel_id(
889 &self,
890 tunnel_id: TunnelId,
891 ) -> CmdResult<
892 ClassifiedWorkerGuard<
893 CmdClientTunnelClassification<C>,
894 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
895 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
896 >,
897 > {
898 self.tunnel_pool
899 .get_classified_worker(CmdClientTunnelClassification {
900 tunnel_id: Some(tunnel_id),
901 classification: None,
902 })
903 .await
904 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
905 }
906
907 async fn get_classified_send(
908 &self,
909 classification: C,
910 ) -> CmdResult<
911 ClassifiedWorkerGuard<
912 CmdClientTunnelClassification<C>,
913 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
914 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
915 >,
916 > {
917 self.tunnel_pool
918 .get_classified_worker(CmdClientTunnelClassification {
919 tunnel_id: None,
920 classification: Some(classification),
921 })
922 .await
923 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
924 }
925}
926
/// Guard handed out by [`DefaultClassifiedCmdClient`] when a caller wants to hold on to one
/// pooled tunnel sender (returned by its `get_send` and `get_send_by_classified`).
pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
    CmdClientTunnelClassification<C>,
    M,
    ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
    ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
>;
933#[async_trait::async_trait]
934impl<
935 C: WorkerClassification,
936 M: CmdTunnelMeta,
937 R: ClassifiedCmdTunnelRead<C, M>,
938 W: ClassifiedCmdTunnelWrite<C, M>,
939 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
940 LEN: RawEncode
941 + for<'a> RawDecode<'a>
942 + Copy
943 + Send
944 + Sync
945 + 'static
946 + FromPrimitive
947 + ToPrimitive
948 + RawFixedBytes,
949 CMD: RawEncode
950 + for<'a> RawDecode<'a>
951 + Copy
952 + Send
953 + Sync
954 + 'static
955 + RawFixedBytes
956 + Eq
957 + Hash
958 + Debug,
959>
960 CmdClient<
961 LEN,
962 CMD,
963 M,
964 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
965 ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
966 > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
967{
968 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
969 self.cmd_handler_map.insert(cmd, handler);
970 }
971
972 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
973 let mut send = self.get_send().await?;
974 send.send(cmd, version, body).await
975 }
976
977 async fn send_with_resp(
978 &self,
979 cmd: CMD,
980 version: u8,
981 body: &[u8],
982 timeout: Duration,
983 ) -> CmdResult<CmdBody> {
984 let mut send = self.get_send().await?;
985 send.send_with_resp(cmd, version, body, timeout).await
986 }
987
988 async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
989 let mut send = self.get_send().await?;
990 send.send2(cmd, version, body).await
991 }
992
993 async fn send2_with_resp(
994 &self,
995 cmd: CMD,
996 version: u8,
997 body: &[&[u8]],
998 timeout: Duration,
999 ) -> CmdResult<CmdBody> {
1000 let mut send = self.get_send().await?;
1001 send.send2_with_resp(cmd, version, body, timeout).await
1002 }
1003
1004 async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
1005 let mut send = self.get_send().await?;
1006 send.send_cmd(cmd, version, body).await
1007 }
1008
1009 async fn send_cmd_with_resp(
1010 &self,
1011 cmd: CMD,
1012 version: u8,
1013 body: CmdBody,
1014 timeout: Duration,
1015 ) -> CmdResult<CmdBody> {
1016 let mut send = self.get_send().await?;
1017 send.send_cmd_with_resp(cmd, version, body, timeout).await
1018 }
1019
1020 async fn send_by_specify_tunnel(
1021 &self,
1022 tunnel_id: TunnelId,
1023 cmd: CMD,
1024 version: u8,
1025 body: &[u8],
1026 ) -> CmdResult<()> {
1027 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1028 send.send(cmd, version, body).await
1029 }
1030
1031 async fn send_by_specify_tunnel_with_resp(
1032 &self,
1033 tunnel_id: TunnelId,
1034 cmd: CMD,
1035 version: u8,
1036 body: &[u8],
1037 timeout: Duration,
1038 ) -> CmdResult<CmdBody> {
1039 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1040 send.send_with_resp(cmd, version, body, timeout).await
1041 }
1042
1043 async fn send2_by_specify_tunnel(
1044 &self,
1045 tunnel_id: TunnelId,
1046 cmd: CMD,
1047 version: u8,
1048 body: &[&[u8]],
1049 ) -> CmdResult<()> {
1050 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1051 send.send2(cmd, version, body).await
1052 }
1053
1054 async fn send2_by_specify_tunnel_with_resp(
1055 &self,
1056 tunnel_id: TunnelId,
1057 cmd: CMD,
1058 version: u8,
1059 body: &[&[u8]],
1060 timeout: Duration,
1061 ) -> CmdResult<CmdBody> {
1062 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1063 send.send2_with_resp(cmd, version, body, timeout).await
1064 }
1065
1066 async fn send_cmd_by_specify_tunnel(
1067 &self,
1068 tunnel_id: TunnelId,
1069 cmd: CMD,
1070 version: u8,
1071 body: CmdBody,
1072 ) -> CmdResult<()> {
1073 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1074 send.send_cmd(cmd, version, body).await
1075 }
1076
1077 async fn send_cmd_by_specify_tunnel_with_resp(
1078 &self,
1079 tunnel_id: TunnelId,
1080 cmd: CMD,
1081 version: u8,
1082 body: CmdBody,
1083 timeout: Duration,
1084 ) -> CmdResult<CmdBody> {
1085 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1086 send.send_cmd_with_resp(cmd, version, body, timeout).await
1087 }
1088
1089 async fn clear_all_tunnel(&self) {
1090 self.tunnel_pool.clear_all_worker().await;
1091 }
1092
1093 async fn get_send(
1094 &self,
1095 tunnel_id: TunnelId,
1096 ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
1097 Ok(ClassifiedSendGuard {
1098 worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1099 _p: std::marker::PhantomData,
1100 })
1101 }
1102}
1103
1104#[async_trait::async_trait]
1105impl<
1106 C: WorkerClassification,
1107 M: CmdTunnelMeta,
1108 R: ClassifiedCmdTunnelRead<C, M>,
1109 W: ClassifiedCmdTunnelWrite<C, M>,
1110 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
1111 LEN: RawEncode
1112 + for<'a> RawDecode<'a>
1113 + Copy
1114 + Send
1115 + Sync
1116 + 'static
1117 + FromPrimitive
1118 + ToPrimitive
1119 + RawFixedBytes,
1120 CMD: RawEncode
1121 + for<'a> RawDecode<'a>
1122 + Copy
1123 + Send
1124 + Sync
1125 + 'static
1126 + RawFixedBytes
1127 + Eq
1128 + Hash
1129 + Debug,
1130>
1131 ClassifiedCmdClient<
1132 LEN,
1133 CMD,
1134 C,
1135 M,
1136 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1137 ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1138 > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1139{
1140 async fn send_by_classified_tunnel(
1141 &self,
1142 classification: C,
1143 cmd: CMD,
1144 version: u8,
1145 body: &[u8],
1146 ) -> CmdResult<()> {
1147 let mut send = self.get_classified_send(classification).await?;
1148 send.send(cmd, version, body).await
1149 }
1150
1151 async fn send_by_classified_tunnel_with_resp(
1152 &self,
1153 classification: C,
1154 cmd: CMD,
1155 version: u8,
1156 body: &[u8],
1157 timeout: Duration,
1158 ) -> CmdResult<CmdBody> {
1159 let mut send = self.get_classified_send(classification).await?;
1160 send.send_with_resp(cmd, version, body, timeout).await
1161 }
1162
1163 async fn send2_by_classified_tunnel(
1164 &self,
1165 classification: C,
1166 cmd: CMD,
1167 version: u8,
1168 body: &[&[u8]],
1169 ) -> CmdResult<()> {
1170 let mut send = self.get_classified_send(classification).await?;
1171 send.send2(cmd, version, body).await
1172 }
1173
1174 async fn send2_by_classified_tunnel_with_resp(
1175 &self,
1176 classification: C,
1177 cmd: CMD,
1178 version: u8,
1179 body: &[&[u8]],
1180 timeout: Duration,
1181 ) -> CmdResult<CmdBody> {
1182 let mut send = self.get_classified_send(classification).await?;
1183 send.send2_with_resp(cmd, version, body, timeout).await
1184 }
1185
1186 async fn send_cmd_by_classified_tunnel(
1187 &self,
1188 classification: C,
1189 cmd: CMD,
1190 version: u8,
1191 body: CmdBody,
1192 ) -> CmdResult<()> {
1193 let mut send = self.get_classified_send(classification).await?;
1194 send.send_cmd(cmd, version, body).await
1195 }
1196
1197 async fn send_cmd_by_classified_tunnel_with_resp(
1198 &self,
1199 classification: C,
1200 cmd: CMD,
1201 version: u8,
1202 body: CmdBody,
1203 timeout: Duration,
1204 ) -> CmdResult<CmdBody> {
1205 let mut send = self.get_classified_send(classification).await?;
1206 send.send_cmd_with_resp(cmd, version, body, timeout).await
1207 }
1208
1209 async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
1210 let send = self.get_classified_send(classification).await?;
1211 Ok(send.get_tunnel_id())
1212 }
1213
1214 async fn get_send_by_classified(
1215 &self,
1216 classification: C,
1217 ) -> CmdResult<
1218 ClassifiedSendGuard<
1219 CmdClientTunnelClassification<C>,
1220 M,
1221 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1222 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
1223 >,
1224 > {
1225 Ok(ClassifiedSendGuard {
1226 worker_guard: self.get_classified_send(classification).await?,
1227 _p: std::marker::PhantomData,
1228 })
1229 }
1230}