1use std::hash::Hash;
2use std::sync::{Arc, Mutex};
3use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
4use num::{FromPrimitive, ToPrimitive};
5use sfo_pool::{into_pool_err, pool_err, ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
7use tokio::spawn;
8use tokio::task::JoinHandle;
9use crate::{CmdBody, CmdTunnelMeta, CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
10use crate::client::{gen_resp_id, gen_seq, ClassifiedCmdClient, ClassifiedSendGuard, CmdClient, CmdSend, RespWaiter, RespWaiterRef};
11use crate::cmd::{CmdBodyRead, CmdHandler, CmdHandlerMap, CmdHeader};
12use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
13use crate::peer_id::PeerId;
14use std::fmt::Debug;
15use std::ops::{DerefMut};
16use std::time::Duration;
17use async_named_locker::ObjectHolder;
18use sfo_split::{RHalf, Splittable, WHalf};
19
20pub trait ClassifiedCmdTunnelRead<C: WorkerClassification, M: CmdTunnelMeta>: CmdTunnelRead<M> + 'static + Send {
21 fn get_classification(&self) -> C;
22}
23
24pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification, M: CmdTunnelMeta>: CmdTunnelWrite<M> + 'static + Send {
25 fn get_classification(&self) -> C;
26}
27
/// A classified tunnel before it is split: an `sfo_split::Splittable` over the read/write pair.
pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
/// Read half obtained from splitting a [`ClassifiedCmdTunnel`].
pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
/// Write half obtained from splitting a [`ClassifiedCmdTunnel`].
pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
31
32#[derive(Debug, Clone, Copy, Eq, Hash)]
33pub struct CmdClientTunnelClassification<C: WorkerClassification> {
34 tunnel_id: Option<TunnelId>,
35 classification: Option<C>,
36}
37
38impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
39 fn eq(&self, other: &Self) -> bool {
40 self.tunnel_id == other.tunnel_id && self.classification == other.classification
41 }
42}
43
44
45#[async_trait::async_trait]
46pub trait ClassifiedCmdTunnelFactory<C: WorkerClassification, M: CmdTunnelMeta, R: ClassifiedCmdTunnelRead<C, M>, W: ClassifiedCmdTunnelWrite<C, M>>: Send + Sync + 'static {
47 async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
48}
49
/// Sending side of one classified command tunnel, together with the handle of
/// the task that receives on the same tunnel.
pub struct ClassifiedCmdSend<C, M, R, W, LEN, CMD>
where
    C: WorkerClassification,
    M: CmdTunnelMeta,
    R: ClassifiedCmdTunnelRead<C, M>,
    W: ClassifiedCmdTunnelWrite<C, M>,
    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes,
{
    /// Handle of the receive task; aborted when the sender is disabled.
    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
    /// Holder of the tunnel's write half; access goes through `get().await`.
    pub(crate) write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
    /// `true` until `set_disable` is called.
    pub(crate) is_work: bool,
    /// Classification of the underlying tunnel.
    pub(crate) classification: C,
    /// Identifier assigned to this tunnel.
    pub(crate) tunnel_id: TunnelId,
    /// Registry used to create the futures that wait for responses.
    pub(crate) resp_waiter: RespWaiterRef,
    /// Peer id of the remote end.
    pub(crate) remote_id: PeerId,
    /// Optional metadata attached to the tunnel.
    pub(crate) tunnel_meta: Option<Arc<M>>,
    /// Marker tying the `LEN` and `CMD` type parameters to the struct.
    _p: std::marker::PhantomData<(LEN, CMD)>,
}
70
71impl<C, M, R, W, LEN, CMD> ClassifiedCmdSend<C, M, R, W, LEN, CMD>
85where C: WorkerClassification,
86 M: CmdTunnelMeta,
87 R: ClassifiedCmdTunnelRead<C, M>,
88 W: ClassifiedCmdTunnelWrite<C, M>,
89 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
90 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
91 pub(crate) fn new(tunnel_id: TunnelId,
92 classification: C,
93 recv_handle: JoinHandle<CmdResult<()>>,
94 write: ObjectHolder<ClassifiedCmdTunnelWHalf<R, W>>,
95 resp_waiter: RespWaiterRef,
96 remote_id: PeerId,
97 tunnel_meta: Option<Arc<M>>) -> Self {
98 Self {
99 recv_handle,
100 write,
101 is_work: true,
102 classification,
103 tunnel_id,
104 resp_waiter,
105 remote_id,
106 tunnel_meta,
107 _p: Default::default(),
108 }
109 }
110
111 pub fn get_tunnel_id(&self) -> TunnelId {
112 self.tunnel_id
113 }
114
115 pub fn set_disable(&mut self) {
116 self.is_work = false;
117 self.recv_handle.abort();
118 }
119
120 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
121 log::trace!("client {:?} send cmd: {:?}, len: {}, data: {}", self.tunnel_id, cmd, body.len(), hex::encode(body));
122 let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(body.len() as u64).unwrap());
123 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
124 let ret = self.send_inner(buf.as_slice(), body).await;
125 if let Err(e) = ret {
126 self.set_disable();
127 return Err(e);
128 }
129 Ok(())
130 }
131
132 pub async fn send_with_resp(&mut self, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
133 if let Some(id) = tokio::task::try_id() {
134 if id == self.recv_handle.id() {
135 return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
136 }
137 }
138 log::trace!("client {:?} send cmd: {:?}, len: {}, data: {}", self.tunnel_id, cmd, body.len(), hex::encode(body));
139 let seq = gen_seq();
140 let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(body.len() as u64).unwrap());
141 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
142 let resp_id = gen_resp_id(cmd, seq);
143 let waiter = self.resp_waiter.clone();
144 let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
145 .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
146 let ret = self.send_inner(buf.as_slice(), body).await;
147 if let Err(e) = ret {
148 self.set_disable();
149 return Err(e);
150 }
151 let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
152 Ok(resp)
153 }
154
155 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
156 let mut len = 0;
157 for b in body.iter() {
158 len += b.len();
159 log::trace!("client {:?} send2 cmd {:?} body: {}", self.tunnel_id, cmd, hex::encode(b));
160 }
161 log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
162 let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(len as u64).unwrap());
163 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
164 let ret = self.send_inner2(buf.as_slice(), body).await;
165 if let Err(e) = ret {
166 self.set_disable();
167 return Err(e);
168 }
169 Ok(())
170 }
171
172 pub async fn send2_with_resp(&mut self, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
173 if let Some(id) = tokio::task::try_id() {
174 if id == self.recv_handle.id() {
175 return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
176 }
177 }
178 let mut len = 0;
179 for b in body.iter() {
180 len += b.len();
181 log::trace!("client {:?} send2 cmd {:?} body: {}", self.tunnel_id, cmd, hex::encode(b));
182 }
183 log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
184 let seq = gen_seq();
185 let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(len as u64).unwrap());
186 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
187 let resp_id = gen_resp_id(cmd, seq);
188 let waiter = self.resp_waiter.clone();
189 let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
190 .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
191 let ret = self.send_inner2(buf.as_slice(), body).await;
192 if let Err(e) = ret {
193 self.set_disable();
194 return Err(e);
195 }
196 let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
197 Ok(resp)
198 }
199
200 pub async fn send_cmd(&mut self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
201 log::trace!("client {:?} send cmd: {:?}, len: {}", self.tunnel_id, cmd, body.len());
202 let header = CmdHeader::<LEN, CMD>::new(version, false, None, cmd, LEN::from_u64(body.len()).unwrap());
203 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
204 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
205 if let Err(e) = ret {
206 self.set_disable();
207 return Err(e);
208 }
209 Ok(())
210 }
211
212 pub async fn send_cmd_with_resp(&mut self, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
213 if let Some(id) = tokio::task::try_id() {
214 if id == self.recv_handle.id() {
215 return Err(cmd_err!(CmdErrorCode::Failed, "can't send with resp in recv task"));
216 }
217 }
218 log::trace!("client {:?} send cmd: {:?}, len: {}", self.tunnel_id, cmd, body.len());
219 let seq = gen_seq();
220 let header = CmdHeader::<LEN, CMD>::new(version, false, Some(seq), cmd, LEN::from_u64(body.len()).unwrap());
221 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
222 let resp_id = gen_resp_id(cmd, seq);
223 let waiter = self.resp_waiter.clone();
224 let resp_waiter = waiter.create_timeout_result_future(resp_id, timeout)
225 .map_err(into_cmd_err!(CmdErrorCode::Failed, "create timeout result future error"))?;
226 let ret = self.send_inner_cmd(buf.as_slice(), body).await;
227 if let Err(e) = ret {
228 self.set_disable();
229 return Err(e);
230 }
231 let resp = resp_waiter.await.map_err(into_cmd_err!(CmdErrorCode::Timeout, "recv resp error"))?;
232 Ok(resp)
233 }
234
235 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
236 let mut write = self.write.get().await;
237 if header.len() > 255 {
238 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
239 }
240 write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
241 write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
242 write.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
243 write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
244 Ok(())
245 }
246
247 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
248 let mut write = self.write.get().await;
249 if header.len() > 255 {
250 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
251 }
252 write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
253 write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
254 for b in body.iter() {
255 write.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
256 }
257 write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
258 Ok(())
259 }
260
261 async fn send_inner_cmd(&mut self, header: &[u8], mut body: CmdBody) -> CmdResult<()> {
262 let mut write = self.write.get().await;
263 if header.len() > 255 {
264 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
265 }
266 write.write_u8(header.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
267 write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
268 tokio::io::copy(&mut body, write.deref_mut().deref_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
269 write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
270 Ok(())
271 }
272}
273
274impl<C, M, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
275where C: WorkerClassification,
276 M: CmdTunnelMeta,
277 R: ClassifiedCmdTunnelRead<C, M>,
278 W: ClassifiedCmdTunnelWrite<C, M>,
279 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
280 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
281 fn drop(&mut self) {
282 self.set_disable();
283 }
284}
285
286impl<C, M, R, W, LEN, CMD> CmdSend<M> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
287where C: WorkerClassification,
288 M: CmdTunnelMeta,
289 R: ClassifiedCmdTunnelRead<C, M>,
290 W: ClassifiedCmdTunnelWrite<C, M>,
291 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
292 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
293 fn get_tunnel_meta(&self) -> Option<Arc<M>> {
294 self.tunnel_meta.clone()
295 }
296
297 fn get_remote_peer_id(&self) -> PeerId {
298 self.remote_id.clone()
299 }
300}
301
302impl<C, M, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>> for ClassifiedCmdSend<C, M, R, W, LEN, CMD>
303where C: WorkerClassification,
304 M: CmdTunnelMeta,
305 R: ClassifiedCmdTunnelRead<C, M>,
306 W: ClassifiedCmdTunnelWrite<C, M>,
307 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
308 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes {
309 fn is_work(&self) -> bool {
310 self.is_work && !self.recv_handle.is_finished()
311 }
312
313 fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
314 if c.tunnel_id.is_some() {
315 self.tunnel_id == c.tunnel_id.unwrap()
316 } else {
317 if c.classification.is_some() {
318 self.classification == c.classification.unwrap()
319 } else {
320 true
321 }
322 }
323 }
324
325 fn classification(&self) -> CmdClientTunnelClassification<C> {
326 CmdClientTunnelClassification {
327 tunnel_id: Some(self.tunnel_id),
328 classification: Some(self.classification.clone()),
329 }
330 }
331}
332
333pub struct ClassifiedCmdWriteFactory<C: WorkerClassification,
334 M: CmdTunnelMeta,
335 R: ClassifiedCmdTunnelRead<C, M>,
336 W: ClassifiedCmdTunnelWrite<C, M>,
337 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
338 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
339 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes> {
340 tunnel_factory: F,
341 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
342 resp_waiter: RespWaiterRef,
343 tunnel_id_generator: TunnelIdGenerator,
344 _p: std::marker::PhantomData<Mutex<(C, M, R, W)>>,
345}
346
347impl<
348 C: WorkerClassification,
349 M: CmdTunnelMeta,
350 R: ClassifiedCmdTunnelRead<C, M>,
351 W: ClassifiedCmdTunnelWrite<C, M>,
352 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
353 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
354 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug + RawFixedBytes
355> ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD> {
356 pub(crate) fn new(tunnel_factory: F,
357 cmd_handler: impl CmdHandler<LEN, CMD>,
358 resp_waiter: RespWaiterRef,) -> Self {
359 Self {
360 tunnel_factory,
361 cmd_handler: Arc::new(cmd_handler),
362 resp_waiter,
363 tunnel_id_generator: TunnelIdGenerator::new(),
364 _p: Default::default(),
365 }
366 }
367}
368
369#[async_trait::async_trait]
370impl<C: WorkerClassification,
371 M: CmdTunnelMeta,
372 R: ClassifiedCmdTunnelRead<C, M>,
373 W: ClassifiedCmdTunnelWrite<C, M>,
374 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
375 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
376 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>
377> for ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD> {
378 async fn create(&self, classification: Option<CmdClientTunnelClassification<C>>) -> PoolResult<ClassifiedCmdSend<C, M, R, W, LEN, CMD>> {
379 if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
380 return Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", classification.as_ref().unwrap().tunnel_id.unwrap()));
381 }
382
383 let classification = if classification.is_some() && classification.as_ref().unwrap().classification.is_some() {
384 classification.unwrap().classification
385 } else {
386 None
387 };
388 let tunnel = self.tunnel_factory.create_tunnel(classification).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
389 let classification = tunnel.get_classification();
390 let peer_id = tunnel.get_remote_peer_id();
391 let tunnel_id = self.tunnel_id_generator.generate();
392 let (mut recv, write) = tunnel.split();
393 let remote_id = peer_id.clone();
394 let tunnel_meta = recv.get_tunnel_meta();
395 let write = ObjectHolder::new(write);
396 let resp_write = write.clone();
397 let cmd_handler = self.cmd_handler.clone();
398 let handle = spawn(async move {
399 let ret: CmdResult<()> = async move {
400 loop {
401 let header_len = recv.read_u8().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
402 let mut header = vec![0u8; header_len as usize];
403 let n = recv.read_exact(header.as_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
404 if n == 0 {
405 break;
406 }
407 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice()).map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
408 log::trace!("recv cmd {:?} from {} len {} tunnel {:?}", header.cmd_code(), peer_id, header.pkg_len().to_u64().unwrap(), tunnel_id);
409 let body_len = header.pkg_len().to_u64().unwrap();
410 let cmd_read = CmdBodyRead::new(recv, header.pkg_len().to_u64().unwrap() as usize);
411 let waiter = cmd_read.get_waiter();
412 let future = waiter.create_result_future().map_err(into_cmd_err!(CmdErrorCode::Failed))?;
413 let version = header.version();
414 let seq = header.seq();
415 let cmd_code = header.cmd_code();
416 match cmd_handler.handle(peer_id.clone(), tunnel_id, header, CmdBody::from_reader(BufReader::new(cmd_read), body_len)).await {
417 Ok(Some(mut body)) => {
418 let mut write = resp_write.get().await;
419 let header = CmdHeader::<LEN, CMD>::new(version, true, seq, cmd_code, LEN::from_u64(body.len()).unwrap());
420 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
421 if buf.len() > 255 {
422 return Err(cmd_err!(CmdErrorCode::InvalidParam, "header len too large"));
423 }
424 write.write_u8(buf.len() as u8).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
425 write.write_all(buf.as_slice()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
426 tokio::io::copy(&mut body, write.deref_mut().deref_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
427 write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
428 }
429 Err(e) => {
430 log::error!("handle cmd error: {:?}", e);
431 }
432 _ => {}
433 }
434 recv = future.await.map_err(into_cmd_err!(CmdErrorCode::Failed))??;
435 log::debug!("handle cmd {:?} from {} len {} tunnel {:?} complete", cmd_code, peer_id, body_len, tunnel_id);
436 }
437 Ok(())
438 }.await;
439 if ret.is_err() {
440 log::error!("recv cmd error: {:?}", ret.as_ref().err().unwrap());
441 }
442 ret
443 });
444 Ok(ClassifiedCmdSend::new(tunnel_id, classification, handle, write, self.resp_waiter.clone(), remote_id, tunnel_meta))
445 }
446}
447
448pub struct DefaultClassifiedCmdClient<C: WorkerClassification,
449 M: CmdTunnelMeta,
450 R: ClassifiedCmdTunnelRead<C, M>,
451 W: ClassifiedCmdTunnelWrite<C, M>,
452 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
453 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
454 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> {
455 tunnel_pool: ClassifiedWorkerPoolRef<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>,
456 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
457}
458
459impl<C: WorkerClassification,
460 M: CmdTunnelMeta,
461 R: ClassifiedCmdTunnelRead<C, M>,
462 W: ClassifiedCmdTunnelWrite<C, M>,
463 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
464 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
465 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
466 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
467 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
468 let resp_waiter = Arc::new(RespWaiter::new());
469 let handler_map = cmd_handler_map.clone();
470 let waiter = resp_waiter.clone();
471 Arc::new(Self {
472 tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, ClassifiedCmdWriteFactory::<C, M, R, W, _, LEN, CMD>::new(factory, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read: CmdBody| {
473 let handler_map = handler_map.clone();
474 let waiter = waiter.clone();
475 async move {
476 if header.is_resp() && header.seq().is_some() {
477 let resp_id = gen_resp_id(header.cmd_code(), header.seq().unwrap());
478 let _ = waiter.set_result(resp_id, body_read);
479 Ok(None)
480 } else {
481 if let Some(handler) = handler_map.get(header.cmd_code()) {
482 handler.handle(peer_id, tunnel_id, header, body_read).await
483 } else {
484 Ok(None)
485 }
486 }
487 }
488 }, resp_waiter.clone())),
489 cmd_handler_map,
490 })
491 }
492
493 async fn get_send(&self) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
494 self.tunnel_pool.get_worker().await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
495 }
496
497 async fn get_send_of_tunnel_id(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
498 self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
499 tunnel_id: Some(tunnel_id),
500 classification: None,
501 }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
502 }
503
504 async fn get_classified_send(&self, classification: C) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
505 self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
506 tunnel_id: None,
507 classification: Some(classification),
508 }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
509 }
510}
511
/// Send guard handed out by [`DefaultClassifiedCmdClient`]: a `ClassifiedSendGuard`
/// specialised to the client's selector, sender and factory types.
pub type ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD> = ClassifiedSendGuard<CmdClientTunnelClassification<C>, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>;
513#[async_trait::async_trait]
514impl<C: WorkerClassification,
515 M: CmdTunnelMeta,
516 R: ClassifiedCmdTunnelRead<C, M>,
517 W: ClassifiedCmdTunnelWrite<C, M>,
518 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
519 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
520 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug,
521> CmdClient<LEN, CMD, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
522 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
523 self.cmd_handler_map.insert(cmd, handler);
524 }
525
526 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
527 let mut send = self.get_send().await?;
528 send.send(cmd, version, body).await
529 }
530
531 async fn send_with_resp(&self, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
532 let mut send = self.get_send().await?;
533 send.send_with_resp(cmd, version, body, timeout).await
534 }
535
536 async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
537 let mut send = self.get_send().await?;
538 send.send2(cmd, version, body).await
539 }
540
541 async fn send2_with_resp(&self, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
542 let mut send = self.get_send().await?;
543 send.send2_with_resp(cmd, version, body, timeout).await
544 }
545
546 async fn send_cmd(&self, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
547 let mut send = self.get_send().await?;
548 send.send_cmd(cmd, version, body).await
549 }
550
551 async fn send_cmd_with_resp(&self, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
552 let mut send = self.get_send().await?;
553 send.send_cmd_with_resp(cmd, version, body, timeout).await
554 }
555
556 async fn send_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
557 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
558 send.send(cmd, version, body).await
559 }
560
561 async fn send_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
562 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
563 send.send_with_resp(cmd, version, body, timeout).await
564 }
565
566 async fn send2_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
567 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
568 send.send2(cmd, version, body).await
569 }
570
571 async fn send2_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
572 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
573 send.send2_with_resp(cmd, version, body, timeout).await
574 }
575
576 async fn send_cmd_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
577 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
578 send.send_cmd(cmd, version, body).await
579 }
580
581 async fn send_cmd_by_specify_tunnel_with_resp(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
582 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
583 send.send_cmd_with_resp(cmd, version, body, timeout).await
584 }
585
586 async fn clear_all_tunnel(&self) {
587 self.tunnel_pool.clear_all_worker().await;
588 }
589
590 async fn get_send(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> {
591 Ok(ClassifiedSendGuard {
592 worker_guard: self.get_send_of_tunnel_id(tunnel_id).await?,
593 _p: std::marker::PhantomData,
594 })
595 }
596}
597
598#[async_trait::async_trait]
599impl<C: WorkerClassification,
600 M: CmdTunnelMeta,
601 R: ClassifiedCmdTunnelRead<C, M>,
602 W: ClassifiedCmdTunnelWrite<C, M>,
603 F: ClassifiedCmdTunnelFactory<C, M, R, W>,
604 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
605 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug
606> ClassifiedCmdClient<LEN, CMD, C, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedClientSendGuard<C, M, R, W, F, LEN, CMD>> for DefaultClassifiedCmdClient<C, M, R, W, F, LEN, CMD> {
607 async fn send_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
608 let mut send = self.get_classified_send(classification).await?;
609 send.send(cmd, version, body).await
610 }
611
612 async fn send_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: &[u8], timeout: Duration) -> CmdResult<CmdBody> {
613 let mut send = self.get_classified_send(classification).await?;
614 send.send_with_resp(cmd, version, body, timeout).await
615 }
616
617 async fn send2_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
618 let mut send = self.get_classified_send(classification).await?;
619 send.send2(cmd, version, body).await
620 }
621
622 async fn send2_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]], timeout: Duration) -> CmdResult<CmdBody> {
623 let mut send = self.get_classified_send(classification).await?;
624 send.send2_with_resp(cmd, version, body, timeout).await
625 }
626
627 async fn send_cmd_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: CmdBody) -> CmdResult<()> {
628 let mut send = self.get_classified_send(classification).await?;
629 send.send_cmd(cmd, version, body).await
630 }
631
632 async fn send_cmd_by_classified_tunnel_with_resp(&self, classification: C, cmd: CMD, version: u8, body: CmdBody, timeout: Duration) -> CmdResult<CmdBody> {
633 let mut send = self.get_classified_send(classification).await?;
634 send.send_cmd_with_resp(cmd, version, body, timeout).await
635 }
636
637 async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
638 let send = self.get_classified_send(classification).await?;
639 Ok(send.get_tunnel_id())
640 }
641
642 async fn get_send_by_classified(&self, classification: C) -> CmdResult<ClassifiedSendGuard<CmdClientTunnelClassification<C>, M, ClassifiedCmdSend<C, M, R, W, LEN, CMD>, ClassifiedCmdWriteFactory<C, M, R, W, F, LEN, CMD>>> {
643 Ok(ClassifiedSendGuard {
644 worker_guard: self.get_classified_send(classification).await?,
645 _p: std::marker::PhantomData,
646 })
647 }
648}