1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::{fmt, io};
5use std::ops::DerefMut;
6use std::pin::Pin;
7use std::sync::{Arc, Mutex};
8use std::task::{Context, Poll};
9use bucky_raw_codec::{RawDecode, RawEncode};
10use callback_result::{SingleCallbackWaiter};
11use futures_lite::ready;
12use num::{FromPrimitive, ToPrimitive};
13use sfo_split::RHalf;
14use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
15use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
16use crate::peer_id::PeerId;
17use crate::{TunnelId};
18
/// Header of a command message.
///
/// Carries a package length, a protocol version, the command code, a response flag and an
/// optional sequence number. The struct is serialized through the `RawEncode`/`RawDecode`
/// derives, so the field order below is presumably the on-wire order — do not reorder the
/// fields without confirming against `bucky_raw_codec`.
#[derive(RawEncode, RawDecode)]
pub struct CmdHeader<LEN, CMD> {
    // Package length; presumably the byte length of the body that follows — TODO confirm.
    pkg_len: LEN,
    // Protocol version byte.
    version: u8,
    // Command code; presumably the key used to look up a `CmdHandler` — TODO confirm.
    cmd_code: CMD,
    // Whether this message is a response (as opposed to a request), judging by the name.
    is_resp: bool,
    // Optional sequence number; presumably pairs a response with its request — TODO confirm.
    seq: Option<u32>,
}
27
28impl<LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
29 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static> CmdHeader<LEN, CMD> {
30 pub fn new(version: u8, is_resp: bool, seq: Option<u32>, cmd_code: CMD, pkg_len: LEN) -> Self {
31 Self {
32 pkg_len,
33 version,
34 seq,
35 cmd_code,
36 is_resp,
37 }
38 }
39
40 pub fn pkg_len(&self) -> LEN {
41 self.pkg_len
42 }
43
44 pub fn version(&self) -> u8 {
45 self.version
46 }
47
48 pub fn seq(&self) -> Option<u32> {
49 self.seq
50 }
51
52 pub fn is_resp(&self) -> bool {
53 self.is_resp
54 }
55
56 pub fn cmd_code(&self) -> CMD {
57 self.cmd_code
58 }
59
60 pub fn set_pkg_len(&mut self, pkg_len: LEN) {
61 self.pkg_len = pkg_len;
62 }
63}
64
65#[async_trait::async_trait]
66pub trait CmdBodyReadAll: tokio::io::AsyncRead + Send + 'static {
67 async fn read_all(&mut self) -> CmdResult<Vec<u8>>;
68}
69
70pub(crate) struct CmdBodyRead<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> {
71 recv: Option<RHalf<R, W>>,
72 len: usize,
73 offset: usize,
74 waiter: Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>>,
75}
76
77impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyRead<R, W> {
78 pub fn new(recv: RHalf<R, W>, len: usize) -> Self {
79 Self {
80 recv: Some(recv),
81 len,
82 offset: 0,
83 waiter: Arc::new(SingleCallbackWaiter::new()),
84 }
85 }
86
87
88 pub(crate) fn get_waiter(&self) -> Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>> {
89 self.waiter.clone()
90 }
91}
92
93#[async_trait::async_trait]
94impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadAll for CmdBodyRead<R, W> {
95 async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
96 if self.offset == self.len {
97 return Ok(Vec::new());
98 }
99 let mut buf = vec![0u8; self.len - self.offset];
100 let ret = self.recv.as_mut().unwrap().read_exact(&mut buf).await.map_err(into_cmd_err!(CmdErrorCode::IoError));
101 if ret.is_ok() {
102 self.offset = self.len;
103 self.waiter.set_result_with_cache(Ok(self.recv.take().unwrap()));
104 Ok(buf)
105 } else {
106 self.recv.take();
107 self.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
108 Err(ret.err().unwrap())
109 }
110 }
111}
112
113impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> Drop for CmdBodyRead<R, W> {
114 fn drop(&mut self) {
115 if self.recv.is_none() || (self.len == self.offset && self.len != 0) {
116 return;
117 }
118 let mut recv = self.recv.take().unwrap();
119 let len = self.len - self.offset;
120 let waiter = self.waiter.clone();
121 if len == 0 {
122 waiter.set_result_with_cache(Ok(recv));
123 return;
124 }
125
126 tokio::spawn(async move {
127 let mut buf = vec![0u8; len];
128 if let Err(e) = recv.read_exact(&mut buf).await {
129 waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error {}", e)));
130 } else {
131 waiter.set_result_with_cache(Ok(recv));
132 }
133 });
134 }
135}
136
137impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> tokio::io::AsyncRead for CmdBodyRead<R, W> {
138 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
139 let this = Pin::into_inner(self);
140 let len = this.len - this.offset;
141 if len == 0 {
142 return Poll::Ready(Ok(()));
143 }
144 let recv = Pin::new(this.recv.as_mut().unwrap().deref_mut());
145 let read_len = std::cmp::min(len, buf.remaining());
146 let mut read_buf = ReadBuf::new(buf.initialize_unfilled_to(read_len));
147 let fut = recv.poll_read(cx, &mut read_buf);
148 match fut {
149 Poll::Ready(Ok(())) => {
150 let len = read_buf.filled().len();
151 drop(read_buf);
152 this.offset += len;
153 buf.advance(len);
154 if this.offset == this.len {
155 this.waiter.set_result_with_cache(Ok(this.recv.take().unwrap()));
156 }
157 Poll::Ready(Ok(()))
158 }
159 Poll::Ready(Err(e)) => {
160 this.recv.take();
161 this.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
162 Poll::Ready(Err(e))
163 },
164 Poll::Pending => Poll::Pending,
165 }
166 }
167}
168
169
/// Handler for an incoming command.
///
/// Implementors must be `Send + Sync + 'static`; instances are stored as
/// `Arc<dyn CmdHandler<LEN, CMD>>` by `CmdHandlerMap`.
#[callback_trait::callback_trait]
pub trait CmdHandler<LEN, CMD>: Send + Sync + 'static
where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static {
    /// Handles one command.
    ///
    /// Receives the sending peer's id, the tunnel id, the decoded header and the body.
    /// `Ok(Some(body))` presumably supplies a response body and `Ok(None)` means no
    /// response — TODO confirm against the dispatcher.
    async fn handle(&self, peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body: CmdBody) -> CmdResult<Option<CmdBody>>;
}
176
/// Registry mapping a command code to its handler.
///
/// The map sits behind a `Mutex` so handlers can be registered and looked up through `&self`.
pub(crate) struct CmdHandlerMap<LEN, CMD> {
    map: Mutex<HashMap<CMD, Arc<dyn CmdHandler<LEN, CMD>>>>,
}
180
181impl <LEN, CMD> CmdHandlerMap<LEN, CMD>
182where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
183 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Eq + Hash {
184 pub fn new() -> Self {
185 Self {
186 map: Mutex::new(HashMap::new()),
187 }
188 }
189
190 pub fn insert(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
191 self.map.lock().unwrap().insert(cmd, Arc::new(handler));
192 }
193
194 pub fn get(&self, cmd: CMD) -> Option<Arc<dyn CmdHandler<LEN, CMD>>> {
195 self.map.lock().unwrap().get(&cmd).map(|v| v.clone())
196 }
197}
pin_project_lite::pin_project! {
/// Asynchronous command body: a boxed buffered reader paired with a declared byte length.
///
/// Build one with `empty`, `from_reader`, `from_bytes`, `from_string`, `from_path` or
/// `from_file`, then consume it through `AsyncRead`/`AsyncBufRead` or the `read_all` /
/// `into_*` helpers.
pub struct CmdBody {
    // Underlying byte source.
    #[pin]
    reader: Box<dyn AsyncBufRead + Unpin + Send + 'static>,
    // Declared number of bytes in the body.
    length: u64,
    // Number of body bytes already consumed by readers.
    bytes_read: u64,
    }
}
206
207impl CmdBody {
208 pub fn empty() -> Self {
209 Self {
210 reader: Box::new(tokio::io::empty()),
211 length: 0,
212 bytes_read: 0,
213 }
214 }
215
216 pub fn from_reader(
217 reader: impl AsyncBufRead + Unpin + Send + 'static,
218 length: u64,
219 ) -> Self {
220 Self {
221 reader: Box::new(reader),
222 length,
223 bytes_read: 0,
224 }
225 }
226
227 pub fn into_reader(self) -> Box<dyn AsyncBufRead + Unpin + Send + 'static> {
228 self.reader
229 }
230
231 pub async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
232 let mut buf = Vec::with_capacity(1024);
233 self.read_to_end(&mut buf).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
234 Ok(buf)
235 }
236
237 pub fn from_bytes(bytes: Vec<u8>) -> Self {
238 Self {
239 length: bytes.len() as u64,
240 reader: Box::new(io::Cursor::new(bytes)),
241 bytes_read: 0,
242 }
243 }
244
245 pub async fn into_bytes(mut self) -> CmdResult<Vec<u8>> {
246 let mut buf = Vec::with_capacity(1024);
247 self.read_to_end(&mut buf)
248 .await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
249 Ok(buf)
250 }
251
252 pub fn from_string(s: String) -> Self {
253 Self {
254 length: s.len() as u64,
255 reader: Box::new(io::Cursor::new(s.into_bytes())),
256 bytes_read: 0,
257 }
258 }
259
260 pub async fn into_string(mut self) -> CmdResult<String> {
261 let mut result = String::with_capacity(self.len() as usize);
262 self.read_to_string(&mut result)
263 .await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to string failed"))?;
264 Ok(result)
265 }
266
267 pub async fn from_path<P>(path: P) -> io::Result<Self>
268 where
269 P: AsRef<std::path::Path>,
270 {
271 let path = path.as_ref();
272 let file = tokio::fs::File::open(path).await?;
273 Self::from_file(file).await
274 }
275
276 pub async fn from_file(
277 file: tokio::fs::File,
278 ) -> io::Result<Self> {
279 let len = file.metadata().await?.len();
280
281 Ok(Self {
282 length: len,
283 reader: Box::new(tokio::io::BufReader::new(file)),
284 bytes_read: 0,
285 })
286 }
287
288 pub fn len(&self) -> u64 {
289 self.length
290 }
291
292 pub fn is_empty(&self) -> bool {
294 self.length == 0
295 }
296
297 pub fn chain(self, other: CmdBody) -> Self {
298 let length = (self.length - self.bytes_read).checked_add(other.length - other.bytes_read).unwrap_or(0);
299 Self {
300 length,
301 reader: Box::new(tokio::io::AsyncReadExt::chain(self, other)),
302 bytes_read: 0,
303 }
304 }
305}
306
307impl Debug for CmdBody {
308 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309 f.debug_struct("CmdResponse")
310 .field("reader", &"<hidden>")
311 .field("length", &self.length)
312 .field("bytes_read", &self.bytes_read)
313 .finish()
314 }
315}
316
317
318impl From<String> for CmdBody {
319 fn from(s: String) -> Self {
320 Self::from_string(s)
321 }
322}
323
324impl<'a> From<&'a str> for CmdBody {
325 fn from(s: &'a str) -> Self {
326 Self::from_string(s.to_owned())
327 }
328}
329
330impl From<Vec<u8>> for CmdBody {
331 fn from(b: Vec<u8>) -> Self {
332 Self::from_bytes(b)
333 }
334}
335
336impl<'a> From<&'a [u8]> for CmdBody {
337 fn from(b: &'a [u8]) -> Self {
338 Self::from_bytes(b.to_owned())
339 }
340}
341
342impl AsyncRead for CmdBody {
343 #[allow(rustdoc::missing_doc_code_examples)]
344 fn poll_read(
345 mut self: Pin<&mut Self>,
346 cx: &mut Context<'_>,
347 buf: &mut ReadBuf<'_>,
348 ) -> Poll<io::Result<()>> {
349 let buf = if self.length == self.bytes_read {
350 return Poll::Ready(Ok(()));
351 } else {
352 buf
353 };
354
355 ready!(Pin::new(&mut self.reader).poll_read(cx, buf))?;
356 self.bytes_read += buf.filled().len() as u64;
357 Poll::Ready(Ok(()))
358 }
359}
360
361impl AsyncBufRead for CmdBody {
362 #[allow(rustdoc::missing_doc_code_examples)]
363 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&'_ [u8]>> {
364 self.project().reader.poll_fill_buf(cx)
365 }
366
367 fn consume(mut self: Pin<&mut Self>, amt: usize) {
368 Pin::new(&mut self.reader).consume(amt)
369 }
370}