1use crate::client::{
2 ClassifiedCmdClient, CmdClient, CmdSend, RespWaiter, RespWaiterRef, TrackedSendGuard,
3 TunnelReserveResult, TunnelRuntimeRegistry, gen_resp_id, gen_seq,
4};
5use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
6use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
7use crate::peer_id::PeerId;
8use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
9use async_named_locker::ObjectHolder;
10use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
11use num::{FromPrimitive, ToPrimitive};
12use sfo_pool::{
13 ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
14 ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
15 pool_err,
16};
17use sfo_split::{RHalf, Splittable, WHalf};
18use std::fmt::Debug;
19use std::hash::Hash;
20use std::ops::DerefMut;
21use std::sync::{Arc, Mutex};
22use std::time::Duration;
23use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::spawn;
25use tokio::task::JoinHandle;
26use tokio::task::yield_now;
27
28pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>:
29 CmdTunnelRead<M> + 'static + Send
30{
31 fn get_classification(&self) -> C;
32}
33
34pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>:
35 CmdTunnelWrite<M> + 'static + Send
36{
37 fn get_classification(&self) -> C;
38}
39
40pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
41pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
42pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
43
44#[derive(Debug, Clone, Copy, Eq, Hash)]
45pub struct CmdClientTunnelClassification<C: WorkerClassification> {
46 tunnel_id: Option<TunnelId>,
47 classification: Option<C>,
48}
49
50impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
51 fn eq(&self, other: &Self) -> bool {
52 self.tunnel_id == other.tunnel_id && self.classification == other.classification
53 }
54}
55
56#[async_trait::async_trait]
57pub trait ClassifiedCmdTunnelFactory<
58 C: WorkerClassification,
59 M: CmdTunnelMeta,
60 R: ClassifiedCmdTunnelRead<C, M>,
61 W: ClassifiedCmdTunnelWrite<C, M>,
62>: Send + Sync + 'static
63{
64 async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
65}
66
67pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
68where
69 C: WorkerClassification,
70 M: CmdTunnelMeta,
71 R: ClassifiedCmdTunnelRead<C, M>,
72 W: ClassifiedCmdTunnelWrite<C, M>,
73 LEN: RawEncode
74 + for<'a> RawDecode<'a>
75 + Copy
76 + Send
77 + Sync
78 + 'static
79 + FromPrimitive
80 + ToPrimitive,
81 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
82{
83 pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
84 pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
85 pub(crate) is_work: bool,
86 pub(crate) classification: C,
87 pub(crate) pool_tunnel_id: Option<TunnelId>,
88 pub(crate) pool_classification: Option<C>,
89 pub(crate) tunnel_id: TunnelId,
90 pub(crate) resp_waiter: RespWaiterRef,
91 pub(crate) remote_id: PeerId,
92 pub(crate) pool_peer_id: Option<PeerId>,
93 pub(crate) tunnel_meta: Option<Arc<M>>,
94 _p: std::marker::PhantomData<(LEN, CMD)>,
95}
96
97impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
111where
112 C: WorkerClassification,
113 M: CmdTunnelMeta,
114 R: ClassifiedCmdTunnelRead<C, M>,
115 W: ClassifiedCmdTunnelWrite<C, M>,
116 LEN: RawEncode
117 + for<'a> RawDecode<'a>
118 + Copy
119 + Send
120 + Sync
121 + 'static
122 + FromPrimitive
123 + ToPrimitive,
124 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
125{
126 pub(crate) fn new(
127 tunnel_id: TunnelId,
128 classification: C,
129 pool_peer_id: Option<PeerId>,
130 pool_tunnel_id: Option<TunnelId>,
131 pool_classification: Option<C>,
132 recv_handle: JoinHandle<CmdResult<()>>,
133 write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
134 resp_waiter: RespWaiterRef,
135 remote_id: PeerId,
136 tunnel_meta: Option<Arc<M>>,
137 ) -> Self {
138 Self {
139 recv_handle,
140 write,
141 is_work: true,
142 classification,
143 pool_tunnel_id,
144 pool_classification,
145 tunnel_id,
146 resp_waiter,
147 remote_id,
148 pool_peer_id,
149 tunnel_meta,
150 _p: Default::default(),
151 }
152 }
153
154 pub fn get_tunnel_id(&self) -> TunnelId {
155 self.tunnel_id
156 }
157
158 pub fn set_disable(&mut self) {
159 self.is_work = false;
160 self.recv_handle.abort();
161 }
162
163 pub(crate) fn set_pool_key(
164 &mut self,
165 peer_id: Option<PeerId>,
166 tunnel_id: Option<TunnelId>,
167 classification: Option<C>,
168 ) {
169 self.pool_peer_id = peer_id;
170 self.pool_tunnel_id = tunnel_id;
171 self.pool_classification = classification;
172 }
173
174 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
175 log::trace!(
176 "client {:?} send cmd: {:?}, len: {}, data: {}",
177 self.tunnel_id,
178 cmd,
179 body.len(),
180 hex::encode(body)
181 );
182 let header = CmdHeader::<LEN, CMD>::new(
183 version,
184 false,
185 None,
186 cmd,
187 LEN::from_u64(body.len() as u64).unwrap(),
188 );
189 let buf = header
190 .to_vec()
191 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
192 let ret = self.send_inner(buf.as_slice(), body).await;
193 if let Err(e) = ret {
194 self.set_disable();
195 return Err(e);
196 }
197 Ok(())
198 }
199
200 pub async fn send_with_resp(
201 &mut self,
202 cmd: CMD,
203 version: u8,
204 body: &[u8],
205 timeout: Duration,
206 ) -> CmdResult<CmdBody> {
207 if let Some(id) = tokio::task::try_id() {
208 if id == self.recv_handle.id() {
209 return Err(cmd_err!(
210 CmdErrorCode::Failed,
211 "can't send with resp in recv task"
212 ));
213 }
214 }
215 log::trace!(
216 "client {:?} send cmd: {:?}, len: {}, data: {}",
217 self.tunnel_id,
218 cmd,
219 body.len(),
220 hex::encode(body)
221 );
222 let seq = gen_seq();
223 let header = CmdHeader::<LEN, CMD>::new(
224 version,
225 false,
226 Some(seq),
227 cmd,
228 LEN::from_u64(body.len() as u64).unwrap(),
229 );
230 let buf = header
231 .to_vec()
232 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
233 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
234 let waiter = self.resp_waiter.clone();
235 let resp_waiter = waiter
236 .create_timeout_result_future(resp_id, timeout)
237 .map_err(into_cmd_err!(
238 CmdErrorCode::Failed,
239 "create timeout result future error"
240 ))?;
241 let ret = self.send_inner(buf.as_slice(), body).await;
242 if let Err(e) = ret {
243 self.set_disable();
244 return Err(e);
245 }
246 let resp = resp_waiter
247 .await
248 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
249 Ok(resp)
250 }
251
252 pub async fn send_parts(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
253 let mut len = 0;
254 for b in body.iter() {
255 len += b.len();
256 log::trace!(
257 "client {:?} send2 cmd {:?} body: {}",
258 self.tunnel_id,
259 cmd,
260 hex::encode(b)
261 );
262 }
263 log::trace!(
264 "client {:?} send2 cmd: {:?}, len {}",
265 self.tunnel_id,
266 cmd,
267 len
268 );
269 let header = CmdHeader::<LEN, CMD>::new(
270 version,
271 false,
272 None,
273 cmd,
274 LEN::from_u64(len as u64).unwrap(),
275 );
276 let buf = header
277 .to_vec()
278 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
279 let ret = self.send_inner2(buf.as_slice(), body).await;
280 if let Err(e) = ret {
281 self.set_disable();
282 return Err(e);
283 }
284 Ok(())
285 }
286
287 pub async fn send_parts_with_resp(
288 &mut self,
289 cmd: CMD,
290 version: u8,
291 body: &[&[u8]],
292 timeout: Duration,
293 ) -> CmdResult<CmdBody> {
294 if let Some(id) = tokio::task::try_id() {
295 if id == self.recv_handle.id() {
296 return Err(cmd_err!(
297 CmdErrorCode::Failed,
298 "can't send with resp in recv task"
299 ));
300 }
301 }
302 let mut len = 0;
303 for b in body.iter() {
304 len += b.len();
305 log::trace!(
306 "client {:?} send2 cmd {:?} body: {}",
307 self.tunnel_id,
308 cmd,
309 hex::encode(b)
310 );
311 }
312 log::trace!(
313 "client {:?} send2 cmd: {:?}, len {}",
314 self.tunnel_id,
315 cmd,
316 len
317 );
318 let seq = gen_seq();
319 let header = CmdHeader::<LEN, CMD>::new(
320 version,
321 false,
322 Some(seq),
323 cmd,
324 LEN::from_u64(len as u64).unwrap(),
325 );
326 let buf = header
327 .to_vec()
328 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
329 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
330 let waiter = self.resp_waiter.clone();
331 let resp_waiter = waiter
332 .create_timeout_result_future(resp_id, timeout)
333 .map_err(into_cmd_err!(
334 CmdErrorCode::Failed,
335 "create timeout result future error"
336 ))?;
337 let ret = self.send_inner2(buf.as_slice(), body).await;
338 if let Err(e) = ret {
339 self.set_disable();
340 return Err(e);
341 }
342 let resp = resp_waiter
343 .await
344 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
345 Ok(resp)
346 }
347
348 #[allow(deprecated)]
349 #[deprecated(note = "use send_parts instead")]
350 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
351 self.send_parts(cmd, version, body).await
352 }
353
354 #[allow(deprecated)]
355 #[deprecated(note = "use send_parts_with_resp instead")]
356 pub async fn send2_with_resp(
357 &mut self,
358 cmd: CMD,
359 version: u8,
360 body: &[&[u8]],
361 timeout: Duration,
362 ) -> CmdResult<CmdBody> {
363 self.send_parts_with_resp(cmd, version, body, timeout).await
364 }
365
366 pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
367 log::trace!(
368 "client {:?} send cmd: {:?}, len: {}",
369 self.tunnel_id,
370 cmd,
371 body.len()
372 );
373 let header = CmdHeader::<LEN, CMD>::new(
374 version,
375 false,
376 None,
377 cmd,
378 LEN::from_u64(body.len()).unwrap(),
379 );
380 let buf = header
381 .to_vec()
382 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
383 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
384 if let Err(e) = ret {
385 self.set_disable();
386 return Err(e);
387 }
388 Ok(())
389 }
390
391 pub async fn send_cmd_with_resp(
392 &mut self,
393 cmd: CMD,
394 version: u8,
395 body: CmdBody,
396 timeout: Duration,
397 ) -> CmdResult<CmdBody> {
398 if let Some(id) = tokio::task::try_id() {
399 if id == self.recv_handle.id() {
400 return Err(cmd_err!(
401 CmdErrorCode::Failed,
402 "can't send with resp in recv task"
403 ));
404 }
405 }
406 log::trace!(
407 "client {:?} send cmd: {:?}, len: {}",
408 self.tunnel_id,
409 cmd,
410 body.len()
411 );
412 let seq = gen_seq();
413 let header = CmdHeader::<LEN, CMD>::new(
414 version,
415 false,
416 Some(seq),
417 cmd,
418 LEN::from_u64(body.len()).unwrap(),
419 );
420 let buf = header
421 .to_vec()
422 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
423 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
424 let waiter = self.resp_waiter.clone();
425 let resp_waiter = waiter
426 .create_timeout_result_future(resp_id, timeout)
427 .map_err(into_cmd_err!(
428 CmdErrorCode::Failed,
429 "create timeout result future error"
430 ))?;
431 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
432 if let Err(e) = ret {
433 self.set_disable();
434 return Err(e);
435 }
436 let resp = resp_waiter
437 .await
438 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
439 Ok(resp)
440 }
441
442 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
443 let mut write = self.write.get().await;
444 if header.len() > 255 {
445 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
446 }
447 write
448 .write_u8(header.len() as u8)
449 .await
450 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
451 write
452 .write_all(header)
453 .await
454 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
455 write
456 .write_all(body)
457 .await
458 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
459 write
460 .flush()
461 .await
462 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
463 Ok(())
464 }
465
466 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
467 let mut write = self.write.get().await;
468 if header.len() > 255 {
469 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
470 }
471 write
472 .write_u8(header.len() as u8)
473 .await
474 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
475 write
476 .write_all(header)
477 .await
478 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
479 for b in body.iter() {
480 write
481 .write_all(b)
482 .await
483 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
484 }
485 write
486 .flush()
487 .await
488 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
489 Ok(())
490 }
491
492 async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
493 let mut write = self.write.get().await;
494 if header.len() > 255 {
495 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
496 }
497 write
498 .write_u8(header.len() as u8)
499 .await
500 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
501 write
502 .write_all(header)
503 .await
504 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
505 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
506 .await
507 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
508 write
509 .flush()
510 .await
511 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
512 Ok(())
513 }
514}
515
516impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
517where
518 C: WorkerClassification,
519 M: CmdTunnelMeta,
520 R: ClassifiedCmdTunnelRead<C, M>,
521 W: ClassifiedCmdTunnelWrite<C, M>,
522 LEN: RawEncode
523 + for<'a> RawDecode<'a>
524 + Copy
525 + Send
526 + Sync
527 + 'static
528 + FromPrimitive
529 + ToPrimitive,
530 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
531{
532 fn drop(&mut self) {
533 self.set_disable();
534 }
535}
536
537impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
538where
539 C: WorkerClassification,
540 M: CmdTunnelMeta,
541 R: ClassifiedCmdTunnelRead<C, M>,
542 W: ClassifiedCmdTunnelWrite<C, M>,
543 LEN: RawEncode
544 + for<'a> RawDecode<'a>
545 + Copy
546 + Send
547 + Sync
548 + 'static
549 + FromPrimitive
550 + ToPrimitive,
551 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
552{
553 fn get_tunnel_meta(&self) -> Option<Arc<M>> {
554 self.tunnel_meta.clone()
555 }
556
557 fn get_remote_peer_id(&self) -> PeerId {
558 self.remote_id.clone()
559 }
560}
561
562impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>>
563 for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
564where
565 C: WorkerClassification,
566 M: CmdTunnelMeta,
567 R: ClassifiedCmdTunnelRead<C, M>,
568 W: ClassifiedCmdTunnelWrite<C, M>,
569 LEN: RawEncode
570 + for<'a> RawDecode<'a>
571 + Copy
572 + Send
573 + Sync
574 + 'static
575 + FromPrimitive
576 + ToPrimitive,
577 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
578{
579 fn is_work(&self) -> bool {
580 self.is_work && !self.recv_handle.is_finished()
581 }
582
583 fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
584 if c.tunnel_id.is_some() {
585 self.tunnel_id == c.tunnel_id.unwrap()
586 } else {
587 if c.classification.is_some() {
588 self.classification == c.classification.unwrap()
589 } else {
590 true
591 }
592 }
593 }
594
595 fn classification(&self) -> CmdClientTunnelClassification<C> {
596 CmdClientTunnelClassification {
597 tunnel_id: self.pool_tunnel_id,
598 classification: self.pool_classification.clone(),
599 }
600 }
601}
602
603pub struct ClassifiedCmdWriteFactory<
604 C: WorkerClassification,
605 M: CmdTunnelMeta,
606 R: ClassifiedCmdTunnelRead<C, M>,
607 W: ClassifiedCmdTunnelWrite<C, M>,
608 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
609 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
610 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
611> {
612 tunnel_factory: F,
613 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
614 resp_waiter: RespWaiterRef,
615 tunnel_id_generator: TunnelIdGenerator,
616 _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
617}
618
619impl<
620 C: WorkerClassification,
621 M: CmdTunnelMeta,
622 R: ClassifiedCmdTunnelRead<C, M>,
623 W: ClassifiedCmdTunnelWrite<C, M>,
624 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
625 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
626 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
627> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
628{
629 pub(crate) fn new(
630 tunnel_factory: F,
631 cmd_handler: impl CmdHandler<LEN, CMD>,
632 resp_waiter: RespWaiterRef,
633 ) -> Self {
634 Self {
635 tunnel_factory,
636 cmd_handler: Arc::new(cmd_handler),
637 resp_waiter,
638 tunnel_id_generator: TunnelIdGenerator::new(),
639 _p: Default::default(),
640 }
641 }
642}
643
644#[async_trait::async_trait]
645impl<
646 C: WorkerClassification,
647 M: CmdTunnelMeta,
648 R: ClassifiedCmdTunnelRead<C, M>,
649 W: ClassifiedCmdTunnelWrite<C, M>,
650 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
651 LEN: RawEncode
652 + for<'a> RawDecode<'a>
653 + Copy
654 + Send
655 + Sync
656 + 'static
657 + FromPrimitive
658 + ToPrimitive
659 + RawFixedBytes,
660 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
661> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>>
662 for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>
663{
664 async fn create(
665 &self,
666 classification: Option<CmdClientTunnelClassification<C>>,
667 ) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
668 if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
669 return Err(pool_err!(
670 PoolErrorCode::Failed,
671 "tunnel {:?} not found",
672 classification.as_ref().unwrap().tunnel_id.unwrap()
673 ));
674 }
675
676 let requested_classification = classification;
677 let classification = requested_classification
678 .as_ref()
679 .and_then(|key| key.classification.clone());
680 let tunnel = self
681 .tunnel_factory
682 .create_tunnel(classification)
683 .await
684 .map_err(into_pool_err!(PoolErrorCode::Failed))?;
685 let classification = tunnel.get_classification();
686 let pool_tunnel_id = requested_classification
687 .as_ref()
688 .and_then(|key| key.tunnel_id);
689 let pool_classification = requested_classification
690 .as_ref()
691 .and_then(|key| key.classification.clone())
692 .or(Some(classification.clone()));
693 let peer_id = tunnel.get_remote_peer_id();
694 let tunnel_id = self.tunnel_id_generator.generate();
695 let (mut recv, write) = tunnel.split();
696 let local_id = recv.get_local_peer_id();
697 let remote_id = peer_id.clone();
698 let tunnel_meta = recv.get_tunnel_meta();
699 let write = ObjectHolder::new(write);
700 let resp_write = write.clone();
701 let cmd_handler = self.cmd_handler.clone();
702 let handle = spawn(async move {
703 let ret: CmdResult<()> = async move {
704 loop {
705 let header_len = recv
706 .read_u8()
707 .await
708 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
709 let mut header = vec![0u8; header_len as usize];
710 let n = recv
711 .read_exact(header.as_mut())
712 .await
713 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
714 if n == 0 {
715 break;
716 }
717 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
718 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
719 log::trace!(
720 "recv cmd {:?} from {} len {} tunnel {:?}",
721 header.cmd_code(),
722 peer_id,
723 header.pkg_len().to_u64().unwrap(),
724 tunnel_id
725 );
726 let body_len = header.pkg_len().to_u64().unwrap();
727 let cmd_read =
728 CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
729 let waiter = cmd_read.get_waiter();
730 let future = waiter
731 .create_result_future()
732 .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
733 let version = header.version();
734 let seq = header.seq();
735 let cmd_code = header.cmd_code();
736 match cmd_handler
737 .handle(
738 local_id.clone(),
739 peer_id.clone(),
740 tunnel_id,
741 header,
742 CmdBody::from_reader(BufReader::new(cmd_read), body_len),
743 )
744 .await
745 {
746 Ok(Some(mut body)) => {
747 let mut write = resp_write.get().await;
748 let header = CmdHeader::<LEN, CMD>::new(
749 version,
750 true,
751 seq,
752 cmd_code,
753 LEN::from_u64(body.len()).unwrap(),
754 );
755 let buf = header
756 .to_vec()
757 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
758 if buf.len() > 255 {
759 return Err(cmd_err!(
760 CmdErrorCode::InvalidParam,
761 "header len too large"
762 ));
763 }
764 write
765 .write_u8(buf.len() as u8)
766 .await
767 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
768 write
769 .write_all(buf.as_slice())
770 .await
771 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
772 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
773 .await
774 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
775 write
776 .flush()
777 .await
778 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
779 }
780 Err(e) => {
781 log::error!("handle cmd error: {:?}", e);
782 }
783 _ => {}
784 }
785 recv = future
786 .await
787 .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
788 log::debug!(
789 "handle cmd {:?} from {} len {} tunnel {:?} complete",
790 cmd_code,
791 peer_id,
792 body_len,
793 tunnel_id
794 );
795 }
796 Ok(())
797 }
798 .await;
799 if ret.is_err() {
800 log::error!("recv cmd error: {:?}", ret.as_ref().err().unwrap());
801 }
802 ret
803 });
804 Ok(ClassifiedCmdSend::new(
805 tunnel_id,
806 classification,
807 None,
808 pool_tunnel_id,
809 pool_classification,
810 handle,
811 write,
812 self.resp_waiter.clone(),
813 remote_id,
814 tunnel_meta,
815 ))
816 }
817}
818
819pub struct DefaultClassifiedCmdClient<
820 C: WorkerClassification,
821 M: CmdTunnelMeta,
822 R: ClassifiedCmdTunnelRead<C, M>,
823 W: ClassifiedCmdTunnelWrite<C, M>,
824 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
825 LEN: RawEncode
826 + for<'a> RawDecode<'a>
827 + Copy
828 + Send
829 + Sync
830 + 'static
831 + FromPrimitive
832 + ToPrimitive
833 + RawFixedBytes,
834 CMD: RawEncode
835 + for<'a> RawDecode<'a>
836 + Copy
837 + Send
838 + Sync
839 + 'static
840 + RawFixedBytes
841 + Eq
842 + Hash
843 + Debug,
844> {
845 tunnel_pool: ClassifiedWorkerPoolRef<
846 CmdClientTunnelClassification<C>,
847 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
848 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
849 >,
850 runtime_tunnels: Arc<TunnelRuntimeRegistry>,
851 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
852}
853
854impl<
855 C: WorkerClassification,
856 M: CmdTunnelMeta,
857 R: ClassifiedCmdTunnelRead<C, M>,
858 W: ClassifiedCmdTunnelWrite<C, M>,
859 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
860 LEN: RawEncode
861 + for<'a> RawDecode<'a>
862 + Copy
863 + Send
864 + Sync
865 + 'static
866 + FromPrimitive
867 + ToPrimitive
868 + RawFixedBytes,
869 CMD: RawEncode
870 + for<'a> RawDecode<'a>
871 + Copy
872 + Send
873 + Sync
874 + 'static
875 + RawFixedBytes
876 + Eq
877 + Hash
878 + Debug,
879> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
880{
881 fn tunnel_not_found(tunnel_id: TunnelId) -> sfo_result::Error<CmdErrorCode> {
882 cmd_err!(CmdErrorCode::Failed, "tunnel {:?} not found", tunnel_id)
883 }
884
885 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
886 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
887 let resp_waiter = Arc::new(RespWaiter::new());
888 let handler_map = cmd_handler_map.clone();
889 let waiter = resp_waiter.clone();
890 Arc::new(Self {
891 tunnel_pool: ClassifiedWorkerPool::new(
892 tunnel_count,
893 ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(
894 factory,
895 move |local_id: PeerId,
896 peer_id: PeerId,
897 tunnel_id: TunnelId,
898 header: CmdHeader<LEN, CMD>,
899 body_read: CmdBody| {
900 let handler_map = handler_map.clone();
901 let waiter = waiter.clone();
902 async move {
903 if header.is_resp() && header.seq().is_some() {
904 let resp_id = gen_resp_id(
905 tunnel_id,
906 header.cmd_code(),
907 header.seq().unwrap(),
908 );
909 let _ = waiter.set_result(resp_id, body_read);
910 Ok(None)
911 } else {
912 if let Some(handler) = handler_map.get(header.cmd_code()) {
913 handler
914 .handle(local_id, peer_id, tunnel_id, header, body_read)
915 .await
916 } else {
917 Ok(None)
918 }
919 }
920 }
921 },
922 resp_waiter.clone(),
923 ),
924 ),
925 runtime_tunnels: TunnelRuntimeRegistry::new(),
926 cmd_handler_map,
927 })
928 }
929
930 fn tracked_send_guard(
931 &self,
932 worker_guard: ClassifiedWorkerGuard<
933 CmdClientTunnelClassification<C>,
934 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
935 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
936 >,
937 ) -> ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> {
938 let tunnel_id = worker_guard.get_tunnel_id();
939 TrackedSendGuard::new(worker_guard, self.runtime_tunnels.clone(), tunnel_id)
940 }
941
942 async fn reserve_tunnel(&self, tunnel_id: TunnelId) -> CmdResult<()> {
943 loop {
944 match self.runtime_tunnels.reserve_existing(tunnel_id) {
945 TunnelReserveResult::Acquired => return Ok(()),
946 TunnelReserveResult::Wait(waiter) => waiter.notified().await,
947 TunnelReserveResult::Missing => return Err(Self::tunnel_not_found(tunnel_id)),
948 }
949 }
950 }
951
952 async fn get_send(&self) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
953 loop {
954 let worker_guard = self
955 .tunnel_pool
956 .get_worker()
957 .await
958 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))?;
959 let tunnel_id = worker_guard.get_tunnel_id();
960 if self.runtime_tunnels.mark_borrowed(tunnel_id) {
961 return Ok(self.tracked_send_guard(worker_guard));
962 }
963 drop(worker_guard);
964 yield_now().await;
965 }
966 }
967
968 async fn get_send_of_tunnel_id(
969 &self,
970 tunnel_id: TunnelId,
971 ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
972 self.reserve_tunnel(tunnel_id).await?;
973 match self
974 .tunnel_pool
975 .get_classified_worker(CmdClientTunnelClassification {
976 tunnel_id: Some(tunnel_id),
977 classification: None,
978 })
979 .await
980 {
981 Ok(worker_guard) => Ok(self.tracked_send_guard(worker_guard)),
982 Err(_) => {
983 self.runtime_tunnels.remove(tunnel_id);
984 Err(Self::tunnel_not_found(tunnel_id))
985 }
986 }
987 }
988
989 async fn get_classified_send(
990 &self,
991 classification: C,
992 ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
993 loop {
994 let worker_guard = self
995 .tunnel_pool
996 .get_classified_worker(CmdClientTunnelClassification {
997 tunnel_id: None,
998 classification: Some(classification.clone()),
999 })
1000 .await
1001 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))?;
1002 let tunnel_id = worker_guard.get_tunnel_id();
1003 if self.runtime_tunnels.mark_borrowed(tunnel_id) {
1004 return Ok(self.tracked_send_guard(worker_guard));
1005 }
1006 drop(worker_guard);
1007 yield_now().await;
1008 }
1009 }
1010}
1011
1012pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = TrackedSendGuard<
1013 CmdClientTunnelClassification<C>,
1014 M,
1015 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1016 ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>,
1017>;
1018#[async_trait::async_trait]
1019impl<
1020 C: WorkerClassification,
1021 M: CmdTunnelMeta,
1022 R: ClassifiedCmdTunnelRead<C, M>,
1023 W: ClassifiedCmdTunnelWrite<C, M>,
1024 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
1025 LEN: RawEncode
1026 + for<'a> RawDecode<'a>
1027 + Copy
1028 + Send
1029 + Sync
1030 + 'static
1031 + FromPrimitive
1032 + ToPrimitive
1033 + RawFixedBytes,
1034 CMD: RawEncode
1035 + for<'a> RawDecode<'a>
1036 + Copy
1037 + Send
1038 + Sync
1039 + 'static
1040 + RawFixedBytes
1041 + Eq
1042 + Hash
1043 + Debug,
1044>
1045 CmdClient<
1046 LEN,
1047 CMD,
1048 M,
1049 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1050 ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1051 > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1052{
1053 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
1054 self.cmd_handler_map.insert(cmd, handler);
1055 }
1056
1057 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
1058 let mut send = self.get_send().await?;
1059 send.send(cmd, version, body).await
1060 }
1061
1062 async fn send_with_resp(
1063 &self,
1064 cmd: CMD,
1065 version: u8,
1066 body: &[u8],
1067 timeout: Duration,
1068 ) -> CmdResult<CmdBody> {
1069 let mut send = self.get_send().await?;
1070 send.send_with_resp(cmd, version, body, timeout).await
1071 }
1072
1073 async fn send_parts(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
1074 let mut send = self.get_send().await?;
1075 send.send_parts(cmd, version, body).await
1076 }
1077
1078 async fn send_parts_with_resp(
1079 &self,
1080 cmd: CMD,
1081 version: u8,
1082 body: &[&[u8]],
1083 timeout: Duration,
1084 ) -> CmdResult<CmdBody> {
1085 let mut send = self.get_send().await?;
1086 send.send_parts_with_resp(cmd, version, body, timeout).await
1087 }
1088
1089 async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
1090 let mut send = self.get_send().await?;
1091 send.send_cmd(cmd, version, body).await
1092 }
1093
1094 async fn send_cmd_with_resp(
1095 &self,
1096 cmd: CMD,
1097 version: u8,
1098 body: CmdBody,
1099 timeout: Duration,
1100 ) -> CmdResult<CmdBody> {
1101 let mut send = self.get_send().await?;
1102 send.send_cmd_with_resp(cmd, version, body, timeout).await
1103 }
1104
1105 async fn send_by_specify_tunnel(
1106 &self,
1107 tunnel_id: TunnelId,
1108 cmd: CMD,
1109 version: u8,
1110 body: &[u8],
1111 ) -> CmdResult<()> {
1112 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1113 send.send(cmd, version, body).await
1114 }
1115
1116 async fn send_by_specify_tunnel_with_resp(
1117 &self,
1118 tunnel_id: TunnelId,
1119 cmd: CMD,
1120 version: u8,
1121 body: &[u8],
1122 timeout: Duration,
1123 ) -> CmdResult<CmdBody> {
1124 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1125 send.send_with_resp(cmd, version, body, timeout).await
1126 }
1127
1128 async fn send_parts_by_specify_tunnel(
1129 &self,
1130 tunnel_id: TunnelId,
1131 cmd: CMD,
1132 version: u8,
1133 body: &[&[u8]],
1134 ) -> CmdResult<()> {
1135 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1136 send.send_parts(cmd, version, body).await
1137 }
1138
1139 async fn send_parts_by_specify_tunnel_with_resp(
1140 &self,
1141 tunnel_id: TunnelId,
1142 cmd: CMD,
1143 version: u8,
1144 body: &[&[u8]],
1145 timeout: Duration,
1146 ) -> CmdResult<CmdBody> {
1147 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1148 send.send_parts_with_resp(cmd, version, body, timeout).await
1149 }
1150
1151 async fn send_cmd_by_specify_tunnel(
1152 &self,
1153 tunnel_id: TunnelId,
1154 cmd: CMD,
1155 version: u8,
1156 body: CmdBody,
1157 ) -> CmdResult<()> {
1158 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1159 send.send_cmd(cmd, version, body).await
1160 }
1161
1162 async fn send_cmd_by_specify_tunnel_with_resp(
1163 &self,
1164 tunnel_id: TunnelId,
1165 cmd: CMD,
1166 version: u8,
1167 body: CmdBody,
1168 timeout: Duration,
1169 ) -> CmdResult<CmdBody> {
1170 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1171 send.send_cmd_with_resp(cmd, version, body, timeout).await
1172 }
1173
1174 async fn clear_all_tunnel(&self) {
1175 self.tunnel_pool.clear_all_worker().await;
1176 self.runtime_tunnels.clear();
1177 }
1178
1179 async fn get_send(
1180 &self,
1181 tunnel_id: TunnelId,
1182 ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
1183 self.get_send_of_tunnel_id(tunnel_id).await
1184 }
1185}
1186
1187#[async_trait::async_trait]
1188impl<
1189 C: WorkerClassification,
1190 M: CmdTunnelMeta,
1191 R: ClassifiedCmdTunnelRead<C, M>,
1192 W: ClassifiedCmdTunnelWrite<C, M>,
1193 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
1194 LEN: RawEncode
1195 + for<'a> RawDecode<'a>
1196 + Copy
1197 + Send
1198 + Sync
1199 + 'static
1200 + FromPrimitive
1201 + ToPrimitive
1202 + RawFixedBytes,
1203 CMD: RawEncode
1204 + for<'a> RawDecode<'a>
1205 + Copy
1206 + Send
1207 + Sync
1208 + 'static
1209 + RawFixedBytes
1210 + Eq
1211 + Hash
1212 + Debug,
1213>
1214 ClassifiedCmdClient<
1215 LEN,
1216 CMD,
1217 C,
1218 M,
1219 ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
1220 ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>,
1221 > for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD>
1222{
1223 async fn send_by_classified_tunnel(
1224 &self,
1225 classification: C,
1226 cmd: CMD,
1227 version: u8,
1228 body: &[u8],
1229 ) -> CmdResult<()> {
1230 let mut send = self.get_classified_send(classification).await?;
1231 send.send(cmd, version, body).await
1232 }
1233
1234 async fn send_by_classified_tunnel_with_resp(
1235 &self,
1236 classification: C,
1237 cmd: CMD,
1238 version: u8,
1239 body: &[u8],
1240 timeout: Duration,
1241 ) -> CmdResult<CmdBody> {
1242 let mut send = self.get_classified_send(classification).await?;
1243 send.send_with_resp(cmd, version, body, timeout).await
1244 }
1245
1246 async fn send_parts_by_classified_tunnel(
1247 &self,
1248 classification: C,
1249 cmd: CMD,
1250 version: u8,
1251 body: &[&[u8]],
1252 ) -> CmdResult<()> {
1253 let mut send = self.get_classified_send(classification).await?;
1254 send.send_parts(cmd, version, body).await
1255 }
1256
1257 async fn send_parts_by_classified_tunnel_with_resp(
1258 &self,
1259 classification: C,
1260 cmd: CMD,
1261 version: u8,
1262 body: &[&[u8]],
1263 timeout: Duration,
1264 ) -> CmdResult<CmdBody> {
1265 let mut send = self.get_classified_send(classification).await?;
1266 send.send_parts_with_resp(cmd, version, body, timeout).await
1267 }
1268
1269 async fn send_cmd_by_classified_tunnel(
1270 &self,
1271 classification: C,
1272 cmd: CMD,
1273 version: u8,
1274 body: CmdBody,
1275 ) -> CmdResult<()> {
1276 let mut send = self.get_classified_send(classification).await?;
1277 send.send_cmd(cmd, version, body).await
1278 }
1279
1280 async fn send_cmd_by_classified_tunnel_with_resp(
1281 &self,
1282 classification: C,
1283 cmd: CMD,
1284 version: u8,
1285 body: CmdBody,
1286 timeout: Duration,
1287 ) -> CmdResult<CmdBody> {
1288 let mut send = self.get_classified_send(classification).await?;
1289 send.send_cmd_with_resp(cmd, version, body, timeout).await
1290 }
1291
1292 async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
1293 let send = self.get_classified_send(classification).await?;
1294 Ok(send.get_tunnel_id())
1295 }
1296
1297 async fn get_send_by_classified(
1298 &self,
1299 classification: C,
1300 ) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
1301 self.get_classified_send(classification).await
1302 }
1303}