sfo_cmd_server/node/
node.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::sync::{Arc, Mutex};
5use std::time::Duration;
6use async_named_locker::ObjectHolder;
7use bucky_raw_codec::{RawDecode, RawEncode, RawFixedBytes};
8use num::{FromPrimitive, ToPrimitive};
9use sfo_pool::{ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult};
10use sfo_split::{Splittable};
11use crate::{into_pool_err, pool_err, CmdBody, CmdHandler, CmdHeader, CmdNode, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, PeerId, TunnelId, TunnelIdGenerator};
12use crate::client::{gen_resp_id, ClassifiedSendGuard, CommonCmdSend, RespWaiter, RespWaiterRef};
13use crate::cmd::{CmdHandlerMap};
14use crate::errors::{into_cmd_err, CmdErrorCode, CmdResult};
15use crate::node::create_recv_handle;
16use crate::server::{CmdTunnelListener};
17
18#[async_trait::async_trait]
19pub trait CmdNodeTunnelFactory<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>>: Send + Sync + 'static {
20    async fn create_tunnel(&self, remote_id: &PeerId) -> CmdResult<Splittable<R, W>>;
21}
22
23
24impl<M, R, W, LEN, CMD> ClassifiedWorker<(PeerId, Option<TunnelId>)> for CommonCmdSend<M, R, W, LEN, CMD>
25where M: CmdTunnelMeta,
26      R: CmdTunnelRead<M>,
27      W: CmdTunnelWrite<M>,
28      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
29      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
30    fn is_work(&self) -> bool {
31        self.is_work && !self.recv_handle.is_finished()
32    }
33
34    fn is_valid(&self, c: (PeerId, Option<TunnelId>)) -> bool {
35        let (peer_id, tunnel_id) = c;
36        if tunnel_id.is_some() {
37            self.tunnel_id == tunnel_id.unwrap() && peer_id == self.remote_id
38        } else {
39            peer_id == self.remote_id
40        }
41    }
42
43    fn classification(&self) -> (PeerId, Option<TunnelId>) {
44        (self.remote_id.clone(), Some(self.tunnel_id))
45    }
46}
47
48struct CmdWriteFactoryImpl<M: CmdTunnelMeta,
49    R: CmdTunnelRead<M>,
50    W: CmdTunnelWrite<M>,
51    F: CmdNodeTunnelFactory<M, R, W>,
52    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
53    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
54    LISTENER: CmdTunnelListener<M, R, W>> {
55    tunnel_listener: LISTENER,
56    tunnel_factory: F,
57    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
58    tunnel_id_generator: TunnelIdGenerator,
59    resp_waiter: RespWaiterRef,
60    send_cache: Arc<Mutex<HashMap<PeerId, Vec<CommonCmdSend<M, R, W, LEN, CMD>>>>>,
61}
62
63
64impl<M: CmdTunnelMeta,
65    R: CmdTunnelRead<M>,
66    W: CmdTunnelWrite<M>,
67    F: CmdNodeTunnelFactory<M, R, W>,
68    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
69    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
70    LISTENER: CmdTunnelListener<M, R, W>> CmdWriteFactoryImpl<M, R, W, F, LEN, CMD, LISTENER> {
71    pub fn new(tunnel_factory: F,
72               tunnel_listener: LISTENER,
73               cmd_handler: impl CmdHandler<LEN, CMD>,
74               resp_waiter: RespWaiterRef) -> Self {
75        Self {
76            tunnel_listener,
77            tunnel_factory,
78            cmd_handler: Arc::new(cmd_handler),
79            tunnel_id_generator: TunnelIdGenerator::new(),
80            resp_waiter,
81            send_cache: Arc::new(Mutex::new(Default::default())),
82        }
83    }
84
85
86    pub fn start(self: &Arc<Self>) {
87        let this = self.clone();
88        tokio::spawn(async move {
89            if let Err(e) = this.run().await {
90                log::error!("cmd server error: {:?}", e);
91            }
92        });
93    }
94
95    async fn run(self: &Arc<Self>) -> CmdResult<()> {
96        loop {
97            let tunnel = self.tunnel_listener.accept().await?;
98            let peer_id = tunnel.get_remote_peer_id();
99            let tunnel_id = self.tunnel_id_generator.generate();
100            let resp_waiter = self.resp_waiter.clone();
101            let this = self.clone();
102            tokio::spawn(async move {
103                let ret: CmdResult<()> = async move {
104                    let this = this.clone();
105                    let cmd_handler = this.cmd_handler.clone();
106                    let (reader, writer) = tunnel.split();
107                    let remote_id = reader.get_remote_peer_id();
108                    let tunnel_meta = reader.get_tunnel_meta();
109                    let writer = ObjectHolder::new(writer);
110                    let recv_handle = create_recv_handle::<M, R, W, LEN, CMD>(reader, writer.clone(), tunnel_id, cmd_handler);
111                    {
112                        let mut send_cache = this.send_cache.lock().unwrap();
113                        let send_list = send_cache.entry(peer_id).or_insert(Vec::new());
114                        send_list.push(CommonCmdSend::new(tunnel_id, recv_handle, writer, resp_waiter, remote_id, tunnel_meta));
115                    }
116                    Ok(())
117                }.await;
118                if let Err(e) = ret {
119                    log::error!("peer connection error: {:?}", e);
120                }
121            });
122        }
123    }
124}
125
126#[async_trait::async_trait]
127impl<M: CmdTunnelMeta,
128    R: CmdTunnelRead<M>,
129    W: CmdTunnelWrite<M>,
130    F: CmdNodeTunnelFactory<M, R, W>,
131    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
132    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
133    LISTENER: CmdTunnelListener<M, R, W>> ClassifiedWorkerFactory<(PeerId, Option<TunnelId>), CommonCmdSend<M, R, W, LEN, CMD>> for CmdWriteFactoryImpl<M, R, W, F, LEN, CMD, LISTENER> {
134    async fn create(&self, c: Option<(PeerId, Option<TunnelId>)>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
135        if c.is_some() {
136            let (peer_id, tunnel_id) = c.unwrap();
137            if tunnel_id.is_some() {
138                let mut send_cache = self.send_cache.lock().unwrap();
139                if let Some(send_list) = send_cache.get_mut(&peer_id) {
140                    let mut send_index = None;
141                    for (index, send) in send_list.iter().enumerate() {
142                        if send.get_tunnel_id() == tunnel_id.unwrap() {
143                            send_index = Some(index);
144                            break;
145                        }
146                    }
147                    if let Some(send_index) = send_index {
148                        let send = send_list.remove(send_index);
149                        Ok(send)
150                    } else {
151                        Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
152                    }
153                } else {
154                    Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
155                }
156            } else {
157                {
158                    let mut send_cache = self.send_cache.lock().unwrap();
159                    if let Some(send_list) = send_cache.get_mut(&peer_id) {
160                        if !send_list.is_empty() {
161                            let send = send_list.pop().unwrap();
162                            if send_list.is_empty() {
163                                send_cache.remove(&peer_id);
164                            }
165                            return Ok(send);
166                        }
167                    }
168                }
169                let tunnel = self.tunnel_factory.create_tunnel(&peer_id).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
170                let tunnel_id = self.tunnel_id_generator.generate();
171                let (recv, write) = tunnel.split();
172                let remote_id = recv.get_remote_peer_id();
173                let tunnel_meta = recv.get_tunnel_meta();
174                let write = ObjectHolder::new(write);
175                let cmd_handler = self.cmd_handler.clone();
176                let handle = create_recv_handle::<M, R, W, LEN, CMD>(recv, write.clone(), tunnel_id, cmd_handler);
177                Ok(CommonCmdSend::new(tunnel_id, handle, write, self.resp_waiter.clone(), remote_id, tunnel_meta))
178            }
179        } else {
180            Err(pool_err!(PoolErrorCode::Failed, "peer id is none"))
181        }
182    }
183}
184
185pub struct CmdNodeWriteFactory<M: CmdTunnelMeta,
186    R: CmdTunnelRead<M>,
187    W: CmdTunnelWrite<M>,
188    F: CmdNodeTunnelFactory<M, R, W>,
189    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
190    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
191    LISTENER: CmdTunnelListener<M, R, W>> {
192    inner: Arc<CmdWriteFactoryImpl<M, R, W, F, LEN, CMD, LISTENER>>
193}
194
195
196impl<M: CmdTunnelMeta,
197    R: CmdTunnelRead<M>,
198    W: CmdTunnelWrite<M>,
199    F: CmdNodeTunnelFactory<M, R, W>,
200    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
201    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
202    LISTENER: CmdTunnelListener<M, R, W>> CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER> {
203    pub(crate) fn new(tunnel_factory: F,
204               tunnel_listener: LISTENER,
205               cmd_handler: impl CmdHandler<LEN, CMD>,
206                      resp_waiter: RespWaiterRef) -> Self {
207        Self {
208            inner: Arc::new(CmdWriteFactoryImpl::new(tunnel_factory, tunnel_listener, cmd_handler, resp_waiter)),
209        }
210    }
211
212    pub fn start(&self) {
213        self.inner.start();
214    }
215}
216
217#[async_trait::async_trait]
218impl<M: CmdTunnelMeta,
219    R: CmdTunnelRead<M>,
220    W: CmdTunnelWrite<M>,
221    F: CmdNodeTunnelFactory<M, R, W>,
222    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
223    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
224    LISTENER: CmdTunnelListener<M, R, W>> ClassifiedWorkerFactory<(PeerId, Option<TunnelId>), CommonCmdSend<M, R, W, LEN, CMD>> for CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER> {
225    async fn create(&self, c: Option<(PeerId, Option<TunnelId>)>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
226        self.inner.create(c).await
227    }
228}
229pub struct DefaultCmdNode<M: CmdTunnelMeta,
230    R: CmdTunnelRead<M>,
231    W: CmdTunnelWrite<M>,
232    F: CmdNodeTunnelFactory<M, R, W>,
233    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
234    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug,
235    LISTENER: CmdTunnelListener<M, R, W>> {
236    tunnel_pool: ClassifiedWorkerPoolRef<(PeerId, Option<TunnelId>), CommonCmdSend<M, R, W, LEN, CMD>, CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER>>,
237    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
238}
239
240impl<M: CmdTunnelMeta,
241    R: CmdTunnelRead<M>,
242    W: CmdTunnelWrite<M>,
243    F: CmdNodeTunnelFactory<M, R, W>,
244    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
245    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
246    LISTENER: CmdTunnelListener<M, R, W>> DefaultCmdNode<M, R, W, F, LEN, CMD, LISTENER> {
247    pub fn new(listener: LISTENER, factory: F, tunnel_count: u16) -> Arc<Self> {
248        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
249        let handler_map = cmd_handler_map.clone();
250        let resp_waiter = Arc::new(RespWaiter::new());
251        let waiter = resp_waiter.clone();
252        let write_factory = CmdNodeWriteFactory::<M, R, W, _, LEN, CMD, LISTENER>::new(factory, listener, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read: CmdBody| {
253            let handler_map = handler_map.clone();
254            let waiter = waiter.clone();
255            async move {
256                if header.is_resp() && header.seq().is_some() {
257                    let resp_id = gen_resp_id(header.cmd_code(), header.seq().unwrap());
258                    let _ = waiter.set_result(resp_id, body_read);
259                    Ok(None)
260                } else {
261                    if let Some(handler) = handler_map.get(header.cmd_code()) {
262                        handler.handle(peer_id, tunnel_id, header, body_read).await
263                    } else {
264                        Ok(None)
265                    }
266                }
267            }
268        }, resp_waiter.clone());
269        write_factory.start();
270        Arc::new(Self {
271            tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, write_factory),
272            cmd_handler_map,
273        })
274    }
275
276    async fn get_send(&self, peer_id: PeerId) -> CmdResult<ClassifiedWorkerGuard<(PeerId, Option<TunnelId>), CommonCmdSend<M, R, W, LEN, CMD>, CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER>>> {
277        self.tunnel_pool.get_classified_worker((peer_id, None)).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
278    }
279
280    async fn get_send_of_tunnel_id(&self, peer_id: PeerId, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<(PeerId, Option<TunnelId>), CommonCmdSend<M, R, W, LEN, CMD>, CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER>>> {
281        self.tunnel_pool.get_classified_worker((peer_id, Some(tunnel_id))).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
282    }
283
284}
285
286pub type CmdNodeSendGuard<M, R, W, F, LEN, CMD, LISTENER> = ClassifiedSendGuard<(PeerId, Option<TunnelId>), M, CommonCmdSend<M, R, W, LEN, CMD>, CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER>>;
287#[async_trait::async_trait]
288impl<M: CmdTunnelMeta,
289    R: CmdTunnelRead<M>,
290    W: CmdTunnelWrite<M>,
291    F: CmdNodeTunnelFactory<M, R, W>,
292    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + FromPrimitive + ToPrimitive,
293    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash + Debug,
294    LISTENER: CmdTunnelListener<M, R, W>> CmdNode<LEN, CMD, M, CommonCmdSend<M, R, W, LEN, CMD>, CmdNodeSendGuard<M, R, W, F, LEN, CMD, LISTENER>> for DefaultCmdNode<M, R, W, F, LEN, CMD, LISTENER> {
295    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
296        self.cmd_handler_map.insert(cmd, handler);
297    }
298
299    async fn send(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
300        let mut send = self.get_send(peer_id.clone()).await?;
301        send.send(cmd, version, body).await
302    }
303
304    async fn send_with_resp(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
305        let mut send = self.get_send(peer_id.clone()).await?;
306        send.send_with_resp(cmd, version, body, timeout).await
307    }
308
309    async fn send2(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
310        let mut send = self.get_send(peer_id.clone()).await?;
311        send.send2(cmd, version, body).await
312    }
313
314    async fn send2_with_resp(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
315        let mut send = self.get_send(peer_id.clone()).await?;
316        send.send2_with_resp(cmd, version, body, timeout).await
317    }
318
319    async fn send_cmd(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
320        let mut send = self.get_send(peer_id.clone()).await?;
321        send.send_cmd(cmd, version, body).await
322    }
323
324    async fn send_cmd_with_resp(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
325        let mut send = self.get_send(peer_id.clone()).await?;
326        send.send_cmd_with_resp(cmd, version, body, timeout).await
327    }
328
329    async fn send_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
330        let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
331        send.send(cmd, version, body).await
332    }
333
334    async fn send_by_specify_tunnel_with_resp(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
335        let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
336        send.send_with_resp(cmd, version, body, timeout).await
337    }
338
339    async fn send2_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
340        let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
341        send.send2(cmd, version, body).await
342    }
343
344    async fn send2_by_specify_tunnel_with_resp(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
345        let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
346        send.send2_with_resp(cmd, version, body, timeout).await
347    }
348
349    async fn send_cmd_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
350        let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
351        send.send_cmd(cmd, version, body).await
352    }
353
354    async fn send_cmd_by_specify_tunnel_with_resp(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
355        let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
356        send.send_cmd_with_resp(cmd, version, body, timeout).await
357    }
358
359    async fn clear_all_tunnel(&self) {
360        self.tunnel_pool.clear_all_worker().await
361    }
362
363    async fn get_send(&self, peer_id: &PeerId, tunnel_id: TunnelId) -> CmdResult<CmdNodeSendGuard<M, R, W, F, LEN, CMD, LISTENER>> {
364        Ok(ClassifiedSendGuard {
365            worker_guard: self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?,
366            _p: Default::default(),
367        })
368    }
369}