Skip to main content

sfo_cmd_server/
cmd.rs

1use crate::TunnelId;
2use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
3use crate::peer_id::PeerId;
4use bucky_raw_codec::{RawDecode, RawEncode};
5use callback_result::SingleCallbackWaiter;
6use futures_lite::ready;
7use num::{FromPrimitive, ToPrimitive};
8use sfo_split::RHalf;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::hash::Hash;
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::sync::{Arc, Mutex};
15use std::task::{Context, Poll};
16use std::{fmt, io};
17use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
18
/// Header of a command packet, encoded/decoded through the derived `RawEncode`/`RawDecode`
/// implementations from `bucky_raw_codec`.
///
/// NOTE(review): the derived codec presumably serializes fields in declaration order, so
/// reordering them would likely change the wire format — confirm before touching.
#[derive(RawEncode, RawDecode)]
pub struct CmdHeader<LEN, CMD> {
    /// Declared package length (presumably the byte size of the body that follows — confirm
    /// against the framing code).
    pkg_len: LEN,
    /// Version byte of the header.
    version: u8,
    /// Command code; the same `CMD` type is the lookup key in `CmdHandlerMap`.
    cmd_code: CMD,
    /// Response flag (presumably `true` when this packet answers an earlier request).
    is_resp: bool,
    /// Optional sequence number (presumably used to pair a response with its request — TODO
    /// confirm).
    seq: Option<u32>,
}
27
28impl<
29    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
30    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static,
31> CmdHeader<LEN, CMD>
32{
33    pub fn new(version: u8, is_resp: bool, seq: Option<u32>, cmd_code: CMD, pkg_len: LEN) -> Self {
34        Self {
35            pkg_len,
36            version,
37            seq,
38            cmd_code,
39            is_resp,
40        }
41    }
42
43    pub fn pkg_len(&self) -> LEN {
44        self.pkg_len
45    }
46
47    pub fn version(&self) -> u8 {
48        self.version
49    }
50
51    pub fn seq(&self) -> Option<u32> {
52        self.seq
53    }
54
55    pub fn is_resp(&self) -> bool {
56        self.is_resp
57    }
58
59    pub fn cmd_code(&self) -> CMD {
60        self.cmd_code
61    }
62
63    pub fn set_pkg_len(&mut self, pkg_len: LEN) {
64        self.pkg_len = pkg_len;
65    }
66}
67
/// Asynchronous command-body source that can also be drained into memory in a single call.
#[async_trait::async_trait]
pub trait CmdBodyReadAll: tokio::io::AsyncRead + Send + 'static {
    /// Reads the rest of the body into memory.
    ///
    /// The in-file implementation (`CmdBodyRead`) returns the bytes not yet consumed, and an
    /// empty vector once the body has been fully read.
    async fn read_all(&mut self) -> CmdResult<Vec<u8>>;
}
72
/// Reader over a command body of `len` bytes taken from the read half (`RHalf`) of a split
/// tunnel.
///
/// Once the body is fully consumed, a read fails, or the value is dropped, the read half — or
/// an error — is delivered through `waiter`, which callers obtain via `get_waiter`.
pub(crate) struct CmdBodyRead<
    R: AsyncRead + Send + 'static + Unpin,
    W: AsyncWrite + Send + 'static + Unpin,
