1use std::fmt::Debug;
2use std::hash::Hash;
3use std::sync::{Arc, Mutex};
4use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
5use num::{FromPrimitive, ToPrimitive};
6use sfo_pool::{into_pool_err, pool_err, ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult};
7use sfo_split::{Splittable, WHalf};
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::spawn;
10use tokio::task::JoinHandle;
11use crate::{CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
12use crate::client::CmdClient;
13use crate::cmd::{CmdBodyReadImpl, CmdHandler, CmdHandlerMap, CmdHeader};
14use crate::errors::{into_cmd_err, CmdErrorCode, CmdResult};
15use crate::peer_id::PeerId;
16
17#[async_trait::async_trait]
18pub trait CmdTunnelFactory<R: CmdTunnelRead, W: CmdTunnelWrite>: Send + Sync + 'static {
19 async fn create_tunnel(&self) -> CmdResult<Splittable<R, W>>;
20}
21
22pub struct CmdSend<R: CmdTunnelRead, W: CmdTunnelWrite, LEN, CMD>
23where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
24 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
25 pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
26 pub(crate) write: WHalf<R, W>,
27 pub(crate) is_work: bool,
28 pub(crate) tunnel_id: TunnelId,
29 _p: std::marker::PhantomData<(LEN, CMD)>,
30
31}
32
33impl<R, W, LEN, CMD> CmdSend<R, W, LEN, CMD>
34where R: CmdTunnelRead,
35 W: CmdTunnelWrite,
36 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
37 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
38 pub fn new(tunnel_id: TunnelId, recv_handle: JoinHandle<CmdResult<()>>, write: WHalf<R, W>) -> Self {
39 Self {
40 recv_handle,
41 write,
42 is_work: true,
43 tunnel_id,
44 _p: Default::default(),
45 }
46 }
47
48 pub fn get_tunnel_id(&self) -> TunnelId {
49 self.tunnel_id
50 }
51
52 pub fn set_disable(&mut self) {
53 self.is_work = false;
54 self.recv_handle.abort();
55 }
56
57 pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
58 log::trace!("client {:?} send cmd: {:?}, len: {} data:{}", self.tunnel_id, cmd, body.len(), hex::encode(body));
59 let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(body.len() as u64).unwrap());
60 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
61 let ret = self.send_inner(buf.as_slice(), body).await;
62 if let Err(e) = ret {
63 self.set_disable();
64 return Err(e);
65 }
66 Ok(())
67 }
68
69 pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
70 let mut len = 0;
71 for b in body.iter() {
72 len += b.len();
73 log::trace!("client {:?} send2 cmd: {:?}, data {}", self.tunnel_id, cmd, hex::encode(b));
74 }
75 log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
76 let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(len as u64).unwrap());
77 let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
78 let ret = self.send_inner2(buf.as_slice(), body).await;
79 if let Err(e) = ret {
80 self.set_disable();
81 return Err(e);
82 }
83 Ok(())
84 }
85
86 async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
87 self.write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
88 self.write.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
89 self.write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
90 Ok(())
91 }
92
93 async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
94 self.write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
95 for b in body.iter() {
96 self.write.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
97 }
98 self.write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
99 Ok(())
100 }
101}
102
103impl<R, W, LEN, CMD> Drop for CmdSend<R, W, LEN, CMD>
104where R: CmdTunnelRead,
105 W: CmdTunnelWrite,
106 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
107 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
108 fn drop(&mut self) {
109 self.set_disable();
110 }
111}
112
113impl<R, W, LEN, CMD> ClassifiedWorker<TunnelId> for CmdSend<R, W, LEN, CMD>
114where R: CmdTunnelRead,
115 W: CmdTunnelWrite,
116 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
117 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug{
118 fn is_work(&self) -> bool {
119 self.is_work && !self.recv_handle.is_finished()
120 }
121
122 fn is_valid(&self, c: TunnelId) -> bool {
123 self.tunnel_id == c
124 }
125
126 fn classification(&self) -> TunnelId {
127 self.tunnel_id
128 }
129}
130
131
132struct CmdWriteFactory<R: CmdTunnelRead,
133 W: CmdTunnelWrite,
134 F: CmdTunnelFactory<R, W>,
135 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
136 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug> {
137 tunnel_factory: F,
138 cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
139 tunnel_id_generator: TunnelIdGenerator,
140 p: std::marker::PhantomData<Mutex<(R, W)>>,
141}
142
143impl<R: CmdTunnelRead,
144 W: CmdTunnelWrite,
145 F: CmdTunnelFactory<R, W>,
146 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
147 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug> CmdWriteFactory<R, W, F, LEN, CMD> {
148 pub fn new(tunnel_factory: F, cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
149 Self {
150 tunnel_factory,
151 cmd_handler: Arc::new(cmd_handler),
152 tunnel_id_generator: TunnelIdGenerator::new(),
153 p: Default::default(),
154 }
155 }
156}
157
158#[async_trait::async_trait]
159impl<R: CmdTunnelRead,
160 W: CmdTunnelWrite,
161 F: CmdTunnelFactory<R, W>,
162 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
163 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug> ClassifiedWorkerFactory<TunnelId, CmdSend<R, W, LEN, CMD>> for CmdWriteFactory<R, W, F, LEN, CMD> {
164 async fn create(&self, c: Option<TunnelId>) -> PoolResult<CmdSend<R, W, LEN, CMD>> {
165 if c.is_some() {
166 return Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", c.unwrap()));
167 }
168 let tunnel = self.tunnel_factory.create_tunnel().await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
169 let peer_id = tunnel.get_remote_peer_id();
170 let tunnel_id = self.tunnel_id_generator.generate();
171 let (mut recv, write) = tunnel.split();
172 let cmd_handler = self.cmd_handler.clone();
173 let handle = spawn(async move {
174 let ret: CmdResult<()> = async move {
175 loop {
176 let mut header = vec![0u8; CmdHeader::<LEN, CMD>::raw_bytes().unwrap()];
177 let n = recv.read_exact(header.as_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
178 if n == 0 {
179 break;
180 }
181 let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice()).map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
182 log::trace!("recv cmd {:?} from {} len {}", header.cmd_code(), peer_id.to_base58(), header.pkg_len().to_u64().unwrap());
183 let cmd_read = Box::new(CmdBodyReadImpl::new(recv, header.pkg_len().to_u64().unwrap() as usize));
184 let waiter = cmd_read.get_waiter();
185 let future = waiter.create_result_future();
186 if let Err(e) = cmd_handler.handle(peer_id.clone(), tunnel_id, header, cmd_read).await {
187 log::error!("handle cmd error: {:?}", e);
188 }
189 recv = future.await.map_err(into_cmd_err!(CmdErrorCode::Failed))??;
190 }
191 Ok(())
192 }.await;
193 ret
194 });
195 Ok(CmdSend::new(tunnel_id, handle, write))
196 }
197}
198
199pub struct DefaultCmdClient<R: CmdTunnelRead,
200 W: CmdTunnelWrite,
201 F: CmdTunnelFactory<R, W>,
202 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
203 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> {
204 tunnel_pool: ClassifiedWorkerPoolRef<TunnelId, CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD>>,
205 cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
206}
207
208impl<R: CmdTunnelRead,
209 W: CmdTunnelWrite,
210 F: CmdTunnelFactory<R, W>,
211 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
212 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> DefaultCmdClient<R, W, F, LEN, CMD> {
213 pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
214 let cmd_handler_map = Arc::new(CmdHandlerMap::new());
215 let handler_map = cmd_handler_map.clone();
216 Arc::new(Self {
217 tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, CmdWriteFactory::<R, W, _, LEN, CMD>::new(factory, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read| {
218 let handler_map = handler_map.clone();
219 async move {
220 if let Some(handler) = handler_map.get(header.cmd_code()) {
221 handler.handle(peer_id, tunnel_id, header, body_read).await?;
222 }
223 Ok(())
224 }
225 })),
226 cmd_handler_map,
227 })
228 }
229
230 async fn get_send(&self) -> CmdResult<ClassifiedWorkerGuard<TunnelId, CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD>>> {
231 self.tunnel_pool.get_worker().await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
232 }
233
234 async fn get_send_of_tunnel_id(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<TunnelId, CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD>>> {
235 self.tunnel_pool.get_classified_worker(tunnel_id).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
236 }
237}
238
239#[async_trait::async_trait]
240impl<R: CmdTunnelRead,
241 W: CmdTunnelWrite,
242 F: CmdTunnelFactory<R, W>,
243 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
244 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> CmdClient<LEN, CMD> for DefaultCmdClient<R, W, F, LEN, CMD> {
245 fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
246 self.cmd_handler_map.insert(cmd, handler);
247 }
248
249 async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
250 let mut send = self.get_send().await?;
251 send.send(cmd, version, body).await
252 }
253
254 async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
255 let mut send = self.get_send().await?;
256 send.send2(cmd, version, body).await
257 }
258
259 async fn send_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
260 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
261 send.send(cmd, version, body).await
262 }
263
264 async fn send2_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
265 let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
266 send.send2(cmd, version, body).await
267 }
268
269 async fn clear_all_tunnel(&self) {
270 self.tunnel_pool.clear_all_worker().await;
271 }
272}