1use crate::client::{
2 CmdClient, CmdSend, RespWaiter, RespWaiterRef, SendGuard, TrackedSendGuard,
3 TunnelReserveResult, TunnelRuntimeRegistry, gen_resp_id, gen_seq,
4};
5use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
6use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
7use crate::peer_id::PeerId;
8use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
9use async_named_locker::ObjectHolder;
10use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
11use num::{FromPrimitive, ToPrimitive};
12use sfo_pool::{
13 ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
14 ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification, into_pool_err,
15 pool_err,
16};
17use sfo_split::{Splittable, WHalf};
18use std::fmt::Debug;
19use std::hash::Hash;
20use std::marker::PhantomData;
21use std::ops::{Deref, DerefMut};
22use std::sync::{Arc, Mutex};
23use std::time::Duration;
24use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
25use tokio::spawn;
26use tokio::task::JoinHandle;
27use tokio::task::yield_now;
28
29#[async_trait::async_trait]
30pub trait CmdTunnelFactory<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>>:
31 Send + Sync + 'static
32{
33 async fn create_tunnel(&self) -> CmdResult<Splittable<R, W>>;
34}
35
36pub struct CommonCmdSend<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>, LEN, CMD>
37where
38 LEN: RawEncode
39 + for<'a> RawDecode<'a>
40 + Copy
41 + Send
42 + Sync
43 + 'static
44 + FromPrimitive
45 + ToPrimitive,
46 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
47{
48 pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
49 pub(crate) write: ObjectHolder<WHalf<R, W>>,
50 pub(crate) is_work: bool,
51 pub(crate) tunnel_id: TunnelId,
52 pub(crate) remote_id: PeerId,
53 pub(crate) resp_waiter: RespWaiterRef,
54 pub(crate) tunnel_meta: Option<Arc<M>>,
55 _p: std::marker::PhantomData<(LEN, CMD)>,
56}
57
58impl<M, R, W, LEN, CMD> CommonCmdSend<M, R, W, LEN, CMD>
71where
72 M: CmdTunnelMeta,
73 R: CmdTunnelRead<M>,
74 W: CmdTunnelWrite<M>,
75 LEN: RawEncode
76 + for<'a> RawDecode<'a>
77 + Copy
78 + Send
79 + Sync
80 + 'static
81 + FromPrimitive
82 + ToPrimitive,
83 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
84{
85 pub fn new(
86 tunnel_id: TunnelId,
87 recv_handle: JoinHandle<CmdResult<()>>,
88 write: ObjectHolder<WHalf<R, W>>,
89 resp_waiter: RespWaiterRef,
90 remote_id: PeerId,
91 tunnel_meta: Option<Arc<M>>,
92 ) -> Self {
93 Self {
94 recv_handle,
95 write,
96 is_work: true,
97 tunnel_id,
98 remote_id,
99 resp_waiter,
100 tunnel_meta,
101 _p: Default::default(),
102 }
103 }
104
105 pub fn get_tunnel_id(&self) -> TunnelId {
106 self.tunnel_id
107 }
108
109 pub fn set_disable(&mut self) {
110 self.is_work = false;
111 self.recv_handle.abort();
112 }
113
114 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
115 log::trace!(
116 "client {:?} send cmd: {:?}, len: {} data:{}",
117 self.tunnel_id,
118 cmd,
119 body.len(),
120 hex::encode(body)
121 );
122 let header = CmdHeader::<LEN, CMD>::new(
123 version,
124 false,
125 None,
126 cmd,
127 LEN::from_u64(body.len() as u64).unwrap(),
128 );
129 let buf = header
130 .to_vec()
131 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
132 let ret = self.send_inner(buf.as_slice(), body).await;
133 if let Err(e) = ret {
134 self.set_disable();
135 return Err(e);
136 }
137 Ok(())
138 }
139
140 pub async fn send_with_resp(
141 &mut self,
142 cmd: CMD,
143 version: u8,
144 body: &[u8],
145 timeout: Duration,
146 ) -> CmdResult<CmdBody> {
147 if let Some(id) = tokio::task::try_id() {
148 if id == self.recv_handle.id() {
149 return Err(cmd_err!(
150 CmdErrorCode::Failed,
151 "can't send with resp in recv task"
152 ));
153 }
154 }
155 log::trace!(
156 "client {:?} send cmd: {:?}, len: {}, data: {}",
157 self.tunnel_id,
158 cmd,
159 body.len(),
160 hex::encode(body)
161 );
162 let seq = gen_seq();
163 let header = CmdHeader::<LEN, CMD>::new(
164 version,
165 false,
166 Some(seq),
167 cmd,
168 LEN::from_u64(body.len() as u64).unwrap(),
169 );
170 let buf = header
171 .to_vec()
172 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
173 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
174 let waiter = self.resp_waiter.clone();
175 let resp_waiter = waiter
176 .create_timeout_result_future(resp_id, timeout)
177 .map_err(into_cmd_err!(
178 CmdErrorCode::Failed,
179 "create timeout result future error"
180 ))?;
181 let ret = self.send_inner(buf.as_slice(), body).await;
182 if let Err(e) = ret {
183 self.set_disable();
184 return Err(e);
185 }
186 let resp = resp_waiter
187 .await
188 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
189 Ok(resp)
190 }
191
192 pub async fn send_parts(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
193 let mut len = 0;
194 for b in body.iter() {
195 len += b.len();
196 log::trace!(
197 "client {:?} send2 cmd: {:?}, data {}",
198 self.tunnel_id,
199 cmd,
200 hex::encode(b)
201 );
202 }
203 log::trace!(
204 "client {:?} send2 cmd: {:?}, len {}",
205 self.tunnel_id,
206 cmd,
207 len
208 );
209 let header = CmdHeader::<LEN, CMD>::new(
210 version,
211 false,
212 None,
213 cmd,
214 LEN::from_u64(len as u64).unwrap(),
215 );
216 let buf = header
217 .to_vec()
218 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
219 let ret = self.send_inner2(buf.as_slice(), body).await;
220 if let Err(e) = ret {
221 self.set_disable();
222 return Err(e);
223 }
224 Ok(())
225 }
226
227 pub async fn send_parts_with_resp(
228 &mut self,
229 cmd: CMD,
230 version: u8,
231 body: &[&[u8]],
232 timeout: Duration,
233 ) -> CmdResult<CmdBody> {
234 if let Some(id) = tokio::task::try_id() {
235 if id == self.recv_handle.id() {
236 return Err(cmd_err!(
237 CmdErrorCode::Failed,
238 "can't send with resp in recv task"
239 ));
240 }
241 }
242 let mut len = 0;
243 for b in body.iter() {
244 len += b.len();
245 log::trace!(
246 "client {:?} send2 cmd {:?} body: {}",
247 self.tunnel_id,
248 cmd,
249 hex::encode(b)
250 );
251 }
252 log::trace!(
253 "client {:?} send2 cmd: {:?}, len {}",
254 self.tunnel_id,
255 cmd,
256 len
257 );
258 let seq = gen_seq();
259 let header = CmdHeader::<LEN, CMD>::new(
260 version,
261 false,
262 Some(seq),
263 cmd,
264 LEN::from_u64(len as u64).unwrap(),
265 );
266 let buf = header
267 .to_vec()
268 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
269 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
270 let waiter = self.resp_waiter.clone();
271 let resp_waiter = waiter
272 .create_timeout_result_future(resp_id, timeout)
273 .map_err(into_cmd_err!(
274 CmdErrorCode::Failed,
275 "create timeout result future error"
276 ))?;
277 let ret = self.send_inner2(buf.as_slice(), body).await;
278 if let Err(e) = ret {
279 self.set_disable();
280 return Err(e);
281 }
282 let resp = resp_waiter
283 .await
284 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
285 Ok(resp)
286 }
287
288 #[allow(deprecated)]
289 #[deprecated(note = "use send_parts instead")]
290 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
291 self.send_parts(cmd, version, body).await
292 }
293
294 #[allow(deprecated)]
295 #[deprecated(note = "use send_parts_with_resp instead")]
296 pub async fn send2_with_resp(
297 &mut self,
298 cmd: CMD,
299 version: u8,
300 body: &[&[u8]],
301 timeout: Duration,
302 ) -> CmdResult<CmdBody> {
303 self.send_parts_with_resp(cmd, version, body, timeout).await
304 }
305
306 pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
307 log::trace!(
308 "client {:?} send cmd: {:?}, len: {}",
309 self.tunnel_id,
310 cmd,
311 body.len()
312 );
313 let header = CmdHeader::<LEN, CMD>::new(
314 version,
315 false,
316 None,
317 cmd,
318 LEN::from_u64(body.len()).unwrap(),
319 );
320 let buf = header
321 .to_vec()
322 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
323 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
324 if let Err(e) = ret {
325 self.set_disable();
326 return Err(e);
327 }
328 Ok(())
329 }
330
331 pub async fn send_cmd_with_resp(
332 &mut self,
333 cmd: CMD,
334 version: u8,
335 body: CmdBody,
336 timeout: Duration,
337 ) -> CmdResult<CmdBody> {
338 if let Some(id) = tokio::task::try_id() {
339 if id == self.recv_handle.id() {
340 return Err(cmd_err!(
341 CmdErrorCode::Failed,
342 "can't send with resp in recv task"
343 ));
344 }
345 }
346 log::trace!(
347 "client {:?} send cmd: {:?}, len: {}",
348 self.tunnel_id,
349 cmd,
350 body.len()
351 );
352 let seq = gen_seq();
353 let header = CmdHeader::<LEN, CMD>::new(
354 version,
355 false,
356 Some(seq),
357 cmd,
358 LEN::from_u64(body.len()).unwrap(),
359 );
360 let buf = header
361 .to_vec()
362 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
363 let resp_id = gen_resp_id(self.tunnel_id, cmd, seq);
364 let waiter = self.resp_waiter.clone();
365 let resp_waiter = waiter
366 .create_timeout_result_future(resp_id, timeout)
367 .map_err(into_cmd_err!(
368 CmdErrorCode::Failed,
369 "create timeout result future error"
370 ))?;
371 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
372 if let Err(e) = ret {
373 self.set_disable();
374 return Err(e);
375 }
376 let resp = resp_waiter
377 .await
378 .map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
379 Ok(resp)
380 }
381
382 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
383 let mut write = self.write.get().await;
384 if header.len() > 255 {
385 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
386 }
387 write
388 .write_u8(header.len() as u8)
389 .await
390 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
391 write
392 .write_all(header)
393 .await
394 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
395 write
396 .write_all(body)
397 .await
398 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
399 write
400 .flush()
401 .await
402 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
403 Ok(())
404 }
405
406 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
407 let mut write = self.write.get().await;
408 if header.len() > 255 {
409 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too long"));
410 }
411 write
412 .write_u8(header.len() as u8)
413 .await
414 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
415 write
416 .write_all(header)
417 .await
418 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
419 for b in body.iter() {
420 write
421 .write_all(b)
422 .await
423 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
424 }
425 write
426 .flush()
427 .await
428 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
429 Ok(())
430 }
431
432 async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
433 let mut write = self.write.get().await;
434 if header.len() > 255 {
435 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
436 }
437 write
438 .write_u8(header.len() as u8)
439 .await
440 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
441 write
442 .write_all(header)
443 .await
444 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
445 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
446 .await
447 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
448 write
449 .flush()
450 .await
451 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
452 Ok(())
453 }
454}
455
456impl<M, R, W, LEN, CMD> Drop for CommonCmdSend<M, R, W, LEN, CMD>
457where
458 M: CmdTunnelMeta,
459 R: CmdTunnelRead<M>,
460 W: CmdTunnelWrite<M>,
461 LEN: RawEncode
462 + for<'a> RawDecode<'a>
463 + Copy
464 + Send
465 + Sync
466 + 'static
467 + FromPrimitive
468 + ToPrimitive,
469 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
470{
471 fn drop(&mut self) {
472 self.set_disable();
473 }
474}
475
476impl<M, R, W, LEN, CMD> CmdSend<M> for CommonCmdSend<M, R, W, LEN, CMD>
477where
478 M: CmdTunnelMeta,
479 R: CmdTunnelRead<M>,
480 W: CmdTunnelWrite<M>,
481 LEN: RawEncode
482 + for<'a> RawDecode<'a>
483 + Copy
484 + Send
485 + Sync
486 + 'static
487 + FromPrimitive
488 + ToPrimitive,
489 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
490{
491 fn get_tunnel_meta(&self) -> Option<Arc<M>> {
492 self.tunnel_meta.clone()
493 }
494
495 fn get_remote_peer_id(&self) -> PeerId {
496 self.remote_id.clone()
497 }
498}
499
500impl<M, R, W, LEN, CMD> ClassifiedWorker<TunnelId> for CommonCmdSend<M, R, W, LEN, CMD>
501where
502 M: CmdTunnelMeta,
503 R: CmdTunnelRead<M>,
504 W: CmdTunnelWrite<M>,
505 LEN: RawEncode
506 + for<'a> RawDecode<'a>
507 + Copy
508 + Send
509 + Sync
510 + 'static
511 + FromPrimitive
512 + ToPrimitive,
513 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
514{
515 fn is_work(&self) -> bool {
516 self.is_work && !self.recv_handle.is_finished()
517 }
518
519 fn is_valid(&self, c: TunnelId) -> bool {
520 self.tunnel_id == c
521 }
522
523 fn classification(&self) -> TunnelId {
524 self.tunnel_id
525 }
526}
527
528pub struct ClassifiedSendGuard<
529 C: WorkerClassification,
530 M: CmdTunnelMeta,
531 CW: ClassifiedWorker<C> + CmdSend<M>,
532 F: ClassifiedWorkerFactory<C, CW>,
533> {
534 pub(crate) worker_guard: ClassifiedWorkerGuard<C, CW, F>,
535 pub(crate) _p: PhantomData<M>,
536}
537
538impl<
539 C: WorkerClassification,
540 M: CmdTunnelMeta,
541 CW: ClassifiedWorker<C> + CmdSend<M>,
542 F: ClassifiedWorkerFactory<C, CW>,
543> Deref for ClassifiedSendGuard<C, M, CW, F>
544{
545 type Target = CW;
546
547 fn deref(&self) -> &Self::Target {
548 &self.worker_guard.deref()
549 }
550}
551
552impl<
553 C: WorkerClassification,
554 M: CmdTunnelMeta,
555 CW: ClassifiedWorker<C> + CmdSend<M>,
556 F: ClassifiedWorkerFactory<C, CW>,
557> SendGuard<M, CW> for ClassifiedSendGuard<C, M, CW, F>
558{
559}
560
561pub struct CmdWriteFactory<
562 M: CmdTunnelMeta,
563 R: CmdTunnelRead<M>,
564 W: CmdTunnelWrite<M>,
565 F: CmdTunnelFactory<M, R, W>,
566 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
567 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
568> {
569 tunnel_factory: F,
570 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
571 resp_waiter: RespWaiterRef,
572 tunnel_id_generator: TunnelIdGenerator,
573 p: std::marker::PhantomData<Mutex<(R, W, M)>>,
574}
575
576impl<
577 M: CmdTunnelMeta,
578 R: CmdTunnelRead<M>,
579 W: CmdTunnelWrite<M>,
580 F: CmdTunnelFactory<M, R, W>,
581 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
582 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
583> CmdWriteFactory<M, R, W, F, LEN, CMD>
584{
585 pub(crate) fn new(
586 tunnel_factory: F,
587 cmd_handler: impl CmdHandler<LEN, CMD>,
588 resp_waiter: RespWaiterRef,
589 ) -> Self {
590 Self {
591 tunnel_factory,
592 cmd_handler: Arc::new(cmd_handler),
593 resp_waiter,
594 tunnel_id_generator: TunnelIdGenerator::new(),
595 p: Default::default(),
596 }
597 }
598}
599
600#[async_trait::async_trait]
601impl<
602 M: CmdTunnelMeta,
603 R: CmdTunnelRead<M>,
604 W: CmdTunnelWrite<M>,
605 F: CmdTunnelFactory<M, R, W>,
606 LEN: RawEncode
607 + for<'a> RawDecode<'a>
608 + Copy
609 + Send
610 + Sync
611 + 'static
612 + FromPrimitive
613 + ToPrimitive
614 + RawFixedBytes,
615 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
616> ClassifiedWorkerFactory<TunnelId, CommonCmdSend<M, R, W, LEN, CMD>>
617 for CmdWriteFactory<M, R, W, F, LEN, CMD>
618{
619 async fn create(&self, c: Option<TunnelId>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
620 if c.is_some() {
621 return Err(pool_err!(
622 PoolErrorCode::Failed,
623 "tunnel {:?} not found",
624 c.unwrap()
625 ));
626 }
627 let tunnel = self
628 .tunnel_factory
629 .create_tunnel()
630 .await
631 .map_err(into_pool_err!(PoolErrorCode::Failed))?;
632 let peer_id = tunnel.get_remote_peer_id();
633 let tunnel_id = self.tunnel_id_generator.generate();
634 let (mut recv, write) = tunnel.split();
635 let local_id = recv.get_local_peer_id();
636 let remote_id = write.get_remote_peer_id();
637 let meta = write.get_tunnel_meta();
638 let write = ObjectHolder::new(write);
639 let resp_write = write.clone();
640 let cmd_handler = self.cmd_handler.clone();
641 let handle = spawn(async move {
642 let ret: CmdResult<()> = async move {
643 loop {
644 let header_len = recv
645 .read_u8()
646 .await
647 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
648 let mut header = vec![0u8; header_len as usize];
649 let n = recv
650 .read_exact(header.as_mut())
651 .await
652 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
653 if n == 0 {
654 break;
655 }
656 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice())
657 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
658 log::trace!(
659 "recv cmd {:?} from {} len {}",
660 header.cmd_code(),
661 peer_id.to_base58(),
662 header.pkg_len().to_u64().unwrap()
663 );
664 let body_len = header.pkg_len().to_u64().unwrap();
665 let cmd_read =
666 CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
667 let waiter = cmd_read.get_waiter();
668 let future = waiter
669 .create_result_future()
670 .map_err(into_cmd_err!(CmdErrorCode::Failed))?;
671 let version = header.version();
672 let seq = header.seq();
673 let cmd_code = header.cmd_code();
674 match cmd_handler
675 .handle(
676 local_id.clone(),
677 peer_id.clone(),
678 tunnel_id,
679 header,
680 CmdBody::from_reader(BufReader::new(cmd_read), body_len),
681 )
682 .await
683 {
684 Ok(Some(mut body)) => {
685 let mut write = resp_write.get().await;
686 let header = CmdHeader::<LEN, CMD>::new(
687 version,
688 true,
689 seq,
690 cmd_code,
691 LEN::from_u64(body.len()).unwrap(),
692 );
693 let buf = header
694 .to_vec()
695 .map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
696 if buf.len() > 255 {
697 return Err(cmd_err!(
698 CmdErrorCode::InvalidParam,
699 "header len too long"
700 ));
701 }
702 write
703 .write_u8(buf.len() as u8)
704 .await
705 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
706 write
707 .write_all(buf.as_slice())
708 .await
709 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
710 tokio::io::copy(&mut body, write.deref_mut().deref_mut())
711 .await
712 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
713 write
714 .flush()
715 .await
716 .map_err(into_cmd_err!(CmdErrorCode::IoError))?;
717 }
718 Ok(None) => {}
719 Err(e) => {
720 log::error!("handle cmd error: {:?}", e);
721 }
722 }
723 recv = future
724 .await
725 .map_err(into_cmd_err!(CmdErrorCode::Failed))??;
726 }
727 Ok(())
728 }
729 .await;
730 ret
731 });
732 Ok(CommonCmdSend::new(
733 tunnel_id,
734 handle,
735 write,
736 self.resp_waiter.clone(),
737 remote_id,
738 meta,
739 ))
740 }
741}
742
743pub struct DefaultCmdClient<
744 M: CmdTunnelMeta,
745 R: CmdTunnelRead<M>,
746 W: CmdTunnelWrite<M>,
747 F: CmdTunnelFactory<M, R, W>,
748 LEN: RawEncode
749 + for<'a> RawDecode<'a>
750 + Copy
751 + Send
752 + Sync
753 + 'static
754 + FromPrimitive
755 + ToPrimitive
756 + RawFixedBytes,
757 CMD: RawEncode
758 + for<'a> RawDecode<'a>
759 + Copy
760 + Send
761 + Sync
762 + 'static
763 + RawFixedBytes
764 + Eq
765 + Hash
766 + Debug,
767> {
768 tunnel_pool: ClassifiedWorkerPoolRef<
769 TunnelId,
770 CommonCmdSend<M, R, W, LEN, CMD>,
771 CmdWriteFactory<M, R, W, F, LEN, CMD>,
772 >,
773 runtime_tunnels: Arc<TunnelRuntimeRegistry>,
774 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
775}
776
777impl<
778 M: CmdTunnelMeta,
779 R: CmdTunnelRead<M>,
780 W: CmdTunnelWrite<M>,
781 F: CmdTunnelFactory<M, R, W>,
782 LEN: RawEncode
783 + for<'a> RawDecode<'a>
784 + Copy
785 + Send
786 + Sync
787 + 'static
788 + FromPrimitive
789 + ToPrimitive
790 + RawFixedBytes,
791 CMD: RawEncode
792 + for<'a> RawDecode<'a>
793 + Copy
794 + Send
795 + Sync
796 + 'static
797 + RawFixedBytes
798 + Eq
799 + Hash
800 + Debug,
801> DefaultCmdClient<M, R, W, F, LEN, CMD>
802{
803 fn tunnel_not_found(tunnel_id: TunnelId) -> sfo_result::Error<CmdErrorCode> {
804 cmd_err!(CmdErrorCode::Failed, "tunnel {:?} not found", tunnel_id)
805 }
806
807 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
808 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
809 let handler_map = cmd_handler_map.clone();
810 let resp_waiter = Arc::new(RespWaiter::new());
811 let waiter = resp_waiter.clone();
812 Arc::new(Self {
813 tunnel_pool: ClassifiedWorkerPool::new(
814 tunnel_count,
815 CmdWriteFactory::<M, R, W, _, LEN, CMD>::new(
816 factory,
817 move |local_id: PeerId,
818 peer_id: PeerId,
819 tunnel_id: TunnelId,
820 header: CmdHeader<LEN, CMD>,
821 body_read: CmdBody| {
822 let handler_map = handler_map.clone();
823 let waiter = waiter.clone();
824 async move {
825 if header.is_resp() && header.seq().is_some() {
826 let resp_id = gen_resp_id(
827 tunnel_id,
828 header.cmd_code(),
829 header.seq().unwrap(),
830 );
831 let _ = waiter.set_result(resp_id, body_read);
832 Ok(None)
833 } else {
834 if let Some(handler) = handler_map.get(header.cmd_code()) {
835 handler
836 .handle(local_id, peer_id, tunnel_id, header, body_read)
837 .await
838 } else {
839 Ok(None)
840 }
841 }
842 }
843 },
844 resp_waiter.clone(),
845 ),
846 ),
847 runtime_tunnels: TunnelRuntimeRegistry::new(),
848 cmd_handler_map,
849 })
850 }
851
852 fn tracked_send_guard(
853 &self,
854 worker_guard: ClassifiedWorkerGuard<
855 TunnelId,
856 CommonCmdSend<M, R, W, LEN, CMD>,
857 CmdWriteFactory<M, R, W, F, LEN, CMD>,
858 >,
859 ) -> CmdClientSendGuard<M, R, W, F, LEN, CMD> {
860 let tunnel_id = worker_guard.get_tunnel_id();
861 TrackedSendGuard::new(worker_guard, self.runtime_tunnels.clone(), tunnel_id)
862 }
863
864 async fn reserve_tunnel(&self, tunnel_id: TunnelId) -> CmdResult<()> {
865 loop {
866 match self.runtime_tunnels.reserve_existing(tunnel_id) {
867 TunnelReserveResult::Acquired => return Ok(()),
868 TunnelReserveResult::Wait(waiter) => waiter.notified().await,
869 TunnelReserveResult::Missing => return Err(Self::tunnel_not_found(tunnel_id)),
870 }
871 }
872 }
873
874 async fn get_send(&self) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
875 loop {
876 let worker_guard = self
877 .tunnel_pool
878 .get_worker()
879 .await
880 .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))?;
881 let tunnel_id = worker_guard.get_tunnel_id();
882 if self.runtime_tunnels.mark_borrowed(tunnel_id) {
883 return Ok(self.tracked_send_guard(worker_guard));
884 }
885 drop(worker_guard);
886 yield_now().await;
887 }
888 }
889
890 async fn get_send_of_tunnel_id(
891 &self,
892 tunnel_id: TunnelId,
893 ) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
894 self.reserve_tunnel(tunnel_id).await?;
895 match self.tunnel_pool.get_classified_worker(tunnel_id).await {
896 Ok(worker_guard) => Ok(self.tracked_send_guard(worker_guard)),
897 Err(_) => {
898 self.runtime_tunnels.remove(tunnel_id);
899 Err(Self::tunnel_not_found(tunnel_id))
900 }
901 }
902 }
903}
904
905pub type CmdClientSendGuard<M, R, W, F, LEN, CMD> = TrackedSendGuard<
906 TunnelId,
907 M,
908 CommonCmdSend<M, R, W, LEN, CMD>,
909 CmdWriteFactory<M, R, W, F, LEN, CMD>,
910>;
911#[async_trait::async_trait]
912impl<
913 M: CmdTunnelMeta,
914 R: CmdTunnelRead<M>,
915 W: CmdTunnelWrite<M>,
916 F: CmdTunnelFactory<M, R, W>,
917 LEN: RawEncode
918 + for<'a> RawDecode<'a>
919 + Copy
920 + Send
921 + Sync
922 + 'static
923 + FromPrimitive
924 + ToPrimitive
925 + RawFixedBytes,
926 CMD: RawEncode
927 + for<'a> RawDecode<'a>
928 + Copy
929 + Send
930 + Sync
931 + 'static
932 + RawFixedBytes
933 + Eq
934 + Hash
935 + Debug,
936> CmdClient<LEN, CMD, M, CommonCmdSend<M, R, W, LEN, CMD>, CmdClientSendGuard<M, R, W, F, LEN, CMD>>
937 for DefaultCmdClient<M, R, W, F, LEN, CMD>
938{
939 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
940 self.cmd_handler_map.insert(cmd, handler);
941 }
942
943 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
944 let mut send = self.get_send().await?;
945 send.send(cmd, version, body).await
946 }
947
948 async fn send_with_resp(
949 &self,
950 cmd: CMD,
951 version: u8,
952 body: &[u8],
953 timeout: Duration,
954 ) -> CmdResult<CmdBody> {
955 let mut send = self.get_send().await?;
956 send.send_with_resp(cmd, version, body, timeout).await
957 }
958
959 async fn send_parts(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
960 let mut send = self.get_send().await?;
961 send.send_parts(cmd, version, body).await
962 }
963
964 async fn send_parts_with_resp(
965 &self,
966 cmd: CMD,
967 version: u8,
968 body: &[&[u8]],
969 timeout: Duration,
970 ) -> CmdResult<CmdBody> {
971 let mut send = self.get_send().await?;
972 send.send_parts_with_resp(cmd, version, body, timeout).await
973 }
974
975 async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
976 let mut send = self.get_send().await?;
977 send.send_cmd(cmd, version, body).await
978 }
979
980 async fn send_cmd_with_resp(
981 &self,
982 cmd: CMD,
983 version: u8,
984 body: CmdBody,
985 timeout: Duration,
986 ) -> CmdResult<CmdBody> {
987 let mut send = self.get_send().await?;
988 send.send_cmd_with_resp(cmd, version, body, timeout).await
989 }
990
991 async fn send_by_specify_tunnel(
992 &self,
993 tunnel_id: TunnelId,
994 cmd: CMD,
995 version: u8,
996 body: &[u8],
997 ) -> CmdResult<()> {
998 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
999 send.send(cmd, version, body).await
1000 }
1001
1002 async fn send_by_specify_tunnel_with_resp(
1003 &self,
1004 tunnel_id: TunnelId,
1005 cmd: CMD,
1006 version: u8,
1007 body: &[u8],
1008 timeout: Duration,
1009 ) -> CmdResult<CmdBody> {
1010 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1011 send.send_with_resp(cmd, version, body, timeout).await
1012 }
1013
1014 async fn send_parts_by_specify_tunnel(
1015 &self,
1016 tunnel_id: TunnelId,
1017 cmd: CMD,
1018 version: u8,
1019 body: &[&[u8]],
1020 ) -> CmdResult<()> {
1021 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1022 send.send_parts(cmd, version, body).await
1023 }
1024
1025 async fn send_parts_by_specify_tunnel_with_resp(
1026 &self,
1027 tunnel_id: TunnelId,
1028 cmd: CMD,
1029 version: u8,
1030 body: &[&[u8]],
1031 timeout: Duration,
1032 ) -> CmdResult<CmdBody> {
1033 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1034 send.send_parts_with_resp(cmd, version, body, timeout).await
1035 }
1036
1037 async fn send_cmd_by_specify_tunnel(
1038 &self,
1039 tunnel_id: TunnelId,
1040 cmd: CMD,
1041 version: u8,
1042 body: CmdBody,
1043 ) -> CmdResult<()> {
1044 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1045 send.send_cmd(cmd, version, body).await
1046 }
1047
1048 async fn send_cmd_by_specify_tunnel_with_resp(
1049 &self,
1050 tunnel_id: TunnelId,
1051 cmd: CMD,
1052 version: u8,
1053 body: CmdBody,
1054 timeout: Duration,
1055 ) -> CmdResult<CmdBody> {
1056 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
1057 send.send_cmd_with_resp(cmd, version, body, timeout).await
1058 }
1059
1060 async fn clear_all_tunnel(&self) {
1061 self.tunnel_pool.clear_all_worker().await;
1062 self.runtime_tunnels.clear();
1063 }
1064
1065 async fn get_send(
1066 &self,
1067 tunnel_id: TunnelId,
1068 ) -> CmdResult<CmdClientSendGuard<M, R, W, F, LEN, CMD>> {
1069 self.get_send_of_tunnel_id(tunnel_id).await
1070 }
1071}