sfo_cmd_server/node/classified_node.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::sync::{Arc, Mutex};
5use bucky_raw_codec::{RawDecode, RawEncode, RawFixedBytes};
6use num::{FromPrimitive, ToPrimitive};
7use sfo_pool::{ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification};
8use sfo_split::Splittable;
9use crate::{into_pool_err, pool_err, ClassifiedCmdNode, CmdHandler, CmdHeader, CmdNode, PeerId, TunnelId, TunnelIdGenerator};
10use crate::client::{ClassifiedCmdSend, ClassifiedCmdTunnelRead, ClassifiedCmdTunnelWrite};
11use crate::cmd::CmdHandlerMap;
12use crate::errors::{into_cmd_err, CmdErrorCode, CmdResult};
13use crate::node::create_recv_handle;
14use crate::server::CmdTunnelListener;
15
16#[derive(Debug, Clone, Eq, Hash)]
17pub struct CmdNodeTunnelClassification<C: WorkerClassification> {
18    peer_id: Option<PeerId>,
19    tunnel_id: Option<TunnelId>,
20    classification: Option<C>,
21}
22
23impl<C: WorkerClassification> PartialEq for CmdNodeTunnelClassification<C> {
24    fn eq(&self, other: &Self) -> bool {
25        self.peer_id == other.peer_id && self.tunnel_id == other.tunnel_id && self.classification == other.classification
26    }
27}
28
29impl<C, R, W, LEN, CMD> ClassifiedWorker<CmdNodeTunnelClassification<C>> for ClassifiedCmdSend<C, R, W, LEN, CMD>
30where C: WorkerClassification,
31      R: ClassifiedCmdTunnelRead<C>,
32      W: ClassifiedCmdTunnelWrite<C>,
33      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
34      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
35    fn is_work(&self) -> bool {
36        self.is_work && !self.recv_handle.is_finished()
37    }
38
39    fn is_valid(&self, c: CmdNodeTunnelClassification<C>) -> bool {
40        if c.peer_id.is_some() && c.peer_id.as_ref().unwrap() != &self.write.get_remote_peer_id() {
41            return false;
42        }
43
44        if c.tunnel_id.is_some() {
45            self.tunnel_id == c.tunnel_id.unwrap()
46        } else {
47            if c.classification.is_some() {
48                self.classification == c.classification.unwrap()
49            } else {
50                true
51            }
52        }
53    }
54
55    fn classification(&self) -> CmdNodeTunnelClassification<C> {
56        CmdNodeTunnelClassification {
57            peer_id: Some(self.write.get_remote_peer_id()),
58            tunnel_id: Some(self.tunnel_id),
59            classification: Some(self.classification.clone()),
60        }
61    }
62}
63
64#[async_trait::async_trait]
65pub trait ClassifiedCmdNodeTunnelFactory<C: WorkerClassification, R: ClassifiedCmdTunnelRead<C>, W: ClassifiedCmdTunnelWrite<C>>: Send + Sync + 'static {
66    async fn create_tunnel(&self, classification: Option<CmdNodeTunnelClassification<C>>) -> CmdResult<Splittable<R, W>>;
67}
68
69struct CmdWriteFactoryImpl<C: WorkerClassification,
70    R: ClassifiedCmdTunnelRead<C>,
71    W: ClassifiedCmdTunnelWrite<C>,
72    F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
73    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
74    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
75    LISTENER: CmdTunnelListener<R, W>> {
76    tunnel_listener: LISTENER,
77    tunnel_factory: F,
78    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
79    tunnel_id_generator: TunnelIdGenerator,
80    send_cache: Arc<Mutex<HashMap<PeerId, Vec<ClassifiedCmdSend<C, R, W, LEN, CMD>>>>>,
81}
82
83
84impl<C: WorkerClassification,
85    R: ClassifiedCmdTunnelRead<C>,
86    W: ClassifiedCmdTunnelWrite<C>,
87    F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
88    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
89    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
90    LISTENER: CmdTunnelListener<R, W>> CmdWriteFactoryImpl<C, R, W, F, LEN, CMD, LISTENER> {
91    pub fn new(tunnel_factory: F,
92               tunnel_listener: LISTENER,
93               cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
94        Self {
95            tunnel_listener,
96            tunnel_factory,
97            cmd_handler: Arc::new(cmd_handler),
98            tunnel_id_generator: TunnelIdGenerator::new(),
99            send_cache: Arc::new(Mutex::new(Default::default())),
100        }
101    }
102
103
104    pub fn start(self: &Arc<Self>) {
105        let this = self.clone();
106        tokio::spawn(async move {
107            if let Err(e) = this.run().await {
108                log::error!("cmd server error: {:?}", e);
109            }
110        });
111    }
112
113    async fn run(self: &Arc<Self>) -> CmdResult<()> {
114        loop {
115            let tunnel = self.tunnel_listener.accept().await?;
116            let peer_id = tunnel.get_remote_peer_id();
117            let classification = tunnel.get_classification();
118            let tunnel_id = self.tunnel_id_generator.generate();
119            let this = self.clone();
120            tokio::spawn(async move {
121                let ret: CmdResult<()> = async move {
122                    let this = this.clone();
123                    let cmd_handler = this.cmd_handler.clone();
124                    let (reader, writer) = tunnel.split();
125                    let recv_handle = create_recv_handle::<R, W, LEN, CMD>(reader, tunnel_id, cmd_handler);
126                    {
127                        let mut send_cache = this.send_cache.lock().unwrap();
128                        let send_list = send_cache.entry(peer_id).or_insert(Vec::new());
129                        send_list.push(ClassifiedCmdSend::new(tunnel_id, classification, recv_handle, writer));
130                    }
131                    Ok(())
132                }.await;
133                if let Err(e) = ret {
134                    log::error!("peer connection error: {:?}", e);
135                }
136            });
137        }
138    }
139}
140
141#[async_trait::async_trait]
142impl<C: WorkerClassification,
143    R: ClassifiedCmdTunnelRead<C>,
144    W: ClassifiedCmdTunnelWrite<C>,
145    F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
146    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
147    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
148    LISTENER: CmdTunnelListener<R, W>> ClassifiedWorkerFactory<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>> for CmdWriteFactoryImpl<C, R, W, F, LEN, CMD, LISTENER> {
149    async fn create(&self, c: Option<CmdNodeTunnelClassification<C>>) -> PoolResult<ClassifiedCmdSend<C, R, W, LEN, CMD>> {
150        if c.is_some() {
151            let classification = c.unwrap();
152            if classification.peer_id.is_some() {
153                let peer_id = classification.peer_id.clone().unwrap();
154                let tunnel_id = classification.tunnel_id;
155                if tunnel_id.is_some() {
156                    let mut send_cache = self.send_cache.lock().unwrap();
157                    if let Some(send_list) = send_cache.get_mut(&peer_id) {
158                        let mut send_index = None;
159                        for (index, send) in send_list.iter().enumerate() {
160                            if send.get_tunnel_id() == tunnel_id.unwrap() {
161                                send_index = Some(index);
162                                break;
163                            }
164                        }
165                        if let Some(send_index) = send_index {
166                            let send = send_list.remove(send_index);
167                            Ok(send)
168                        } else {
169                            Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
170                        }
171                    } else {
172                        Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
173                    }
174                } else {
175                    {
176                        let mut send_cache = self.send_cache.lock().unwrap();
177                        if let Some(send_list) = send_cache.get_mut(&peer_id) {
178                            if !send_list.is_empty() {
179                                let send = send_list.pop().unwrap();
180                                if send_list.is_empty() {
181                                    send_cache.remove(&peer_id);
182                                }
183                                return Ok(send);
184                            }
185                        }
186                    }
187                    let tunnel = self.tunnel_factory.create_tunnel(Some(classification)).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
188                    let classification = tunnel.get_classification();
189                    let tunnel_id = self.tunnel_id_generator.generate();
190                    let (recv, write) = tunnel.split();
191                    let cmd_handler = self.cmd_handler.clone();
192                    let handle = create_recv_handle::<R, W, LEN, CMD>(recv, tunnel_id, cmd_handler);
193                    Ok(ClassifiedCmdSend::new(tunnel_id, classification, handle, write))
194                }
195            } else {
196                if classification.tunnel_id.is_some() {
197                    Err(pool_err!(PoolErrorCode::Failed, "must set peer id when set tunnel id"))
198                } else {
199                    let tunnel = self.tunnel_factory.create_tunnel(Some(classification)).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
200                    let classification = tunnel.get_classification();
201                    let tunnel_id = self.tunnel_id_generator.generate();
202                    let (recv, write) = tunnel.split();
203                    let cmd_handler = self.cmd_handler.clone();
204                    let handle = create_recv_handle::<R, W, LEN, CMD>(recv, tunnel_id, cmd_handler);
205                    Ok(ClassifiedCmdSend::new(tunnel_id, classification, handle, write))
206                }
207            }
208
209        } else {
210            Err(pool_err!(PoolErrorCode::Failed, "peer id is none"))
211        }
212    }
213}
214
215struct CmdWriteFactory<C: WorkerClassification,
216    R: ClassifiedCmdTunnelRead<C>,
217    W: ClassifiedCmdTunnelWrite<C>,
218    F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
219    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
220    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
221    LISTENER: CmdTunnelListener<R, W>> {
222    inner: Arc<CmdWriteFactoryImpl<C, R, W, F, LEN, CMD, LISTENER>>
223}
224
225
226impl<C: WorkerClassification,
227    R: ClassifiedCmdTunnelRead<C>,
228    W: ClassifiedCmdTunnelWrite<C>,
229    F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
230    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes + RawFixedBytes,
231    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes + RawFixedBytes,
232    LISTENER: CmdTunnelListener<R, W>> CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER> {
233    pub fn new(tunnel_factory: F,
234               tunnel_listener: LISTENER,
235               cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
236        Self {
237            inner: Arc::new(CmdWriteFactoryImpl::new(tunnel_factory, tunnel_listener, cmd_handler)),
238        }
239    }
240
241    pub fn start(&self) {
242        self.inner.start();
243    }
244}
245
246#[async_trait::async_trait]
247impl<C: WorkerClassification,
248    R: ClassifiedCmdTunnelRead<C>,
249    W: ClassifiedCmdTunnelWrite<C>,
250    F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
251    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
252    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
253    LISTENER: CmdTunnelListener<R, W>> ClassifiedWorkerFactory<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>> for CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER> {
254    async fn create(&self, c: Option<CmdNodeTunnelClassification<C>>) -> PoolResult<ClassifiedCmdSend<C, R, W, LEN, CMD>> {
255        self.inner.create(c).await
256    }
257}
258
259pub struct DefaultClassifiedCmdNode<C: WorkerClassification,
260    R: ClassifiedCmdTunnelRead<C>,
261    W: ClassifiedCmdTunnelWrite<C>,
262    F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
263    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
264    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
265    LISTENER: CmdTunnelListener<R, W>> {
266    tunnel_pool: ClassifiedWorkerPoolRef<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER>>,
267    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
268}
269
270impl<C: WorkerClassification,
271    R: ClassifiedCmdTunnelRead<C>,
272    W: ClassifiedCmdTunnelWrite<C>,
273    F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
274    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
275    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
276    LISTENER: CmdTunnelListener<R, W>> DefaultClassifiedCmdNode<C, R, W, F, LEN, CMD, LISTENER> {
277    pub fn new(listener: LISTENER, factory: F, tunnel_count: u16) -> Arc<Self> {
278        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
279        let handler_map = cmd_handler_map.clone();
280        let write_factory = CmdWriteFactory::<C, R, W, _, LEN, CMD, LISTENER>::new(factory, listener, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read| {
281            let handler_map = handler_map.clone();
282            async move {
283                if let Some(handler) = handler_map.get(header.cmd_code()) {
284                    handler.handle(peer_id, tunnel_id, header, body_read).await?;
285                }
286                Ok(())
287            }
288        });
289        write_factory.start();
290        Arc::new(Self {
291            tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, write_factory),
292            cmd_handler_map,
293        })
294    }
295
296
297    async fn get_send(&self, peer_id: PeerId) -> CmdResult<ClassifiedWorkerGuard<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER>>> {
298        self.tunnel_pool.get_classified_worker(CmdNodeTunnelClassification {
299            peer_id: Some(peer_id),
300            tunnel_id: None,
301            classification: None,
302        }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
303    }
304
305    async fn get_send_of_tunnel_id(&self, peer_id: PeerId, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER>>> {
306        self.tunnel_pool.get_classified_worker(CmdNodeTunnelClassification {
307            peer_id: Some(peer_id),
308            tunnel_id: Some(tunnel_id),
309            classification: None,
310        }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
311    }
312
313    async fn get_classified_send(&self, classification: C) -> CmdResult<ClassifiedWorkerGuard<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER>>> {
314        self.tunnel_pool.get_classified_worker(CmdNodeTunnelClassification {
315            peer_id: None,
316            tunnel_id: None,
317            classification: Some(classification),
318        }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
319    }
320}
321
322#[async_trait::async_trait]
323impl<C: WorkerClassification,
324    R: ClassifiedCmdTunnelRead<C>,
325    W: ClassifiedCmdTunnelWrite<C>,
326    F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
327    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
328    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
329    LISTENER: CmdTunnelListener<R, W>> CmdNode<LEN, CMD> for DefaultClassifiedCmdNode<C, R, W, F, LEN, CMD, LISTENER> {
330    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
331        self.cmd_handler_map.insert(cmd, handler)
332    }
333
334    async fn send(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
335        let mut send = self.get_send(peer_id.clone()).await?;
336        send.send(cmd, version, body).await
337    }
338
339    async fn send2(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
340        let mut send = self.get_send(peer_id.clone()).await?;
341        send.send2(cmd, version, body).await
342    }
343
344    async fn send_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
345        let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
346        send.send(cmd, version, body).await
347    }
348
349    async fn send2_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
350        let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
351        send.send2(cmd, version, body).await
352    }
353
354    async fn clear_all_tunnel(&self) {
355        self.tunnel_pool.clear_all_worker().await;
356    }
357}
358
359#[async_trait::async_trait]
360impl<C: WorkerClassification,
361    R: ClassifiedCmdTunnelRead<C>,
362    W: ClassifiedCmdTunnelWrite<C>,
363    F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
364    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
365    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
366    LISTENER: CmdTunnelListener<R, W>> ClassifiedCmdNode<LEN, CMD, C> for DefaultClassifiedCmdNode<C, R, W, F, LEN, CMD, LISTENER> {
367    async fn send_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
368        let mut send = self.get_classified_send(classification).await?;
369        send.send(cmd, version, body).await
370    }
371
372    async fn send2_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
373        let mut send = self.get_classified_send(classification).await?;
374        send.send2(cmd, version, body).await
375    }
376
377    async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
378        let send = self.get_classified_send(classification).await?;
379        Ok(send.get_tunnel_id())
380    }
381}