sfo_cmd_server/node/
node.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::sync::{Arc, Mutex};
5use bucky_raw_codec::{RawDecode, RawEncode, RawFixedBytes};
6use num::{FromPrimitive, ToPrimitive};
7use sfo_pool::{ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult};
8use sfo_split::{Splittable};
9use crate::{into_pool_err, pool_err, CmdHandler, CmdHeader, CmdNode, CmdTunnelRead, CmdTunnelWrite, PeerId, TunnelId, TunnelIdGenerator};
10use crate::client::{CmdSend};
11use crate::cmd::{CmdHandlerMap};
12use crate::errors::{into_cmd_err, CmdErrorCode, CmdResult};
13use crate::node::create_recv_handle;
14use crate::server::{CmdTunnelListener};
15
16#[async_trait::async_trait]
17pub trait CmdNodeTunnelFactory<R: CmdTunnelRead, W: CmdTunnelWrite>: Send + Sync + 'static {
18    async fn create_tunnel(&self, remote_id: &PeerId) -> CmdResult<Splittable<R, W>>;
19}
20
21
22impl<R, W, LEN, CMD> ClassifiedWorker<(PeerId, Option<TunnelId>)> for CmdSend<R, W, LEN, CMD>
23where R: CmdTunnelRead,
24      W: CmdTunnelWrite,
25      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
26      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug{
27    fn is_work(&self) -> bool {
28        self.is_work && !self.recv_handle.is_finished()
29    }
30
31    fn is_valid(&self, c: (PeerId, Option<TunnelId>)) -> bool {
32        let (peer_id, tunnel_id) = c;
33        if tunnel_id.is_some() {
34            self.tunnel_id == tunnel_id.unwrap() && peer_id == self.write.get_remote_peer_id()
35        } else {
36            peer_id == self.write.get_remote_peer_id()
37        }
38    }
39
40    fn classification(&self) -> (PeerId, Option<TunnelId>) {
41        (self.write.get_remote_peer_id().clone(), Some(self.tunnel_id))
42    }
43}
44
45struct CmdWriteFactoryImpl<R: CmdTunnelRead,
46    W: CmdTunnelWrite,
47    F: CmdNodeTunnelFactory<R, W>,
48    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
49    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
50    LISTENER: CmdTunnelListener<R, W>> {
51    tunnel_listener: LISTENER,
52    tunnel_factory: F,
53    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
54    tunnel_id_generator: TunnelIdGenerator,
55    send_cache: Arc<Mutex<HashMap<PeerId, Vec<CmdSend<R, W, LEN, CMD>>>>>,
56}
57
58
59impl<R: CmdTunnelRead,
60    W: CmdTunnelWrite,
61    F: CmdNodeTunnelFactory<R, W>,
62    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
63    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
64    LISTENER: CmdTunnelListener<R, W>> CmdWriteFactoryImpl<R, W, F, LEN, CMD, LISTENER> {
65    pub fn new(tunnel_factory: F,
66               tunnel_listener: LISTENER,
67               cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
68        Self {
69            tunnel_listener,
70            tunnel_factory,
71            cmd_handler: Arc::new(cmd_handler),
72            tunnel_id_generator: TunnelIdGenerator::new(),
73            send_cache: Arc::new(Mutex::new(Default::default())),
74        }
75    }
76
77
78    pub fn start(self: &Arc<Self>) {
79        let this = self.clone();
80        tokio::spawn(async move {
81            if let Err(e) = this.run().await {
82                log::error!("cmd server error: {:?}", e);
83            }
84        });
85    }
86
87    async fn run(self: &Arc<Self>) -> CmdResult<()> {
88        loop {
89            let tunnel = self.tunnel_listener.accept().await?;
90            let peer_id = tunnel.get_remote_peer_id();
91            let tunnel_id = self.tunnel_id_generator.generate();
92            let this = self.clone();
93            tokio::spawn(async move {
94                let ret: CmdResult<()> = async move {
95                    let this = this.clone();
96                    let cmd_handler = this.cmd_handler.clone();
97                    let (reader, writer) = tunnel.split();
98                    let recv_handle = create_recv_handle::<R, W, LEN, CMD>(reader, tunnel_id, cmd_handler);
99                    {
100                        let mut send_cache = this.send_cache.lock().unwrap();
101                        let send_list = send_cache.entry(peer_id).or_insert(Vec::new());
102                        send_list.push(CmdSend::new(tunnel_id, recv_handle, writer));
103                    }
104                    Ok(())
105                }.await;
106                if let Err(e) = ret {
107                    log::error!("peer connection error: {:?}", e);
108                }
109            });
110        }
111    }
112}
113
114#[async_trait::async_trait]
115impl<R: CmdTunnelRead,
116    W: CmdTunnelWrite,
117    F: CmdNodeTunnelFactory<R, W>,
118    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
119    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
120    LISTENER: CmdTunnelListener<R, W>> ClassifiedWorkerFactory<(PeerId, Option<TunnelId>), CmdSend<R, W, LEN, CMD>> for CmdWriteFactoryImpl<R, W, F, LEN, CMD, LISTENER> {
121    async fn create(&self, c: Option<(PeerId, Option<TunnelId>)>) -> PoolResult<CmdSend<R, W, LEN, CMD>> {
122        if c.is_some() {
123            let (peer_id, tunnel_id) = c.unwrap();
124            if tunnel_id.is_some() {
125                let mut send_cache = self.send_cache.lock().unwrap();
126                if let Some(send_list) = send_cache.get_mut(&peer_id) {
127                    let mut send_index = None;
128                    for (index, send) in send_list.iter().enumerate() {
129                        if send.get_tunnel_id() == tunnel_id.unwrap() {
130                            send_index = Some(index);
131                            break;
132                        }
133                    }
134                    if let Some(send_index) = send_index {
135                        let send = send_list.remove(send_index);
136                        Ok(send)
137                    } else {
138                        Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
139                    }
140                } else {
141                    Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
142                }
143            } else {
144                {
145                    let mut send_cache = self.send_cache.lock().unwrap();
146                    if let Some(send_list) = send_cache.get_mut(&peer_id) {
147                        if !send_list.is_empty() {
148                            let send = send_list.pop().unwrap();
149                            if send_list.is_empty() {
150                                send_cache.remove(&peer_id);
151                            }
152                            return Ok(send);
153                        }
154                    }
155                }
156                let tunnel = self.tunnel_factory.create_tunnel(&peer_id).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
157                let tunnel_id = self.tunnel_id_generator.generate();
158                let (recv, write) = tunnel.split();
159                let cmd_handler = self.cmd_handler.clone();
160                let handle = create_recv_handle::<R, W, LEN, CMD>(recv, tunnel_id, cmd_handler);
161                Ok(CmdSend::new(tunnel_id, handle, write))
162            }
163        } else {
164            Err(pool_err!(PoolErrorCode::Failed, "peer id is none"))
165        }
166    }
167}
168
169struct CmdWriteFactory<R: CmdTunnelRead,
170    W: CmdTunnelWrite,
171    F: CmdNodeTunnelFactory<R, W>,
172    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
173    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
174    LISTENER: CmdTunnelListener<R, W>> {
175    inner: Arc<CmdWriteFactoryImpl<R, W, F, LEN, CMD, LISTENER>>
176}
177
178
179impl<R: CmdTunnelRead,
180    W: CmdTunnelWrite,
181    F: CmdNodeTunnelFactory<R, W>,
182    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
183    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
184    LISTENER: CmdTunnelListener<R, W>> CmdWriteFactory<R, W, F, LEN, CMD, LISTENER> {
185    pub fn new(tunnel_factory: F,
186               tunnel_listener: LISTENER,
187               cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
188        Self {
189            inner: Arc::new(CmdWriteFactoryImpl::new(tunnel_factory, tunnel_listener, cmd_handler)),
190        }
191    }
192
193    pub fn start(&self) {
194        self.inner.start();
195    }
196}
197
198#[async_trait::async_trait]
199impl<R: CmdTunnelRead,
200    W: CmdTunnelWrite,
201    F: CmdNodeTunnelFactory<R, W>,
202    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
203    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
204    LISTENER: CmdTunnelListener<R, W>> ClassifiedWorkerFactory<(PeerId, Option<TunnelId>), CmdSend<R, W, LEN, CMD>> for CmdWriteFactory<R, W, F, LEN, CMD, LISTENER> {
205    async fn create(&self, c: Option<(PeerId, Option<TunnelId>)>) -> PoolResult<CmdSend<R, W, LEN, CMD>> {
206        self.inner.create(c).await
207    }
208}
209pub struct DefaultCmdNode<R: CmdTunnelRead,
210    W: CmdTunnelWrite,
211    F: CmdNodeTunnelFactory<R, W>,
212    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
213    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug,
214    LISTENER: CmdTunnelListener<R, W>> {
215    tunnel_pool: ClassifiedWorkerPoolRef<(PeerId, Option<TunnelId>), CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD, LISTENER>>,
216    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
217}
218
219impl<R: CmdTunnelRead,
220    W: CmdTunnelWrite,
221    F: CmdNodeTunnelFactory<R, W>,
222    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
223    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
224    LISTENER: CmdTunnelListener<R, W>> DefaultCmdNode<R, W, F, LEN, CMD, LISTENER> {
225    pub fn new(listener: LISTENER, factory: F, tunnel_count: u16) -> Arc<Self> {
226        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
227        let handler_map = cmd_handler_map.clone();
228        let write_factory = CmdWriteFactory::<R, W, _, LEN, CMD, LISTENER>::new(factory, listener, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read| {
229            let handler_map = handler_map.clone();
230            async move {
231                if let Some(handler) = handler_map.get(header.cmd_code()) {
232                    handler.handle(peer_id, tunnel_id, header, body_read).await?;
233                }
234                Ok(())
235            }
236        });
237        write_factory.start();
238        Arc::new(Self {
239            tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, write_factory),
240            cmd_handler_map,
241        })
242    }
243
244    async fn get_send(&self, peer_id: PeerId) -> CmdResult<ClassifiedWorkerGuard<(PeerId, Option<TunnelId>), CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD, LISTENER>>> {
245        self.tunnel_pool.get_classified_worker((peer_id, None)).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
246    }
247
248    async fn get_send_of_tunnel_id(&self, peer_id: PeerId, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<(PeerId, Option<TunnelId>), CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD, LISTENER>>> {
249        self.tunnel_pool.get_classified_worker((peer_id, Some(tunnel_id))).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
250    }
251
252}
253
254#[async_trait::async_trait]
255impl<R: CmdTunnelRead,
256    W: CmdTunnelWrite,
257    F: CmdNodeTunnelFactory<R, W>,
258    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + FromPrimitive + ToPrimitive,
259    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash + Debug,
260    LISTENER: CmdTunnelListener<R, W>> CmdNode<LEN, CMD> for DefaultCmdNode<R, W, F, LEN, CMD, LISTENER> {
261    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
262        self.cmd_handler_map.insert(cmd, handler);
263    }
264
265    async fn send(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
266        let mut send = self.get_send(peer_id.clone()).await?;
267        send.send(cmd, version, body).await
268    }
269
270    async fn send2(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
271        let mut send = self.get_send(peer_id.clone()).await?;
272        send.send2(cmd, version, body).await
273    }
274
275    async fn send_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
276        let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
277        send.send(cmd, version, body).await
278    }
279
280    async fn send2_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
281        let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
282        send.send2(cmd, version, body).await
283    }
284
285    async fn clear_all_tunnel(&self) {
286        self.tunnel_pool.clear_all_worker().await
287    }
288}