> {
    /// Read half of the tunnel; `None` once it was handed to `waiter` or discarded on error.
    recv: Option<RHalf<R, W>>,
    /// Total body length in bytes.
    len: usize,
    /// Number of body bytes consumed so far.
    offset: usize,
    /// Resolved with the read half (or an error) when the body is finished.
    waiter: Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>>,
}
82
83impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin>
84    CmdBodyRead<R, W>
85{
86    pub fn new(recv: RHalf<R, W>, len: usize) -> Self {
87        Self {
88            recv: Some(recv),
89            len,
90            offset: 0,
91            waiter: Arc::new(SingleCallbackWaiter::new()),
92        }
93    }
94
95    pub(crate) fn get_waiter(&self) -> Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>> {
96        self.waiter.clone()
97    }
98}
99
100#[async_trait::async_trait]
101impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadAll
102    for CmdBodyRead<R, W>
103{
104    async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
105        if self.offset == self.len {
106            return Ok(Vec::new());
107        }
108        let mut buf = vec![0u8; self.len - self.offset];
109        let ret = self
110            .recv
111            .as_mut()
112            .unwrap()
113            .read_exact(&mut buf)
114            .await
115            .map_err(into_cmd_err!(CmdErrorCode::IoError));
116        if ret.is_ok() {
117            self.offset = self.len;
118            self.waiter
119                .set_result_with_cache(Ok(self.recv.take().unwrap()));
120            Ok(buf)
121        } else {
122            self.recv.take();
123            self.waiter
124                .set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
125            Err(ret.err().unwrap())
126        }
127    }
128}
129
130impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> Drop
131    for CmdBodyRead<R, W>
132{
133    fn drop(&mut self) {
134        if self.recv.is_none() || (self.len == self.offset && self.len != 0) {
135            return;
136        }
137        let mut recv = self.recv.take().unwrap();
138        let len = self.len - self.offset;
139        let waiter = self.waiter.clone();
140        if len == 0 {
141            waiter.set_result_with_cache(Ok(recv));
142            return;
143        }
144
145        tokio::spawn(async move {
146            let mut buf = vec![0u8; len];
147            if let Err(e) = recv.read_exact(&mut buf).await {
148                waiter.set_result_with_cache(Err(cmd_err!(
149                    CmdErrorCode::IoError,
150                    "read body error {}",
151                    e
152                )));
153            } else {
154                waiter.set_result_with_cache(Ok(recv));
155            }
156        });
157    }
158}
159
160impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin>
161    tokio::io::AsyncRead for CmdBodyRead<R, W>
162{
163    fn poll_read(
164        self: Pin<&mut Self>,
165        cx: &mut Context<'_>,
166        buf: &mut ReadBuf<'_>,
167    ) -> Poll<std::io::Result<()>> {
168        let this = Pin::into_inner(self);
169        let len = this.len - this.offset;
170        if len == 0 {
171            return Poll::Ready(Ok(()));
172        }
173        let recv = Pin::new(this.recv.as_mut().unwrap().deref_mut());
174        let read_len = std::cmp::min(len, buf.remaining());
175        let mut read_buf = ReadBuf::new(buf.initialize_unfilled_to(read_len));
176        let fut = recv.poll_read(cx, &mut read_buf);
177        match fut {
178            Poll::Ready(Ok(())) => {
179                let len = read_buf.filled().len();
180                drop(read_buf);
181                this.offset += len;
182                buf.advance(len);
183                if this.offset == this.len {
184                    this.waiter
185                        .set_result_with_cache(Ok(this.recv.take().unwrap()));
186                }
187                Poll::Ready(Ok(()))
188            }
189            Poll::Ready(Err(e)) => {
190                this.recv.take();
191                this.waiter
192                    .set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
193                Poll::Ready(Err(e))
194            }
195            Poll::Pending => Poll::Pending,
196        }
197    }
198}
199
/// Asynchronous handler for one kind of incoming command.
///
/// Handlers are stored as `Arc<dyn CmdHandler<LEN, CMD>>` in `CmdHandlerMap`, keyed by command
/// code. NOTE(review): the `callback_trait` attribute presumably also lets suitable async
/// closures act as handlers — confirm against that crate's documentation.
#[callback_trait::callback_trait]
pub trait CmdHandler<LEN, CMD>: Send + Sync + 'static
where
    LEN: RawEncode
        + for<'a> RawDecode<'a>
        + Copy
        + Send
        + Sync
        + 'static
        + FromPrimitive
        + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static,
{
    /// Handles a command received from `peer_id` on tunnel `tunnel_id`.
    ///
    /// `header` is the command's header and `body` its payload. The returned
    /// `Option<CmdBody>` is presumably the reply body, with `None` meaning no reply body —
    /// confirm with the dispatching code.
    async fn handle(
        &self,
        peer_id: PeerId,
        tunnel_id: TunnelId,
        header: CmdHeader<LEN, CMD>,
        body: CmdBody,
    ) -> CmdResult<Option<CmdBody>>;
}
221
/// Registry of command handlers keyed by command code.
///
/// The map sits behind a `Mutex`, so handlers are registered and looked up through `&self`.
pub(crate) struct CmdHandlerMap<LEN, CMD> {
    // One handler per command code; values are `Arc`s so lookups return cheap clones.
    map: Mutex<HashMap<CMD, Arc<dyn CmdHandler<LEN, CMD>>>>,
}
225
226impl<LEN, CMD> CmdHandlerMap<LEN, CMD>
227where
228    LEN: RawEncode
229        + for<'a> RawDecode<'a>
230        + Copy
231        + Send
232        + Sync
233        + 'static
234        + FromPrimitive
235        + ToPrimitive,
236    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Eq + Hash,
237{
238    pub fn new() -> Self {
239        Self {
240            map: Mutex::new(HashMap::new()),
241        }
242    }
243
244    pub fn insert(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
245        self.map.lock().unwrap().insert(cmd, Arc::new(handler));
246    }
247
248    pub fn get(&self, cmd: CMD) -> Option<Arc<dyn CmdHandler<LEN, CMD>>> {
249        self.map.lock().unwrap().get(&cmd).map(|v| v.clone())
250    }
251}
pin_project_lite::pin_project! {
/// Payload of a command: a boxed buffered reader together with its declared length.
///
/// `len()` reports the declared length; `bytes_read` tracks how much has been read so far.
pub struct CmdBody {
        // Source of the body bytes.
        #[pin]
        reader: Box<dyn AsyncBufRead + Unpin + Send  + 'static>,
        // Declared body length in bytes.
        length: u64,
        // Number of bytes already read through this body.
        bytes_read: u64,
    }
}
260
261impl CmdBody {
262    pub fn empty() -> Self {
263        Self {
264            reader: Box::new(tokio::io::empty()),
265            length: 0,
266            bytes_read: 0,
267        }
268    }
269
270    pub fn from_reader(reader: impl AsyncBufRead + Unpin + Send + 'static, length: u64) -> Self {
271        Self {
272            reader: Box::new(reader),
273            length,
274            bytes_read: 0,
275        }
276    }
277
278    pub fn into_reader(self) -> Box<dyn AsyncBufRead + Unpin + Send + 'static> {
279        self.reader
280    }
281
282    pub async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
283        let mut buf = Vec::with_capacity(1024);
284        self.read_to_end(&mut buf)
285            .await
286            .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
287        Ok(buf)
288    }
289
290    pub fn from_bytes(bytes: Vec<u8>) -> Self {
291        Self {
292            length: bytes.len() as u64,
293            reader: Box::new(io::Cursor::new(bytes)),
294            bytes_read: 0,
295        }
296    }
297
298    pub async fn into_bytes(mut self) -> CmdResult<Vec<u8>> {
299        let mut buf = Vec::with_capacity(1024);
300        self.read_to_end(&mut buf)
301            .await
302            .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
303        Ok(buf)
304    }
305
306    pub fn from_string(s: String) -> Self {
307        Self {
308            length: s.len() as u64,
309            reader: Box::new(io::Cursor::new(s.into_bytes())),
310            bytes_read: 0,
311        }
312    }
313
314    pub async fn into_string(mut self) -> CmdResult<String> {
315        let mut result = String::with_capacity(self.len() as usize);
316        self.read_to_string(&mut result)
317            .await
318            .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to string failed"))?;
319        Ok(result)
320    }
321
322    pub async fn from_path<P>(path: P) -> io::Result<Self>
323    where
324        P: AsRef<std::path::Path>,
325    {
326        let path = path.as_ref();
327        let file = tokio::fs::File::open(path).await?;
328        Self::from_file(file).await
329    }
330
331    pub async fn from_file(file: tokio::fs::File) -> io::Result<Self> {
332        let len = file.metadata().await?.len();
333
334        Ok(Self {
335            length: len,
336            reader: Box::new(tokio::io::BufReader::new(file)),
337            bytes_read: 0,
338        })
339    }
340
341    pub fn len(&self) -> u64 {
342        self.length
343    }
344
345    /// Returns `true` if the body has a length of zero, and `false` otherwise.
346    pub fn is_empty(&self) -> bool {
347        self.length == 0
348    }
349
350    pub fn chain(self, other: CmdBody) -> Self {
351        let length = (self.length - self.bytes_read)
352            .checked_add(other.length - other.bytes_read)
353            .unwrap_or(0);
354        Self {
355            length,
356            reader: Box::new(tokio::io::AsyncReadExt::chain(self, other)),
357            bytes_read: 0,
358        }
359    }
360}
361
362impl Debug for CmdBody {
363    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
364        f.debug_struct("CmdResponse")
365            .field("reader", &"<hidden>")
366            .field("length", &self.length)
367            .field("bytes_read", &self.bytes_read)
368            .finish()
369    }
370}
371
372impl From<String> for CmdBody {
373    fn from(s: String) -> Self {
374        Self::from_string(s)
375    }
376}
377
378impl<'a> From<&'a str> for CmdBody {
379    fn from(s: &'a str) -> Self {
380        Self::from_string(s.to_owned())
381    }
382}
383
384impl From<Vec<u8>> for CmdBody {
385    fn from(b: Vec<u8>) -> Self {
386        Self::from_bytes(b)
387    }
388}
389
390impl<'a> From<&'a [u8]> for CmdBody {
391    fn from(b: &'a [u8]) -> Self {
392        Self::from_bytes(b.to_owned())
393    }
394}
395
396impl AsyncRead for CmdBody {
397    #[allow(rustdoc::missing_doc_code_examples)]
398    fn poll_read(
399        mut self: Pin<&mut Self>,
400        cx: &mut Context<'_>,
401        buf: &mut ReadBuf<'_>,
402    ) -> Poll<io::Result<()>> {
403        let buf = if self.length == self.bytes_read {
404            return Poll::Ready(Ok(()));
405        } else {
406            buf
407        };
408
409        ready!(Pin::new(&mut self.reader).poll_read(cx, buf))?;
410        self.bytes_read += buf.filled().len() as u64;
411        Poll::Ready(Ok(()))
412    }
413}
414
415impl AsyncBufRead for CmdBody {
416    #[allow(rustdoc::missing_doc_code_examples)]
417    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&'_ [u8]>> {
418        self.project().reader.poll_fill_buf(cx)
419    }
420
421    fn consume(mut self: Pin<&mut Self>, amt: usize) {
422        Pin::new(&mut self.reader).consume(amt)
423    }
424}
425
#[cfg(test)]
mod tests {
    use super::{CmdBody, CmdBodyRead, CmdBodyReadAll, CmdHeader};
    use crate::{CmdTunnel, CmdTunnelRead, CmdTunnelWrite, PeerId};
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{
        AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, split,
    };

