1use crate::client::{
2 ClassifiedCmdClient, ClassifiedSendGuard, CmdClient, CmdSend, RespWaiter, RespWaiterRef,
3 gen_resp_id, gen_seq,
4};
5use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
6use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
7use crate::peer_id::PeerId;
8use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
9use async_named_locker::ObjectHolder;
10use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
11use num::{FromPrimitive, ToPrimitive};
12use sfo_pool::{
13 ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
14 ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
15 pool_err,
16};
17use sfo_split::{RHalf, Splittable, WHalf};
18use std::fmt::Debug;
19use std::hash::Hash;
20use std::ops::DerefMut;
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26
27pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>:
28 CmdTunnelRead<M> + 'static + Send
29{
30 fn get_classification(&self) -> C;
31}
32
33pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>:
34 CmdTunnelWrite<M> + 'static + Send
35{
36 fn get_classification(&self) -> C;
37}
38
39pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
40pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
41pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
42
43#[derive(Debug, Clone, Copy, Eq, Hash)]
44pub struct CmdClientTunnelClassification<C: WorkerClassification> {
45 tunnel_id: Option<TunnelId>,
46 classification: Option<C>,
47}
48
49impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
50 fn eq(&self, other: &Self) -> bool {
51 self.tunnel_id == other.tunnel_id && self.classification == other.classification
52 }
53}
54
55#[async_trait::async_trait]
56pub trait ClassifiedCmdTunnelFactory<
57 C: WorkerClassification,
58 M: CmdTunnelMeta,
59 R: ClassifiedCmdTunnelRead<C, M>,
60 W: ClassifiedCmdTunnelWrite<C, M>,
61>: Send + Sync + 'static
62{
63 async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
64}
65
66pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
67where
68 C: WorkerClassification,
69 M: CmdTunnelMeta,
70 R: ClassifiedCmdTunnelRead<C, M>,
71 W: ClassifiedCmdTunnelWrite<C, M>,
72 LEN: RawEncode
73 + for<'a> RawDecode<'a>
74 + Copy
75 + Send
76 + Sync
77 + 'static
78 + FromPrimitive
79 + ToPrimitive,
80 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
81{
82 pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
83 pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
84 pub(crate) is_work: bool,
85 pub(crate) classification: C,
86 pub(crate) pool_tunnel_id: Option<TunnelId>,
87 pub(crate) pool_classification: Option<C>,
88 pub(crate) tunnel_id: TunnelId,
89 pub(crate) resp_waiter: RespWaiterRef,
90 pub(crate) remote_id: PeerId,
91 pub(crate) pool_peer_id: Option<PeerId>,
92 pub(crate) tunnel_meta: Option<Arc<M>>,
93 _p: std::marker::PhantomData<(LEN, CMD)>,
94}
95
96impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
110where
111 C: WorkerClassification,
112 M: CmdTunnelMeta,
113 R: ClassifiedCmdTunnelRead<C, M>,
114 W: ClassifiedCmdTunnelWrite<C, M>,
115 LEN: RawEncode
116 + for<'a> RawDecode<'a>
117 + Copy
118 + Send
119 + Sync
120 + 'static
121 + FromPrimitive
122 + ToPrimitive,
123 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
124{
125 pub(crate) fn new(
126 tunnel_id: TunnelId,
127 classification: C,
128 pool_peer_id: Option<PeerId>,
129 pool_tunnel_id: Option<TunnelId>,
130 pool_classification: Option<C>,
131 recv_handle: JoinHandle<CmdResult<()>>,
132 write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
133 resp_waiter: RespWaiterRef,
134 remote_id: PeerId,
135 tunnel_meta: Option<Arc<M>>,
136 ) -> Self {
137 Self {
138 recv_handle,
139 write,
140 is_work: true,
141 classification,
142 pool_tunnel_id,
143 pool_classification,
144 tunnel_id,
145 resp_waiter,
146 remote_id,
147 pool_peer_id,
148 tunnel_meta,
149 _p: Default::default(),
150 }
151 }
152
153 pub fn get_tunnel_id(&self) -> TunnelId {
154 self.tunnel_id
155 }
156
157 pub fn set_disable(&mut self) {
158 self.is_work = false;
159 self.recv_handle.abort();
160 }
161
162 pub(crate) fn set_pool_key(
163 &mut self,
164 peer_id: Option<PeerId>,
165 tunnel_id: Option<TunnelId>,
166 classification: Option<C>,
167 ) {
168 self.pool_peer_id = peer_id;
169 self.pool_tunnel_id = tunnel_id;
170 self.pool_classification = classification;
171 }
172
173 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
174 log::trace!(
175 "client {:?} send cmd: {:?}, len: {}, data: {}",
176 self.tunnel_id,
177 cmd,
178 body.len(),
179 hex::encode(body)
180 );
181 let header = CmdHeader::<LEN, CMD>::new(
182 version,
183 false,
184 None,
185 cmd,
186 LEN::from_u64(body.len() as u64).unwrap(),
187 );
188 let buf = header
189 .to_vec()
190 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
191 let ret = self.send_inner(buf.as_slice(), body).await;
192 if let Err(e) = ret {
193 self.set_disable();
194 return Err(e);
195 }
196 Ok(())
197 }
198
199 pub async fn send_with_resp(
200 &mut self,
201 cmd: CMD,
202 version: u8,
203 body: &[u8],
204 timeout: Duration,
205 ) -> CmdResult<CmdBody> {
206 if let Some(id) = tokio::task::try_id() {
207 if id == self.recv_handle.id() {
208 return Err(cmd_err!(
209 CmdErrorCode::Failed,
210 "can't send with resp in recv task"
211 ));
212 }
213 }
214 log::trace!(
215 "client {:?} send cmd: {:?}, len: {}, data: {}",
216 self.tunnel_id,
217 cmd,
218 body.len(),
219 hex::encode(body)
220 );
221 let seq = gen_seq();
222 let header = CmdHeader::<LEN, CMD>::new(
223 version,
224 false,
225 Some(seq),
226 cmd,
227 LEN::from_u64(body.len() as u64).unwrap(),
228 );
229 let buf = header
230 .to_vec()
231 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
232 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
233 let waiter = self.resp_waiter.clone();
234 let resp_waiter = waiter
235 .create_timeout_result_future(resp_id, timeout)
236 .map_err(into_cmd_err!(
237 CmdErrorCode::Failed,
238 "create timeout result future error"
239 ))?;
240 let ret = self.send_inner(buf.as_slice(), body).await;
241 if let Err(e) = ret {
242 self.set_disable();
243 return Err(e);
244 }
245 let resp = resp_waiter
246 .await
247 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
248 Ok(resp)
249 }
250
251 pub async fn send_parts(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
252 let mut len = 0;
253 for b in body.iter() {
254 len += b.len();
255 log::trace!(
256 "client {:?} send2 cmd {:?} body: {}",
257 self.tunnel_id,
258 cmd,
259 hex::encode(b)
260 );
261 }
262 log::trace!(
263 "client {:?} send2 cmd: {:?}, len {}",
264 self.tunnel_id,
265 cmd,
266 len
267 );
268 let header = CmdHeader::<LEN, CMD>::new(
269 version,
270 false,
271 None,
272 cmd,
273 LEN::from_u64(len as u64).unwrap(),
274 );
275 let buf = header
276 .to_vec()
277 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
278 let ret = self.send_inner2(buf.as_slice(), body).await;
279 if let Err(e) = ret {
280 self.set_disable();
281 return Err(e);
282 }
283 Ok(())
284 }
285
286 pub async fn send_parts_with_resp(
287 &mut self,
288 cmd: CMD,
289 version: u8,
290 body: &[&[u8]],
291 timeout: Duration,
292 ) -> CmdResult<CmdBody> {
293 if let Some(id) = tokio::task::try_id() {
294 if id == self.recv_handle.id() {
295 return Err(cmd_err!(
296 CmdErrorCode::Failed,
297 "can't send with resp in recv task"
298 ));
299 }
300 }
301 let mut len = 0;
302 for b in body.iter() {
303 len += b.len();
304 log::trace!(
305 "client {:?} send2 cmd {:?} body: {}",
306 self.tunnel_id,
307 cmd,
308 hex::encode(b)
309 );
310 }
311 log::trace!(
312 "client {:?} send2 cmd: {:?}, len {}",
313 self.tunnel_id,
314 cmd,
315 len
316 );
317 let seq = gen_seq();
318 let header = CmdHeader::<LEN, CMD>::new(
319 version,
320 false,
321 Some(seq),
322 cmd,
323 LEN::from_u64(len as u64).unwrap(),
324 );
325 let buf = header
326 .to_vec()
327 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
328 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
329 let waiter = self.resp_waiter.clone();
330 let resp_waiter = waiter
331 .create_timeout_result_future(resp_id, timeout)
332 .map_err(into_cmd_err!(
333 CmdErrorCode::Failed,
334 "create timeout result future error"
335 ))?;
336 let ret = self.send_inner2(buf.as_slice(), body).await;
337 if let Err(e) = ret {
338 self.set_disable();
339 return Err(e);
340 }
341 let resp = resp_waiter
342 .await
343 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
344 Ok(resp)
345 }
346
347 #[allow(deprecated)]
348 #[deprecated(note = "use send_parts instead")]
349 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
350 self.send_parts(cmd, version, body).await
351 }
352
353 #[allow(deprecated)]
354 #[deprecated(note = "use send_parts_with_resp instead")]
355 pub async fn send2_with_resp(
356 &mut self,
357 cmd: CMD,
358 version: u8,
359 body: &[&[u8]],
360 timeout: Duration,
361 ) -> CmdResult<CmdBody> {
362 self.send_parts_with_resp(cmd, version, body, timeout).await
363 }
364
365 pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
366 log::trace!(
367 "client {:?} send cmd: {:?}, len: {}",
368 self.tunnel_id,
369 cmd,
370 body.len()
371 );
372 let header = CmdHeader::<LEN, CMD>::new(
373 version,
374 false,
375 None,
376 cmd,
377 LEN::from_u64(body.len()).unwrap(),
378 );
379 let buf = header
380 .to_vec()
381 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
382 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
383 if let Err(e) = ret {
384 self.set_disable();
385 return Err(e);
386 }
387 Ok(())
388 }
389
390 pub async fn send_cmd_with_resp(
391 &mut self,
392 cmd: CMD,
393 version: u8,
394 body: CmdBody,
395 timeout: Duration,
396 ) -> CmdResult<CmdBody> {
397 if let Some(id) = tokio::task::try_id() {
398 if id == self.recv_handle.id() {
399 return Err(cmd_err!(
400 CmdErrorCode::Failed,
401 "can't send with resp in recv task"
402 ));
403 }
404 }
405 log::trace!(
406 "client {:?} send cmd: {:?}, len: {}",
407 self.tunnel_id,
408 cmd,
409 body.len()
410 );
411 let seq = gen_seq();
412 let header = CmdHeader::<LEN, CMD>::new(
413 version,
414 false,
415 Some(seq),
416 cmd,
417 LEN::from_u64(body.len()).unwrap(),
418 );
419 let buf = header
420 .to_vec()
421 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
422 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
423 let waiter = self.resp_waiter.clone();
424 let resp_waiter = waiter
425 .create_timeout_result_future(resp_id, timeout)
426 .map_err(into_cmd_err!(
427 CmdErrorCode::Failed,
428 "create timeout result future error"
429 ))?;
430 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
431 if let Err(e) = ret {
432 self.set_disable();
433 return Err(e);
434 }
435 let resp = resp_waiter
436 .await
437 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
438 Ok(resp)
439 }
440
441 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
442 let mut write = self.write.get().await;
443 if header.len() > 255 {
444 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
445 }
446 write
447 .write_u8(header.len() as u8)
448 .await
449 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
450 write
451 .write_all(header)
452 .await
453 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
454 write
455 .write_all(body)
456 .await
457 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
458 write
459 .flush()
460 .await
461 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
462 Ok(())
463 }
464
465 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
466 let mut write = self.write.get().await;
467 if header.len() > 255 {
468 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
469 }
470 write
471 .write_u8(header.len() as u8)
472 .await
473 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
474 write
475 .write_all(header)
476 .await
477 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
478 for b in body.iter() {
479 write
480 .write_all(b)
481 .await
482 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
483 }
484 write
485 .flush()
486 .await
487 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
488 Ok(())
489 }
490
491 async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
492 let mut write = self.write.get().await;
493 if header.len() > 255 {
494 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
495 }
496 write
497 .write_u8(header.len() as u8)
498 .await
499 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
500 write
501 .write_all(header)
502 .await
503 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
504 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
505 .await
506 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
507 write
508 .flush()
509 .await
510 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
511 Ok(())
512 }
513}
514
515impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
516where
517 C: WorkerClassification,
518 M: CmdTunnelMeta,
519 R: ClassifiedCmdTunnelRead<C, M>,
520 W: ClassifiedCmdTunnelWrite<C, M>,
521 LEN: RawEncode
522 + for<'a> RawDecode<'a>
523 + Copy
524 + Send
525 + Sync
526 + 'static
527 + FromPrimitive
528 + ToPrimitive,
529 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
530{
531 fn drop(&mut self) {
532 self.set_disable();
533 }
534}
535
536impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
537where
538 C: WorkerClassification,
539 M: CmdTunnelMeta,
540 R: ClassifiedCmdTunnelRead<C, M>,
541 W: ClassifiedCmdTunnelWrite<C, M>,
542 LEN: RawEncode
543 + for<'a> RawDecode<'a>
544 + Copy
545 + Send
546 + Sync
547 + 'static
548 + FromPrimitive
549 + ToPrimitive,
550 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
551{
552 fn get_tunnel_meta(&self) -> Option<Arc<M>> {
553 self.tunnel_meta.clone()
554 }
555
556 fn get_remote_peer_id(&self) -> PeerId {
557 self.remote_id.clone()
558 }
559}
560
561impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>>
562 for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
563where
564 C: WorkerClassification,
565 M: CmdTunnelMeta,
566 R: ClassifiedCmdTunnelRead<C, M>,
567 W: ClassifiedCmdTunnelWrite<C, M>,
568 LEN: RawEncode
569 + for<'a> RawDecode<'a>
570 + Copy
571 + Send
572 + Sync
573 + 'static
574 + FromPrimitive
575 + ToPrimitive,
576 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
577{
578 fn is_work(&self) -> bool {
579 self.is_work && !self.recv_handle.is_finished()
580 }
581
582 fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
583 if c.tunnel_id.is_some() {
584 self.tunnel_id == c.tunnel_id.unwrap()
585 } else {
586 if c.classification.is_some() {
587 self.classification == c.classification.unwrap()
588 } else {
589 true
590 }
591 }
592 }
593
594 fn classification(&self) -> CmdClientTunnelClassification<C> {
595 CmdClientTunnelClassification {
596 tunnel_id: self.pool_tunnel_id,
597 classification: self.pool_classification.clone(),
598 }
599 }
600}
601
602pub struct ClassifiedCmdWriteFactory<
603 C: WorkerClassification,
604 M: CmdTunnelMeta,
605 R: ClassifiedCmdTunnelRead<C, M>,
606 W: ClassifiedCmdTunnelWrite<C, M>,
607 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
608 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
609 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
610> {
611 tunnel_factory: F,
612 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
613 resp_waiter: RespWaiterRef,
614 tunnel_id_generator: TunnelIdGenerator,
615 _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
616}
617
618impl<
619 C: WorkerClassification,
620 M: CmdTunnelMeta,
621 R: ClassifiedCmdTunnelRead<C, M>,
622 W: ClassifiedCmdTunnelWrite<C, M>,
623 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
624 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
625 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
626> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
627{
628 pub(crate) fn new(
629 tunnel_factory: F,
630 cmd_handler: impl CmdHandler<LEN, CMD>,
631 resp_waiter: RespWaiterRef,
632 ) -> Self {
633 Self {
634 tunnel_factory,
635 cmd_handler: Arc::new(cmd_handler),
636 resp_waiter,
637 tunnel_id_generator: TunnelIdGenerator::new(),
638 _p: Default::default(),
639 }
640 }
641}
642
643#[async_trait::async_trait]
644impl<
645 C: WorkerClassification,
646 M: CmdTunnelMeta,
647 R: ClassifiedCmdTunnelRead<C, M>,
648 W: ClassifiedCmdTunnelWrite<C, M>,
649 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
650 LEN: RawEncode
651 + for<'a> RawDecode<'a>
652 + Copy
653 + Send
654 + Sync
655 + 'static
656 + FromPrimitive
657 + ToPrimitive
658 + RawFixedBytes,
659 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
660> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>>
661 for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
662{
663 async fn create(
664 &self,
665 classification: Option<CmdClientTunnelClassification<C>>,
666 ) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
667 if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
668 return Err(pool_err!(
669 PoolErrorCode::Failed,
670 "tunnel {:?} not found",
671 classification.as_ref().unwrap().tunnel_id.unwrap()
672 ));
673 }
674
675 let requested_classification = classification;
676 let classification = requested_classification
677 .as_ref()
678 .and_then(|key| key.classification.clone());
679 let tunnel = self
680 .tunnel_factory
681 .create_tunnel(classification)
682 .await
683 .map_err(into_pool_err!(PoolErrorCode::Failed))?;
684 let classification = tunnel.get_classification();
685 let pool_tunnel_id = requested_classification
686 .as_ref()
687 .and_then(|key| key.tunnel_id);
688 let pool_classification = requested_classification
689 .as_ref()
690 .and_then(|key| key.classification.clone())
691 .or(Some(classification.clone()));
692 let peer_id = tunnel.get_remote_peer_id();
693 let tunnel_id = self.tunnel_id_generator.generate();
694 let (mut recv, write) = tunnel.split();
695 let local_id = recv.get_local_peer_id();
696 let remote_id = peer_id.clone();
697 let tunnel_meta = recv.get_tunnel_meta();
698 let write = ObjectHolder::new(write);
699 let resp_write = write.clone();
700 let cmd_handler = self.cmd_handler.clone();
701 let handle = spawn(async move {
702 let ret: CmdResult<()> = async move {
703 loop {
704 let header_len = recv
705 .read_u8()
706 .await
707 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
708 let mut header = vec![0u8; header_len as usize];
709 let n = recv
710 .read_exact(header.as_mut())
711 .await
712 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
713 if n == 0 {
714 break;
715 }
716 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
717 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
718 log::trace!(
719 "recv cmd {:?} from {} len {} tunnel {:?}",
720 header.cmd_code(),
721 peer_id,
722 header.pkg_len().to_u64().unwrap(),
723 tunnel_id
724 );
725 let body_len = header.pkg_len().to_u64().unwrap();
726 let cmd_read =
727 CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
728 let waiter = cmd_read.get_waiter();
729 let future = waiter
730 .create_result_future()
731 .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
732 let version = header.version();
733 let seq = header.seq();
734 let cmd_code = header.cmd_code();
735 match cmd_handler
736 .handle(
737 local_id.clone(),
738 peer_id.clone(),
739 tunnel_id,
740 header,
741 CmdBody::from_reader(BufReader::new(cmd_read), body_len),
742 )
743 .await
744 {
745 Ok(Some(mut body)) => {
746 let mut write = resp_write.get().await;
747 let header = CmdHeader::<LEN, CMD>::new(
748 version,
749 true,
750 seq,
751 cmd_code,
752 LEN::from_u64(body.len()).unwrap(),
753 );
754 let buf = header
755 .to_vec()
756 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
757 if buf.len() > 255 {
758 return Err(cmd_err!(
759 CmdErrorCode::InvalidParam,
760 "header len too large"
761 ));
762 }
763 write
764 .write_u8(buf.len() as u8)
765 .await
766 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
767 write
768 .write_all(buf.as_slice())
769 .await
770 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
771 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
772 .await
773 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
774 write
775 .flush()
776 .await
777 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
778 }
779 Err(e) => {
780 log::error!("handle cmd error: {:?}", e);
781 }
782 _ => {}
783 }
784 recv = future
785 .await
786 .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
787 log::debug!(
788 "handle cmd {:?} from {} len {} tunnel {:?} complete",
789 cmd_code,
790 peer_id,
791 body_len,
792 tunnel_id
793 );
794 }
795 Ok(())
796 }
797 .await;
798 if ret.is_err() {
799 log::error!("recv cmd error: {:?}", ret.as_ref().err().unwrap());
800 }
801 ret
802 });
803 Ok(ClassifiedCmdSend::new(
804 tunnel_id,
805 classification,
806 None,
807 pool_tunnel_id,
808 pool_classification,
809 handle,
810 write,
811 self.resp_waiter.clone(),
812 remote_id,
813 tunnel_meta,
814 ))
815 }
816}
817
818pub struct DefaultClassifiedCmdClient<
819 C: WorkerClassification,
820 M: CmdTunnelMeta,
821 R: ClassifiedCmdTunnelRead<C, M>,
822 W: ClassifiedCmdTunnelWrite<C, M>,
823 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
824 LEN: RawEncode
825 + for<'a> RawDecode<'a>
826 + Copy
827 + Send
828 + Sync
829 + 'static
830 + FromPrimitive
831 + ToPrimitive
832 + RawFixedBytes,
833 CMD: RawEncode
834 + for<'a> RawDecode<'a>
835 + Copy
836 + Send
837 + Sync
838 + 'static
839 + RawFixedBytes
840 + Eq
841 + Hash
842 + Debug,
843> {
844 tunnel_pool: ClassifiedWorkerPoolRef<
845 CmdClientTunnelClassification<C>,
846 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
847 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
848 >,
849 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
850}
851
852impl<
853 C: WorkerClassification,
854 M: CmdTunnelMeta,
855 R: ClassifiedCmdTunnelRead<C, M>,
856 W: ClassifiedCmdTunnelWrite<C, M>,
857 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
858 LEN: RawEncode
859 + for<'a> RawDecode<'a>
860 + Copy
861 + Send
862 + Sync
863 + 'static
864 + FromPrimitive
865 + ToPrimitive
866 + RawFixedBytes,
867 CMD: RawEncode
868 + for<'a> RawDecode<'a>
869 + Copy
870 + Send
871 + Sync
872 + 'static
873 + RawFixedBytes
874 + Eq
875 + Hash
876 + Debug,
877> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
878{
879 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
880 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
881 let resp_waiter = Arc::new(RespWaiter::new());
882 let handler_map = cmd_handler_map.clone();
883 let waiter = resp_waiter.clone();
884 Arc::new(Self {
885 tunnel_pool: ClassifiedWorkerPool::new(
886 tunnel_count,
887 ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(
888 factory,
889 move |local_id: PeerId,
890 peer_id: PeerId,
891 tunnel_id: TunnelId,
892 header: CmdHeader<LEN, CMD>,
893 body_read: CmdBody| {
894 let handler_map = handler_map.clone();
895 let waiter = waiter.clone();
896 async move {
897 if header.is_resp() && header.seq().is_some() {
898 let resp_id = gen_resp_id(
899 tunnel_id,
900 header.cmd_code(),
901 header.seq().unwrap(),
902 );
903 let _ = waiter.set_result(resp_id, body_read);
904 Ok(None)
905 } else {
906 if let Some(handler) = handler_map.get(header.cmd_code()) {
907 handler
908 .handle(local_id, peer_id, tunnel_id, header, body_read)
909 .await
910 } else {
911 Ok(None)
912 }
913 }
914 }
915 },
916 resp_waiter.clone(),
917 ),
918 ),
919 cmd_handler_map,
920 })
921 }
922
923 async fn get_send(
924 &self,
925 ) -> CmdResult<
926 ClassifiedWorkerGuard<
927 CmdClientTunnelClassification<C>,
928 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
929 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
930 >,
931 > {
932 self.tunnel_pool
933 .get_worker()
934 .await
935 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
936 }
937
938 async fn get_send_of_tunnel_id(
939 &self,
940 tunnel_id: TunnelId,
941 ) -> CmdResult<
942 ClassifiedWorkerGuard<
943 CmdClientTunnelClassification<C>,
944 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
945 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
946 >,
947 > {
948 self.tunnel_pool
949 .get_classified_worker(CmdClientTunnelClassification {
950 tunnel_id: Some(tunnel_id),
951 classification: None,
952 })
953 .await
954 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
955 }
956
957 async fn get_classified_send(
958 &self,
959 classification: C,
960 ) -> CmdResult<
961 ClassifiedWorkerGuard<
962 CmdClientTunnelClassification<C>,
963 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
964 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
965 >,
966 > {
967 self.tunnel_pool
968 .get_classified_worker(CmdClientTunnelClassification {
969 tunnel_id: None,
970 classification: Some(classification),
971 })
972 .await
973 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
974 }
975}
976
977pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = ClassifiedSendGuard<
978 CmdClientTunnelClassification<C>,
979 M,
980 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
981 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
982>;
983#[async_trait::async_trait]
984impl<
985 C: WorkerClassification,
986 M: CmdTunnelMeta,
987 R: ClassifiedCmdTunnelRead<C, M>,
988 W: ClassifiedCmdTunnelWrite<C, M>,
989 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
990 LEN: RawEncode
991 + for<'a> RawDecode<'a>
992 + Copy
993 + Send
994 + Sync
995 + 'static
996 + FromPrimitive
997 + ToPrimitive
998 + RawFixedBytes,
999 CMD: RawEncode
1000 + for<'a> RawDecode<'a>
1001 + Copy
1002 + Send
1003 + Sync
1004 + 'static
1005 + RawFixedBytes
1006 + Eq
1007 + Hash
1008 + Debug,
1009>
1010 CmdClient<
1011 LEN,
1012 CMD,
1013 M,
1014 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1015 ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1016 > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1017{
1018 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
1019 self.cmd_handler_map.insert(cmd, handler);
1020 }
1021
1022 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
1023 let mut send = self.get_send().await?;
1024 send.send(cmd, version, body).await
1025 }
1026
1027 async fn send_with_resp(
1028 &self,
1029 cmd: CMD,
1030 version: u8,
1031 body: &[u8],
1032 timeout: Duration,
1033 ) -> CmdResult<CmdBody> {
1034 let mut send = self.get_send().await?;
1035 send.send_with_resp(cmd, version, body, timeout).await
1036 }
1037
1038 async fn send_parts(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
1039 let mut send = self.get_send().await?;
1040 send.send_parts(cmd, version, body).await
1041 }
1042
1043 async fn send_parts_with_resp(
1044 &self,
1045 cmd: CMD,
1046 version: u8,
1047 body: &[&[u8]],
1048 timeout: Duration,
1049 ) -> CmdResult<CmdBody> {
1050 let mut send = self.get_send().await?;
1051 send.send_parts_with_resp(cmd, version, body, timeout).await
1052 }
1053
1054 async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
1055 let mut send = self.get_send().await?;
1056 send.send_cmd(cmd, version, body).await
1057 }
1058
1059 async fn send_cmd_with_resp(
1060 &self,
1061 cmd: CMD,
1062 version: u8,
1063 body: CmdBody,
1064 timeout: Duration,
1065 ) -> CmdResult<CmdBody> {
1066 let mut send = self.get_send().await?;
1067 send.send_cmd_with_resp(cmd, version, body, timeout).await
1068 }
1069
1070 async fn send_by_specify_tunnel(
1071 &self,
1072 tunnel_id: TunnelId,
1073 cmd: CMD,
1074 version: u8,
1075 body: &[u8],
1076 ) -> CmdResult<()> {
1077 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1078 send.send(cmd, version, body).await
1079 }
1080
1081 async fn send_by_specify_tunnel_with_resp(
1082 &self,
1083 tunnel_id: TunnelId,
1084 cmd: CMD,
1085 version: u8,
1086 body: &[u8],
1087 timeout: Duration,
1088 ) -> CmdResult<CmdBody> {
1089 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1090 send.send_with_resp(cmd, version, body, timeout).await
1091 }
1092
1093 async fn send_parts_by_specify_tunnel(
1094 &self,
1095 tunnel_id: TunnelId,
1096 cmd: CMD,
1097 version: u8,
1098 body: &[&[u8]],
1099 ) -> CmdResult<()> {
1100 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1101 send.send_parts(cmd, version, body).await
1102 }
1103
1104 async fn send_parts_by_specify_tunnel_with_resp(
1105 &self,
1106 tunnel_id: TunnelId,
1107 cmd: CMD,
1108 version: u8,
1109 body: &[&[u8]],
1110 timeout: Duration,
1111 ) -> CmdResult<CmdBody> {
1112 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1113 send.send_parts_with_resp(cmd, version, body, timeout).await
1114 }
1115
1116 async fn send_cmd_by_specify_tunnel(
1117 &self,
1118 tunnel_id: TunnelId,
1119 cmd: CMD,
1120 version: u8,
1121 body: CmdBody,
1122 ) -> CmdResult<()> {
1123 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1124 send.send_cmd(cmd, version, body).await
1125 }
1126
1127 async fn send_cmd_by_specify_tunnel_with_resp(
1128 &self,
1129 tunnel_id: TunnelId,
1130 cmd: CMD,
1131 version: u8,
1132 body: CmdBody,
1133 timeout: Duration,
1134 ) -> CmdResult<CmdBody> {
1135 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1136 send.send_cmd_with_resp(cmd, version, body, timeout).await
1137 }
1138
1139 async fn clear_all_tunnel(&self) {
1140 self.tunnel_pool.clear_all_worker().await;
1141 }
1142
1143 async fn get_send(
1144 &self,
1145 tunnel_id: TunnelId,
1146 ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
1147 Ok(ClassifiedSendGuard {
1148 worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
1149 _p: std::marker::PhantomData,
1150 })
1151 }
1152}
1153
1154#[async_trait::async_trait]
1155impl<
1156 C: WorkerClassification,
1157 M: CmdTunnelMeta,
1158 R: ClassifiedCmdTunnelRead<C, M>,
1159 W: ClassifiedCmdTunnelWrite<C, M>,
1160 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
1161 LEN: RawEncode
1162 + for<'a> RawDecode<'a>
1163 + Copy
1164 + Send
1165 + Sync
1166 + 'static
1167 + FromPrimitive
1168 + ToPrimitive
1169 + RawFixedBytes,
1170 CMD: RawEncode
1171 + for<'a> RawDecode<'a>
1172 + Copy
1173 + Send
1174 + Sync
1175 + 'static
1176 + RawFixedBytes
1177 + Eq
1178 + Hash
1179 + Debug,
1180>
1181 ClassifiedCmdClient<
1182 LEN,
1183 CMD,
1184 C,
1185 M,
1186 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1187 ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1188 > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1189{
1190 async fn send_by_classified_tunnel(
1191 &self,
1192 classification: C,
1193 cmd: CMD,
1194 version: u8,
1195 body: &[u8],
1196 ) -> CmdResult<()> {
1197 let mut send = self.get_classified_send(classification).await?;
1198 send.send(cmd, version, body).await
1199 }
1200
1201 async fn send_by_classified_tunnel_with_resp(
1202 &self,
1203 classification: C,
1204 cmd: CMD,
1205 version: u8,
1206 body: &[u8],
1207 timeout: Duration,
1208 ) -> CmdResult<CmdBody> {
1209 let mut send = self.get_classified_send(classification).await?;
1210 send.send_with_resp(cmd, version, body, timeout).await
1211 }
1212
1213 async fn send_parts_by_classified_tunnel(
1214 &self,
1215 classification: C,
1216 cmd: CMD,
1217 version: u8,
1218 body: &[&[u8]],
1219 ) -> CmdResult<()> {
1220 let mut send = self.get_classified_send(classification).await?;
1221 send.send_parts(cmd, version, body).await
1222 }
1223
1224 async fn send_parts_by_classified_tunnel_with_resp(
1225 &self,
1226 classification: C,
1227 cmd: CMD,
1228 version: u8,
1229 body: &[&[u8]],
1230 timeout: Duration,
1231 ) -> CmdResult<CmdBody> {
1232 let mut send = self.get_classified_send(classification).await?;
1233 send.send_parts_with_resp(cmd, version, body, timeout).await
1234 }
1235
1236 async fn send_cmd_by_classified_tunnel(
1237 &self,
1238 classification: C,
1239 cmd: CMD,
1240 version: u8,
1241 body: CmdBody,
1242 ) -> CmdResult<()> {
1243 let mut send = self.get_classified_send(classification).await?;
1244 send.send_cmd(cmd, version, body).await
1245 }
1246
1247 async fn send_cmd_by_classified_tunnel_with_resp(
1248 &self,
1249 classification: C,
1250 cmd: CMD,
1251 version: u8,
1252 body: CmdBody,
1253 timeout: Duration,
1254 ) -> CmdResult<CmdBody> {
1255 let mut send = self.get_classified_send(classification).await?;
1256 send.send_cmd_with_resp(cmd, version, body, timeout).await
1257 }
1258
1259 async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
1260 let send = self.get_classified_send(classification).await?;
1261 Ok(send.get_tunnel_id())
1262 }
1263
1264 async fn get_send_by_classified(
1265 &self,
1266 classification: C,
1267 ) -> CmdResult<
1268 ClassifiedSendGuard<
1269 CmdClientTunnelClassification<C>,
1270 M,
1271 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1272 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
1273 >,
1274 > {
1275 Ok(ClassifiedSendGuard {
1276 worker_guard: self.get_classified_send(classification).await?,
1277 _p: std::marker::PhantomData,
1278 })
1279 }
1280}