1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::sync::{Arc, Mutex};
5use std::time::Duration;
6use async_named_locker::ObjectHolder;
7use bucky_raw_codec::{RawDecode, RawEncode, RawFixedBytes};
8use num::{FromPrimitive, ToPrimitive};
9use sfo_pool::{ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult};
10use sfo_split::{Splittable};
11use crate::{into_pool_err, pool_err, CmdBody, CmdHandler, CmdHeader, CmdNode, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, PeerId, TunnelId, TunnelIdGenerator};
12use crate::client::{gen_resp_id, ClassifiedSendGuard, CommonCmdSend, RespWaiter, RespWaiterRef};
13use crate::cmd::{CmdHandlerMap};
14use crate::errors::{into_cmd_err, CmdErrorCode, CmdResult};
15use crate::node::create_recv_handle;
16use crate::server::{CmdTunnelListener};
17
18#[async_trait::async_trait]
19pub trait CmdNodeTunnelFactory<M: CmdTunnelMeta, R: CmdTunnelRead<M>, W: CmdTunnelWrite<M>>: Send + Sync + 'static {
20 async fn create_tunnel(&self, remote_id: &PeerId) -> CmdResult<Splittable<R, W>>;
21}
22
23
24impl<M, R, W, LEN, CMD> ClassifiedWorker<(PeerId, Option<TunnelId>)> for CommonCmdSend<M, R, W, LEN, CMD>
25where M: CmdTunnelMeta,
26 R: CmdTunnelRead<M>,
27 W: CmdTunnelWrite<M>,
28 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
29 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
30 fn is_work(&self) -> bool {
31 self.is_work && !self.recv_handle.is_finished()
32 }
33
34 fn is_valid(&self, c: (PeerId, Option<TunnelId>)) -> bool {
35 let (peer_id, tunnel_id) = c;
36 if tunnel_id.is_some() {
37 self.tunnel_id == tunnel_id.unwrap() && peer_id == self.remote_id
38 } else {
39 peer_id == self.remote_id
40 }
41 }
42
43 fn classification(&self) -> (PeerId, Option<TunnelId>) {
44 (self.remote_id.clone(), Some(self.tunnel_id))
45 }
46}
47
48struct CmdWriteFactoryImpl<M: CmdTunnelMeta,
49 R: CmdTunnelRead<M>,
50 W: CmdTunnelWrite<M>,
51 F: CmdNodeTunnelFactory<M, R, W>,
52 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
53 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
54 LISTENER: CmdTunnelListener<M, R, W>> {
55 tunnel_listener: LISTENER,
56 tunnel_factory: F,
57 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
58 tunnel_id_generator: TunnelIdGenerator,
59 resp_waiter: RespWaiterRef,
60 send_cache: Arc<Mutex<HashMap<PeerId, Vec<CommonCmdSend<M, R, W, LEN, CMD>>>>>,
61}
62
63
64impl<M: CmdTunnelMeta,
65 R: CmdTunnelRead<M>,
66 W: CmdTunnelWrite<M>,
67 F: CmdNodeTunnelFactory<M, R, W>,
68 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
69 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
70 LISTENER: CmdTunnelListener<M, R, W>> CmdWriteFactoryImpl<M, R, W, F, LEN, CMD, LISTENER> {
71 pub fn new(tunnel_factory: F,
72 tunnel_listener: LISTENER,
73 cmd_handler: impl CmdHandler<LEN, CMD>,
74 resp_waiter: RespWaiterRef) -> Self {
75 Self {
76 tunnel_listener,
77 tunnel_factory,
78 cmd_handler: Arc::new(cmd_handler),
79 tunnel_id_generator: TunnelIdGenerator::new(),
80 resp_waiter,
81 send_cache: Arc::new(Mutex::new(Default::default())),
82 }
83 }
84
85
86 pub fn start(self: &Arc<Self>) {
87 let this = self.clone();
88 tokio::spawn(async move {
89 if let Err(e) = this.run().await {
90 log::error!("cmd server error: {:?}", e);
91 }
92 });
93 }
94
95 async fn run(self: &Arc<Self>) -> CmdResult<()> {
96 loop {
97 let tunnel = self.tunnel_listener.accept().await?;
98 let peer_id = tunnel.get_remote_peer_id();
99 let tunnel_id = self.tunnel_id_generator.generate();
100 let resp_waiter = self.resp_waiter.clone();
101 let this = self.clone();
102 tokio::spawn(async move {
103 let ret: CmdResult<()> = async move {
104 let this = this.clone();
105 let cmd_handler = this.cmd_handler.clone();
106 let (reader, writer) = tunnel.split();
107 let remote_id = reader.get_remote_peer_id();
108 let tunnel_meta = reader.get_tunnel_meta();
109 let writer = ObjectHolder::new(writer);
110 let recv_handle = create_recv_handle::<M, R, W, LEN, CMD>(reader, writer.clone(), tunnel_id, cmd_handler);
111 {
112 let mut send_cache = this.send_cache.lock().unwrap();
113 let send_list = send_cache.entry(peer_id).or_insert(Vec::new());
114 send_list.push(CommonCmdSend::new(tunnel_id, recv_handle, writer, resp_waiter, remote_id, tunnel_meta));
115 }
116 Ok(())
117 }.await;
118 if let Err(e) = ret {
119 log::error!("peer connection error: {:?}", e);
120 }
121 });
122 }
123 }
124}
125
126#[async_trait::async_trait]
127impl<M: CmdTunnelMeta,
128 R: CmdTunnelRead<M>,
129 W: CmdTunnelWrite<M>,
130 F: CmdNodeTunnelFactory<M, R, W>,
131 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
132 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
133 LISTENER: CmdTunnelListener<M, R, W>> ClassifiedWorkerFactory<(PeerId, Option<TunnelId>), CommonCmdSend<M, R, W, LEN, CMD>> for CmdWriteFactoryImpl<M, R, W, F, LEN, CMD, LISTENER> {
134 async fn create(&self, c: Option<(PeerId, Option<TunnelId>)>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
135 if c.is_some() {
136 let (peer_id, tunnel_id) = c.unwrap();
137 if tunnel_id.is_some() {
138 let mut send_cache = self.send_cache.lock().unwrap();
139 if let Some(send_list) = send_cache.get_mut(&peer_id) {
140 let mut send_index = None;
141 for (index, send) in send_list.iter().enumerate() {
142 if send.get_tunnel_id() == tunnel_id.unwrap() {
143 send_index = Some(index);
144 break;
145 }
146 }
147 if let Some(send_index) = send_index {
148 let send = send_list.remove(send_index);
149 Ok(send)
150 } else {
151 Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
152 }
153 } else {
154 Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", tunnel_id.unwrap()))
155 }
156 } else {
157 {
158 let mut send_cache = self.send_cache.lock().unwrap();
159 if let Some(send_list) = send_cache.get_mut(&peer_id) {
160 if !send_list.is_empty() {
161 let send = send_list.pop().unwrap();
162 if send_list.is_empty() {
163 send_cache.remove(&peer_id);
164 }
165 return Ok(send);
166 }
167 }
168 }
169 let tunnel = self.tunnel_factory.create_tunnel(&peer_id).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
170 let tunnel_id = self.tunnel_id_generator.generate();
171 let (recv, write) = tunnel.split();
172 let remote_id = recv.get_remote_peer_id();
173 let tunnel_meta = recv.get_tunnel_meta();
174 let write = ObjectHolder::new(write);
175 let cmd_handler = self.cmd_handler.clone();
176 let handle = create_recv_handle::<M, R, W, LEN, CMD>(recv, write.clone(), tunnel_id, cmd_handler);
177 Ok(CommonCmdSend::new(tunnel_id, handle, write, self.resp_waiter.clone(), remote_id, tunnel_meta))
178 }
179 } else {
180 Err(pool_err!(PoolErrorCode::Failed, "peer id is none"))
181 }
182 }
183}
184
185pub struct CmdNodeWriteFactory<M: CmdTunnelMeta,
186 R: CmdTunnelRead<M>,
187 W: CmdTunnelWrite<M>,
188 F: CmdNodeTunnelFactory<M, R, W>,
189 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
190 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
191 LISTENER: CmdTunnelListener<M, R, W>> {
192 inner: Arc<CmdWriteFactoryImpl<M, R, W, F, LEN, CMD, LISTENER>>
193}
194
195
196impl<M: CmdTunnelMeta,
197 R: CmdTunnelRead<M>,
198 W: CmdTunnelWrite<M>,
199 F: CmdNodeTunnelFactory<M, R, W>,
200 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
201 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
202 LISTENER: CmdTunnelListener<M, R, W>> CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER> {
203 pub(crate) fn new(tunnel_factory: F,
204 tunnel_listener: LISTENER,
205 cmd_handler: impl CmdHandler<LEN, CMD>,
206 resp_waiter: RespWaiterRef) -> Self {
207 Self {
208 inner: Arc::new(CmdWriteFactoryImpl::new(tunnel_factory, tunnel_listener, cmd_handler, resp_waiter)),
209 }
210 }
211
212 pub fn start(&self) {
213 self.inner.start();
214 }
215}
216
217#[async_trait::async_trait]
218impl<M: CmdTunnelMeta,
219 R: CmdTunnelRead<M>,
220 W: CmdTunnelWrite<M>,
221 F: CmdNodeTunnelFactory<M, R, W>,
222 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
223 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug,
224 LISTENER: CmdTunnelListener<M, R, W>> ClassifiedWorkerFactory<(PeerId, Option<TunnelId>), CommonCmdSend<M, R, W, LEN, CMD>> for CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER> {
225 async fn create(&self, c: Option<(PeerId, Option<TunnelId>)>) -> PoolResult<CommonCmdSend<M, R, W, LEN, CMD>> {
226 self.inner.create(c).await
227 }
228}
229pub struct DefaultCmdNode<M: CmdTunnelMeta,
230 R: CmdTunnelRead<M>,
231 W: CmdTunnelWrite<M>,
232 F: CmdNodeTunnelFactory<M, R, W>,
233 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
234 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug,
235 LISTENER: CmdTunnelListener<M, R, W>> {
236 tunnel_pool: ClassifiedWorkerPoolRef<(PeerId, Option<TunnelId>), CommonCmdSend<M, R, W, LEN, CMD>, CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER>>,
237 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
238}
239
240impl<M: CmdTunnelMeta,
241 R: CmdTunnelRead<M>,
242 W: CmdTunnelWrite<M>,
243 F: CmdNodeTunnelFactory<M, R, W>,
244 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + FromPrimitive + ToPrimitive + RawFixedBytes,
245 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync +'static + RawFixedBytes + Eq + Hash + Debug,
246 LISTENER: CmdTunnelListener<M, R, W>> DefaultCmdNode<M, R, W, F, LEN, CMD, LISTENER> {
247 pub fn new(listener: LISTENER, factory: F, tunnel_count: u16) -> Arc<Self> {
248 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
249 let handler_map = cmd_handler_map.clone();
250 let resp_waiter = Arc::new(RespWaiter::new());
251 let waiter = resp_waiter.clone();
252 let write_factory = CmdNodeWriteFactory::<M, R, W, _, LEN, CMD, LISTENER>::new(factory, listener, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read: CmdBody| {
253 let handler_map = handler_map.clone();
254 let waiter = waiter.clone();
255 async move {
256 if header.is_resp() && header.seq().is_some() {
257 let resp_id = gen_resp_id(header.cmd_code(), header.seq().unwrap());
258 let _ = waiter.set_result(resp_id, body_read);
259 Ok(None)
260 } else {
261 if let Some(handler) = handler_map.get(header.cmd_code()) {
262 handler.handle(peer_id, tunnel_id, header, body_read).await
263 } else {
264 Ok(None)
265 }
266 }
267 }
268 }, resp_waiter.clone());
269 write_factory.start();
270 Arc::new(Self {
271 tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, write_factory),
272 cmd_handler_map,
273 })
274 }
275
276 async fn get_send(&self, peer_id: PeerId) -> CmdResult<ClassifiedWorkerGuard<(PeerId, Option<TunnelId>), CommonCmdSend<M, R, W, LEN, CMD>, CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER>>> {
277 self.tunnel_pool.get_classified_worker((peer_id, None)).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
278 }
279
280 async fn get_send_of_tunnel_id(&self, peer_id: PeerId, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<(PeerId, Option<TunnelId>), CommonCmdSend<M, R, W, LEN, CMD>, CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER>>> {
281 self.tunnel_pool.get_classified_worker((peer_id, Some(tunnel_id))).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
282 }
283
284}
285
286pub type CmdNodeSendGuard<M, R, W, F, LEN, CMD, LISTENER> = ClassifiedSendGuard<(PeerId, Option<TunnelId>), M, CommonCmdSend<M, R, W, LEN, CMD>, CmdNodeWriteFactory<M, R, W, F, LEN, CMD, LISTENER>>;
287#[async_trait::async_trait]
288impl<M: CmdTunnelMeta,
289 R: CmdTunnelRead<M>,
290 W: CmdTunnelWrite<M>,
291 F: CmdNodeTunnelFactory<M, R, W>,
292 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + FromPrimitive + ToPrimitive,
293 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes + Sync + Send + 'static + Eq + Hash + Debug,
294 LISTENER: CmdTunnelListener<M, R, W>> CmdNode<LEN, CMD, M, CommonCmdSend<M, R, W, LEN, CMD>, CmdNodeSendGuard<M, R, W, F, LEN, CMD, LISTENER>> for DefaultCmdNode<M, R, W, F, LEN, CMD, LISTENER> {
295 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
296 self.cmd_handler_map.insert(cmd, handler);
297 }
298
299 async fn send(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
300 let mut send = self.get_send(peer_id.clone()).await?;
301 send.send(cmd, version, body).await
302 }
303
304 async fn send_with_resp(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
305 let mut send = self.get_send(peer_id.clone()).await?;
306 send.send_with_resp(cmd, version, body, timeout).await
307 }
308
309 async fn send2(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
310 let mut send = self.get_send(peer_id.clone()).await?;
311 send.send2(cmd, version, body).await
312 }
313
314 async fn send2_with_resp(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
315 let mut send = self.get_send(peer_id.clone()).await?;
316 send.send2_with_resp(cmd, version, body, timeout).await
317 }
318
319 async fn send_cmd(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
320 let mut send = self.get_send(peer_id.clone()).await?;
321 send.send_cmd(cmd, version, body).await
322 }
323
324 async fn send_cmd_with_resp(&self, peer_id: &PeerId, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
325 let mut send = self.get_send(peer_id.clone()).await?;
326 send.send_cmd_with_resp(cmd, version, body, timeout).await
327 }
328
329 async fn send_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
330 let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
331 send.send(cmd, version, body).await
332 }
333
334 async fn send_by_specify_tunnel_with_resp(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
335 let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
336 send.send_with_resp(cmd, version, body, timeout).await
337 }
338
339 async fn send2_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
340 let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
341 send.send2(cmd, version, body).await
342 }
343
344 async fn send2_by_specify_tunnel_with_resp(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
345 let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
346 send.send2_with_resp(cmd, version, body, timeout).await
347 }
348
349 async fn send_cmd_by_specify_tunnel(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
350 let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
351 send.send_cmd(cmd, version, body).await
352 }
353
354 async fn send_cmd_by_specify_tunnel_with_resp(&self, peer_id: &PeerId, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
355 let mut send = self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?;
356 send.send_cmd_with_resp(cmd, version, body, timeout).await
357 }
358
359 async fn clear_all_tunnel(&self) {
360 self.tunnel_pool.clear_all_worker().await
361 }
362
363 async fn get_send(&self, peer_id: &PeerId, tunnel_id: TunnelId) -> CmdResult<CmdNodeSendGuard<M, R, W, F, LEN, CMD, LISTENER>> {
364 Ok(ClassifiedSendGuard {
365 worker_guard: self.get_send_of_tunnel_id(peer_id.clone(), tunnel_id).await?,
366 _p: Default::default(),
367 })
368 }
369}