1use std::hash::Hash;
2use std::sync::{Arc, Mutex};
3use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
4use num::{FromPrimitive, ToPrimitive};
5use sfo_pool::{into_pool_err, pool_err, ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
7use tokio::spawn;
8use tokio::task::JoinHandle;
9use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
10use crate::client::{gen_resp_id, gen_seq, ClassifiedCmdClient, ClassifiedSendGuard, CmdClient, CmdSend, RespWaiter, RespWaiterRef};
11use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
12use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
13use crate::peer_id::PeerId;
14use std::fmt::Debug;
15use std::ops::{DerefMut};
16use std::time::Duration;
17use async_named_locker::ObjectHolder;
18use sfo_split::{RHalf, Splittable, WHalf};
19
20pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>: CmdTunnelRead<M> + 'static + Send {
21 fn get_classification(&self) -> C;
22}
23
24pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>: CmdTunnelWrite<M> + 'static + Send {
25 fn get_classification(&self) -> C;
26}
27
/// A classified command tunnel made of a read half `R` and a write half `W`;
/// it is split into its two halves before use.
pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
/// Read half obtained by splitting a `ClassifiedCmdTunnel`.
pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
/// Write half obtained by splitting a `ClassifiedCmdTunnel`.
pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
31
32#[derive(Debug, Clone, Copy, Eq, Hash)]
33pub struct CmdClientTunnelClassification<C: WorkerClassification> {
34 tunnel_id: Option<TunnelId>,
35 classification: Option<C>,
36}
37
38impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
39 fn eq(&self, other: &Self) -> bool {
40 self.tunnel_id == other.tunnel_id && self.classification == other.classification
41 }
42}
43
44
45#[async_trait::async_trait]
46pub trait ClassifiedCmdTunnelFactory<C: WorkerClassification, M: CmdTunnelMeta, R: ClassifiedCmdTunnelRead<C, M>, W: ClassifiedCmdTunnelWrite<C, M>>: Send + Sync + 'static {
47 async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
48}
49
/// One pooled command tunnel: the shared write half used to send frames, plus
/// the handle of the task that reads incoming frames from the read half.
pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
where
    C: WorkerClassification,
    M: CmdTunnelMeta,
    R: ClassifiedCmdTunnelRead<C, M>,
    W: ClassifiedCmdTunnelWrite<C, M>,
    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
{
    /// Handle of the receive task; `set_disable` aborts it.
    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
    /// Write half. The receive task holds a clone and writes responses through it.
    pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
    /// Cleared by `set_disable`; `is_work()` reports false once this is false.
    pub(crate) is_work: bool,
    /// Classification reported by the tunnel when it was created.
    pub(crate) classification: C,
    /// Id generated for this tunnel by the factory.
    pub(crate) tunnel_id: TunnelId,
    /// Registry in which the `*_with_resp` calls wait for replies.
    pub(crate) resp_waiter: RespWaiterRef,
    /// Peer id of the remote end of the tunnel.
    pub(crate) remote_id: PeerId,
    /// Metadata taken from the read half at creation time, if any.
    pub(crate) tunnel_meta: Option<Arc<M>>,
    /// Binds the `LEN`/`CMD` frame types to this sender without storing values.
    _p: std::marker::PhantomData<(LEN, CMD)>,

}
70
71impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
85where C: WorkerClassification,
86 M: CmdTunnelMeta,
87 R: ClassifiedCmdTunnelRead<C, M>,
88 W: ClassifiedCmdTunnelWrite<C, M>,
89 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
90 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
91 pub(crate) fn new(tunnel_id: TunnelId,
92 classification: C,
93 recv_handle: JoinHandle<CmdResult<()>>,
94 write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
95 resp_waiter: RespWaiterRef,
96 remote_id: PeerId,
97 tunnel_meta: Option<Arc<M>>) -> Self {
98 Self {
99 recv_handle,
100 write,
101 is_work: true,
102 classification,
103 tunnel_id,
104 resp_waiter,
105 remote_id,
106 tunnel_meta,
107 _p: Default::default(),
108 }
109 }
110
111 pub fn get_tunnel_id(&self) -> TunnelId {
112 self.tunnel_id
113 }
114
115 pub fn set_disable(&mut self) {
116 self.is_work = false;
117 self.recv_handle.abort();
118 }
119
120 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
121 log::trace!("client {:?} send cmd: {:?}, len: {}, data: {}", self.tunnel_id, cmd, body.len(), hex::encode(body));
122 let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(body.len() as u64).unwrap());
123 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
124 let ret = self.send_inner(buf.as_slice(), body).await;
125 if let Err(e) = ret {
126 self.set_disable();
127 return Err(e);
128 }
129 Ok(())
130 }
131
132 pub async fn send_with_resp(&mut self, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
133 if let Some(id) = tokio::task::try_id() {
134 if id == self.recv_handle.id() {
135 return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
136 }
137 }
138 log::trace!("client {:?} send cmd: {:?}, len: {}, data: {}", self.tunnel_id, cmd, body.len(), hex::encode(body));
139 let seq = gen_seq();
140 let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(body.len() as u64).unwrap());
141 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
142 let resp_id = gen_resp_id(cmd, seq);
143 let waiter = self.resp_waiter.clone();
144 let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
145 .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
146 let ret = self.send_inner(buf.as_slice(), body).await;
147 if let Err(e) = ret {
148 self.set_disable();
149 return Err(e);
150 }
151 let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
152 Ok(resp)
153 }
154
155 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
156 let mut len = 0;
157 for b in body.iter() {
158 len += b.len();
159 log::trace!("client {:?} send2 cmd {:?} body: {}", self.tunnel_id, cmd, hex::encode(b));
160 }
161 log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
162 let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(len as u64).unwrap());
163 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
164 let ret = self.send_inner2(buf.as_slice(), body).await;
165 if let Err(e) = ret {
166 self.set_disable();
167 return Err(e);
168 }
169 Ok(())
170 }
171
172 pub async fn send2_with_resp(&mut self, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
173 if let Some(id) = tokio::task::try_id() {
174 if id == self.recv_handle.id() {
175 return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
176 }
177 }
178 let mut len = 0;
179 for b in body.iter() {
180 len += b.len();
181 log::trace!("client {:?} send2 cmd {:?} body: {}", self.tunnel_id, cmd, hex::encode(b));
182 }
183 log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
184 let seq = gen_seq();
185 let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(len as u64).unwrap());
186 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
187 let resp_id = gen_resp_id(cmd, seq);
188 let waiter = self.resp_waiter.clone();
189 let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
190 .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
191 let ret = self.send_inner2(buf.as_slice(), body).await;
192 if let Err(e) = ret {
193 self.set_disable();
194 return Err(e);
195 }
196 let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
197 Ok(resp)
198 }
199
200 pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
201 log::trace!("client {:?} send cmd: {:?}, len: {}", self.tunnel_id, cmd, body.len());
202 let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(body.len()).unwrap());
203 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
204 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
205 if let Err(e) = ret {
206 self.set_disable();
207 return Err(e);
208 }
209 Ok(())
210 }
211
212 pub async fn send_cmd_with_resp(&mut self, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
213 if let Some(id) = tokio::task::try_id() {
214 if id == self.recv_handle.id() {
215 return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
216 }
217 }
218 log::trace!("client {:?} send cmd: {:?}, len: {}", self.tunnel_id, cmd, body.len());
219 let seq = gen_seq();
220 let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(body.len()).unwrap());
221 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
222 let resp_id = gen_resp_id(cmd, seq);
223 let waiter = self.resp_waiter.clone();
224 let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
225 .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
226 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
227 if let Err(e) = ret {
228 self.set_disable();
229 return Err(e);
230 }
231 let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
232 Ok(resp)
233 }
234
235 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
236 let mut write = self.write.get().await;
237 if header.len() > 255 {
238 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
239 }
240 write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
241 write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
242 write.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
243 write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
244 Ok(())
245 }
246
247 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
248 let mut write = self.write.get().await;
249 if header.len() > 255 {
250 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
251 }
252 write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
253 write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
254 for b in body.iter() {
255 write.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
256 }
257 write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
258 Ok(())
259 }
260
261 async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
262 let mut write = self.write.get().await;
263 if header.len() > 255 {
264 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
265 }
266 write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
267 write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
268 tokio::io::copy(&mut body, write.deref_mut().deref_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
269 write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
270 Ok(())
271 }
272}
273
274impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
275where C: WorkerClassification,
276 M: CmdTunnelMeta,
277 R: ClassifiedCmdTunnelRead<C, M>,
278 W: ClassifiedCmdTunnelWrite<C, M>,
279 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
280 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
281 fn drop(&mut self) {
282 self.set_disable();
283 }
284}
285
286impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
287where C: WorkerClassification,
288 M: CmdTunnelMeta,
289 R: ClassifiedCmdTunnelRead<C, M>,
290 W: ClassifiedCmdTunnelWrite<C, M>,
291 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
292 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
293 fn get_tunnel_meta(&self) -> Option<Arc<M>> {
294 self.tunnel_meta.clone()
295 }
296
297 fn get_remote_peer_id(&self) -> PeerId {
298 self.remote_id.clone()
299 }
300}
301
302impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
303where C: WorkerClassification,
304 M: CmdTunnelMeta,
305 R: ClassifiedCmdTunnelRead<C, M>,
306 W: ClassifiedCmdTunnelWrite<C, M>,
307 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
308 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
309 fn is_work(&self) -> bool {
310 self.is_work && !self.recv_handle.is_finished()
311 }
312
313 fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
314 if c.tunnel_id.is_some() {
315 self.tunnel_id == c.tunnel_id.unwrap()
316 } else {
317 if c.classification.is_some() {
318 self.classification == c.classification.unwrap()
319 } else {
320 true
321 }
322 }
323 }
324
325 fn classification(&self) -> CmdClientTunnelClassification<C> {
326 CmdClientTunnelClassification {
327 tunnel_id: Some(self.tunnel_id),
328 classification: Some(self.classification.clone()),
329 }
330 }
331}
332
/// Pool worker factory: opens tunnels through `F`, starts a receive task for
/// each one, and wraps the write half in a `ClassifiedCmdSend`.
pub struct ClassifiedCmdWriteFactory<C: WorkerClassification,
    M: CmdTunnelMeta,
    R: ClassifiedCmdTunnelRead<C, M>,
    W: ClassifiedCmdTunnelWrite<C, M>,
    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes> {
    /// Opens the underlying tunnels.
    tunnel_factory: F,
    /// Invoked by each tunnel's receive task for every incoming command.
    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
    /// Shared with every sender this factory creates.
    resp_waiter: RespWaiterRef,
    /// Assigns an id to each newly created tunnel.
    tunnel_id_generator: TunnelIdGenerator,
    // NOTE(review): `Mutex` inside the PhantomData presumably makes the factory
    // `Sync` whenever (C, M, R, W) are `Send`, independent of R/W being `Sync`.
    _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
}
346
347impl<
348 C: WorkerClassification,
349 M: CmdTunnelMeta,
350 R: ClassifiedCmdTunnelRead<C, M>,
351 W: ClassifiedCmdTunnelWrite<C, M>,
352 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
353 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
354 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes
355> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD> {
356 pub(crate) fn new(tunnel_factory: F,
357 cmd_handler: impl CmdHandler<LEN, CMD>,
358 resp_waiter: RespWaiterRef,) -> Self {
359 Self {
360 tunnel_factory,
361 cmd_handler: Arc::new(cmd_handler),
362 resp_waiter,
363 tunnel_id_generator: TunnelIdGenerator::new(),
364 _p: Default::default(),
365 }
366 }
367}
368
369#[async_trait::async_trait]
370impl<C: WorkerClassification,
371 M: CmdTunnelMeta,
372 R: ClassifiedCmdTunnelRead<C, M>,
373 W: ClassifiedCmdTunnelWrite<C, M>,
374 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
375 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
376 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>
377> for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD> {
378 async fn create(&self, classification: Option<CmdClientTunnelClassification<C>>) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
379 if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
380 return Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", classification.as_ref().unwrap().tunnel_id.unwrap()));
381 }
382
383 let classification = if classification.is_some() && classification.as_ref().unwrap().classification.is_some() {
384 classification.unwrap().classification
385 } else {
386 None
387 };
388 let tunnel = self.tunnel_factory.create_tunnel(classification).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
389 let classification = tunnel.get_classification();
390 let peer_id = tunnel.get_remote_peer_id();
391 let tunnel_id = self.tunnel_id_generator.generate();
392 let (mut recv, write) = tunnel.split();
393 let remote_id = peer_id.clone();
394 let tunnel_meta = recv.get_tunnel_meta();
395 let write = ObjectHolder::new(write);
396 let resp_write = write.clone();
397 let cmd_handler = self.cmd_handler.clone();
398 let handle = spawn(async move {
399 let ret: CmdResult<()> = async move {
400 loop {
401 let header_len = recv.read_u8().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
402 let mut header = vec![0u8; header_len as usize];
403 let n = recv.read_exact(header.as_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
404 if n == 0 {
405 break;
406 }
407 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice()).map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
408 log::trace!("recv cmd {:?} from {} len {} tunnel {:?}", header.cmd_code(), peer_id, header.pkg_len().to_u64().unwrap(), tunnel_id);
409 let body_len = header.pkg_len().to_u64().unwrap();
410 let cmd_read = CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
411 let waiter = cmd_read.get_waiter();
412 let future = waiter.create_result_future().map_err(into_cmd_err!(CmdErrorCode::Failed))?;
413 let version = header.version();
414 let seq = header.seq();
415 let cmd_code = header.cmd_code();
416 match cmd_handler.handle(peer_id.clone(), tunnel_id, header, CmdBody::from_reader(BufReader::new(cmd_read), body_len)).await {
417 Ok(Some(mut body)) => {
418 let mut write = resp_write.get().await;
419 let header = CmdHeader::<LEN, CMD>::new(version, true, seq, cmd_code, LEN::from_u64(body.len()).unwrap());
420 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
421 if buf.len() > 255 {
422 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
423 }
424 write.write_u8(buf.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
425 write.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
426 tokio::io::copy(&mut body, write.deref_mut().deref_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
427 write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
428 }
429 Err(e) => {
430 log::error!("handle cmd error: {:?}", e);
431 }
432 _ => {}
433 }
434 recv = future.await.map_err(into_cmd_err!(CmdErrorCode::Failed))??;
435 }
436 Ok(())
437 }.await;
438 ret
439 });
440 Ok(ClassifiedCmdSend::new(tunnel_id, classification, handle, write, self.resp_waiter.clone(), remote_id, tunnel_meta))
441 }
442}
443
/// Command client backed by a pool of classified tunnels.
///
/// Commands can go to any pooled tunnel, to a specific tunnel id, or to a
/// tunnel of a given classification.
pub struct DefaultClassifiedCmdClient<C: WorkerClassification,
    M: CmdTunnelMeta,
    R: ClassifiedCmdTunnelRead<C, M>,
    W: ClassifiedCmdTunnelWrite<C, M>,
    F: ClassifiedCmdTunnelFactory<C, M, R, W>,
    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> {
    /// Pool of senders; new tunnels are opened through `ClassifiedCmdWriteFactory`.
    tunnel_pool: ClassifiedWorkerPoolRef<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>,
    /// Handlers registered per command code for incoming (non-response) commands.
    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
}
454
455impl<C: WorkerClassification,
456 M: CmdTunnelMeta,
457 R: ClassifiedCmdTunnelRead<C, M>,
458 W: ClassifiedCmdTunnelWrite<C, M>,
459 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
460 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
461 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
462 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
463 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
464 let resp_waiter = Arc::new(RespWaiter::new());
465 let handler_map = cmd_handler_map.clone();
466 let waiter = resp_waiter.clone();
467 Arc::new(Self {
468 tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(factory, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read: CmdBody| {
469 let handler_map = handler_map.clone();
470 let waiter = waiter.clone();
471 async move {
472 if header.is_resp() && header.seq().is_some() {
473 let resp_id = gen_resp_id(header.cmd_code(), header.seq().unwrap());
474 let _ = waiter.set_result(resp_id, body_read);
475 Ok(None)
476 } else {
477 if let Some(handler) = handler_map.get(header.cmd_code()) {
478 handler.handle(peer_id, tunnel_id, header, body_read).await
479 } else {
480 Ok(None)
481 }
482 }
483 }
484 }, resp_waiter.clone())),
485 cmd_handler_map,
486 })
487 }
488
489 async fn get_send(&self) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
490 self.tunnel_pool.get_worker().await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
491 }
492
493 async fn get_send_of_tunnel_id(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
494 self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
495 tunnel_id: Some(tunnel_id),
496 classification: None,
497 }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
498 }
499
500 async fn get_classified_send(&self, classification: C) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
501 self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
502 tunnel_id: None,
503 classification: Some(classification),
504 }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
505 }
506}
507
/// Guard returned by `get_send` and `get_send_by_classified`: wraps a
/// `ClassifiedCmdSend` checked out from the client's tunnel pool.
pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = ClassifiedSendGuard<CmdClientTunnelClassification<C>, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>;
509#[async_trait::async_trait]
510impl<C: WorkerClassification,
511 M: CmdTunnelMeta,
512 R: ClassifiedCmdTunnelRead<C, M>,
513 W: ClassifiedCmdTunnelWrite<C, M>,
514 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
515 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
516 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug,
517> CmdClient<LEN, CMD, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
518 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
519 self.cmd_handler_map.insert(cmd, handler);
520 }
521
522 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
523 let mut send = self.get_send().await?;
524 send.send(cmd, version, body).await
525 }
526
527 async fn send_with_resp(&self, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
528 let mut send = self.get_send().await?;
529 send.send_with_resp(cmd, version, body, timeout).await
530 }
531
532 async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
533 let mut send = self.get_send().await?;
534 send.send2(cmd, version, body).await
535 }
536
537 async fn send2_with_resp(&self, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
538 let mut send = self.get_send().await?;
539 send.send2_with_resp(cmd, version, body, timeout).await
540 }
541
542 async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
543 let mut send = self.get_send().await?;
544 send.send_cmd(cmd, version, body).await
545 }
546
547 async fn send_cmd_with_resp(&self, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
548 let mut send = self.get_send().await?;
549 send.send_cmd_with_resp(cmd, version, body, timeout).await
550 }
551
552 async fn send_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
553 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
554 send.send(cmd, version, body).await
555 }
556
557 async fn send_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
558 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
559 send.send_with_resp(cmd, version, body, timeout).await
560 }
561
562 async fn send2_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
563 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
564 send.send2(cmd, version, body).await
565 }
566
567 async fn send2_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
568 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
569 send.send2_with_resp(cmd, version, body, timeout).await
570 }
571
572 async fn send_cmd_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
573 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
574 send.send_cmd(cmd, version, body).await
575 }
576
577 async fn send_cmd_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
578 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
579 send.send_cmd_with_resp(cmd, version, body, timeout).await
580 }
581
582 async fn clear_all_tunnel(&self) {
583 self.tunnel_pool.clear_all_worker().await;
584 }
585
586 async fn get_send(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
587 Ok(ClassifiedSendGuard {
588 worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
589 _p: std::marker::PhantomData,
590 })
591 }
592}
593
594#[async_trait::async_trait]
595impl<C: WorkerClassification,
596 M: CmdTunnelMeta,
597 R: ClassifiedCmdTunnelRead<C, M>,
598 W: ClassifiedCmdTunnelWrite<C, M>,
599 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
600 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
601 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug
602> ClassifiedCmdClient<LEN, CMD, C, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
603 async fn send_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
604 let mut send = self.get_classified_send(classification).await?;
605 send.send(cmd, version, body).await
606 }
607
608 async fn send_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
609 let mut send = self.get_classified_send(classification).await?;
610 send.send_with_resp(cmd, version, body, timeout).await
611 }
612
613 async fn send2_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
614 let mut send = self.get_classified_send(classification).await?;
615 send.send2(cmd, version, body).await
616 }
617
618 async fn send2_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
619 let mut send = self.get_classified_send(classification).await?;
620 send.send2_with_resp(cmd, version, body, timeout).await
621 }
622
623 async fn send_cmd_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
624 let mut send = self.get_classified_send(classification).await?;
625 send.send_cmd(cmd, version, body).await
626 }
627
628 async fn send_cmd_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
629 let mut send = self.get_classified_send(classification).await?;
630 send.send_cmd_with_resp(cmd, version, body, timeout).await
631 }
632
633 async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
634 let send = self.get_classified_send(classification).await?;
635 Ok(send.get_tunnel_id())
636 }
637
638 async fn get_send_by_classified(&self, classification: C) -> CmdResult<ClassifiedSendGuard<CmdClientTunnelClassification<C>, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
639 Ok(ClassifiedSendGuard {
640 worker_guard: self.get_classified_send(classification).await?,
641 _p: std::marker::PhantomData,
642 })
643 }
644}