    /// Read side of an in-memory tunnel backed by a `tokio::io::duplex` pipe.
    struct PipeReader {
        half: tokio::io::ReadHalf<DuplexStream>,
    }

    impl AsyncRead for PipeReader {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let half = &mut self.get_mut().half;
            Pin::new(half).poll_read(cx, buf)
        }
    }

    impl CmdTunnelRead<()> for PipeReader {
        fn get_remote_peer_id(&self) -> PeerId {
            PeerId::from(vec![1; 32])
        }
    }

    /// Write side of an in-memory tunnel backed by a `tokio::io::duplex` pipe.
    struct PipeWriter {
        half: tokio::io::WriteHalf<DuplexStream>,
    }

    impl AsyncWrite for PipeWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            let half = &mut self.get_mut().half;
            Pin::new(half).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            let half = &mut self.get_mut().half;
            Pin::new(half).poll_flush(cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<()>> {
            let half = &mut self.get_mut().half;
            Pin::new(half).poll_shutdown(cx)
        }
    }

    impl CmdTunnelWrite<()> for PipeWriter {
        fn get_remote_peer_id(&self) -> PeerId {
            PeerId::from(vec![2; 32])
        }
    }

    #[tokio::test]
    async fn cmd_body_bytes_round_trip() {
        let decoded = CmdBody::from_bytes(b"hello-body".to_vec())
            .into_bytes()
            .await
            .unwrap();
        assert_eq!(decoded, b"hello-body");
    }

    #[tokio::test]
    async fn cmd_body_string_round_trip() {
        let decoded = CmdBody::from_string("hello-string".to_owned())
            .into_string()
            .await
            .unwrap();
        assert_eq!(decoded, "hello-string");
    }

    #[tokio::test]
    async fn cmd_body_chain_respects_consumed_prefix() {
        let mut head = CmdBody::from_bytes(b"abc".to_vec());
        let mut first_byte = [0u8; 1];
        head.read_exact(&mut first_byte).await.unwrap();
        assert_eq!(&first_byte, b"a");

        let tail = CmdBody::from_bytes(b"XYZ".to_vec());
        let joined = head.chain(tail).into_string().await.unwrap();
        assert_eq!(joined, "bcXYZ");
    }

    #[test]
    fn cmd_body_empty_and_len() {
        let nothing = CmdBody::empty();
        assert!(nothing.is_empty());
        assert_eq!(nothing.len(), 0);

        let four = CmdBody::from_bytes(vec![1, 2, 3, 4]);
        assert!(!four.is_empty());
        assert_eq!(four.len(), 4);
    }

    #[tokio::test]
    async fn cmd_body_into_reader_and_read_all() {
        let mut owned = CmdBody::from_string("reader-body".to_owned());
        let everything = owned.read_all().await.unwrap();
        assert_eq!(everything, b"reader-body");

        let mut raw = CmdBody::from_string("reader-body2".to_owned()).into_reader();
        let mut collected = Vec::new();
        raw.read_to_end(&mut collected).await.unwrap();
        assert_eq!(collected, b"reader-body2");
    }

    #[test]
    fn cmd_header_set_pkg_len() {
        let mut header: CmdHeader<u16, u8> = CmdHeader::new(1, false, Some(7), 0x11, 3);
        assert_eq!(header.pkg_len(), 3);
        header.set_pkg_len(9);
        assert_eq!(header.pkg_len(), 9);
    }

    #[tokio::test]
    async fn cmd_body_read_all_success_and_empty_after_read() {
        let (local, remote) = tokio::io::duplex(128);
        let (local_read, local_write) = split(local);
        let (_remote_read, mut remote_write) = split(remote);
        remote_write.write_all(b"abcdef").await.unwrap();
        remote_write.flush().await.unwrap();

        let tunnel = CmdTunnel::new(
            PipeReader { half: local_read },
            PipeWriter { half: local_write },
        );
        let (reader, _writer) = tunnel.split();
        let mut body = CmdBodyRead::new(reader, 6);

        let whole = body.read_all().await.unwrap();
        assert_eq!(whole, b"abcdef");
        let after = body.read_all().await.unwrap();
        assert!(after.is_empty());
    }

    #[tokio::test]
    async fn cmd_body_read_all_error_when_source_short() {
        let (local, remote) = tokio::io::duplex(128);
        let (local_read, local_write) = split(local);
        let (_remote_read, mut remote_write) = split(remote);
        remote_write.write_all(b"ab").await.unwrap();
        remote_write.shutdown().await.unwrap();

        let tunnel = CmdTunnel::new(
            PipeReader { half: local_read },
            PipeWriter { half: local_write },
        );
        let (reader, _writer) = tunnel.split();
        let mut body = CmdBodyRead::new(reader, 5);

        let outcome = body.read_all().await;
        assert!(outcome.is_err());
    }
}