1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::{fmt, io};
5use std::ops::DerefMut;
6use std::pin::Pin;
7use std::sync::{Arc, Mutex};
8use std::task::{Context, Poll};
9use bucky_raw_codec::{RawDecode, RawEncode};
10use callback_result::{SingleCallbackWaiter};
11use futures_lite::ready;
12use num::{FromPrimitive, ToPrimitive};
13use sfo_split::RHalf;
14use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
15use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
16use crate::peer_id::PeerId;
17use crate::{TunnelId};
18
/// Fixed header that precedes every command packet.
///
/// Serialized with the `RawEncode`/`RawDecode` derives; the field declaration
/// order below is presumably the encoded order — do not reorder fields without
/// checking the peer implementation.
#[derive(RawEncode, RawDecode)]
pub struct CmdHeader<LEN, CMD> {
    /// Packet length, stored in the generic integer type `LEN`.
    /// NOTE(review): whether this includes the header itself is not visible here — confirm.
    pkg_len: LEN,
    /// Protocol version byte.
    version: u8,
    /// Command code, of the same type `CMD` that `CmdHandlerMap` is keyed by.
    cmd_code: CMD,
    /// `true` when this packet is a response rather than a request.
    is_resp: bool,
    /// Optional sequence number — presumably used to pair requests with responses; confirm with callers.
    seq: Option<u32>,
}
27
28impl<LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
29 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static> CmdHeader<LEN, CMD> {
30 pub fn new(version: u8, is_resp: bool, seq: Option<u32>, cmd_code: CMD, pkg_len: LEN) -> Self {
31 Self {
32 pkg_len,
33 version,
34 seq,
35 cmd_code,
36 is_resp,
37 }
38 }
39
40 pub fn pkg_len(&self) -> LEN {
41 self.pkg_len
42 }
43
44 pub fn version(&self) -> u8 {
45 self.version
46 }
47
48 pub fn seq(&self) -> Option<u32> {
49 self.seq
50 }
51
52 pub fn is_resp(&self) -> bool {
53 self.is_resp
54 }
55
56 pub fn cmd_code(&self) -> CMD {
57 self.cmd_code
58 }
59
60 pub fn set_pkg_len(&mut self, pkg_len: LEN) {
61 self.pkg_len = pkg_len;
62 }
63}
64
/// An async body reader that can also read everything it has left in one call.
#[async_trait::async_trait]
pub trait CmdBodyReadAll: tokio::io::AsyncRead + Send + 'static {
    /// Reads all remaining bytes of the body and returns them as a `Vec`.
    async fn read_all(&mut self) -> CmdResult<Vec<u8>>;
}
69
/// Reader over the body of one incoming command, temporarily owning the
/// connection's read half.
///
/// When the body has been fully read, the read half is handed back through
/// `waiter`. When a read fails, the waiter receives an error and the read half
/// is dropped. Presumably the connection owner awaits the waiter to resume
/// reading the stream — confirm with callers of `get_waiter`.
pub(crate) struct CmdBodyRead<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> {
    // Read half of the connection; `None` once handed back or dropped after an error.
    recv: Option<RHalf<R, W>>,
    // Total body length in bytes.
    len: usize,
    // Number of body bytes consumed so far.
    offset: usize,
    // Receives the read half (or an error) once this reader is done with it.
    waiter: Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>>,
}
76
77impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyRead<R, W> {
78 pub fn new(recv: RHalf<R, W>, len: usize) -> Self {
79 Self {
80 recv: Some(recv),
81 len,
82 offset: 0,
83 waiter: Arc::new(SingleCallbackWaiter::new()),
84 }
85 }
86
87
88 pub(crate) fn get_waiter(&self) -> Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>> {
89 self.waiter.clone()
90 }
91}
92
93#[async_trait::async_trait]
94impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadAll for CmdBodyRead<R, W> {
95 async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
96 if self.offset == self.len {
97 return Ok(Vec::new());
98 }
99 let mut buf = vec![0u8; self.len - self.offset];
100 let ret = self.recv.as_mut().unwrap().read_exact(&mut buf).await.map_err(into_cmd_err!(CmdErrorCode::IoError));
101 if ret.is_ok() {
102 self.offset = self.len;
103 self.waiter.set_result_with_cache(Ok(self.recv.take().unwrap()));
104 Ok(buf)
105 } else {
106 self.recv.take();
107 self.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
108 Err(ret.err().unwrap())
109 }
110 }
111}
112
113impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> Drop for CmdBodyRead<R, W> {
114 fn drop(&mut self) {
115 if self.recv.is_none() || (self.len == self.offset && self.len != 0) {
116 return;
117 }
118 let mut recv = self.recv.take().unwrap();
119 let len = self.len - self.offset;
120 let waiter = self.waiter.clone();
121 if len == 0 {
122 waiter.set_result_with_cache(Ok(recv));
123 return;
124 }
125
126 tokio::spawn(async move {
127 let mut buf = vec![0u8; len];
128 if let Err(e) = recv.read_exact(&mut buf).await {
129 waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error {}", e)));
130 } else {
131 waiter.set_result_with_cache(Ok(recv));
132 }
133 });
134 }
135}
136
137impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> tokio::io::AsyncRead for CmdBodyRead<R, W> {
138 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
139 let this = Pin::into_inner(self);
140 let len = this.len - this.offset;
141 if len == 0 {
142 return Poll::Ready(Ok(()));
143 }
144 let recv = Pin::new(this.recv.as_mut().unwrap().deref_mut());
145 let fut = recv.poll_read(cx, buf);
146 match fut {
147 Poll::Ready(Ok(())) => {
148 this.offset += buf.filled().len();
149 if this.offset == this.len {
150 this.waiter.set_result_with_cache(Ok(this.recv.take().unwrap()));
151 }
152 Poll::Ready(Ok(()))
153 }
154 Poll::Ready(Err(e)) => {
155 this.recv.take();
156 this.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
157 Poll::Ready(Err(e))
158 },
159 Poll::Pending => Poll::Pending,
160 }
161 }
162}
163
164
/// Handler for one command code; instances are registered in `CmdHandlerMap` via `insert`.
#[callback_trait::callback_trait]
pub trait CmdHandler<LEN, CMD>: Send + Sync + 'static
where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static {
    /// Handles a command with `header` and `body` received from `peer_id` on `tunnel_id`.
    ///
    /// NOTE(review): `Ok(Some(body))` presumably means "send `body` back as the
    /// response" and `Ok(None)` "no response" — the caller is not visible here; confirm.
    async fn handle(&self, peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body: CmdBody) -> CmdResult<Option<CmdBody>>;
}
171
/// Thread-safe registry mapping a command code to its handler.
pub(crate) struct CmdHandlerMap<LEN, CMD> {
    // Guarded by a std `Mutex`; the lock is only held for the map insert or lookup.
    map: Mutex<HashMap<CMD, Arc<dyn CmdHandler<LEN, CMD>>>>,
}
175
176impl <LEN, CMD> CmdHandlerMap<LEN, CMD>
177where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
178 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Eq + Hash {
179 pub fn new() -> Self {
180 Self {
181 map: Mutex::new(HashMap::new()),
182 }
183 }
184
185 pub fn insert(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
186 self.map.lock().unwrap().insert(cmd, Arc::new(handler));
187 }
188
189 pub fn get(&self, cmd: CMD) -> Option<Arc<dyn CmdHandler<LEN, CMD>>> {
190 self.map.lock().unwrap().get(&cmd).map(|v| v.clone())
191 }
192}
pin_project_lite::pin_project! {
    /// Command body: a boxed buffered reader plus the number of bytes the body is declared to contain.
    pub struct CmdBody {
        #[pin]
        reader: Box<dyn AsyncBufRead + Unpin + Send + 'static>,
        // Declared body length in bytes.
        length: u64,
        // Bytes already read out of the body; compared against `length` in `poll_read`.
        bytes_read: u64,
    }
}
201
202impl CmdBody {
203 pub fn empty() -> Self {
204 Self {
205 reader: Box::new(tokio::io::empty()),
206 length: 0,
207 bytes_read: 0,
208 }
209 }
210
211 pub fn from_reader(
212 reader: impl AsyncBufRead + Unpin + Send + 'static,
213 length: u64,
214 ) -> Self {
215 Self {
216 reader: Box::new(reader),
217 length,
218 bytes_read: 0,
219 }
220 }
221
222 pub fn into_reader(self) -> Box<dyn AsyncBufRead + Unpin + Send + 'static> {
223 self.reader
224 }
225
226 pub async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
227 let mut buf = Vec::with_capacity(1024);
228 self.read_to_end(&mut buf).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
229 Ok(buf)
230 }
231
232 pub fn from_bytes(bytes: Vec<u8>) -> Self {
233 Self {
234 length: bytes.len() as u64,
235 reader: Box::new(io::Cursor::new(bytes)),
236 bytes_read: 0,
237 }
238 }
239
240 pub async fn into_bytes(mut self) -> CmdResult<Vec<u8>> {
241 let mut buf = Vec::with_capacity(1024);
242 self.read_to_end(&mut buf)
243 .await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
244 Ok(buf)
245 }
246
247 pub fn from_string(s: String) -> Self {
248 Self {
249 length: s.len() as u64,
250 reader: Box::new(io::Cursor::new(s.into_bytes())),
251 bytes_read: 0,
252 }
253 }
254
255 pub async fn into_string(mut self) -> CmdResult<String> {
256 let mut result = String::with_capacity(self.len() as usize);
257 self.read_to_string(&mut result)
258 .await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to string failed"))?;
259 Ok(result)
260 }
261
262 pub async fn from_path<P>(path: P) -> io::Result<Self>
263 where
264 P: AsRef<std::path::Path>,
265 {
266 let path = path.as_ref();
267 let file = tokio::fs::File::open(path).await?;
268 Self::from_file(file).await
269 }
270
271 pub async fn from_file(
272 file: tokio::fs::File,
273 ) -> io::Result<Self> {
274 let len = file.metadata().await?.len();
275
276 Ok(Self {
277 length: len,
278 reader: Box::new(tokio::io::BufReader::new(file)),
279 bytes_read: 0,
280 })
281 }
282
283 pub fn len(&self) -> u64 {
284 self.length
285 }
286
287 pub fn is_empty(&self) -> bool {
289 self.length == 0
290 }
291
292 pub fn chain(self, other: CmdBody) -> Self {
293 let length = (self.length - self.bytes_read).checked_add(other.length - other.bytes_read).unwrap_or(0);
294 Self {
295 length,
296 reader: Box::new(tokio::io::AsyncReadExt::chain(self, other)),
297 bytes_read: 0,
298 }
299 }
300}
301
302impl Debug for CmdBody {
303 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304 f.debug_struct("CmdResponse")
305 .field("reader", &"<hidden>")
306 .field("length", &self.length)
307 .field("bytes_read", &self.bytes_read)
308 .finish()
309 }
310}
311
312
313impl From<String> for CmdBody {
314 fn from(s: String) -> Self {
315 Self::from_string(s)
316 }
317}
318
319impl<'a> From<&'a str> for CmdBody {
320 fn from(s: &'a str) -> Self {
321 Self::from_string(s.to_owned())
322 }
323}
324
325impl From<Vec<u8>> for CmdBody {
326 fn from(b: Vec<u8>) -> Self {
327 Self::from_bytes(b)
328 }
329}
330
331impl<'a> From<&'a [u8]> for CmdBody {
332 fn from(b: &'a [u8]) -> Self {
333 Self::from_bytes(b.to_owned())
334 }
335}
336
337impl AsyncRead for CmdBody {
338 #[allow(rustdoc::missing_doc_code_examples)]
339 fn poll_read(
340 mut self: Pin<&mut Self>,
341 cx: &mut Context<'_>,
342 buf: &mut ReadBuf<'_>,
343 ) -> Poll<io::Result<()>> {
344 let buf = if self.length == self.bytes_read {
345 return Poll::Ready(Ok(()));
346 } else {
347 buf
348 };
349
350 ready!(Pin::new(&mut self.reader).poll_read(cx, buf))?;
351 self.bytes_read += buf.filled().len() as u64;
352 Poll::Ready(Ok(()))
353 }
354}
355
356impl AsyncBufRead for CmdBody {
357 #[allow(rustdoc::missing_doc_code_examples)]
358 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&'_ [u8]>> {
359 self.project().reader.poll_fill_buf(cx)
360 }
361
362 fn consume(mut self: Pin<&mut Self>, amt: usize) {
363 Pin::new(&mut self.reader).consume(amt)
364 }
365}