1use crate::client::{
2 ClassifiedCmdClient, ClassifiedSendGuard, CmdClient, CmdSend, RespWaiter, RespWaiterRef,
3 gen_resp_id, gen_seq,
4};
5use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
6use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
7use crate::peer_id::PeerId;
8use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
9use async_named_locker::ObjectHolder;
10use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
11use num::{FromPrimitive, ToPrimitive};
12use sfo_pool::{
13 ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
14 ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
15 pool_err,
16};
17use sfo_split::{RHalf, Splittable, WHalf};
18use std::fmt::Debug;
19use std::hash::Hash;
20use std::ops::DerefMut;
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>:
28 CmdTunnelRead<M> + 'static + Send
29{
30 fn get_classification(&self) -> C;
31}
32
33pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>:
34 CmdTunnelWrite<M> + 'static + Send
35{
36 fn get_classification(&self) -> C;
37}
38
39pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
40pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
41pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
42
43#[derive(Debug, Clone, Copy, Eq, Hash)]
44pub struct CmdClientTunnelClassification<C: WorkerClassification> {
45 tunnel_id: Option<TunnelId>,
46 classification: Option<C>,
47}
48
49impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
50 fn eq(&self, other: &Self) -> bool {
51 self.tunnel_id == other.tunnel_id && self.classification == other.classification
52 }
53}
54
55#[async_trait::async_trait]
56pub trait ClassifiedCmdTunnelFactory<
57 C: WorkerClassification,
58 M: CmdTunnelMeta,
59 R: ClassifiedCmdTunnelRead<C, M>,
60 W: ClassifiedCmdTunnelWrite<C, M>,
61>: Send + Sync + 'static
62{
63 async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
64}
65
66pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
67where
68 C: WorkerClassification,
69 M: CmdTunnelMeta,
70 R: ClassifiedCmdTunnelRead<C, M>,
71 W: ClassifiedCmdTunnelWrite<C, M>,
72 LEN: RawEncode
73 + for<'a> RawDecode<'a>
74 + Copy
75 + Send
76 + Sync
77 + 'static
78 + FromPrimitive
79 + ToPrimitive,
80 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
81{
82 pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
83 pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
84 pub(crate) is_work: bool,
85 pub(crate) classification: C,
86 pub(crate) tunnel_id: TunnelId,
87 pub(crate) resp_waiter: RespWaiterRef,
88 pub(crate) remote_id: PeerId,
89 pub(crate) tunnel_meta: Option<Arc<M>>,
90 _p: std::marker::PhantomData<(LEN, CMD)>,
91}
92
93impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
107where
108 C: WorkerClassification,
109 M: CmdTunnelMeta,
110 R: ClassifiedCmdTunnelRead<C, M>,
111 W: ClassifiedCmdTunnelWrite<C, M>,
112 LEN: RawEncode
113 + for<'a> RawDecode<'a>
114 + Copy
115 + Send
116 + Sync
117 + 'static
118 + FromPrimitive
119 + ToPrimitive,
120 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
121{
122 pub(crate) fn new(
123 tunnel_id: TunnelId,
124 classification: C,
125 recv_handle: JoinHandle<CmdResult<()>>,
126 write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
127 resp_waiter: RespWaiterRef,
128 remote_id: PeerId,
129 tunnel_meta: Option<Arc<M>>,
130 ) -> Self {
131 Self {
132 recv_handle,
133 write,
134 is_work: true,
135 classification,
136 tunnel_id,
137 resp_waiter,
138 remote_id,
139 tunnel_meta,
140 _p: Default::default(),
141 }
142 }
143
144 pub fn get_tunnel_id(&self) -> TunnelId {
145 self.tunnel_id
146 }
147
148 pub fn set_disable(&mut self) {
149 self.is_work = false;
150 self.recv_handle.abort();
151 }
152
153 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
154 log::trace!(
155 "client {:?} send cmd: {:?}, len: {}, data: {}",
156 self.tunnel_id,
157 cmd,
158 body.len(),
159 hex::encode(body)
160 );
161 let header = CmdHeader::<LEN, CMD>::new(
162 version,
163 false,
164 None,
165 cmd,
166 LEN::from_u64(body.len() as u64).unwrap(),
167 );
168 let buf = header
169 .to_vec()
170 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
171 let ret = self.send_inner(buf.as_slice(), body).await;
172 if let Err(e) = ret {
173 self.set_disable();
174 return Err(e);
175 }
176 Ok(())
177 }
178
179 pub async fn send_with_resp(
180 &mut self,
181 cmd: CMD,
182 version: u8,
183 body: &[u8],
184 timeout: Duration,
185 ) -> CmdResult<CmdBody> {
186 if let Some(id) = tokio::task::try_id() {
187 if id == self.recv_handle.id() {
188 return Err(cmd_err!(
189 CmdErrorCode::Failed,
190 "can't send with resp in recv task"
191 ));
192 }
193 }
194 log::trace!(
195 "client {:?} send cmd: {:?}, len: {}, data: {}",
196 self.tunnel_id,
197 cmd,
198 body.len(),
199 hex::encode(body)
200 );
201 let seq = gen_seq();
202 let header = CmdHeader::<LEN, CMD>::new(
203 version,
204 false,
205 Some(seq),
206 cmd,
207 LEN::from_u64(body.len() as u64).unwrap(),
208 );
209 let buf = header
210 .to_vec()
211 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
212 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
213 let waiter = self.resp_waiter.clone();
214 let resp_waiter = waiter
215 .create_timeout_result_future(resp_id, timeout)
216 .map_err(into_cmd_err!(
217 CmdErrorCode::Failed,
218 "create timeout result future error"
219 ))?;
220 let ret = self.send_inner(buf.as_slice(), body).await;
221 if let Err(e) = ret {
222 self.set_disable();
223 return Err(e);
224 }
225 let resp = resp_waiter
226 .await
227 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
228 Ok(resp)
229 }
230
231 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
232 let mut len = 0;
233 for b in body.iter() {
234 len += b.len();
235 log::trace!(
236 "client {:?} send2 cmd {:?} body: {}",
237 self.tunnel_id,
238 cmd,
239 hex::encode(b)
240 );
241 }
242 log::trace!(
243 "client {:?} send2 cmd: {:?}, len {}",
244 self.tunnel_id,
245 cmd,
246 len
247 );
248 let header = CmdHeader::<LEN, CMD>::new(
249 version,
250 false,
251 None,
252 cmd,
253 LEN::from_u64(len as u64).unwrap(),
254 );
255 let buf = header
256 .to_vec()
257 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
258 let ret = self.send_inner2(buf.as_slice(), body).await;
259 if let Err(e) = ret {
260 self.set_disable();
261 return Err(e);
262 }
263 Ok(())
264 }
265
266 pub async fn send2_with_resp(
267 &mut self,
268 cmd: CMD,
269 version: u8,
270 body: &[&[u8]],
271 timeout: Duration,
272 ) -> CmdResult<CmdBody> {
273 if let Some(id) = tokio::task::try_id() {
274 if id == self.recv_handle.id() {
275 return Err(cmd_err!(
276 CmdErrorCode::Failed,
277 "can't send with resp in recv task"
278 ));
279 }
280 }
281 let mut len = 0;
282 for b in body.iter() {
283 len += b.len();
284 log::trace!(
285 "client {:?} send2 cmd {:?} body: {}",
286 self.tunnel_id,
287 cmd,
288 hex::encode(b)
289 );
290 }
291 log::trace!(
292 "client {:?} send2 cmd: {:?}, len {}",
293 self.tunnel_id,
294 cmd,
295 len
296 );
297 let seq = gen_seq();
298 let header = CmdHeader::<LEN, CMD>::new(
299 version,
300 false,
301 Some(seq),
302 cmd,
303 LEN::from_u64(len as u64).unwrap(),
304 );
305 let buf = header
306 .to_vec()
307 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
308 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
309 let waiter = self.resp_waiter.clone();
310 let resp_waiter = waiter
311 .create_timeout_result_future(resp_id, timeout)
312 .map_err(into_cmd_err!(
313 CmdErrorCode::Failed,
314 "create timeout result future error"
315 ))?;
316 let ret = self.send_inner2(buf.as_slice(), body).await;
317 if let Err(e) = ret {
318 self.set_disable();
319 return Err(e);
320 }
321 let resp = resp_waiter
322 .await
323 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
324 Ok(resp)
325 }
326
327 pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
328 log::trace!(
329 "client {:?} send cmd: {:?}, len: {}",
330 self.tunnel_id,
331 cmd,
332 body.len()
333 );
334 let header = CmdHeader::<LEN, CMD>::new(
335 version,
336 false,
337 None,
338 cmd,
339 LEN::from_u64(body.len()).unwrap(),
340 );
341 let buf = header
342 .to_vec()
343 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
344 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
345 if let Err(e) = ret {
346 self.set_disable();
347 return Err(e);
348 }
349 Ok(())
350 }
351
352 pub async fn send_cmd_with_resp(
353 &mut self,
354 cmd: CMD,
355 version: u8,
356 body: CmdBody,
357 timeout: Duration,
358 ) -> CmdResult<CmdBody> {
359 if let Some(id) = tokio::task::try_id() {
360 if id == self.recv_handle.id() {
361 return Err(cmd_err!(
362 CmdErrorCode::Failed,
363 "can't send with resp in recv task"
364 ));
365 }
366 }
367 log::trace!(
368 "client {:?} send cmd: {:?}, len: {}",
369 self.tunnel_id,
370 cmd,
371 body.len()
372 );
373 let seq = gen_seq();
374 let header = CmdHeader::<LEN, CMD>::new(
375 version,
376 false,
377 Some(seq),
378 cmd,
379 LEN::from_u64(body.len()).unwrap(),
380 );
381 let buf = header
382 .to_vec()
383 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
384 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
385 let waiter = self.resp_waiter.clone();
386 let resp_waiter = waiter
387 .create_timeout_result_future(resp_id, timeout)
388 .map_err(into_cmd_err!(
389 CmdErrorCode::Failed,
390 "create timeout result future error"
391 ))?;
392 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
393 if let Err(e) = ret {
394 self.set_disable();
395 return Err(e);
396 }
397 let resp = resp_waiter
398 .await
399 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
400 Ok(resp)
401 }
402
403 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
404 let mut write = self.write.get().await;
405 if header.len() > 255 {
406 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
407 }
408 write
409 .write_u8(header.len() as u8)
410 .await
411 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
412 write
413 .write_all(header)
414 .await
415 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
416 write
417 .write_all(body)
418 .await
419 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
420 write
421 .flush()
422 .await
423 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
424 Ok(())
425 }
426
427 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
428 let mut write = self.write.get().await;
429 if header.len() > 255 {
430 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
431 }
432 write
433 .write_u8(header.len() as u8)
434 .await
435 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
436 write
437 .write_all(header)
438 .await
439 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
440 for b in body.iter() {
441 write
442 .write_all(b)
443 .await
444 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
445 }
446 write
447 .flush()
448 .await
449 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
450 Ok(())
451 }
452
453 async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
454 let mut write = self.write.get().await;
455 if header.len() > 255 {
456 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
457 }
458 write
459 .write_u8(header.len() as u8)
460 .await
461 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
462 write
463 .write_all(header)
464 .await
465 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
466 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
467 .await
468 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
469 write
470 .flush()
471 .await
472 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
473 Ok(())
474 }
475}
476
477impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
478where
479 C: WorkerClassification,
480 M: CmdTunnelMeta,
481 R: ClassifiedCmdTunnelRead<C, M>,
482 W: ClassifiedCmdTunnelWrite<C, M>,
483 LEN: RawEncode
484 + for<'a> RawDecode<'a>
485 + Copy
486 + Send
487 + Sync
488 + 'static
489 + FromPrimitive
490 + ToPrimitive,
491 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
492{
493 fn drop(&mut self) {
494 self.set_disable();
495 }
496}
497
498impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
499where
500 C: WorkerClassification,
501 M: CmdTunnelMeta,
502 R: ClassifiedCmdTunnelRead<C, M>,
503 W: ClassifiedCmdTunnelWrite<C, M>,
504 LEN: RawEncode
505 + for<'a> RawDecode<'a>
506 + Copy
507 + Send
508 + Sync
509 + 'static
510 + FromPrimitive
511 + ToPrimitive,
512 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
513{
514 fn get_tunnel_meta(&self) -> Option<Arc<M>> {
515 self.tunnel_meta.clone()
516 }
517
518 fn get_remote_peer_id(&self) -> PeerId {
519 self.remote_id.clone()
520 }
521}
522
523impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>>
524 for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
525where
526 C: WorkerClassification,
527 M: CmdTunnelMeta,
528 R: ClassifiedCmdTunnelRead<C, M>,
529 W: ClassifiedCmdTunnelWrite<C, M>,
530 LEN: RawEncode
531 + for<'a> RawDecode<'a>
532 + Copy
533 + Send
534 + Sync
535 + 'static
536 + FromPrimitive
537 + ToPrimitive,
538 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
539{
540 fn is_work(&self) -> bool {
541 self.is_work && !self.recv_handle.is_finished()
542 }
543
544 fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
545 if c.tunnel_id.is_some() {
546 self.tunnel_id == c.tunnel_id.unwrap()
547 } else {
548 if c.classification.is_some() {
549 self.classification == c.classification.unwrap()
550 } else {
551 true
552 }
553 }
554 }
555
556 fn classification(&self) -> CmdClientTunnelClassification<C> {
557 CmdClientTunnelClassification {
558 tunnel_id: Some(self.tunnel_id),
559 classification: Some(self.classification.clone()),
560 }
561 }
562}
563
564pub struct ClassifiedCmdWriteFactory<
565 C: WorkerClassification,
566 M: CmdTunnelMeta,
567 R: ClassifiedCmdTunnelRead<C, M>,
568 W: ClassifiedCmdTunnelWrite<C, M>,
569 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
570 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
571 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
572> {
573 tunnel_factory: F,
574 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
575 resp_waiter: RespWaiterRef,
576 tunnel_id_generator: TunnelIdGenerator,
577 _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
578}
579
580impl<
581 C: WorkerClassification,
582 M: CmdTunnelMeta,
583 R: ClassifiedCmdTunnelRead<C, M>,
584 W: ClassifiedCmdTunnelWrite<C, M>,
585 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
586 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
587 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
588> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
589{
590 pub(crate) fn new(
591 tunnel_factory: F,
592 cmd_handler: impl CmdHandler<LEN, CMD>,
593 resp_waiter: RespWaiterRef,
594 ) -> Self {
595 Self {
596 tunnel_factory,
597 cmd_handler: Arc::new(cmd_handler),
598 resp_waiter,
599 tunnel_id_generator: TunnelIdGenerator::new(),
600 _p: Default::default(),
601 }
602 }
603}
604
605#[async_trait::async_trait]
606impl<
607 C: WorkerClassification,
608 M: CmdTunnelMeta,
609 R: ClassifiedCmdTunnelRead<C, M>,
610 W: ClassifiedCmdTunnelWrite<C, M>,
611 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
612 LEN: RawEncode
613 + for<'a> RawDecode<'a>
614 + Copy
615 + Send
616 + Sync
617 + 'static
618 + FromPrimitive
619 + ToPrimitive
620 + RawFixedBytes,
621 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
622> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>>
623 for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
624{
625 async fn create(
626 &self,
627 classification: Option<CmdClientTunnelClassification<C>>,
628 ) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
629 if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
630 return Err(pool_err!(
631 PoolErrorCode::Failed,
632 "tunnel {:?} not found",
633 classification.as_ref().unwrap().tunnel_id.unwrap()
634 ));
635 }
636
637 let classification = if classification.is_some()
638 && classification.as_ref().unwrap().classification.is_some()
639 {
640 classification.unwrap().classification
641 } else {
642 None
643 };
644 let tunnel = self
645 .tunnel_factory
646 .create_tunnel(classification)
647 .await
648 .map_err(into_pool_err!(PoolErrorCode::Failed))?;
649 let classification = tunnel.get_classification();
650 let peer_id = tunnel.get_remote_peer_id();
651 let tunnel_id = self.tunnel_id_generator.generate();
652 let (mut recv, write) = tunnel.split();
653 let local_id = recv.get_local_peer_id();
654 let remote_id = peer_id.clone();
655 let tunnel_meta = recv.get_tunnel_meta();
656 let write = ObjectHolder::new(write);
657 let resp_write = write.clone();
658 let cmd_handler = self.cmd_handler.clone();
659 let handle = spawn(async move {
660 let ret: CmdResult<()> = async move {
661 loop {
662 let header_len = recv
663 .read_u8()
664 .await
665 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
666 let mut header = vec![0u8; header_len as usize];
667 let n = recv
668 .read_exact(header.as_mut())
669 .await
670 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
671 if n == 0 {
672 break;
673 }
674 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
675 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
676 log::trace!(
677 "recv cmd {:?} from {} len {} tunnel {:?}",
678 header.cmd_code(),
679 peer_id,
680 header.pkg_len().to_u64().unwrap(),
681 tunnel_id
682 );
683 let body_len = header.pkg_len().to_u64().unwrap();
684 let cmd_read =
685 CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
686 let waiter = cmd_read.get_waiter();
687 let future = waiter
688 .create_result_future()
689 .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
690 let version = header.version();
691 let seq = header.seq();
692 let cmd_code = header.cmd_code();
693 match cmd_handler
694 .handle(
695 local_id.clone(),
696 peer_id.clone(),
697 tunnel_id,
698 header,
699 CmdBody::from_reader(BufReader::new(cmd_read), body_len),
700 )
701 .await
702 {
703 Ok(Some(mut body)) => {
704 let mut write = resp_write.get().await;
705 let header = CmdHeader::<LEN, CMD>::new(
706 version,
707 true,
708 seq,
709 cmd_code,
710 LEN::from_u64(body.len()).unwrap(),
711 );
712 let buf = header
713 .to_vec()
714 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
715 if buf.len() > 255 {
716 return Err(cmd_err!(
717 CmdErrorCode::InvalidParam,
718 "header len too large"
719 ));
720 }
721 write
722 .write_u8(buf.len() as u8)
723 .await
724 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
725 write
726 .write_all(buf.as_slice())
727 .await
728 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
729 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
730 .await
731 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
732 write
733 .flush()
734 .await
735 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
736 }
737 Err(e) => {
738 log::error!("handle cmd error: {:?}", e);
739 }
740 _ => {}
741 }
742 recv = future
743 .await
744 .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
745 log::debug!(
746 "handle cmd {:?} from {} len {} tunnel {:?} complete",
747 cmd_code,
748 peer_id,
749 body_len,
750 tunnel_id
751 );
752 }
753 Ok(())
754 }
755 .await;
756 if ret.is_err() {
757 log::error!("recv cmd error: {:?}", ret.as_ref().err().unwrap());
758 }
759 ret
760 });
761 Ok(ClassifiedCmdSend::new(
762 tunnel_id,
763 classification,
764 handle,
765 write,
766 self.resp_waiter.clone(),
767 remote_id,
768 tunnel_meta,
769 ))
770 }
771}
772
773pub struct DefaultClassifiedCmdClient<
774 C: WorkerClassification,
775 M: CmdTunnelMeta,
776 R: ClassifiedCmdTunnelRead<C, M>,
777 W: ClassifiedCmdTunnelWrite<C, M>,
778 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
779 LEN: RawEncode
780 + for<'a> RawDecode<'a>
781 + Copy
782 + Send
783 + Sync
784 + 'static
785 + FromPrimitive
786 + ToPrimitive
787 + RawFixedBytes,
788 CMD: RawEncode
789 + for<'a> RawDecode<'a>
790 + Copy
791 + Send
792 + Sync
793 + 'static
794 + RawFixedBytes
795 + Eq
796 + Hash
797 + Debug,
798> {
799 tunnel_pool: ClassifiedWorkerPoolRef<
800 CmdClientTunnelClassification<C>,
801 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
802 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
803 >,
804 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
805}
806
807impl<
808 C: WorkerClassification,
809 M: CmdTunnelMeta,
810 R: ClassifiedCmdTunnelRead<C, M>,
811 W: ClassifiedCmdTunnelWrite<C, M>,
812 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
813 LEN: RawEncode
814 + for<'a> RawDecode<'a>
815 + Copy
816 + Send
817 + Sync
818 + 'static
819 + FromPrimitive
820 + ToPrimitive
821 + RawFixedBytes,
822 CMD: RawEncode
823 + for<'a> RawDecode<'a>
824 + Copy
825 + Send
826 + Sync
827 + 'static
828 + RawFixedBytes
829 + Eq
830 + Hash
831 + Debug,
832> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
833{
834 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
835 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
836 let resp_waiter = Arc::new(RespWaiter::new());
837 let handler_map = cmd_handler_map.clone();
838 let waiter = resp_waiter.clone();
839 Arc::new(Self {
840 tunnel_pool: ClassifiedWorkerPool::new(
841 tunnel_count,
842 ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(
843 factory,
844 move |local_id: PeerId,
845 peer_id: PeerId,
846 tunnel_id: TunnelId,
847 header: CmdHeader<LEN, CMD>,
848 body_read: CmdBody| {
849 let handler_map = handler_map.clone();
850 let waiter = waiter.clone();
851 async move {
852 if header.is_resp() && header.seq().is_some() {
853 let resp_id = gen_resp_id(
854 tunnel_id,
855 header.cmd_code(),
856 header.seq().unwrap(),
857 );
858 let _ = waiter.set_result(resp_id, body_read);
859 Ok(None)
860 } else {
861 if let Some(handler) = handler_map.get(header.cmd_code()) {
862 handler
863 .handle(local_id, peer_id, tunnel_id, header, body_read)
864 .await
865 } else {
866 Ok(None)
867 }
868 }
869 }
870 },
871 resp_waiter.clone(),
872 ),
873 ),
874 cmd_handler_map,
875 })
876 }
877
878 async fn get_send(
879 &self,
880 ) -> CmdResult<
881 ClassifiedWorkerGuard<
882 CmdClientTunnelClassification<C>,
883 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
884 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
885 >,
886 > {
887 self.tunnel_pool
888 .get_worker()
889 .await
890 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
891 }
892
893 async fn get_send_of_tunnel_id(
894 &self,
895 tunnel_id: TunnelId,
896 ) -> CmdResult<
897 ClassifiedWorkerGuard<
898 CmdClientTunnelClassification<C>,
899 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
900 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
901 >,
902 > {
903 self.tunnel_pool
904 .get_classified_worker(CmdClientTunnelClassification {
905 tunnel_id: Some(tunnel_id),
906 classification: None,
907 })
908 .await
909 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
910 }
911
912 async fn get_classified_send(
913 &self,
914 classification: C,
915 ) -> CmdResult<
916 ClassifiedWorkerGuard<
917 CmdClientTunnelClassification<C>,
918 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
919 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
920 >,
921 > {
922 self.tunnel_pool
923 .get_classified_worker(CmdClientTunnelClassification {
924 tunnel_id: None,
925 classification: Some(classification),
926 })
927 .await
928 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
929 }
930}
931
932pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
933 CmdClientTunnelClassification<C>,
934 M,
935 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
936 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
937>;
938#[async_trait::async_trait]
939impl<
940 C: WorkerClassification,
941 M: CmdTunnelMeta,
942 R: ClassifiedCmdTunnelRead<C, M>,
943 W: ClassifiedCmdTunnelWrite<C, M>,
944 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
945 LEN: RawEncode
946 + for<'a> RawDecode<'a>
947 + Copy
948 + Send
949 + Sync
950 + 'static
951 + FromPrimitive
952 + ToPrimitive
953 + RawFixedBytes,
954 CMD: RawEncode
955 + for<'a> RawDecode<'a>
956 + Copy
957 + Send
958 + Sync
959 + 'static
960 + RawFixedBytes
961 + Eq
962 + Hash
963 + Debug,
964>
965 CmdClient<
966 LEN,
967 CMD,
968 M,
969 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
970 ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
971 > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
972{
973 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
974 self.cmd_handler_map.insert(cmd, handler);
975 }
976
977 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
978 let mut send = self.get_send().await?;
979 send.send(cmd, version, body).await
980 }
981
982 async fn send_with_resp(
983 &self,
984 cmd: CMD,
985 version: u8,
986 body: &[u8],
987 timeout: Duration,
988 ) -> CmdResult<CmdBody> {
989 let mut send = self.get_send().await?;
990 send.send_with_resp(cmd, version, body, timeout).await
991 }
992
993 async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
994 let mut send = self.get_send().await?;
995 send.send2(cmd, version, body).await
996 }
997
998 async fn send2_with_resp(
999 &self,
1000 cmd: CMD,
1001 version: u8,
1002 body: &[&[u8]],
1003 timeout: Duration,
1004 ) -> CmdResult<CmdBody> {
1005 let mut send = self.get_send().await?;
1006 send.send2_with_resp(cmd, version, body, timeout).await
1007 }
1008
1009 async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
1010 let mut send = self.get_send().await?;
1011 send.send_cmd(cmd, version, body).await
1012 }
1013
1014 async fn send_cmd_with_resp(
1015 &self,
1016 cmd: CMD,
1017 version: u8,
1018 body: CmdBody,
1019 timeout: Duration,
1020 ) -> CmdResult<CmdBody> {
1021 let mut send = self.get_send().await?;
1022 send.send_cmd_with_resp(cmd, version, body, timeout).await
1023 }
1024
1025 async fn send_by_specify_tunnel(
1026 &self,
1027 tunnel_id: TunnelId,
1028 cmd: CMD,
1029 version: u8,
1030 body: &[u8],
1031 ) -> CmdResult<()> {
1032 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1033 send.send(cmd, version, body).await
1034 }
1035
1036 async fn send_by_specify_tunnel_with_resp(
1037 &self,
1038 tunnel_id: TunnelId,
1039 cmd: CMD,
1040 version: u8,
1041 body: &[u8],
1042 timeout: Duration,
1043 ) -> CmdResult<CmdBody> {
1044 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1045 send.send_with_resp(cmd, version, body, timeout).await
1046 }
1047
1048 async fn send2_by_specify_tunnel(
1049 &self,
1050 tunnel_id: TunnelId,
1051 cmd: CMD,
1052 version: u8,
1053 body: &[&[u8]],
1054 ) -> CmdResult<()> {
1055 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1056 send.send2(cmd, version, body).await
1057 }
1058
1059 async fn send2_by_specify_tunnel_with_resp(
1060 &self,
1061 tunnel_id: TunnelId,
1062 cmd: CMD,
1063 version: u8,
1064 body: &[&[u8]],
1065 timeout: Duration,
1066 ) -> CmdResult<CmdBody> {
1067 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1068 send.send2_with_resp(cmd, version, body, timeout).await
1069 }
1070
1071 async fn send_cmd_by_specify_tunnel(
1072 &self,
1073 tunnel_id: TunnelId,
1074 cmd: CMD,
1075 version: u8,
1076 body: CmdBody,
1077 ) -> CmdResult<()> {
1078 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1079 send.send_cmd(cmd, version, body).await
1080 }
1081
1082 async fn send_cmd_by_specify_tunnel_with_resp(
1083 &self,
1084 tunnel_id: TunnelId,
1085 cmd: CMD,
1086 version: u8,
1087 body: CmdBody,
1088 timeout: Duration,
1089 ) -> CmdResult<CmdBody> {
1090 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1091 send.send_cmd_with_resp(cmd, version, body, timeout).await
1092 }
1093
1094 async fn clear_all_tunnel(&self) {
1095 self.tunnel_pool.clear_all_worker().await;
1096 }
1097
1098 async fn get_send(
1099 &self,
1100 tunnel_id: TunnelId,
1101 ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
1102 Ok(ClassifiedSendGuard {
1103 worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1104 _p: std::marker::PhantomData,
1105 })
1106 }
1107}
1108
1109#[async_trait::async_trait]
1110impl<
1111 C: WorkerClassification,
1112 M: CmdTunnelMeta,
1113 R: ClassifiedCmdTunnelRead<C, M>,
1114 W: ClassifiedCmdTunnelWrite<C, M>,
1115 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
1116 LEN: RawEncode
1117 + for<'a> RawDecode<'a>
1118 + Copy
1119 + Send
1120 + Sync
1121 + 'static
1122 + FromPrimitive
1123 + ToPrimitive
1124 + RawFixedBytes,
1125 CMD: RawEncode
1126 + for<'a> RawDecode<'a>
1127 + Copy
1128 + Send
1129 + Sync
1130 + 'static
1131 + RawFixedBytes
1132 + Eq
1133 + Hash
1134 + Debug,
1135>
1136 ClassifiedCmdClient<
1137 LEN,
1138 CMD,
1139 C,
1140 M,
1141 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1142 ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1143 > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1144{
1145 async fn send_by_classified_tunnel(
1146 &self,
1147 classification: C,
1148 cmd: CMD,
1149 version: u8,
1150 body: &[u8],
1151 ) -> CmdResult<()> {
1152 let mut send = self.get_classified_send(classification).await?;
1153 send.send(cmd, version, body).await
1154 }
1155
1156 async fn send_by_classified_tunnel_with_resp(
1157 &self,
1158 classification: C,
1159 cmd: CMD,
1160 version: u8,
1161 body: &[u8],
1162 timeout: Duration,
1163 ) -> CmdResult<CmdBody> {
1164 let mut send = self.get_classified_send(classification).await?;
1165 send.send_with_resp(cmd, version, body, timeout).await
1166 }
1167
1168 async fn send2_by_classified_tunnel(
1169 &self,
1170 classification: C,
1171 cmd: CMD,
1172 version: u8,
1173 body: &[&[u8]],
1174 ) -> CmdResult<()> {
1175 let mut send = self.get_classified_send(classification).await?;
1176 send.send2(cmd, version, body).await
1177 }
1178
1179 async fn send2_by_classified_tunnel_with_resp(
1180 &self,
1181 classification: C,
1182 cmd: CMD,
1183 version: u8,
1184 body: &[&[u8]],
1185 timeout: Duration,
1186 ) -> CmdResult<CmdBody> {
1187 let mut send = self.get_classified_send(classification).await?;
1188 send.send2_with_resp(cmd, version, body, timeout).await
1189 }
1190
1191 async fn send_cmd_by_classified_tunnel(
1192 &self,
1193 classification: C,
1194 cmd: CMD,
1195 version: u8,
1196 body: CmdBody,
1197 ) -> CmdResult<()> {
1198 let mut send = self.get_classified_send(classification).await?;
1199 send.send_cmd(cmd, version, body).await
1200 }
1201
1202 async fn send_cmd_by_classified_tunnel_with_resp(
1203 &self,
1204 classification: C,
1205 cmd: CMD,
1206 version: u8,
1207 body: CmdBody,
1208 timeout: Duration,
1209 ) -> CmdResult<CmdBody> {
1210 let mut send = self.get_classified_send(classification).await?;
1211 send.send_cmd_with_resp(cmd, version, body, timeout).await
1212 }
1213
1214 async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
1215 let send = self.get_classified_send(classification).await?;
1216 Ok(send.get_tunnel_id())
1217 }
1218
1219 async fn get_send_by_classified(
1220 &self,
1221 classification: C,
1222 ) -> CmdResult<
1223 ClassifiedSendGuard<
1224 CmdClientTunnelClassification<C>,
1225 M,
1226 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1227 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
1228 >,
1229 > {
1230 Ok(ClassifiedSendGuard {
1231 worker_guard: self.get_classified_send(classification).await?,
1232 _p: std::marker::PhantomData,
1233 })
1234 }
1235}