sfo_cmd_server/server/
server.rs

1use std::fmt::Debug;
2use std::hash::Hash;
3use std::sync::{Arc, Mutex};
4use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
5use num::{FromPrimitive, ToPrimitive};
6use sfo_split::Splittable;
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use crate::cmd::{CmdBodyReadImpl, CmdHandler, CmdHandlerMap, CmdHeader};
9use crate::{CmdTunnelRead, CmdTunnelWrite, TunnelId};
10use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
11use crate::peer_connection::PeerConnection;
12use crate::peer_id::PeerId;
13use crate::server::CmdServer;
14use super::peer_manager::{PeerManager, PeerManagerRef};
15
16#[async_trait::async_trait]
17pub trait CmdTunnelListener<R: CmdTunnelRead, W: CmdTunnelWrite>: Send + Sync + 'static {
18    async fn accept(&self) -> CmdResult<Splittable<R, W>>;
19}
20
21#[async_trait::async_trait]
22pub trait CmdServerEventListener: Send + Sync + 'static {
23    async fn on_peer_connected(&self, peer_id: &PeerId) -> CmdResult<()>;
24    async fn on_peer_disconnected(&self, peer_id: &PeerId) -> CmdResult<()>;
25}
26
27#[derive(Clone)]
28struct CmdServerEventListenerEmit {
29    listeners: Arc<Mutex<Vec<Arc<dyn CmdServerEventListener>>>>,
30}
31
32impl CmdServerEventListenerEmit {
33    pub fn new() -> Self {
34        Self {
35            listeners: Arc::new(Mutex::new(Vec::new())),
36        }
37    }
38
39    pub fn attach_event_listener(&self, event_listener: Arc<dyn CmdServerEventListener>) {
40        self.listeners.lock().unwrap().push(event_listener);
41    }
42}
43
44#[async_trait::async_trait]
45impl CmdServerEventListener for CmdServerEventListenerEmit {
46    async fn on_peer_connected(&self, peer_id: &PeerId) -> CmdResult<()> {
47        let listeners = {
48            self.listeners.lock().unwrap().clone()
49        };
50        for listener in listeners.iter() {
51            if let Err(e) = listener.on_peer_connected(peer_id).await {
52                log::error!("on_peer_connected error: {:?}", e);
53            }
54        }
55        Ok(())
56    }
57
58    async fn on_peer_disconnected(&self, peer_id: &PeerId) -> CmdResult<()> {
59        let listeners = {
60            self.listeners.lock().unwrap().clone()
61        };
62        for listener in listeners.iter() {
63            if let Err(e) = listener.on_peer_disconnected(peer_id).await {
64                log::error!("on_peer_disconnected error: {:?}", e);
65            }
66        }
67        Ok(())
68    }
69}
70
71pub struct DefaultCmdServer<R: CmdTunnelRead, W: CmdTunnelWrite, LEN, CMD, LISTENER> {
72    tunnel_listener: LISTENER,
73    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
74    peer_manager: PeerManagerRef<R, W>,
75    event_emit: CmdServerEventListenerEmit,
76    _l: Mutex<std::marker::PhantomData<(R, W, LEN, CMD)>>,
77}
78
79impl<R: CmdTunnelRead,
80    W: CmdTunnelWrite,
81    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + FromPrimitive + ToPrimitive,
82    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash,
83    LISTENER: CmdTunnelListener<R, W>> DefaultCmdServer<R, W, LEN, CMD, LISTENER> {
84    pub fn new(tunnel_listener: LISTENER) -> Arc<Self> {
85        let event_emit = CmdServerEventListenerEmit::new();
86        Arc::new(Self {
87            tunnel_listener,
88            cmd_handler_map: Arc::new(CmdHandlerMap::new()),
89            peer_manager: PeerManager::new(Arc::new(event_emit.clone())),
90            event_emit,
91            _l: Default::default(),
92        })
93    }
94
95    pub fn attach_event_listener(&self, event_listener: Arc<dyn CmdServerEventListener>) {
96        self.event_emit.attach_event_listener(event_listener);
97    }
98
99    pub async fn get_peer_tunnels(&self, peer_id: &PeerId) -> Vec<Arc<tokio::sync::Mutex<PeerConnection<R, W>>>> {
100        let connections = self.peer_manager.find_connections(peer_id);
101        connections
102    }
103
104    pub fn start(self: &Arc<Self>) {
105        let this = self.clone();
106        tokio::spawn(async move {
107            if let Err(e) = this.run().await {
108                log::error!("cmd server error: {:?}", e);
109            }
110        });
111    }
112
113    async fn run(self: &Arc<Self>) -> CmdResult<()> {
114        loop {
115            let tunnel = self.tunnel_listener.accept().await?;
116            let peer_id = tunnel.get_remote_peer_id();
117            let tunnel_id = self.peer_manager.generate_conn_id();
118            let this = self.clone();
119            tokio::spawn(async move {
120                let remote_id = peer_id.clone();
121                let ret: CmdResult<()> = async move {
122                    let this = this.clone();
123                    let cmd_handler_map = this.cmd_handler_map.clone();
124                    let (mut reader, writer) = tunnel.split();
125                    let recv_handle = tokio::spawn(async move {
126                        let ret: CmdResult<()> = async move {
127                            loop {
128                                let mut header = vec![0u8; CmdHeader::<LEN, CMD>::raw_bytes().unwrap()];
129                                let n = reader.read_exact(&mut header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
130                                if n == 0 {
131                                    break;
132                                }
133                                let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice()).map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
134                                let cmd_read = Box::new(CmdBodyReadImpl::new(reader, header.pkg_len().to_u64().unwrap() as usize));
135                                let waiter = cmd_read.get_waiter();
136                                let future = waiter.create_result_future();
137                                {
138                                    let body_read = cmd_read;
139                                    if let Some(handler) = cmd_handler_map.get(header.cmd_code()) {
140                                        if let Err(e) = handler.handle(remote_id.clone(), tunnel_id, header, body_read).await {
141                                            log::error!("handle cmd error: {:?}", e);
142                                        }
143                                    }
144                                };
145                                reader = future.await.map_err(into_cmd_err!(CmdErrorCode::Failed))??;
146                                // }
147                            }
148                            Ok(())
149                        }.await;
150                        ret
151                    });
152
153                    let peer_conn = PeerConnection {
154                        conn_id: tunnel_id,
155                        peer_id: peer_id.clone(),
156                        send: writer,
157                        handle: Some(recv_handle),
158                    };
159                    this.peer_manager.add_peer_connection(peer_conn).await;
160                    Ok(())
161                }.await;
162                if let Err(e) = ret {
163                    log::error!("peer connection error: {:?}", e);
164                }
165            });
166        }
167    }
168}
169
170#[async_trait::async_trait]
171impl<R: CmdTunnelRead,
172    W: CmdTunnelWrite,
173    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + FromPrimitive + ToPrimitive,
174    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash + Debug,
175    LISTENER: CmdTunnelListener<R, W>> CmdServer<LEN, CMD> for DefaultCmdServer<R, W, LEN, CMD, LISTENER> {
176    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
177        self.cmd_handler_map.insert(cmd, handler);
178    }
179
180    async fn send(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
181        let connections = self.peer_manager.find_connections(peer_id);
182        for conn in connections {
183            let ret: CmdResult<()> = async move {
184                let mut conn = conn.lock().await;
185                log::trace!("send peer_id: {}, tunnel_id {:?}, cmd: {:?}, len: {} data: {}", peer_id, conn.conn_id, cmd, body.len(), hex::encode(body));
186                let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(body.len() as u64).unwrap());
187                let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
188                conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
189                conn.send.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
190                conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
191                Ok(())
192            }.await;
193            if ret.is_ok() {
194                break;
195            }
196        }
197        Ok(())
198    }
199
200    async fn send2(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
201        let connections = self.peer_manager.find_connections(peer_id);
202        for conn in connections {
203            let ret: CmdResult<()> = async move {
204                let mut conn = conn.lock().await;
205                let mut len = 0;
206                for b in body.iter() {
207                    len += b.len();
208                    log::trace!("send2 peer_id: {}, tunnel_id: {:?}, cmd: {:?} body: {}", peer_id, conn.conn_id, cmd, hex::encode(b));
209                }
210                log::trace!("send2 peer_id: {}, tunnel_id: {:?}, cmd: {:?} len: {}", peer_id, conn.conn_id, cmd, len);
211                let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(len as u64).unwrap());
212                let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
213                conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
214                for b in body.iter() {
215                    conn.send.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
216                }
217                conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
218                Ok(())
219            }.await;
220            if ret.is_ok() {
221                break;
222            }
223        }
224        Ok(())
225    }
226
227    async fn send_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
228        let conn = self.peer_manager.find_connection(tunnel_id);
229        if conn.is_none() {
230            return Err(cmd_err!(CmdErrorCode::PeerConnectionNotFound, "tunnel_id: {:?}", tunnel_id));
231        }
232        let conn = conn.unwrap();
233        let mut conn = conn.lock().await;
234        assert_eq!(tunnel_id, conn.conn_id);
235        log::trace!("send_by_specify_tunnel peer_id: {}, tunnel_id: {:?}, cmd: {:?}, len: {} data: {}", peer_id, conn.conn_id, cmd, body.len(), hex::encode(body));
236        let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(body.len() as u64).unwrap());
237        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
238        conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
239        conn.send.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
240        conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
241        Ok(())
242    }
243
244    async fn send2_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
245        let conn = self.peer_manager.find_connection(tunnel_id);
246        if conn.is_none() {
247            return Err(cmd_err!(CmdErrorCode::PeerConnectionNotFound, "tunnel_id: {:?}", tunnel_id));
248        }
249        let conn = conn.unwrap();
250        let mut conn = conn.lock().await;
251        assert_eq!(tunnel_id, conn.conn_id);
252        let mut len = 0;
253        for b in body.iter() {
254            len += b.len();
255            log::trace!("send2_by_specify_tunnel peer_id: {}, tunnel_id: {:?}, cmd: {:?} body: {}", peer_id, conn.conn_id, cmd, hex::encode(b));
256        }
257        log::trace!("send2_by_specify_tunnel peer_id: {}, tunnel_id: {:?}, cmd: {:?} len: {}", peer_id, conn.conn_id, cmd, len);
258        let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(len as u64).unwrap());
259        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
260        conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
261        for b in body.iter() {
262            conn.send.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
263        }
264        conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
265        Ok(())
266    }
267
268    async fn send_by_all_tunnels(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
269        let connections = self.peer_manager.find_connections(peer_id);
270        for conn in connections {
271            let _ret: CmdResult<()> = async move {
272                let mut conn = conn.lock().await;
273                let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(body.len() as u64).unwrap());
274                let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
275                conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
276                conn.send.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
277                conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
278                Ok(())
279            }.await;
280        }
281        Ok(())
282    }
283
284    async fn send2_by_all_tunnels(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
285        let connections = self.peer_manager.find_connections(peer_id);
286        let mut len = 0;
287        for b in body.iter() {
288            len += b.len();
289        }
290        for conn in connections {
291            let _ret: CmdResult<()> = async move {
292                let mut conn = conn.lock().await;
293                let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(len as u64).unwrap());
294                let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
295                conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
296                for b in body.iter() {
297                    conn.send.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
298                }
299                conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
300                Ok(())
301            }.await;
302        }
303        Ok(())
304    }
305}