1use std::hash::Hash;
2use std::sync::{Arc, Mutex};
3use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
4use num::{FromPrimitive, ToPrimitive};
5use sfo_pool::{into_pool_err, pool_err, ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification};
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::spawn;
8use tokio::task::JoinHandle;
9use crate::{CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
10use crate::client::{ClassifiedCmdClient, CmdClient};
11use crate::cmd::{CmdBodyReadImpl, CmdHandler, CmdHandlerMap, CmdHeader};
12use crate::errors::{into_cmd_err, CmdErrorCode, CmdResult};
13use crate::peer_id::PeerId;
14use std::fmt::Debug;
15use sfo_split::{RHalf, Splittable, WHalf};
16
17pub trait ClassifiedCmdTunnelRead<C: WorkerClassification>: CmdTunnelRead + 'static + Send {
18 fn get_classification(&self) -> C;
19}
20
21pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification>: CmdTunnelWrite + 'static + Send {
22 fn get_classification(&self) -> C;
23}
24
25pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
26pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
27pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
28
29#[derive(Debug, Clone, Copy, Eq, Hash)]
30pub struct CmdClientTunnelClassification<C: WorkerClassification> {
31 tunnel_id: Option<TunnelId>,
32 classification: Option<C>,
33}
34
35impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
36 fn eq(&self, other: &Self) -> bool {
37 self.tunnel_id == other.tunnel_id && self.classification == other.classification
38 }
39}
40
41
42#[async_trait::async_trait]
43pub trait ClassifiedCmdTunnelFactory<C: WorkerClassification, R: ClassifiedCmdTunnelRead<C>, W: ClassifiedCmdTunnelWrite<C>>: Send + Sync + 'static {
44 async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
45}
46
47pub struct ClassifiedCmdSend<C, R, W, LEN, CMD>
48where
49 C: WorkerClassification,
50 R: ClassifiedCmdTunnelRead<C>,
51 W: ClassifiedCmdTunnelWrite<C>,
52 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
53 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
54{
55 pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
56 pub(crate) write: ClassifiedCmdTunnelWHalf<R, W>,
57 pub(crate) is_work: bool,
58 pub(crate) classification: C,
59 pub(crate) tunnel_id: TunnelId,
60 _p: std::marker::PhantomData<(LEN, CMD)>,
61
62}
63
64impl<C, R, W, LEN, CMD> ClassifiedCmdSend<C, R, W, LEN, CMD>
65where C: WorkerClassification,
66 R: ClassifiedCmdTunnelRead<C>,
67 W: ClassifiedCmdTunnelWrite<C>,
68 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
69 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
70 pub(crate) fn new(tunnel_id: TunnelId, classification: C, recv_handle: JoinHandle<CmdResult<()>>, write: ClassifiedCmdTunnelWHalf<R, W>) -> Self {
71 Self {
72 recv_handle,
73 write,
74 is_work: true,
75 classification,
76 tunnel_id,
77 _p: Default::default(),
78 }
79 }
80
81 pub fn get_tunnel_id(&self) -> TunnelId {
82 self.tunnel_id
83 }
84
85 pub fn set_disable(&mut self) {
86 self.is_work = false;
87 self.recv_handle.abort();
88 }
89
90 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
91 log::trace!("client {:?} send cmd: {:?}, len: {}, data: {}", self.tunnel_id, cmd, body.len(), hex::encode(body));
92 let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(body.len() as u64).unwrap());
93 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
94 let ret = self.send_inner(buf.as_slice(), body).await;
95 if let Err(e) = ret {
96 self.set_disable();
97 return Err(e);
98 }
99 Ok(())
100 }
101
102 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
103 let mut len = 0;
104 for b in body.iter() {
105 len += b.len();
106 log::trace!("client {:?} send2 cmd {:?} body: {}", self.tunnel_id, cmd, hex::encode(b));
107 }
108 log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
109 let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(len as u64).unwrap());
110 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
111 let ret = self.send_inner2(buf.as_slice(), body).await;
112 if let Err(e) = ret {
113 self.set_disable();
114 return Err(e);
115 }
116 Ok(())
117 }
118
119 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
120 self.write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
121 self.write.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
122 self.write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
123 Ok(())
124 }
125
126 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
127 self.write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
128 for b in body.iter() {
129 self.write.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
130 }
131 self.write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
132 Ok(())
133 }
134}
135
136impl<C, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, R, W, LEN, CMD>
137where C: WorkerClassification,
138 R: ClassifiedCmdTunnelRead<C>,
139 W: ClassifiedCmdTunnelWrite<C>,
140 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
141 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
142 fn drop(&mut self) {
143 self.set_disable();
144 }
145}
146
147impl<C, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>> for ClassifiedCmdSend<C, R, W, LEN, CMD>
148where C: WorkerClassification,
149 R: ClassifiedCmdTunnelRead<C>,
150 W: ClassifiedCmdTunnelWrite<C>,
151 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
152 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
153 fn is_work(&self) -> bool {
154 self.is_work && !self.recv_handle.is_finished()
155 }
156
157 fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
158 if c.tunnel_id.is_some() {
159 self.tunnel_id == c.tunnel_id.unwrap()
160 } else {
161 if c.classification.is_some() {
162 self.classification == c.classification.unwrap()
163 } else {
164 true
165 }
166 }
167 }
168
169 fn classification(&self) -> CmdClientTunnelClassification<C> {
170 CmdClientTunnelClassification {
171 tunnel_id: Some(self.tunnel_id),
172 classification: Some(self.classification.clone()),
173 }
174 }
175}
176
177struct CmdWriteFactory<C: WorkerClassification,
178 R: ClassifiedCmdTunnelRead<C>,
179 W: ClassifiedCmdTunnelWrite<C>,
180 F: ClassifiedCmdTunnelFactory<C, R, W>,
181 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
182 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug> {
183 tunnel_factory: F,
184 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
185 tunnel_id_generator: TunnelIdGenerator,
186 _p: std::marker::PhantomData<Mutex<(C, R, W)>>,
187}
188
189impl<
190 C: WorkerClassification,
191 R: ClassifiedCmdTunnelRead<C>,
192 W: ClassifiedCmdTunnelWrite<C>,
193 F: ClassifiedCmdTunnelFactory<C, R, W>,
194 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
195 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug
196> CmdWriteFactory<C, R, W, F, LEN, CMD> {
197 pub fn new(tunnel_factory: F, cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
198 Self {
199 tunnel_factory,
200 cmd_handler: Arc::new(cmd_handler),
201 tunnel_id_generator: TunnelIdGenerator::new(),
202 _p: Default::default(),
203 }
204 }
205}
206
207#[async_trait::async_trait]
208impl<C: WorkerClassification,
209 R: ClassifiedCmdTunnelRead<C>,
210 W: ClassifiedCmdTunnelWrite<C>,
211 F: ClassifiedCmdTunnelFactory<C, R, W>,
212 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
213 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>
214> for CmdWriteFactory<C, R, W, F, LEN, CMD> {
215 async fn create(&self, classification: Option<CmdClientTunnelClassification<C>>) -> PoolResult<ClassifiedCmdSend<C, R, W, LEN, CMD>> {
216 if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
217 return Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", classification.as_ref().unwrap().tunnel_id.unwrap()));
218 }
219
220 let classification = if classification.is_some() && classification.as_ref().unwrap().classification.is_some() {
221 classification.unwrap().classification
222 } else {
223 None
224 };
225 let tunnel = self.tunnel_factory.create_tunnel(classification).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
226 let classification = tunnel.get_classification();
227 let peer_id = tunnel.get_remote_peer_id();
228 let tunnel_id = self.tunnel_id_generator.generate();
229 let (mut recv, write) = tunnel.split();
230 let cmd_handler = self.cmd_handler.clone();
231 let handle = spawn(async move {
232 let ret: CmdResult<()> = async move {
233 loop {
234 let mut header = vec![0u8; CmdHeader::<LEN, CMD>::raw_bytes().unwrap()];
235 let n = recv.read_exact(header.as_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
236 if n == 0 {
237 break;
238 }
239 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice()).map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
240 log::trace!("recv cmd {:?} from {} len {} tunnel {:?}", header.cmd_code(), peer_id, header.pkg_len().to_u64().unwrap(), tunnel_id);
241 let cmd_read = Box::new(CmdBodyReadImpl::new(recv, header.pkg_len().to_u64().unwrap() as usize));
242 let waiter = cmd_read.get_waiter();
243 let future = waiter.create_result_future();
244 if let Err(e) = cmd_handler.handle(peer_id.clone(), tunnel_id, header, cmd_read).await {
245 log::error!("handle cmd error: {:?}", e);
246 }
247 recv = future.await.map_err(into_cmd_err!(CmdErrorCode::Failed))??;
248 }
249 Ok(())
250 }.await;
251 ret
252 });
253 Ok(ClassifiedCmdSend::new(tunnel_id, classification, handle, write))
254 }
255}
256
257pub struct DefaultClassifiedCmdClient<C: WorkerClassification,
258 R: ClassifiedCmdTunnelRead<C>,
259 W: ClassifiedCmdTunnelWrite<C>,
260 F: ClassifiedCmdTunnelFactory<C, R, W>,
261 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
262 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> {
263 tunnel_pool: ClassifiedWorkerPoolRef<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD>>,
264 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
265}
266
267impl<C: WorkerClassification,
268 R: ClassifiedCmdTunnelRead<C>,
269 W: ClassifiedCmdTunnelWrite<C>,
270 F: ClassifiedCmdTunnelFactory<C, R, W>,
271 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
272 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> DefaultClassifiedCmdClient<C, R, W, F, LEN, CMD> {
273 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
274 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
275 let handler_map = cmd_handler_map.clone();
276 Arc::new(Self {
277 tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, CmdWriteFactory::<C, R, W, _, LEN, CMD>::new(factory, move |peer_id: PeerId, tunnel_id, header: CmdHeader<LEN, CMD>, body_read| {
278 let handler_map = handler_map.clone();
279 async move {
280 if let Some(handler) = handler_map.get(header.cmd_code()) {
281 handler.handle(peer_id, tunnel_id, header, body_read).await?;
282 }
283 Ok(())
284 }
285 })),
286 cmd_handler_map,
287 })
288 }
289
290 async fn get_send(&self) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD>>> {
291 self.tunnel_pool.get_worker().await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
292 }
293
294 async fn get_send_of_tunnel_id(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD>>> {
295 self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
296 tunnel_id: Some(tunnel_id),
297 classification: None,
298 }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
299 }
300
301 async fn get_classified_send(&self, classification: C) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD>>> {
302 self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
303 tunnel_id: None,
304 classification: Some(classification),
305 }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
306 }
307}
308
309#[async_trait::async_trait]
310impl<C: WorkerClassification,
311 R: ClassifiedCmdTunnelRead<C>,
312 W: ClassifiedCmdTunnelWrite<C>,
313 F: ClassifiedCmdTunnelFactory<C, R, W>,
314 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
315 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> CmdClient<LEN, CMD> for DefaultClassifiedCmdClient<C, R, W, F, LEN, CMD> {
316 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
317 self.cmd_handler_map.insert(cmd, handler);
318 }
319
320 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
321 let mut send = self.get_send().await?;
322 send.send(cmd, version, body).await
323 }
324
325 async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
326 let mut send = self.get_send().await?;
327 send.send2(cmd, version, body).await
328 }
329
330 async fn send_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
331 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
332 send.send(cmd, version, body).await
333 }
334
335 async fn send2_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
336 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
337 send.send2(cmd, version, body).await
338 }
339
340 async fn clear_all_tunnel(&self) {
341 self.tunnel_pool.clear_all_worker().await;
342 }
343}
344
345#[async_trait::async_trait]
346impl<C: WorkerClassification,
347 R: ClassifiedCmdTunnelRead<C>,
348 W: ClassifiedCmdTunnelWrite<C>,
349 F: ClassifiedCmdTunnelFactory<C, R, W>,
350 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
351 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> ClassifiedCmdClient<LEN, CMD, C> for DefaultClassifiedCmdClient<C, R, W, F, LEN, CMD> {
352 async fn send_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
353 let mut send = self.get_classified_send(classification).await?;
354 send.send(cmd, version, body).await
355 }
356
357 async fn send2_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
358 let mut send = self.get_classified_send(classification).await?;
359 send.send2(cmd, version, body).await
360 }
361
362 async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
363 let send = self.get_classified_send(classification).await?;
364 Ok(send.get_tunnel_id())
365 }
366}