1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::sync::{Arc, Mutex};
5use bucky_raw_codec::{RawDecode, RawEncode, RawFixedBytes};
6use num::{FromPrimitive, ToPrimitive};
7use sfo_pool::{ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification};
8use sfo_split::Splittable;
9use crate::{into_pool_err, pool_err, ClassifiedCmdNode, CmdHandler, CmdHeader, CmdNode, PeerId, TunnelId, TunnelIdGenerator};
10use crate::client::{ClassifiedCmdSend, ClassifiedCmdTunnelRead, ClassifiedCmdTunnelWrite};
11use crate::cmd::CmdHandlerMap;
12use crate::errors::{into_cmd_err, CmdErrorCode, CmdResult};
13use crate::node::create_recv_handle;
14use crate::server::CmdTunnelListener;
15
16#[derive(Debug, Clone, Eq, Hash)]
17pub struct CmdNodeTunnelClassification<C: WorkerClassification> {
18 peer_id: Option<PeerId>,
19 tunnel_id: Option<TunnelId>,
20 classification: Option<C>,
21}
22
23impl<C: WorkerClassification> PartialEq for CmdNodeTunnelClassification<C> {
24 fn eq(&self, other: &Self) -> bool {
25 self.peer_id == other.peer_id && self.tunnel_id == other.tunnel_id && self.classification == other.classification
26 }
27}
28
29impl<C, R, W, LEN, CMD> ClassifiedWorker<CmdNodeTunnelClassification<C>> for ClassifiedCmdSend<C, R, W, LEN, CMD>
30where C: WorkerClassification,
31 R: ClassifiedCmdTunnelRead<C>,
32 W: ClassifiedCmdTunnelWrite<C>,
33 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
34 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
35 fn is_work(&self) -> bool {
36 self.is_work && !self.recv_handle.is_finished()
37 }
38
39 fn is_valid(&self, c: CmdNodeTunnelClassification<C>) -> bool {
40 if c.peer_id.is_some() && c.peer_id.as_ref().unwrap() != &self.write.get_remote_peer_id() {
41 return false;
42 }
43
44 if c.tunnel_id.is_some() {
45 self.tunnel_id == c.tunnel_id.unwrap()
46 } else {
47 if c.classification.is_some() {
48 self.classification == c.classification.unwrap()
49 } else {
50 true
51 }
52 }
53 }
54
55 fn classification(&self) -> CmdNodeTunnelClassification<C> {
56 CmdNodeTunnelClassification {
57 peer_id: Some(self.write.get_remote_peer_id()),
58 tunnel_id: Some(self.tunnel_id),
59 classification: Some(self.classification.clone()),
60 }
61 }
62}
63
64#[async_trait::async_trait]
65pub trait ClassifiedCmdNodeTunnelFactory<C: WorkerClassification, R: ClassifiedCmdTunnelRead<C>, W: ClassifiedCmdTunnelWrite<C>>: Send + Sync + 'static {
66 async fn create_tunnel(&self, classification: Option<CmdNodeTunnelClassification<C>>) -> CmdResult<Splittable<R, W>>;
67}
68
69struct CmdWriteFactoryImpl<C: WorkerClassification,
70 R: ClassifiedCmdTunnelRead<C>,
71 W: ClassifiedCmdTunnelWrite<C>,
72 F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
73 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
74 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
75 LISTENER: CmdTunnelListener<R, W>> {
76 tunnel_listener: LISTENER,
77 tunnel_factory: F,
78 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
79 tunnel_id_generator: TunnelIdGenerator,
80 send_cache: Arc<Mutex<HashMap<PeerId, Vec<ClassifiedCmdSend<C, R, W, LEN, CMD>>>>>,
81}
82
83
84impl<C: WorkerClassification,
85 R: ClassifiedCmdTunnelRead<C>,
86 W: ClassifiedCmdTunnelWrite<C>,
87 F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
88 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
89 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
90 LISTENER: CmdTunnelListener<R, W>> CmdWriteFactoryImpl<C, R, W, F, LEN, CMD, LISTENER> {
91 pub fn new(tunnel_factory: F,
92 tunnel_listener: LISTENER,
93 cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
94 Self {
95 tunnel_listener,
96 tunnel_factory,
97 cmd_handler: Arc::new(cmd_handler),
98 tunnel_id_generator: TunnelIdGenerator::new(),
99 send_cache: Arc::new(Mutex::new(Default::default())),
100 }
101 }
102
103
104 pub fn start(self: &Arc<Self>) {
105 let this = self.clone();
106 tokio::spawn(async move {
107 if let Err(e) = this.run().await {
108 log::error!("cmd server error: {:?}", e);
109 }
110 });
111 }
112
113 async fn run(self: &Arc<Self>) -> CmdResult<()> {
114 loop {
115 let tunnel = self.tunnel_listener.accept().await?;
116 let peer_id = tunnel.get_remote_peer_id();
117 let classification = tunnel.get_classification();
118 let tunnel_id = self.tunnel_id_generator.generate();
119 let this = self.clone();
120 tokio::spawn(async move {
121 let ret: CmdResult<()> = async move {
122 let this = this.clone();
123 let cmd_handler = this.cmd_handler.clone();
124 let (reader, writer) = tunnel.split();
125 let recv_handle = create_recv_handle::<R, W, LEN, CMD>(reader, tunnel_id, cmd_handler);
126 {
127 let mut send_cache = this.send_cache.lock().unwrap();
128 let send_list = send_cache.entry(peer_id).or_insert(Vec::new());
129 send_list.push(ClassifiedCmdSend::new(tunnel_id, classification, recv_handle, writer));
130 }
131 Ok(())
132 }.await;
133 if let Err(e) = ret {
134 log::error!("peer connection error: {:?}", e);
135 }
136 });
137 }
138 }
139}
140
141#[async_trait::async_trait]
142impl<C: WorkerClassification,
143 R: ClassifiedCmdTunnelRead<C>,
144 W: ClassifiedCmdTunnelWrite<C>,
145 F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
146 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
147 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
148 LISTENER: CmdTunnelListener<R, W>> ClassifiedWorkerFactory<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>> for CmdWriteFactoryImpl<C, R, W, F, LEN, CMD, LISTENER> {
149 async fn create(&self, c: Option<CmdNodeTunnelClassification<C>>) -> PoolResult<ClassifiedCmdSend<C, R, W, LEN, CMD>> {
150 if c.is_some() {
151 let classification = c.unwrap();
152 if classification.peer_id.is_some() {
153 let peer_id = classification.peer_id.clone().unwrap();
154 let tunnel_id = classification.tunnel_id;
155 if tunnel_id.is_some() {
156 let mut send_cache = self.send_cache.lock().unwrap();
157 if let Some(send_list) = send_cache.get_mut(&peer_id) {
158 let mut send_index = None;
159 for (index, send) in send_list.iter().enumerate() {
160 if send.get_tunnel_id() == tunnel_id.unwrap() {
161 send_index = Some(index);
162 break;
163 }
164 }
165 if let Some(send_index) = send_index {
166 let send = send_list.remove(send_index);
167 Ok(send)
168 } else {
169 Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
170 }
171 } else {
172 Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
173 }
174 } else {
175 {
176 let mut send_cache = self.send_cache.lock().unwrap();
177 if let Some(send_list) = send_cache.get_mut(&peer_id) {
178 if !send_list.is_empty() {
179 let send = send_list.pop().unwrap();
180 if send_list.is_empty() {
181 send_cache.remove(&peer_id);
182 }
183 return Ok(send);
184 }
185 }
186 }
187 let tunnel = self.tunnel_factory.create_tunnel(Some(classification)).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
188 let classification = tunnel.get_classification();
189 let tunnel_id = self.tunnel_id_generator.generate();
190 let (recv, write) = tunnel.split();
191 let cmd_handler = self.cmd_handler.clone();
192 let handle = create_recv_handle::<R, W, LEN, CMD>(recv, tunnel_id, cmd_handler);
193 Ok(ClassifiedCmdSend::new(tunnel_id, classification, handle, write))
194 }
195 } else {
196 if classification.tunnel_id.is_some() {
197 Err(pool_err!(PoolErrorCode::Failed, "must set peer id when set tunnel id"))
198 } else {
199 let tunnel = self.tunnel_factory.create_tunnel(Some(classification)).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
200 let classification = tunnel.get_classification();
201 let tunnel_id = self.tunnel_id_generator.generate();
202 let (recv, write) = tunnel.split();
203 let cmd_handler = self.cmd_handler.clone();
204 let handle = create_recv_handle::<R, W, LEN, CMD>(recv, tunnel_id, cmd_handler);
205 Ok(ClassifiedCmdSend::new(tunnel_id, classification, handle, write))
206 }
207 }
208
209 } else {
210 Err(pool_err!(PoolErrorCode::Failed, "peer id is none"))
211 }
212 }
213}
214
215struct CmdWriteFactory<C: WorkerClassification,
216 R: ClassifiedCmdTunnelRead<C>,
217 W: ClassifiedCmdTunnelWrite<C>,
218 F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
219 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
220 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
221 LISTENER: CmdTunnelListener<R, W>> {
222 inner: Arc<CmdWriteFactoryImpl<C, R, W, F, LEN, CMD, LISTENER>>
223}
224
225
226impl<C: WorkerClassification,
227 R: ClassifiedCmdTunnelRead<C>,
228 W: ClassifiedCmdTunnelWrite<C>,
229 F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
230 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes + RawFixedBytes,
231 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes + RawFixedBytes,
232 LISTENER: CmdTunnelListener<R, W>> CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER> {
233 pub fn new(tunnel_factory: F,
234 tunnel_listener: LISTENER,
235 cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
236 Self {
237 inner: Arc::new(CmdWriteFactoryImpl::new(tunnel_factory, tunnel_listener, cmd_handler)),
238 }
239 }
240
241 pub fn start(&self) {
242 self.inner.start();
243 }
244}
245
246#[async_trait::async_trait]
247impl<C: WorkerClassification,
248 R: ClassifiedCmdTunnelRead<C>,
249 W: ClassifiedCmdTunnelWrite<C>,
250 F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
251 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
252 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
253 LISTENER: CmdTunnelListener<R, W>> ClassifiedWorkerFactory<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>> for CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER> {
254 async fn create(&self, c: Option<CmdNodeTunnelClassification<C>>) -> PoolResult<ClassifiedCmdSend<C, R, W, LEN, CMD>> {
255 self.inner.create(c).await
256 }
257}
258
259pub struct DefaultClassifiedCmdNode<C: WorkerClassification,
260 R: ClassifiedCmdTunnelRead<C>,
261 W: ClassifiedCmdTunnelWrite<C>,
262 F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
263 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
264 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
265 LISTENER: CmdTunnelListener<R, W>> {
266 tunnel_pool: ClassifiedWorkerPoolRef<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER>>,
267 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
268}
269
270impl<C: WorkerClassification,
271 R: ClassifiedCmdTunnelRead<C>,
272 W: ClassifiedCmdTunnelWrite<C>,
273 F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
274 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
275 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
276 LISTENER: CmdTunnelListener<R, W>> DefaultClassifiedCmdNode<C, R, W, F, LEN, CMD, LISTENER> {
277 pub fn new(listener: LISTENER, factory: F, tunnel_count: u16) -> Arc<Self> {
278 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
279 let handler_map = cmd_handler_map.clone();
280 let write_factory = CmdWriteFactory::<C, R, W, _, LEN, CMD, LISTENER>::new(factory, listener, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read| {
281 let handler_map = handler_map.clone();
282 async move {
283 if let Some(handler) = handler_map.get(header.cmd_code()) {
284 handler.handle(peer_id, tunnel_id, header, body_read).await?;
285 }
286 Ok(())
287 }
288 });
289 write_factory.start();
290 Arc::new(Self {
291 tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, write_factory),
292 cmd_handler_map,
293 })
294 }
295
296
297 async fn get_send(&self, peer_id: PeerId) -> CmdResult<ClassifiedWorkerGuard<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER>>> {
298 self.tunnel_pool.get_classified_worker(CmdNodeTunnelClassification {
299 peer_id: Some(peer_id),
300 tunnel_id: None,
301 classification: None,
302 }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
303 }
304
305 async fn get_send_of_tunnel_id(&self, peer_id: PeerId, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER>>> {
306 self.tunnel_pool.get_classified_worker(CmdNodeTunnelClassification {
307 peer_id: Some(peer_id),
308 tunnel_id: Some(tunnel_id),
309 classification: None,
310 }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
311 }
312
313 async fn get_classified_send(&self, classification: C) -> CmdResult<ClassifiedWorkerGuard<CmdNodeTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD, LISTENER>>> {
314 self.tunnel_pool.get_classified_worker(CmdNodeTunnelClassification {
315 peer_id: None,
316 tunnel_id: None,
317 classification: Some(classification),
318 }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
319 }
320}
321
322#[async_trait::async_trait]
323impl<C: WorkerClassification,
324 R: ClassifiedCmdTunnelRead<C>,
325 W: ClassifiedCmdTunnelWrite<C>,
326 F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
327 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
328 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
329 LISTENER: CmdTunnelListener<R, W>> CmdNode<LEN, CMD> for DefaultClassifiedCmdNode<C, R, W, F, LEN, CMD, LISTENER> {
330 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
331 self.cmd_handler_map.insert(cmd, handler)
332 }
333
334 async fn send(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
335 let mut send = self.get_send(peer_id.clone()).await?;
336 send.send(cmd, version, body).await
337 }
338
339 async fn send2(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
340 let mut send = self.get_send(peer_id.clone()).await?;
341 send.send2(cmd, version, body).await
342 }
343
344 async fn send_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
345 let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
346 send.send(cmd, version, body).await
347 }
348
349 async fn send2_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
350 let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
351 send.send2(cmd, version, body).await
352 }
353
354 async fn clear_all_tunnel(&self) {
355 self.tunnel_pool.clear_all_worker().await;
356 }
357}
358
359#[async_trait::async_trait]
360impl<C: WorkerClassification,
361 R: ClassifiedCmdTunnelRead<C>,
362 W: ClassifiedCmdTunnelWrite<C>,
363 F: ClassifiedCmdNodeTunnelFactory<C, R, W>,
364 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
365 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
366 LISTENER: CmdTunnelListener<R, W>> ClassifiedCmdNode<LEN, CMD, C> for DefaultClassifiedCmdNode<C, R, W, F, LEN, CMD, LISTENER> {
367 async fn send_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
368 let mut send = self.get_classified_send(classification).await?;
369 send.send(cmd, version, body).await
370 }
371
372 async fn send2_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
373 let mut send = self.get_classified_send(classification).await?;
374 send.send2(cmd, version, body).await
375 }
376
377 async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
378 let send = self.get_classified_send(classification).await?;
379 Ok(send.get_tunnel_id())
380 }
381}