1use crate::TunnelId;
2use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
3use crate::peer_id::PeerId;
4use bucky_raw_codec::{RawDecode, RawEncode};
5use callback_result::SingleCallbackWaiter;
6use futures_lite::ready;
7use num::{FromPrimitive, ToPrimitive};
8use sfo_split::RHalf;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::hash::Hash;
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::sync::{Arc, Mutex};
15use std::task::{Context, Poll};
16use std::{fmt, io};
17use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
18
/// Header of a command package, handed to `CmdHandler::handle` alongside the body.
///
/// Encoded/decoded through `bucky_raw_codec`'s derives. NOTE(review): the derive presumably
/// serializes fields in declaration order, so reordering them would change the wire format —
/// confirm against the codec before touching the layout.
#[derive(RawEncode, RawDecode)]
pub struct CmdHeader<LEN, CMD> {
    // Package length; its integer width is chosen by the `LEN` type parameter.
    pkg_len: LEN,
    // Protocol version byte.
    version: u8,
    // Command code (the key used by `CmdHandlerMap` to find a handler).
    cmd_code: CMD,
    // `true` when this package is a response rather than a request.
    is_resp: bool,
    // Optional sequence number; presumably used to pair responses with requests (confirm).
    seq: Option<u32>,
}
27
28impl<
29 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
30 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static,
31> CmdHeader<LEN, CMD>
32{
33 pub fn new(version: u8, is_resp: bool, seq: Option<u32>, cmd_code: CMD, pkg_len: LEN) -> Self {
34 Self {
35 pkg_len,
36 version,
37 seq,
38 cmd_code,
39 is_resp,
40 }
41 }
42
43 pub fn pkg_len(&self) -> LEN {
44 self.pkg_len
45 }
46
47 pub fn version(&self) -> u8 {
48 self.version
49 }
50
51 pub fn seq(&self) -> Option<u32> {
52 self.seq
53 }
54
55 pub fn is_resp(&self) -> bool {
56 self.is_resp
57 }
58
59 pub fn cmd_code(&self) -> CMD {
60 self.cmd_code
61 }
62
63 pub fn set_pkg_len(&mut self, pkg_len: LEN) {
64 self.pkg_len = pkg_len;
65 }
66}
67
68#[async_trait::async_trait]
69pub trait CmdBodyReadAll: tokio::io::AsyncRead + Send + 'static {
70 async fn read_all(&mut self) -> CmdResult<Vec<u8>>;
71}
72
73pub(crate) struct CmdBodyRead<
74 R: AsyncRead + Send + 'static + Unpin,
75 W: AsyncWrite + Send + 'static + Unpin,
76> {
77 recv: Option<RHalf<R, W>>,
78 len: usize,
79 offset: usize,
80 waiter: Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>>,
81}
82
83impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin>
84 CmdBodyRead<R, W>
85{
86 pub fn new(recv: RHalf<R, W>, len: usize) -> Self {
87 Self {
88 recv: Some(recv),
89 len,
90 offset: 0,
91 waiter: Arc::new(SingleCallbackWaiter::new()),
92 }
93 }
94
95 pub(crate) fn get_waiter(&self) -> Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>> {
96 self.waiter.clone()
97 }
98}
99
100#[async_trait::async_trait]
101impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadAll
102 for CmdBodyRead<R, W>
103{
104 async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
105 if self.offset == self.len {
106 return Ok(Vec::new());
107 }
108 let mut buf = vec![0u8; self.len - self.offset];
109 let ret = self
110 .recv
111 .as_mut()
112 .unwrap()
113 .read_exact(&mut buf)
114 .await
115 .map_err(into_cmd_err!(CmdErrorCode::IoError));
116 if ret.is_ok() {
117 self.offset = self.len;
118 self.waiter
119 .set_result_with_cache(Ok(self.recv.take().unwrap()));
120 Ok(buf)
121 } else {
122 self.recv.take();
123 self.waiter
124 .set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
125 Err(ret.err().unwrap())
126 }
127 }
128}
129
130impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> Drop
131 for CmdBodyRead<R, W>
132{
133 fn drop(&mut self) {
134 if self.recv.is_none() || (self.len == self.offset && self.len != 0) {
135 return;
136 }
137 let mut recv = self.recv.take().unwrap();
138 let len = self.len - self.offset;
139 let waiter = self.waiter.clone();
140 if len == 0 {
141 waiter.set_result_with_cache(Ok(recv));
142 return;
143 }
144
145 tokio::spawn(async move {
146 let mut buf = vec![0u8; len];
147 if let Err(e) = recv.read_exact(&mut buf).await {
148 waiter.set_result_with_cache(Err(cmd_err!(
149 CmdErrorCode::IoError,
150 "read body error {}",
151 e
152 )));
153 } else {
154 waiter.set_result_with_cache(Ok(recv));
155 }
156 });
157 }
158}
159
160impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin>
161 tokio::io::AsyncRead for CmdBodyRead<R, W>
162{
163 fn poll_read(
164 self: Pin<&mut Self>,
165 cx: &mut Context<'_>,
166 buf: &mut ReadBuf<'_>,
167 ) -> Poll<std::io::Result<()>> {
168 let this = Pin::into_inner(self);
169 let len = this.len - this.offset;
170 if len == 0 {
171 return Poll::Ready(Ok(()));
172 }
173 let recv = Pin::new(this.recv.as_mut().unwrap().deref_mut());
174 let read_len = std::cmp::min(len, buf.remaining());
175 let mut read_buf = ReadBuf::new(buf.initialize_unfilled_to(read_len));
176 let fut = recv.poll_read(cx, &mut read_buf);
177 match fut {
178 Poll::Ready(Ok(())) => {
179 let len = read_buf.filled().len();
180 drop(read_buf);
181 this.offset += len;
182 buf.advance(len);
183 if this.offset == this.len {
184 this.waiter
185 .set_result_with_cache(Ok(this.recv.take().unwrap()));
186 }
187 Poll::Ready(Ok(()))
188 }
189 Poll::Ready(Err(e)) => {
190 this.recv.take();
191 this.waiter
192 .set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
193 Poll::Ready(Err(e))
194 }
195 Poll::Pending => Poll::Pending,
196 }
197 }
198}
199
/// Handler invoked for an incoming command.
///
/// NOTE(review): `callback_trait` presumably also lets closures/async fns implement this trait —
/// confirm against that macro's documentation.
#[callback_trait::callback_trait]
pub trait CmdHandler<LEN, CMD>: Send + Sync + 'static
where
    LEN: RawEncode
        + for<'a> RawDecode<'a>
        + Copy
        + Send
        + Sync
        + 'static
        + FromPrimitive
        + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static,
{
    /// Handles one command.
    ///
    /// Receives the local and remote peer ids, the tunnel the command arrived on, its decoded
    /// header and its body. Returns an optional body; presumably `Some` is sent back as the
    /// response payload (confirm at the call site).
    async fn handle(
        &self,
        local_id: PeerId,
        peer_id: PeerId,
        tunnel_id: TunnelId,
        header: CmdHeader<LEN, CMD>,
        body: CmdBody,
    ) -> CmdResult<Option<CmdBody>>;
}
222
/// Registry mapping command codes to their handlers.
pub(crate) struct CmdHandlerMap<LEN, CMD> {
    // Handlers are stored as `Arc` so a lookup can return a clone and release the lock
    // before the handler is used.
    map: Mutex<HashMap<CMD, Arc<dyn CmdHandler<LEN, CMD>>>>,
}
226
227impl<LEN, CMD> CmdHandlerMap<LEN, CMD>
228where
229 LEN: RawEncode
230 + for<'a> RawDecode<'a>
231 + Copy
232 + Send
233 + Sync
234 + 'static
235 + FromPrimitive
236 + ToPrimitive,
237 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Eq + Hash,
238{
239 pub fn new() -> Self {
240 Self {
241 map: Mutex::new(HashMap::new()),
242 }
243 }
244
245 pub fn insert(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
246 self.map.lock().unwrap().insert(cmd, Arc::new(handler));
247 }
248
249 pub fn get(&self, cmd: CMD) -> Option<Arc<dyn CmdHandler<LEN, CMD>>> {
250 self.map.lock().unwrap().get(&cmd).map(|v| v.clone())
251 }
252}
pin_project_lite::pin_project! {
/// A command body: a buffered reader plus the number of bytes it is declared to contain.
pub struct CmdBody {
    // Source of the body bytes.
    #[pin]
    reader: Box<dyn AsyncBufRead + Unpin + Send + 'static>,
    // Declared body length in bytes.
    length: u64,
    // Progress counter compared against `length` to detect the end of the body.
    bytes_read: u64,
    }
}
261
262impl CmdBody {
263 pub fn empty() -> Self {
264 Self {
265 reader: Box::new(tokio::io::empty()),
266 length: 0,
267 bytes_read: 0,
268 }
269 }
270
271 pub fn from_reader(reader: impl AsyncBufRead + Unpin + Send + 'static, length: u64) -> Self {
272 Self {
273 reader: Box::new(reader),
274 length,
275 bytes_read: 0,
276 }
277 }
278
279 pub fn into_reader(self) -> Box<dyn AsyncBufRead + Unpin + Send + 'static> {
280 self.reader
281 }
282
283 pub async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
284 let mut buf = Vec::with_capacity(1024);
285 self.read_to_end(&mut buf)
286 .await
287 .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
288 Ok(buf)
289 }
290
291 pub fn from_bytes(bytes: Vec<u8>) -> Self {
292 Self {
293 length: bytes.len() as u64,
294 reader: Box::new(io::Cursor::new(bytes)),
295 bytes_read: 0,
296 }
297 }
298
299 pub async fn into_bytes(mut self) -> CmdResult<Vec<u8>> {
300 let mut buf = Vec::with_capacity(1024);
301 self.read_to_end(&mut buf)
302 .await
303 .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
304 Ok(buf)
305 }
306
307 pub fn from_string(s: String) -> Self {
308 Self {
309 length: s.len() as u64,
310 reader: Box::new(io::Cursor::new(s.into_bytes())),
311 bytes_read: 0,
312 }
313 }
314
315 pub async fn into_string(mut self) -> CmdResult<String> {
316 let mut result = String::with_capacity(self.len() as usize);
317 self.read_to_string(&mut result)
318 .await
319 .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to string failed"))?;
320 Ok(result)
321 }
322
323 pub async fn from_path<P>(path: P) -> io::Result<Self>
324 where
325 P: AsRef<std::path::Path>,
326 {
327 let path = path.as_ref();
328 let file = tokio::fs::File::open(path).await?;
329 Self::from_file(file).await
330 }
331
332 pub async fn from_file(file: tokio::fs::File) -> io::Result<Self> {
333 let len = file.metadata().await?.len();
334
335 Ok(Self {
336 length: len,
337 reader: Box::new(tokio::io::BufReader::new(file)),
338 bytes_read: 0,
339 })
340 }
341
342 pub fn len(&self) -> u64 {
343 self.length
344 }
345
346 pub fn is_empty(&self) -> bool {
348 self.length == 0
349 }
350
351 pub fn chain(self, other: CmdBody) -> Self {
352 let length = (self.length - self.bytes_read)
353 .checked_add(other.length - other.bytes_read)
354 .unwrap_or(0);
355 Self {
356 length,
357 reader: Box::new(tokio::io::AsyncReadExt::chain(self, other)),
358 bytes_read: 0,
359 }
360 }
361}
362
363impl Debug for CmdBody {
364 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365 f.debug_struct("CmdResponse")
366 .field("reader", &"<hidden>")
367 .field("length", &self.length)
368 .field("bytes_read", &self.bytes_read)
369 .finish()
370 }
371}
372
373impl From<String> for CmdBody {
374 fn from(s: String) -> Self {
375 Self::from_string(s)
376 }
377}
378
379impl<'a> From<&'a str> for CmdBody {
380 fn from(s: &'a str) -> Self {
381 Self::from_string(s.to_owned())
382 }
383}
384
385impl From<Vec<u8>> for CmdBody {
386 fn from(b: Vec<u8>) -> Self {
387 Self::from_bytes(b)
388 }
389}
390
391impl<'a> From<&'a [u8]> for CmdBody {
392 fn from(b: &'a [u8]) -> Self {
393 Self::from_bytes(b.to_owned())
394 }
395}
396
397impl AsyncRead for CmdBody {
398 #[allow(rustdoc::missing_doc_code_examples)]
399 fn poll_read(
400 mut self: Pin<&mut Self>,
401 cx: &mut Context<'_>,
402 buf: &mut ReadBuf<'_>,
403 ) -> Poll<io::Result<()>> {
404 let buf = if self.length == self.bytes_read {
405 return Poll::Ready(Ok(()));
406 } else {
407 buf
408 };
409
410 ready!(Pin::new(&mut self.reader).poll_read(cx, buf))?;
411 self.bytes_read += buf.filled().len() as u64;
412 Poll::Ready(Ok(()))
413 }
414}
415
416impl AsyncBufRead for CmdBody {
417 #[allow(rustdoc::missing_doc_code_examples)]
418 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&'_ [u8]>> {
419 self.project().reader.poll_fill_buf(cx)
420 }
421
422 fn consume(mut self: Pin<&mut Self>, amt: usize) {
423 Pin::new(&mut self.reader).consume(amt)
424 }
425}
426
// Unit tests for `CmdHeader`, `CmdBody` and `CmdBodyRead`.
#[cfg(test)]
mod tests {
    use super::{CmdBody, CmdBodyRead, CmdBodyReadAll, CmdHeader};
    use crate::{CmdTunnel, CmdTunnelRead, CmdTunnelWrite, PeerId};
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{
        AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, split,
    };

