1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::sync::{Arc, Mutex};
5use bucky_raw_codec::{RawDecode, RawEncode, RawFixedBytes};
6use num::{FromPrimitive, ToPrimitive};
7use sfo_pool::{ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult};
8use sfo_split::{Splittable};
9use crate::{into_pool_err, pool_err, CmdHandler, CmdHeader, CmdNode, CmdTunnelRead, CmdTunnelWrite, PeerId, TunnelId, TunnelIdGenerator};
10use crate::client::{CmdSend};
11use crate::cmd::{CmdHandlerMap};
12use crate::errors::{into_cmd_err, CmdErrorCode, CmdResult};
13use crate::node::create_recv_handle;
14use crate::server::{CmdTunnelListener};
15
16#[async_trait::async_trait]
17pub trait CmdNodeTunnelFactory<R: CmdTunnelRead, W: CmdTunnelWrite>: Send + Sync + 'static {
18 async fn create_tunnel(&self, remote_id: &PeerId) -> CmdResult<Splittable<R, W>>;
19}
20
21
22impl<R, W, LEN, CMD> ClassifiedWorker<(PeerId, Option<TunnelId>)> for CmdSend<R, W, LEN, CMD>
23where R: CmdTunnelRead,
24 W: CmdTunnelWrite,
25 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
26 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug{
27 fn is_work(&self) -> bool {
28 self.is_work && !self.recv_handle.is_finished()
29 }
30
31 fn is_valid(&self, c: (PeerId, Option<TunnelId>)) -> bool {
32 let (peer_id, tunnel_id) = c;
33 if tunnel_id.is_some() {
34 self.tunnel_id == tunnel_id.unwrap() && peer_id == self.write.get_remote_peer_id()
35 } else {
36 peer_id == self.write.get_remote_peer_id()
37 }
38 }
39
40 fn classification(&self) -> (PeerId, Option<TunnelId>) {
41 (self.write.get_remote_peer_id().clone(), Some(self.tunnel_id))
42 }
43}
44
45struct CmdWriteFactoryImpl<R: CmdTunnelRead,
46 W: CmdTunnelWrite,
47 F: CmdNodeTunnelFactory<R, W>,
48 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
49 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
50 LISTENER: CmdTunnelListener<R, W>> {
51 tunnel_listener: LISTENER,
52 tunnel_factory: F,
53 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
54 tunnel_id_generator: TunnelIdGenerator,
55 send_cache: Arc<Mutex<HashMap<PeerId, Vec<CmdSend<R, W, LEN, CMD>>>>>,
56}
57
58
59impl<R: CmdTunnelRead,
60 W: CmdTunnelWrite,
61 F: CmdNodeTunnelFactory<R, W>,
62 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
63 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
64 LISTENER: CmdTunnelListener<R, W>> CmdWriteFactoryImpl<R, W, F, LEN, CMD, LISTENER> {
65 pub fn new(tunnel_factory: F,
66 tunnel_listener: LISTENER,
67 cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
68 Self {
69 tunnel_listener,
70 tunnel_factory,
71 cmd_handler: Arc::new(cmd_handler),
72 tunnel_id_generator: TunnelIdGenerator::new(),
73 send_cache: Arc::new(Mutex::new(Default::default())),
74 }
75 }
76
77
78 pub fn start(self: &Arc<Self>) {
79 let this = self.clone();
80 tokio::spawn(async move {
81 if let Err(e) = this.run().await {
82 log::error!("cmd server error: {:?}", e);
83 }
84 });
85 }
86
87 async fn run(self: &Arc<Self>) -> CmdResult<()> {
88 loop {
89 let tunnel = self.tunnel_listener.accept().await?;
90 let peer_id = tunnel.get_remote_peer_id();
91 let tunnel_id = self.tunnel_id_generator.generate();
92 let this = self.clone();
93 tokio::spawn(async move {
94 let ret: CmdResult<()> = async move {
95 let this = this.clone();
96 let cmd_handler = this.cmd_handler.clone();
97 let (reader, writer) = tunnel.split();
98 let recv_handle = create_recv_handle::<R, W, LEN, CMD>(reader, tunnel_id, cmd_handler);
99 {
100 let mut send_cache = this.send_cache.lock().unwrap();
101 let send_list = send_cache.entry(peer_id).or_insert(Vec::new());
102 send_list.push(CmdSend::new(tunnel_id, recv_handle, writer));
103 }
104 Ok(())
105 }.await;
106 if let Err(e) = ret {
107 log::error!("peer connection error: {:?}", e);
108 }
109 });
110 }
111 }
112}
113
114#[async_trait::async_trait]
115impl<R: CmdTunnelRead,
116 W: CmdTunnelWrite,
117 F: CmdNodeTunnelFactory<R, W>,
118 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
119 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
120 LISTENER: CmdTunnelListener<R, W>> ClassifiedWorkerFactory<(PeerId, Option<TunnelId>), CmdSend<R, W, LEN, CMD>> for CmdWriteFactoryImpl<R, W, F, LEN, CMD, LISTENER> {
121 async fn create(&self, c: Option<(PeerId, Option<TunnelId>)>) -> PoolResult<CmdSend<R, W, LEN, CMD>> {
122 if c.is_some() {
123 let (peer_id, tunnel_id) = c.unwrap();
124 if tunnel_id.is_some() {
125 let mut send_cache = self.send_cache.lock().unwrap();
126 if let Some(send_list) = send_cache.get_mut(&peer_id) {
127 let mut send_index = None;
128 for (index, send) in send_list.iter().enumerate() {
129 if send.get_tunnel_id() == tunnel_id.unwrap() {
130 send_index = Some(index);
131 break;
132 }
133 }
134 if let Some(send_index) = send_index {
135 let send = send_list.remove(send_index);
136 Ok(send)
137 } else {
138 Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
139 }
140 } else {
141 Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
142 }
143 } else {
144 {
145 let mut send_cache = self.send_cache.lock().unwrap();
146 if let Some(send_list) = send_cache.get_mut(&peer_id) {
147 if !send_list.is_empty() {
148 let send = send_list.pop().unwrap();
149 if send_list.is_empty() {
150 send_cache.remove(&peer_id);
151 }
152 return Ok(send);
153 }
154 }
155 }
156 let tunnel = self.tunnel_factory.create_tunnel(&peer_id).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
157 let tunnel_id = self.tunnel_id_generator.generate();
158 let (recv, write) = tunnel.split();
159 let cmd_handler = self.cmd_handler.clone();
160 let handle = create_recv_handle::<R, W, LEN, CMD>(recv, tunnel_id, cmd_handler);
161 Ok(CmdSend::new(tunnel_id, handle, write))
162 }
163 } else {
164 Err(pool_err!(PoolErrorCode::Failed, "peer id is none"))
165 }
166 }
167}
168
169struct CmdWriteFactory<R: CmdTunnelRead,
170 W: CmdTunnelWrite,
171 F: CmdNodeTunnelFactory<R, W>,
172 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
173 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
174 LISTENER: CmdTunnelListener<R, W>> {
175 inner: Arc<CmdWriteFactoryImpl<R, W, F, LEN, CMD, LISTENER>>
176}
177
178
179impl<R: CmdTunnelRead,
180 W: CmdTunnelWrite,
181 F: CmdNodeTunnelFactory<R, W>,
182 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
183 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
184 LISTENER: CmdTunnelListener<R, W>> CmdWriteFactory<R, W, F, LEN, CMD, LISTENER> {
185 pub fn new(tunnel_factory: F,
186 tunnel_listener: LISTENER,
187 cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
188 Self {
189 inner: Arc::new(CmdWriteFactoryImpl::new(tunnel_factory, tunnel_listener, cmd_handler)),
190 }
191 }
192
193 pub fn start(&self) {
194 self.inner.start();
195 }
196}
197
198#[async_trait::async_trait]
199impl<R: CmdTunnelRead,
200 W: CmdTunnelWrite,
201 F: CmdNodeTunnelFactory<R, W>,
202 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
203 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
204 LISTENER: CmdTunnelListener<R, W>> ClassifiedWorkerFactory<(PeerId, Option<TunnelId>), CmdSend<R, W, LEN, CMD>> for CmdWriteFactory<R, W, F, LEN, CMD, LISTENER> {
205 async fn create(&self, c: Option<(PeerId, Option<TunnelId>)>) -> PoolResult<CmdSend<R, W, LEN, CMD>> {
206 self.inner.create(c).await
207 }
208}
209pub struct DefaultCmdNode<R: CmdTunnelRead,
210 W: CmdTunnelWrite,
211 F: CmdNodeTunnelFactory<R, W>,
212 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
213 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug,
214 LISTENER: CmdTunnelListener<R, W>> {
215 tunnel_pool: ClassifiedWorkerPoolRef<(PeerId, Option<TunnelId>), CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD, LISTENER>>,
216 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
217}
218
219impl<R: CmdTunnelRead,
220 W: CmdTunnelWrite,
221 F: CmdNodeTunnelFactory<R, W>,
222 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
223 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
224 LISTENER: CmdTunnelListener<R, W>> DefaultCmdNode<R, W, F, LEN, CMD, LISTENER> {
225 pub fn new(listener: LISTENER, factory: F, tunnel_count: u16) -> Arc<Self> {
226 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
227 let handler_map = cmd_handler_map.clone();
228 let write_factory = CmdWriteFactory::<R, W, _, LEN, CMD, LISTENER>::new(factory, listener, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read| {
229 let handler_map = handler_map.clone();
230 async move {
231 if let Some(handler) = handler_map.get(header.cmd_code()) {
232 handler.handle(peer_id, tunnel_id, header, body_read).await?;
233 }
234 Ok(())
235 }
236 });
237 write_factory.start();
238 Arc::new(Self {
239 tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, write_factory),
240 cmd_handler_map,
241 })
242 }
243
244 async fn get_send(&self, peer_id: PeerId) -> CmdResult<ClassifiedWorkerGuard<(PeerId, Option<TunnelId>), CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD, LISTENER>>> {
245 self.tunnel_pool.get_classified_worker((peer_id, None)).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
246 }
247
248 async fn get_send_of_tunnel_id(&self, peer_id: PeerId, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<(PeerId, Option<TunnelId>), CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD, LISTENER>>> {
249 self.tunnel_pool.get_classified_worker((peer_id, Some(tunnel_id))).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
250 }
251
252}
253
254#[async_trait::async_trait]
255impl<R: CmdTunnelRead,
256 W: CmdTunnelWrite,
257 F: CmdNodeTunnelFactory<R, W>,
258 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + FromPrimitive + ToPrimitive,
259 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash + Debug,
260 LISTENER: CmdTunnelListener<R, W>> CmdNode<LEN, CMD> for DefaultCmdNode<R, W, F, LEN, CMD, LISTENER> {
261 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
262 self.cmd_handler_map.insert(cmd, handler);
263 }
264
265 async fn send(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
266 let mut send = self.get_send(peer_id.clone()).await?;
267 send.send(cmd, version, body).await
268 }
269
270 async fn send2(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
271 let mut send = self.get_send(peer_id.clone()).await?;
272 send.send2(cmd, version, body).await
273 }
274
275 async fn send_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
276 let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
277 send.send(cmd, version, body).await
278 }
279
280 async fn send2_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
281 let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
282 send.send2(cmd, version, body).await
283 }
284
285 async fn clear_all_tunnel(&self) {
286 self.tunnel_pool.clear_all_worker().await
287 }
288}