// sfo_cmd_server/node/classified_node.rs

1use crate::client::{
2    ClassifiedCmdSend, ClassifiedCmdTunnelRead, ClassifiedCmdTunnelWrite, RespWaiter,
3    RespWaiterRef, TrackedSendGuard, TunnelReserveResult, TunnelRuntimeRegistry, gen_resp_id,
4};
5use crate::cmd::CmdHandlerMap;
6use crate::errors::{CmdErrorCode, CmdResult, into_cmd_err};
7use crate::node::create_recv_handle;
8use crate::server::CmdTunnelListener;
9use crate::{
10    ClassifiedCmdNode, CmdBody, CmdHandler, CmdHeader, CmdNode, CmdTunnelMeta, PeerId, TunnelId,
11    TunnelIdGenerator, into_pool_err, pool_err,
12};
13use async_named_locker::ObjectHolder;
14use bucky_raw_codec::{RawDecode, RawEncode, RawFixedBytes};
15use num::{FromPrimitive, ToPrimitive};
16use sfo_pool::{
17    ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool,
18    ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification,
19};
20use sfo_split::Splittable;
21use std::collections::HashMap;
22use std::fmt::Debug;
23use std::hash::Hash;
24use std::sync::{Arc, Mutex};
25use std::time::Duration;
26use tokio::task::yield_now;
27
28#[derive(Debug, Clone, Eq, Hash)]
29pub struct CmdNodeTunnelClassification<C: WorkerClassification> {
30    pub peer_id: Option<PeerId>,
31    pub tunnel_id: Option<TunnelId>,
32    pub classification: Option<C>,
33}
34
35impl<C: WorkerClassification> PartialEq for CmdNodeTunnelClassification<C> {
36    fn eq(&self, other: &Self) -> bool {
37        self.peer_id == other.peer_id
38            && self.tunnel_id == other.tunnel_id
39            && self.classification == other.classification
40    }
41}
42
43impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdNodeTunnelClassification<C>>
44    for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
45where
46    C: WorkerClassification,
47    M: CmdTunnelMeta,
48    R: ClassifiedCmdTunnelRead<C, M>,
49    W: ClassifiedCmdTunnelWrite<C, M>,
50    LEN: RawEncode
51        + for<'a> RawDecode<'a>
52        + Copy
53        + Send
54        + Sync
55        + 'static
56        + FromPrimitive
57        + ToPrimitive,
58    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
59{
60    fn is_work(&self) -> bool {
61        self.is_work && !self.recv_handle.is_finished()
62    }
63
64    fn is_valid(&self, c: CmdNodeTunnelClassification<C>) -> bool {
65        if c.peer_id.is_some() && c.peer_id.as_ref().unwrap() != &self.remote_id {
66            return false;
67        }
68
69        if c.tunnel_id.is_some() {
70            self.tunnel_id == c.tunnel_id.unwrap()
71        } else {
72            if c.classification.is_some() {
73                self.classification == c.classification.unwrap()
74            } else {
75                true
76            }
77        }
78    }
79
80    fn classification(&self) -> CmdNodeTunnelClassification<C> {
81        CmdNodeTunnelClassification {
82            peer_id: self.pool_peer_id.clone(),
83            tunnel_id: self.pool_tunnel_id,
84            classification: self.pool_classification.clone(),
85        }
86    }
87}
88
89#[async_trait::async_trait]
90pub trait ClassifiedCmdNodeTunnelFactory<
91    C: WorkerClassification,
92    M: CmdTunnelMeta,
93    R: ClassifiedCmdTunnelRead<C, M>,
94    W: ClassifiedCmdTunnelWrite<C, M>,
95>: Send + Sync + 'static
96{
97    async fn create_tunnel(
98        &self,
99        classification: Option<CmdNodeTunnelClassification<C>>,
100    ) -> CmdResult<Splittable<R, W>>;
101}
102
103struct CmdWriteFactoryImpl<
104    C: WorkerClassification,
105    M: CmdTunnelMeta,
106    R: ClassifiedCmdTunnelRead<C, M>,
107    W: ClassifiedCmdTunnelWrite<C, M>,
108    F: ClassifiedCmdNodeTunnelFactory<C, M, R, W>,
109    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
110    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
111    LISTENER: CmdTunnelListener<M, R, W>,
112> {
113    tunnel_listener: LISTENER,
114    tunnel_factory: F,
115    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
116    tunnel_id_generator: TunnelIdGenerator,
117    resp_waiter: RespWaiterRef,
118    send_cache: Arc<Mutex<HashMap<PeerId, Vec<ClassifiedCmdSend<C, M, R, W, LEN, CMD>>>>>,
119}
120
121impl<
122    C: WorkerClassification,
123    M: CmdTunnelMeta,
124    R: ClassifiedCmdTunnelRead<C, M>,
125    W: ClassifiedCmdTunnelWrite<C, M>,
126    F: ClassifiedCmdNodeTunnelFactory<C, M, R, W>,
127    LEN: RawEncode
128        + for<'a> RawDecode<'a>
129        + Copy
130        + Send
131        + Sync
132        + 'static
133        + FromPrimitive
134        + ToPrimitive
135        + RawFixedBytes,
136    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
137    LISTENER: CmdTunnelListener<M, R, W>,
138> CmdWriteFactoryImpl<C, M, R, W, F, LEN, CMD, LISTENER>
139{
140    pub fn new(
141        tunnel_factory: F,
142        tunnel_listener: LISTENER,
143        cmd_handler: impl CmdHandler<LEN, CMD>,
144        resp_waiter: RespWaiterRef,
145    ) -> Self {
146        Self {
147            tunnel_listener,
148            tunnel_factory,
149            cmd_handler: Arc::new(cmd_handler),
150            tunnel_id_generator: TunnelIdGenerator::new(),
151            resp_waiter,
152            send_cache: Arc::new(Mutex::new(Default::default())),
153        }
154    }
155
156    pub fn start(self: &Arc<Self>) {
157        let this = self.clone();
158        tokio::spawn(async move {
159            if let Err(e) = this.run().await {
160                log::error!("cmd server error: {:?}", e);
161            }
162        });
163    }
164
165    async fn run(self: &Arc<Self>) -> CmdResult<()> {
166        loop {
167            let tunnel = self.tunnel_listener.accept().await?;
168            let peer_id = tunnel.get_remote_peer_id();
169            let classification = tunnel.get_classification();
170            let tunnel_id = self.tunnel_id_generator.generate();
171            let this = self.clone();
172            let resp_waiter = self.resp_waiter.clone();
173            tokio::spawn(async move {
174                let ret: CmdResult<()> = async move {
175                    let this = this.clone();
176                    let cmd_handler = this.cmd_handler.clone();
177                    let (reader, writer) = tunnel.split();
178                    let tunnel_meta = reader.get_tunnel_meta();
179                    let remote_id = writer.get_remote_peer_id();
180                    let writer = ObjectHolder::new(writer);
181                    let recv_handle = create_recv_handle::<M, R, W, LEN, CMD>(
182                        reader,
183                        writer.clone(),
184                        tunnel_id,
185                        cmd_handler,
186                    );
187                    {
188                        let mut send_cache = this.send_cache.lock().unwrap();
189                        let send_list = send_cache.entry(peer_id).or_insert(Vec::new());
190                        send_list.push(ClassifiedCmdSend::new(
191                            tunnel_id,
192                            classification.clone(),
193                            Some(remote_id.clone()),
194                            Some(tunnel_id),
195                            Some(classification),
196                            recv_handle,
197                            writer,
198                            resp_waiter,
199                            remote_id,
200                            tunnel_meta,
201                        ));
202                    }
203                    Ok(())
204                }
205                .await;
206                if let Err(e) = ret {
207                    log::error!("peer connection error: {:?}", e);
208                }
209            });
210        }
211    }
212}
213
214#[async_trait::async_trait]
215impl<
216    C: WorkerClassification,
217    M: CmdTunnelMeta,
218    R: ClassifiedCmdTunnelRead<C, M>,
219    W: ClassifiedCmdTunnelWrite<C, M>,
220    F: ClassifiedCmdNodeTunnelFactory<C, M, R, W>,
221    LEN: RawEncode
222        + for<'a> RawDecode<'a>
223        + Copy
224        + Send
225        + Sync
226        + 'static
227        + FromPrimitive
228        + ToPrimitive
229        + RawFixedBytes,
230    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
231    LISTENER: CmdTunnelListener<M, R, W>,
232> ClassifiedWorkerFactory<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>>
233    for CmdWriteFactoryImpl<C, M, R, W, F, LEN, CMD, LISTENER>
234{
235    async fn create(
236        &self,
237        c: Option<CmdNodeTunnelClassification<C>>,
238    ) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
239        if c.is_some() {
240            let classification = c.unwrap();
241            if classification.peer_id.is_some() {
242                let peer_id = classification.peer_id.clone().unwrap();
243                let tunnel_id = classification.tunnel_id;
244                if tunnel_id.is_some() {
245                    let mut send_cache = self.send_cache.lock().unwrap();
246                    if let Some(send_list) = send_cache.get_mut(&peer_id) {
247                        let mut send_index = None;
248                        for (index, send) in send_list.iter().enumerate() {
249                            if send.get_tunnel_id() == tunnel_id.unwrap() {
250                                send_index = Some(index);
251                                break;
252                            }
253                        }
254                        if let Some(send_index) = send_index {
255                            let mut send = send_list.remove(send_index);
256                            send.set_pool_key(
257                                Some(peer_id.clone()),
258                                tunnel_id,
259                                classification.classification.clone(),
260                            );
261                            Ok(send)
262                        } else {
263                            Err(pool_err!(
264                                PoolErrorCode::Failed,
265                                "tunnel {:?} not found",
266                                tunnel_id.unwrap()
267                            ))
268                        }
269                    } else {
270                        Err(pool_err!(
271                            PoolErrorCode::Failed,
272                            "tunnel {:?} not found",
273                            tunnel_id.unwrap()
274                        ))
275                    }
276                } else {
277                    {
278                        let mut send_cache = self.send_cache.lock().unwrap();
279                        if let Some(send_list) = send_cache.get_mut(&peer_id) {
280                            if !send_list.is_empty() {
281                                let mut send = send_list.pop().unwrap();
282                                send.set_pool_key(
283                                    Some(peer_id.clone()),
284                                    tunnel_id,
285                                    classification.classification.clone(),
286                                );
287                                if send_list.is_empty() {
288                                    send_cache.remove(&peer_id);
289                                }
290                                return Ok(send);
291                            }
292                        }
293                    }
294                    let tunnel = self
295                        .tunnel_factory
296                        .create_tunnel(Some(classification.clone()))
297                        .await
298                        .map_err(into_pool_err!(PoolErrorCode::Failed))?;
299                    let actual_classification = tunnel.get_classification();
300                    let pool_peer_id = classification.peer_id.clone();
301                    let pool_tunnel_id = classification.tunnel_id;
302                    let pool_classification = classification.classification.clone();
303                    let tunnel_id = self.tunnel_id_generator.generate();
304                    let (recv, write) = tunnel.split();
305                    let remote_id = write.get_remote_peer_id();
306                    let tunnel_meta = recv.get_tunnel_meta();
307                    let write = ObjectHolder::new(write);
308                    let cmd_handler = self.cmd_handler.clone();
309                    let handle = create_recv_handle::<M, R, W, LEN, CMD>(
310                        recv,
311                        write.clone(),
312                        tunnel_id,
313                        cmd_handler,
314                    );
315                    Ok(ClassifiedCmdSend::new(
316                        tunnel_id,
317                        actual_classification,
318                        pool_peer_id,
319                        pool_tunnel_id,
320                        pool_classification,
321                        handle,
322                        write,
323                        self.resp_waiter.clone(),
324                        remote_id,
325                        tunnel_meta,
326                    ))
327                }
328            } else {
329                if classification.tunnel_id.is_some() {
330                    Err(pool_err!(
331                        PoolErrorCode::Failed,
332                        "must set peer id when set tunnel id"
333                    ))
334                } else {
335                    let tunnel = self
336                        .tunnel_factory
337                        .create_tunnel(Some(classification.clone()))
338                        .await
339                        .map_err(into_pool_err!(PoolErrorCode::Failed))?;
340                    let actual_classification = tunnel.get_classification();
341                    let pool_peer_id = classification.peer_id.clone();
342                    let pool_tunnel_id = classification.tunnel_id;
343                    let pool_classification = classification.classification.clone();
344                    let tunnel_id = self.tunnel_id_generator.generate();
345                    let (recv, write) = tunnel.split();
346                    let remote_id = write.get_remote_peer_id();
347                    let tunnel_meta = write.get_tunnel_meta();
348                    let write = ObjectHolder::new(write);
349                    let cmd_handler = self.cmd_handler.clone();
350                    let handle = create_recv_handle::<M, R, W, LEN, CMD>(
351                        recv,
352                        write.clone(),
353                        tunnel_id,
354                        cmd_handler,
355                    );
356                    Ok(ClassifiedCmdSend::new(
357                        tunnel_id,
358                        actual_classification,
359                        pool_peer_id,
360                        pool_tunnel_id,
361                        pool_classification,
362                        handle,
363                        write,
364                        self.resp_waiter.clone(),
365                        remote_id,
366                        tunnel_meta,
367                    ))
368                }
369            }
370        } else {
371            Err(pool_err!(PoolErrorCode::Failed, "peer id is none"))
372        }
373    }
374}
375
376pub struct ClassifiedCmdNodeWriteFactory<
377    C: WorkerClassification,
378    M: CmdTunnelMeta,
379    R: ClassifiedCmdTunnelRead<C, M>,
380    W: ClassifiedCmdTunnelWrite<C, M>,
381    F: ClassifiedCmdNodeTunnelFactory<C, M, R, W>,
382    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
383    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
384    LISTENER: CmdTunnelListener<M, R, W>,
385> {
386    inner: Arc<CmdWriteFactoryImpl<C, M, R, W, F, LEN, CMD, LISTENER>>,
387}
388
389impl<
390    C: WorkerClassification,
391    M: CmdTunnelMeta,
392    R: ClassifiedCmdTunnelRead<C, M>,
393    W: ClassifiedCmdTunnelWrite<C, M>,
394    F: ClassifiedCmdNodeTunnelFactory<C, M, R, W>,
395    LEN: RawEncode
396        + for<'a> RawDecode<'a>
397        + Copy
398        + Send
399        + Sync
400        + 'static
401        + FromPrimitive
402        + ToPrimitive
403        + RawFixedBytes
404        + RawFixedBytes,
405    CMD: RawEncode
406        + for<'a> RawDecode<'a>
407        + Copy
408        + Send
409        + Sync
410        + 'static
411        + Debug
412        + RawFixedBytes
413        + RawFixedBytes,
414    LISTENER: CmdTunnelListener<M, R, W>,
415> ClassifiedCmdNodeWriteFactory<C, M, R, W, F, LEN, CMD, LISTENER>
416{
417    pub(crate) fn new(
418        tunnel_factory: F,
419        tunnel_listener: LISTENER,
420        cmd_handler: impl CmdHandler<LEN, CMD>,
421        resp_waiter: RespWaiterRef,
422    ) -> Self {
423        Self {
424            inner: Arc::new(CmdWriteFactoryImpl::new(
425                tunnel_factory,
426                tunnel_listener,
427                cmd_handler,
428                resp_waiter,
429            )),
430        }
431    }
432
433    pub fn start(&self) {
434        self.inner.start();
435    }
436}
437
438#[async_trait::async_trait]
439impl<
440    C: WorkerClassification,
441    M: CmdTunnelMeta,
442    R: ClassifiedCmdTunnelRead<C, M>,
443    W: ClassifiedCmdTunnelWrite<C, M>,
444    F: ClassifiedCmdNodeTunnelFactory<C, M, R, W>,
445    LEN: RawEncode
446        + for<'a> RawDecode<'a>
447        + Copy
448        + Send
449        + Sync
450        + 'static
451        + FromPrimitive
452        + ToPrimitive
453        + RawFixedBytes,
454    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
455    LISTENER: CmdTunnelListener<M, R, W>,
456> ClassifiedWorkerFactory<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>>
457    for ClassifiedCmdNodeWriteFactory<C, M, R, W, F, LEN, CMD, LISTENER>
458{
459    async fn create(
460        &self,
461        c: Option<CmdNodeTunnelClassification<C>>,
462    ) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
463        self.inner.create(c).await
464    }
465}
466
467pub struct DefaultClassifiedCmdNode<
468    C: WorkerClassification,
469    M: CmdTunnelMeta,
470    R: ClassifiedCmdTunnelRead<C, M>,
471    W: ClassifiedCmdTunnelWrite<C, M>,
472    F: ClassifiedCmdNodeTunnelFactory<C, M, R, W>,
473    LEN: RawEncode
474        + for<'a> RawDecode<'a>
475        + Copy
476        + Send
477        + Sync
478        + 'static
479        + FromPrimitive
480        + ToPrimitive
481        + RawFixedBytes,
482    CMD: RawEncode
483        + for<'a> RawDecode<'a>
484        + Copy
485        + Send
486        + Sync
487        + 'static
488        + RawFixedBytes
489        + Eq
490        + Hash
491        + Debug,
492    LISTENER: CmdTunnelListener<M, R, W>,
493> {
494    tunnel_pool: ClassifiedWorkerPoolRef<
495        CmdNodeTunnelClassification<C>,
496        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
497        ClassifiedCmdNodeWriteFactory<C, M, R, W, F, LEN, CMD, LISTENER>,
498    >,
499    runtime_tunnels: Arc<TunnelRuntimeRegistry>,
500    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
501}
502
503impl<
504    C: WorkerClassification,
505    M: CmdTunnelMeta,
506    R: ClassifiedCmdTunnelRead<C, M>,
507    W: ClassifiedCmdTunnelWrite<C, M>,
508    F: ClassifiedCmdNodeTunnelFactory<C, M, R, W>,
509    LEN: RawEncode
510        + for<'a> RawDecode<'a>
511        + Copy
512        + Send
513        + Sync
514        + 'static
515        + FromPrimitive
516        + ToPrimitive
517        + RawFixedBytes,
518    CMD: RawEncode
519        + for<'a> RawDecode<'a>
520        + Copy
521        + Send
522        + Sync
523        + 'static
524        + RawFixedBytes
525        + Eq
526        + Hash
527        + Debug,
528    LISTENER: CmdTunnelListener<M, R, W>,
529> DefaultClassifiedCmdNode<C, M, R, W, F, LEN, CMD, LISTENER>
530{
531    fn tunnel_not_found(tunnel_id: TunnelId) -> sfo_result::Error<CmdErrorCode> {
532        crate::errors::cmd_err!(CmdErrorCode::Failed, "tunnel {:?} not found", tunnel_id)
533    }
534
535    pub fn new(listener: LISTENER, factory: F, tunnel_count: u16) -> Arc<Self> {
536        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
537        let handler_map = cmd_handler_map.clone();
538        let resp_waiter = Arc::new(RespWaiter::new());
539        let waiter = resp_waiter.clone();
540        let write_factory = ClassifiedCmdNodeWriteFactory::<C, M, R, W, _, LEN, CMD, LISTENER>::new(
541            factory,
542            listener,
543            move |local_id: PeerId,
544                  peer_id: PeerId,
545                  tunnel_id: TunnelId,
546                  header: CmdHeader<LEN, CMD>,
547                  body_read: CmdBody| {
548                let handler_map = handler_map.clone();
549                let waiter = waiter.clone();
550                async move {
551                    if header.is_resp() && header.seq().is_some() {
552                        let resp_id =
553                            gen_resp_id(tunnel_id, header.cmd_code(), header.seq().unwrap());
554                        let _ = waiter.set_result(resp_id, body_read);
555                        Ok(None)
556                    } else {
557                        if let Some(handler) = handler_map.get(header.cmd_code()) {
558                            handler
559                                .handle(local_id, peer_id, tunnel_id, header, body_read)
560                                .await
561                        } else {
562                            Ok(None)
563                        }
564                    }
565                }
566            },
567            resp_waiter.clone(),
568        );
569        write_factory.start();
570        Arc::new(Self {
571            tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, write_factory),
572            runtime_tunnels: TunnelRuntimeRegistry::new(),
573            cmd_handler_map,
574        })
575    }
576
577    fn tracked_send_guard(
578        &self,
579        worker_guard: ClassifiedWorkerGuard<
580            CmdNodeTunnelClassification<C>,
581            ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
582            ClassifiedCmdNodeWriteFactory<C, M, R, W, F, LEN, CMD, LISTENER>,
583        >,
584    ) -> ClassifiedCmdNodeSendGuard<C, M, R, W, F, LEN, CMD, LISTENER> {
585        let tunnel_id = worker_guard.get_tunnel_id();
586        TrackedSendGuard::new(worker_guard, self.runtime_tunnels.clone(), tunnel_id)
587    }
588
589    async fn reserve_tunnel(&self, tunnel_id: TunnelId) -> CmdResult<()> {
590        loop {
591            match self.runtime_tunnels.reserve_existing(tunnel_id) {
592                TunnelReserveResult::Acquired => return Ok(()),
593                TunnelReserveResult::Wait(waiter) => waiter.notified().await,
594                TunnelReserveResult::Missing => return Err(Self::tunnel_not_found(tunnel_id)),
595            }
596        }
597    }
598
599    async fn get_send(
600        &self,
601        peer_id: PeerId,
602    ) -> CmdResult<ClassifiedCmdNodeSendGuard<C, M, R, W, F, LEN, CMD, LISTENER>> {
603        loop {
604            let worker_guard = self
605                .tunnel_pool
606                .get_classified_worker(CmdNodeTunnelClassification {
607                    peer_id: Some(peer_id.clone()),
608                    tunnel_id: None,
609                    classification: None,
610                })
611                .await
612                .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))?;
613            let tunnel_id = worker_guard.get_tunnel_id();
614            if self.runtime_tunnels.mark_borrowed(tunnel_id) {
615                return Ok(self.tracked_send_guard(worker_guard));
616            }
617            drop(worker_guard);
618            yield_now().await;
619        }
620    }
621
622    async fn get_send_of_tunnel_id(
623        &self,
624        peer_id: PeerId,
625        tunnel_id: TunnelId,
626    ) -> CmdResult<ClassifiedCmdNodeSendGuard<C, M, R, W, F, LEN, CMD, LISTENER>> {
627        self.reserve_tunnel(tunnel_id).await?;
628        match self
629            .tunnel_pool
630            .get_classified_worker(CmdNodeTunnelClassification {
631                peer_id: Some(peer_id),
632                tunnel_id: Some(tunnel_id),
633                classification: None,
634            })
635            .await
636        {
637            Ok(worker_guard) => Ok(self.tracked_send_guard(worker_guard)),
638            Err(_) => {
639                self.runtime_tunnels.remove(tunnel_id);
640                Err(Self::tunnel_not_found(tunnel_id))
641            }
642        }
643    }
644
645    async fn get_classified_send(
646        &self,
647        classification: C,
648    ) -> CmdResult<ClassifiedCmdNodeSendGuard<C, M, R, W, F, LEN, CMD, LISTENER>> {
649        loop {
650            let worker_guard = self
651                .tunnel_pool
652                .get_classified_worker(CmdNodeTunnelClassification {
653                    peer_id: None,
654                    tunnel_id: None,
655                    classification: Some(classification.clone()),
656                })
657                .await
658                .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))?;
659            let tunnel_id = worker_guard.get_tunnel_id();
660            if self.runtime_tunnels.mark_borrowed(tunnel_id) {
661                return Ok(self.tracked_send_guard(worker_guard));
662            }
663            drop(worker_guard);
664            yield_now().await;
665        }
666    }
667
668    async fn get_peer_classified_send(
669        &self,
670        peer_id: PeerId,
671        classification: C,
672    ) -> CmdResult<ClassifiedCmdNodeSendGuard<C, M, R, W, F, LEN, CMD, LISTENER>> {
673        loop {
674            let worker_guard = self
675                .tunnel_pool
676                .get_classified_worker(CmdNodeTunnelClassification {
677                    peer_id: Some(peer_id.clone()),
678                    tunnel_id: None,
679                    classification: Some(classification.clone()),
680                })
681                .await
682                .map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))?;
683            let tunnel_id = worker_guard.get_tunnel_id();
684            if self.runtime_tunnels.mark_borrowed(tunnel_id) {
685                return Ok(self.tracked_send_guard(worker_guard));
686            }
687            drop(worker_guard);
688            yield_now().await;
689        }
690    }
691}
692
693pub type ClassifiedCmdNodeSendGuard<C, M, R, W, F, LEN, CMD, LISTENER> = TrackedSendGuard<
694    CmdNodeTunnelClassification<C>,
695    M,
696    ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
697    ClassifiedCmdNodeWriteFactory<C, M, R, W, F, LEN, CMD, LISTENER>,
698>;
699#[async_trait::async_trait]
700impl<
701    C: WorkerClassification,
702    M: CmdTunnelMeta,
703    R: ClassifiedCmdTunnelRead<C, M>,
704    W: ClassifiedCmdTunnelWrite<C, M>,
705    F: ClassifiedCmdNodeTunnelFactory<C, M, R, W>,
706    LEN: RawEncode
707        + for<'a> RawDecode<'a>
708        + Copy
709        + Send
710        + Sync
711        + 'static
712        + FromPrimitive
713        + ToPrimitive
714        + RawFixedBytes,
715    CMD: RawEncode
716        + for<'a> RawDecode<'a>
717        + Copy
718        + Send
719        + Sync
720        + 'static
721        + RawFixedBytes
722        + Eq
723        + Hash
724        + Debug,
725    LISTENER: CmdTunnelListener<M, R, W>,
726>
727    CmdNode<
728        LEN,
729        CMD,
730        M,
731        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
732        ClassifiedCmdNodeSendGuard<C, M, R, W, F, LEN, CMD, LISTENER>,
733    > for DefaultClassifiedCmdNode<C, M, R, W, F, LEN, CMD, LISTENER>
734{
735    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
736        self.cmd_handler_map.insert(cmd, handler)
737    }
738
739    async fn send(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
740        let mut send = self.get_send(peer_id.clone()).await?;
741        send.send(cmd, version, body).await
742    }
743
744    async fn send_with_resp(
745        &self,
746        peer_id: &PeerId,
747        cmd: CMD,
748        version: u8,
749        body: &[u8],
750        timeout: Duration,
751    ) -> CmdResult<CmdBody> {
752        let mut send = self.get_send(peer_id.clone()).await?;
753        send.send_with_resp(cmd, version, body, timeout).await
754    }
755
756    async fn send_parts(
757        &self,
758        peer_id: &PeerId,
759        cmd: CMD,
760        version: u8,
761        body: &[&[u8]],
762    ) -> CmdResult<()> {
763        let mut send = self.get_send(peer_id.clone()).await?;
764        send.send_parts(cmd, version, body).await
765    }
766
767    async fn send_parts_with_resp(
768        &self,
769        peer_id: &PeerId,
770        cmd: CMD,
771        version: u8,
772        body: &[&[u8]],
773        timeout: Duration,
774    ) -> CmdResult<CmdBody> {
775        let mut send = self.get_send(peer_id.clone()).await?;
776        send.send_parts_with_resp(cmd, version, body, timeout).await
777    }
778
779    async fn send_cmd(
780        &self,
781        peer_id: &PeerId,
782        cmd: CMD,
783        version: u8,
784        body: CmdBody,
785    ) -> CmdResult<()> {
786        let mut send = self.get_send(peer_id.clone()).await?;
787        send.send_cmd(cmd, version, body).await
788    }
789
790    async fn send_cmd_with_resp(
791        &self,
792        peer_id: &PeerId,
793        cmd: CMD,
794        version: u8,
795        body: CmdBody,
796        timeout: Duration,
797    ) -> CmdResult<CmdBody> {
798        let mut send = self.get_send(peer_id.clone()).await?;
799        send.send_cmd_with_resp(cmd, version, body, timeout).await
800    }
801
802    async fn send_by_specify_tunnel(
803        &self,
804        peer_id: &PeerId,
805        tunnel_id: TunnelId,
806        cmd: CMD,
807        version: u8,
808        body: &[u8],
809    ) -> CmdResult<()> {
810        let mut send = self
811            .get_send_of_tunnel_id(peer_id.clone(), tunnel_id)
812            .await?;
813        send.send(cmd, version, body).await
814    }
815
816    async fn send_by_specify_tunnel_with_resp(
817        &self,
818        peer_id: &PeerId,
819        tunnel_id: TunnelId,
820        cmd: CMD,
821        version: u8,
822        body: &[u8],
823        timeout: Duration,
824    ) -> CmdResult<CmdBody> {
825        let mut send = self
826            .get_send_of_tunnel_id(peer_id.clone(), tunnel_id)
827            .await?;
828        send.send_with_resp(cmd, version, body, timeout).await
829    }
830
831    async fn send_parts_by_specify_tunnel(
832        &self,
833        peer_id: &PeerId,
834        tunnel_id: TunnelId,
835        cmd: CMD,
836        version: u8,
837        body: &[&[u8]],
838    ) -> CmdResult<()> {
839        let mut send = self
840            .get_send_of_tunnel_id(peer_id.clone(), tunnel_id)
841            .await?;
842        send.send_parts(cmd, version, body).await
843    }
844
845    async fn send_parts_by_specify_tunnel_with_resp(
846        &self,
847        peer_id: &PeerId,
848        tunnel_id: TunnelId,
849        cmd: CMD,
850        version: u8,
851        body: &[&[u8]],
852        timeout: Duration,
853    ) -> CmdResult<CmdBody> {
854        let mut send = self
855            .get_send_of_tunnel_id(peer_id.clone(), tunnel_id)
856            .await?;
857        send.send_parts_with_resp(cmd, version, body, timeout).await
858    }
859
860    async fn send_cmd_by_specify_tunnel(
861        &self,
862        peer_id: &PeerId,
863        tunnel_id: TunnelId,
864        cmd: CMD,
865        version: u8,
866        body: CmdBody,
867    ) -> CmdResult<()> {
868        let mut send = self
869            .get_send_of_tunnel_id(peer_id.clone(), tunnel_id)
870            .await?;
871        send.send_cmd(cmd, version, body).await
872    }
873
874    async fn send_cmd_by_specify_tunnel_with_resp(
875        &self,
876        peer_id: &PeerId,
877        tunnel_id: TunnelId,
878        cmd: CMD,
879        version: u8,
880        body: CmdBody,
881        timeout: Duration,
882    ) -> CmdResult<CmdBody> {
883        let mut send = self
884            .get_send_of_tunnel_id(peer_id.clone(), tunnel_id)
885            .await?;
886        send.send_cmd_with_resp(cmd, version, body, timeout).await
887    }
888
889    async fn clear_all_tunnel(&self) {
890        self.tunnel_pool.clear_all_worker().await;
891        self.runtime_tunnels.clear();
892    }
893
894    async fn get_send(
895        &self,
896        peer_id: &PeerId,
897        tunnel_id: TunnelId,
898    ) -> CmdResult<ClassifiedCmdNodeSendGuard<C, M, R, W, F, LEN, CMD, LISTENER>> {
899        self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await
900    }
901}
902
903#[async_trait::async_trait]
904impl<
905    C: WorkerClassification,
906    M: CmdTunnelMeta,
907    R: ClassifiedCmdTunnelRead<C, M>,
908    W: ClassifiedCmdTunnelWrite<C, M>,
909    F: ClassifiedCmdNodeTunnelFactory<C, M, R, W>,
910    LEN: RawEncode
911        + for<'a> RawDecode<'a>
912        + Copy
913        + Send
914        + Sync
915        + 'static
916        + FromPrimitive
917        + ToPrimitive
918        + RawFixedBytes,
919    CMD: RawEncode
920        + for<'a> RawDecode<'a>
921        + Copy
922        + Send
923        + Sync
924        + 'static
925        + RawFixedBytes
926        + Eq
927        + Hash
928        + Debug,
929    LISTENER: CmdTunnelListener<M, R, W>,
930>
931    ClassifiedCmdNode<
932        LEN,
933        CMD,
934        C,
935        M,
936        ClassifiedCmdSend<C, M, R, W, LEN, CMD>,
937        ClassifiedCmdNodeSendGuard<C, M, R, W, F, LEN, CMD, LISTENER>,
938    > for DefaultClassifiedCmdNode<C, M, R, W, F, LEN, CMD, LISTENER>
939{
940    async fn send_by_classified_tunnel(
941        &self,
942        classification: C,
943        cmd: CMD,
944        version: u8,
945        body: &[u8],
946    ) -> CmdResult<()> {
947        let mut send = self.get_classified_send(classification).await?;
948        send.send(cmd, version, body).await
949    }
950
951    async fn send_by_classified_tunnel_with_resp(
952        &self,
953        classification: C,
954        cmd: CMD,
955        version: u8,
956        body: &[u8],
957        timeout: Duration,
958    ) -> CmdResult<CmdBody> {
959        let mut send = self.get_classified_send(classification).await?;
960        send.send_with_resp(cmd, version, body, timeout).await
961    }
962
963    async fn send_parts_by_classified_tunnel(
964        &self,
965        classification: C,
966        cmd: CMD,
967        version: u8,
968        body: &[&[u8]],
969    ) -> CmdResult<()> {
970        let mut send = self.get_classified_send(classification).await?;
971        send.send_parts(cmd, version, body).await
972    }
973
974    async fn send_parts_by_classified_tunnel_with_resp(
975        &self,
976        classification: C,
977        cmd: CMD,
978        version: u8,
979        body: &[&[u8]],
980        timeout: Duration,
981    ) -> CmdResult<CmdBody> {
982        let mut send = self.get_classified_send(classification).await?;
983        send.send_parts_with_resp(cmd, version, body, timeout).await
984    }
985
986    async fn send_cmd_by_classified_tunnel(
987        &self,
988        classification: C,
989        cmd: CMD,
990        version: u8,
991        body: CmdBody,
992    ) -> CmdResult<()> {
993        let mut send = self.get_classified_send(classification).await?;
994        send.send_cmd(cmd, version, body).await
995    }
996
997    async fn send_cmd_by_classified_tunnel_with_resp(
998        &self,
999        classification: C,
1000        cmd: CMD,
1001        version: u8,
1002        body: CmdBody,
1003        timeout: Duration,
1004    ) -> CmdResult<CmdBody> {
1005        let mut send = self.get_classified_send(classification).await?;
1006        send.send_cmd_with_resp(cmd, version, body, timeout).await
1007    }
1008
1009    async fn send_by_peer_classified_tunnel(
1010        &self,
1011        peer_id: &PeerId,
1012        classification: C,
1013        cmd: CMD,
1014        version: u8,
1015        body: &[u8],
1016    ) -> CmdResult<()> {
1017        let mut send = self
1018            .get_peer_classified_send(peer_id.clone(), classification)
1019            .await?;
1020        send.send(cmd, version, body).await
1021    }
1022
1023    async fn send_by_peer_classified_tunnel_with_resp(
1024        &self,
1025        peer_id: &PeerId,
1026        classification: C,
1027        cmd: CMD,
1028        version: u8,
1029        body: &[u8],
1030        timeout: Duration,
1031    ) -> CmdResult<CmdBody> {
1032        let mut send = self
1033            .get_peer_classified_send(peer_id.clone(), classification)
1034            .await?;
1035        send.send_with_resp(cmd, version, body, timeout).await
1036    }
1037
1038    async fn send_parts_by_peer_classified_tunnel(
1039        &self,
1040        peer_id: &PeerId,
1041        classification: C,
1042        cmd: CMD,
1043        version: u8,
1044        body: &[&[u8]],
1045    ) -> CmdResult<()> {
1046        let mut send = self
1047            .get_peer_classified_send(peer_id.clone(), classification)
1048            .await?;
1049        send.send_parts(cmd, version, body).await
1050    }
1051
1052    async fn send_parts_by_peer_classified_tunnel_with_resp(
1053        &self,
1054        peer_id: &PeerId,
1055        classification: C,
1056        cmd: CMD,
1057        version: u8,
1058        body: &[&[u8]],
1059        timeout: Duration,
1060    ) -> CmdResult<CmdBody> {
1061        let mut send = self
1062            .get_peer_classified_send(peer_id.clone(), classification)
1063            .await?;
1064        send.send_parts_with_resp(cmd, version, body, timeout).await
1065    }
1066
1067    async fn send_cmd_by_peer_classified_tunnel(
1068        &self,
1069        peer_id: &PeerId,
1070        classification: C,
1071        cmd: CMD,
1072        version: u8,
1073        body: CmdBody,
1074    ) -> CmdResult<()> {
1075        let mut send = self
1076            .get_peer_classified_send(peer_id.clone(), classification)
1077            .await?;
1078        send.send_cmd(cmd, version, body).await
1079    }
1080
1081    async fn send_cmd_by_peer_classified_tunnel_with_resp(
1082        &self,
1083        peer_id: &PeerId,
1084        classification: C,
1085        cmd: CMD,
1086        version: u8,
1087        body: CmdBody,
1088        timeout: Duration,
1089    ) -> CmdResult<CmdBody> {
1090        let mut send = self
1091            .get_peer_classified_send(peer_id.clone(), classification)
1092            .await?;
1093        send.send_cmd_with_resp(cmd, version, body, timeout).await
1094    }
1095
1096    async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
1097        let send = self.get_classified_send(classification).await?;
1098        Ok(send.get_tunnel_id())
1099    }
1100
1101    async fn find_tunnel_id_by_peer_classified(
1102        &self,
1103        peer_id: &PeerId,
1104        classification: C,
1105    ) -> CmdResult<TunnelId> {
1106        let send = self
1107            .get_peer_classified_send(peer_id.clone(), classification)
1108            .await?;
1109        Ok(send.get_tunnel_id())
1110    }
1111
1112    async fn get_send_by_classified(
1113        &self,
1114        classification: C,
1115    ) -> CmdResult<ClassifiedCmdNodeSendGuard<C, M, R, W, F, LEN, CMD, LISTENER>> {
1116        self.get_classified_send(classification).await
1117    }
1118
1119    async fn get_send_by_peer_classified(
1120        &self,
1121        peer_id: &PeerId,
1122        classification: C,
1123    ) -> CmdResult<ClassifiedCmdNodeSendGuard<C, M, R, W, F, LEN, CMD, LISTENER>> {
1124        self.get_peer_classified_send(peer_id.clone(), classification)
1125            .await
1126    }
1127}