    // Read half of an in-memory duplex stream, adapted to `CmdTunnelRead` for the tests.
    struct TestRead {
        read: tokio::io::ReadHalf<DuplexStream>,
    }

    impl AsyncRead for TestRead {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            Pin::new(&mut self.read).poll_read(cx, buf)
        }
    }

    // Fixed, arbitrary peer ids; the tests never inspect them.
    impl CmdTunnelRead<()> for TestRead {
        fn get_local_peer_id(&self) -> PeerId {
            PeerId::from(vec![9; 32])
        }

        fn get_remote_peer_id(&self) -> PeerId {
            PeerId::from(vec![1; 32])
        }
    }

    // Write half of an in-memory duplex stream, adapted to `CmdTunnelWrite` for the tests.
    struct TestWrite {
        write: tokio::io::WriteHalf<DuplexStream>,
    }

    impl AsyncWrite for TestWrite {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Pin::new(&mut self.write).poll_write(cx, buf)
        }

        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Pin::new(&mut self.write).poll_flush(cx)
        }

        fn poll_shutdown(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<()>> {
            Pin::new(&mut self.write).poll_shutdown(cx)
        }
    }

    impl CmdTunnelWrite<()> for TestWrite {
        fn get_local_peer_id(&self) -> PeerId {
            PeerId::from(vec![9; 32])
        }

        fn get_remote_peer_id(&self) -> PeerId {
            PeerId::from(vec![2; 32])
        }
    }

    // Bytes in, bytes out: `from_bytes` followed by `into_bytes` is lossless.
    #[tokio::test]
    async fn cmd_body_bytes_round_trip() {
        let body = CmdBody::from_bytes(b"hello-body".to_vec());
        let data = body.into_bytes().await.unwrap();
        assert_eq!(data, b"hello-body");
    }

    // String in, string out: `from_string` followed by `into_string` is lossless.
    #[tokio::test]
    async fn cmd_body_string_round_trip() {
        let body = CmdBody::from_string("hello-string".to_owned());
        let s = body.into_string().await.unwrap();
        assert_eq!(s, "hello-string");
    }

    // Bytes consumed from the first body before `chain` must not reappear in the chained body.
    #[tokio::test]
    async fn cmd_body_chain_respects_consumed_prefix() {
        let mut first = CmdBody::from_bytes(b"abc".to_vec());
        let mut buf = [0u8; 1];
        first.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"a");

        let chained = first.chain(CmdBody::from_bytes(b"XYZ".to_vec()));
        let s = chained.into_string().await.unwrap();
        assert_eq!(s, "bcXYZ");
    }

    // `len`/`is_empty` report the declared length.
    #[test]
    fn cmd_body_empty_and_len() {
        let empty = CmdBody::empty();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let body = CmdBody::from_bytes(vec![1, 2, 3, 4]);
        assert!(!body.is_empty());
        assert_eq!(body.len(), 4);
    }

    // Both `read_all` and the raw reader from `into_reader` yield the full contents.
    #[tokio::test]
    async fn cmd_body_into_reader_and_read_all() {
        let mut body = CmdBody::from_string("reader-body".to_owned());
        let all = body.read_all().await.unwrap();
        assert_eq!(all, b"reader-body");

        let body = CmdBody::from_string("reader-body2".to_owned());
        let mut reader = body.into_reader();
        let mut out = Vec::new();
        reader.read_to_end(&mut out).await.unwrap();
        assert_eq!(out, b"reader-body2");
    }

    #[test]
    fn cmd_header_set_pkg_len() {
        let mut header = CmdHeader::<u16, u8>::new(1, false, Some(7), 0x11, 3);
        assert_eq!(header.pkg_len(), 3);
        header.set_pkg_len(9);
        assert_eq!(header.pkg_len(), 9);
    }

    // The peer writes exactly the 6-byte body. The first `read_all` returns it; a second call
    // on the now-consumed body returns an empty vector.
    #[tokio::test]
    async fn cmd_body_read_all_success_and_empty_after_read() {
        let (side_a, side_b) = tokio::io::duplex(128);
        let (a_read, a_write) = split(side_a);
        let (_b_read, mut b_write) = split(side_b);
        b_write.write_all(b"abcdef").await.unwrap();
        b_write.flush().await.unwrap();

        let tunnel = CmdTunnel::new(TestRead { read: a_read }, TestWrite { write: a_write });
        let (reader, _writer) = tunnel.split();
        let mut body_read = CmdBodyRead::new(reader, 6);

        let first = body_read.read_all().await.unwrap();
        assert_eq!(first, b"abcdef");
        let second = body_read.read_all().await.unwrap();
        assert!(second.is_empty());
    }

    // The peer sends only 2 of the 5 declared bytes and then shuts down, so `read_all` hits
    // EOF and must report an error.
    #[tokio::test]
    async fn cmd_body_read_all_error_when_source_short() {
        let (side_a, side_b) = tokio::io::duplex(128);
        let (a_read, a_write) = split(side_a);
        let (_b_read, mut b_write) = split(side_b);
        b_write.write_all(b"ab").await.unwrap();
        b_write.shutdown().await.unwrap();

        let tunnel = CmdTunnel::new(TestRead { read: a_read }, TestWrite { write: a_write });
        let (reader, _writer) = tunnel.split();
        let mut body_read = CmdBodyRead::new(reader, 5);
        assert!(body_read.read_all().await.is_err());
    }
}