1use std::fmt::Debug;
2use std::hash::Hash;
3use std::sync::{Arc, Mutex};
4use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
5use num::{FromPrimitive, ToPrimitive};
6use sfo_split::Splittable;
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use crate::cmd::{CmdBodyReadImpl, CmdHandler, CmdHandlerMap, CmdHeader};
9use crate::{CmdTunnelRead, CmdTunnelWrite, TunnelId};
10use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
11use crate::peer_connection::PeerConnection;
12use crate::peer_id::PeerId;
13use crate::server::CmdServer;
14use super::peer_manager::{PeerManager, PeerManagerRef};
15
16#[async_trait::async_trait]
17pub trait CmdTunnelListener<R: CmdTunnelRead, W: CmdTunnelWrite>: Send + Sync + 'static {
18 async fn accept(&self) -> CmdResult<Splittable<R, W>>;
19}
20
21#[async_trait::async_trait]
22pub trait CmdServerEventListener: Send + Sync + 'static {
23 async fn on_peer_connected(&self, peer_id: &PeerId) -> CmdResult<()>;
24 async fn on_peer_disconnected(&self, peer_id: &PeerId) -> CmdResult<()>;
25}
26
27#[derive(Clone)]
28struct CmdServerEventListenerEmit {
29 listeners: Arc<Mutex<Vec<Arc<dyn CmdServerEventListener>>>>,
30}
31
32impl CmdServerEventListenerEmit {
33 pub fn new() -> Self {
34 Self {
35 listeners: Arc::new(Mutex::new(Vec::new())),
36 }
37 }
38
39 pub fn attach_event_listener(&self, event_listener: Arc<dyn CmdServerEventListener>) {
40 self.listeners.lock().unwrap().push(event_listener);
41 }
42}
43
44#[async_trait::async_trait]
45impl CmdServerEventListener for CmdServerEventListenerEmit {
46 async fn on_peer_connected(&self, peer_id: &PeerId) -> CmdResult<()> {
47 let listeners = {
48 self.listeners.lock().unwrap().clone()
49 };
50 for listener in listeners.iter() {
51 if let Err(e) = listener.on_peer_connected(peer_id).await {
52 log::error!("on_peer_connected error: {:?}", e);
53 }
54 }
55 Ok(())
56 }
57
58 async fn on_peer_disconnected(&self, peer_id: &PeerId) -> CmdResult<()> {
59 let listeners = {
60 self.listeners.lock().unwrap().clone()
61 };
62 for listener in listeners.iter() {
63 if let Err(e) = listener.on_peer_disconnected(peer_id).await {
64 log::error!("on_peer_disconnected error: {:?}", e);
65 }
66 }
67 Ok(())
68 }
69}
70
71pub struct DefaultCmdServer<R: CmdTunnelRead, W: CmdTunnelWrite, LEN, CMD, LISTENER> {
72 tunnel_listener: LISTENER,
73 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
74 peer_manager: PeerManagerRef<R, W>,
75 event_emit: CmdServerEventListenerEmit,
76 _l: Mutex<std::marker::PhantomData<(R, W, LEN, CMD)>>,
77}
78
79impl<R: CmdTunnelRead,
80 W: CmdTunnelWrite,
81 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + FromPrimitive + ToPrimitive,
82 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash,
83 LISTENER: CmdTunnelListener<R, W>> DefaultCmdServer<R, W, LEN, CMD, LISTENER> {
84 pub fn new(tunnel_listener: LISTENER) -> Arc<Self> {
85 let event_emit = CmdServerEventListenerEmit::new();
86 Arc::new(Self {
87 tunnel_listener,
88 cmd_handler_map: Arc::new(CmdHandlerMap::new()),
89 peer_manager: PeerManager::new(Arc::new(event_emit.clone())),
90 event_emit,
91 _l: Default::default(),
92 })
93 }
94
95 pub fn attach_event_listener(&self, event_listener: Arc<dyn CmdServerEventListener>) {
96 self.event_emit.attach_event_listener(event_listener);
97 }
98
99 pub async fn get_peer_tunnels(&self, peer_id: &PeerId) -> Vec<Arc<tokio::sync::Mutex<PeerConnection<R, W>>>> {
100 let connections = self.peer_manager.find_connections(peer_id);
101 connections
102 }
103
104 pub fn start(self: &Arc<Self>) {
105 let this = self.clone();
106 tokio::spawn(async move {
107 if let Err(e) = this.run().await {
108 log::error!("cmd server error: {:?}", e);
109 }
110 });
111 }
112
113 async fn run(self: &Arc<Self>) -> CmdResult<()> {
114 loop {
115 let tunnel = self.tunnel_listener.accept().await?;
116 let peer_id = tunnel.get_remote_peer_id();
117 let tunnel_id = self.peer_manager.generate_conn_id();
118 let this = self.clone();
119 tokio::spawn(async move {
120 let remote_id = peer_id.clone();
121 let ret: CmdResult<()> = async move {
122 let this = this.clone();
123 let cmd_handler_map = this.cmd_handler_map.clone();
124 let (mut reader, writer) = tunnel.split();
125 let recv_handle = tokio::spawn(async move {
126 let ret: CmdResult<()> = async move {
127 loop {
128 let mut header = vec![0u8; CmdHeader::<LEN, CMD>::raw_bytes().unwrap()];
129 let n = reader.read_exact(&mut header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
130 if n == 0 {
131 break;
132 }
133 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice()).map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
134 let cmd_read = Box::new(CmdBodyReadImpl::new(reader, header.pkg_len().to_u64().unwrap() as usize));
135 let waiter = cmd_read.get_waiter();
136 let future = waiter.create_result_future();
137 {
138 let body_read = cmd_read;
139 if let Some(handler) = cmd_handler_map.get(header.cmd_code()) {
140 if let Err(e) = handler.handle(remote_id.clone(), tunnel_id, header, body_read).await {
141 log::error!("handle cmd error: {:?}", e);
142 }
143 }
144 };
145 reader = future.await.map_err(into_cmd_err!(CmdErrorCode::Failed))??;
146 }
148 Ok(())
149 }.await;
150 ret
151 });
152
153 let peer_conn = PeerConnection {
154 conn_id: tunnel_id,
155 peer_id: peer_id.clone(),
156 send: writer,
157 handle: Some(recv_handle),
158 };
159 this.peer_manager.add_peer_connection(peer_conn).await;
160 Ok(())
161 }.await;
162 if let Err(e) = ret {
163 log::error!("peer connection error: {:?}", e);
164 }
165 });
166 }
167 }
168}
169
170#[async_trait::async_trait]
171impl<R: CmdTunnelRead,
172 W: CmdTunnelWrite,
173 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + FromPrimitive + ToPrimitive,
174 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash + Debug,
175 LISTENER: CmdTunnelListener<R, W>> CmdServer<LEN, CMD> for DefaultCmdServer<R, W, LEN, CMD, LISTENER> {
176 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
177 self.cmd_handler_map.insert(cmd, handler);
178 }
179
180 async fn send(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
181 let connections = self.peer_manager.find_connections(peer_id);
182 for conn in connections {
183 let ret: CmdResult<()> = async move {
184 let mut conn = conn.lock().await;
185 log::trace!("send peer_id: {}, tunnel_id {:?}, cmd: {:?}, len: {} data: {}", peer_id, conn.conn_id, cmd, body.len(), hex::encode(body));
186 let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(body.len() as u64).unwrap());
187 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
188 conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
189 conn.send.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
190 conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
191 Ok(())
192 }.await;
193 if ret.is_ok() {
194 break;
195 }
196 }
197 Ok(())
198 }
199
200 async fn send2(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
201 let connections = self.peer_manager.find_connections(peer_id);
202 for conn in connections {
203 let ret: CmdResult<()> = async move {
204 let mut conn = conn.lock().await;
205 let mut len = 0;
206 for b in body.iter() {
207 len += b.len();
208 log::trace!("send2 peer_id: {}, tunnel_id: {:?}, cmd: {:?} body: {}", peer_id, conn.conn_id, cmd, hex::encode(b));
209 }
210 log::trace!("send2 peer_id: {}, tunnel_id: {:?}, cmd: {:?} len: {}", peer_id, conn.conn_id, cmd, len);
211 let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(len as u64).unwrap());
212 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
213 conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
214 for b in body.iter() {
215 conn.send.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
216 }
217 conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
218 Ok(())
219 }.await;
220 if ret.is_ok() {
221 break;
222 }
223 }
224 Ok(())
225 }
226
227 async fn send_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
228 let conn = self.peer_manager.find_connection(tunnel_id);
229 if conn.is_none() {
230 return Err(cmd_err!(CmdErrorCode::PeerConnectionNotFound, "tunnel_id: {:?}", tunnel_id));
231 }
232 let conn = conn.unwrap();
233 let mut conn = conn.lock().await;
234 assert_eq!(tunnel_id, conn.conn_id);
235 log::trace!("send_by_specify_tunnel peer_id: {}, tunnel_id: {:?}, cmd: {:?}, len: {} data: {}", peer_id, conn.conn_id, cmd, body.len(), hex::encode(body));
236 let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(body.len() as u64).unwrap());
237 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
238 conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
239 conn.send.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
240 conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
241 Ok(())
242 }
243
244 async fn send2_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
245 let conn = self.peer_manager.find_connection(tunnel_id);
246 if conn.is_none() {
247 return Err(cmd_err!(CmdErrorCode::PeerConnectionNotFound, "tunnel_id: {:?}", tunnel_id));
248 }
249 let conn = conn.unwrap();
250 let mut conn = conn.lock().await;
251 assert_eq!(tunnel_id, conn.conn_id);
252 let mut len = 0;
253 for b in body.iter() {
254 len += b.len();
255 log::trace!("send2_by_specify_tunnel peer_id: {}, tunnel_id: {:?}, cmd: {:?} body: {}", peer_id, conn.conn_id, cmd, hex::encode(b));
256 }
257 log::trace!("send2_by_specify_tunnel peer_id: {}, tunnel_id: {:?}, cmd: {:?} len: {}", peer_id, conn.conn_id, cmd, len);
258 let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(len as u64).unwrap());
259 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
260 conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
261 for b in body.iter() {
262 conn.send.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
263 }
264 conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
265 Ok(())
266 }
267
268 async fn send_by_all_tunnels(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
269 let connections = self.peer_manager.find_connections(peer_id);
270 for conn in connections {
271 let _ret: CmdResult<()> = async move {
272 let mut conn = conn.lock().await;
273 let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(body.len() as u64).unwrap());
274 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
275 conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
276 conn.send.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
277 conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
278 Ok(())
279 }.await;
280 }
281 Ok(())
282 }
283
284 async fn send2_by_all_tunnels(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
285 let connections = self.peer_manager.find_connections(peer_id);
286 let mut len = 0;
287 for b in body.iter() {
288 len += b.len();
289 }
290 for conn in connections {
291 let _ret: CmdResult<()> = async move {
292 let mut conn = conn.lock().await;
293 let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(len as u64).unwrap());
294 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
295 conn.send.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
296 for b in body.iter() {
297 conn.send.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
298 }
299 conn.send.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
300 Ok(())
301 }.await;
302 }
303 Ok(())
304 }
305}