// sfo_cmd_server/cmd.rs

1use crate::TunnelId;
2use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
3use crate::peer_id::PeerId;
4use bucky_raw_codec::{RawDecode, RawEncode};
5use callback_result::SingleCallbackWaiter;
6use futures_lite::ready;
7use num::{FromPrimitive, ToPrimitive};
8use sfo_split::RHalf;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::hash::Hash;
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::sync::{Arc, Mutex};
15use std::task::{Context, Poll};
16use std::{fmt, io};
17use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
18
/// Header describing a command packet: packet length, protocol version,
/// command code, request/response flag and an optional sequence number.
///
/// NOTE(review): the field order below presumably defines the wire layout
/// produced by the `RawEncode`/`RawDecode` derives — do not reorder fields
/// without confirming against `bucky_raw_codec`.
#[derive(RawEncode, RawDecode)]
pub struct CmdHeader<LEN, CMD> {
    // Packet length, stored in the caller-chosen integer type `LEN`.
    // NOTE(review): unclear from here whether this counts the header or only the body — confirm.
    pkg_len: LEN,
    // Protocol version byte, stored as given to `new`.
    version: u8,
    // Command code; the same `CMD` type is used as the key of `CmdHandlerMap`.
    cmd_code: CMD,
    // Presumably `true` when this packet is a response rather than a request — confirm.
    is_resp: bool,
    // Optional sequence number; presumably correlates responses with requests — confirm.
    seq: Option<u32>,
}
27
28impl<
29    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
30    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static,
31> CmdHeader<LEN, CMD>
32{
33    pub fn new(version: u8, is_resp: bool, seq: Option<u32>, cmd_code: CMD, pkg_len: LEN) -> Self {
34        Self {
35            pkg_len,
36            version,
37            seq,
38            cmd_code,
39            is_resp,
40        }
41    }
42
43    pub fn pkg_len(&self) -> LEN {
44        self.pkg_len
45    }
46
47    pub fn version(&self) -> u8 {
48        self.version
49    }
50
51    pub fn seq(&self) -> Option<u32> {
52        self.seq
53    }
54
55    pub fn is_resp(&self) -> bool {
56        self.is_resp
57    }
58
59    pub fn cmd_code(&self) -> CMD {
60        self.cmd_code
61    }
62
63    pub fn set_pkg_len(&mut self, pkg_len: LEN) {
64        self.pkg_len = pkg_len;
65    }
66}
67
68#[async_trait::async_trait]
69pub trait CmdBodyReadAll: tokio::io::AsyncRead + Send + 'static {
70    async fn read_all(&mut self) -> CmdResult<Vec<u8>>;
71}
72
73pub(crate) struct CmdBodyRead<
74    R: AsyncRead + Send + 'static + Unpin,
75    W: AsyncWrite + Send + 'static + Unpin,
76> {
77    recv: Option<RHalf<R, W>>,
78    len: usize,
79    offset: usize,
80    waiter: Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>>,
81}
82
83impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin>
84    CmdBodyRead<R, W>
85{
86    pub fn new(recv: RHalf<R, W>, len: usize) -> Self {
87        Self {
88            recv: Some(recv),
89            len,
90            offset: 0,
91            waiter: Arc::new(SingleCallbackWaiter::new()),
92        }
93    }
94
95    pub(crate) fn get_waiter(&self) -> Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>> {
96        self.waiter.clone()
97    }
98}
99
100#[async_trait::async_trait]
101impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadAll
102    for CmdBodyRead<R, W>
103{
104    async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
105        if self.offset == self.len {
106            return Ok(Vec::new());
107        }
108        let mut buf = vec![0u8; self.len - self.offset];
109        let ret = self
110            .recv
111            .as_mut()
112            .unwrap()
113            .read_exact(&mut buf)
114            .await
115            .map_err(into_cmd_err!(CmdErrorCode::IoError));
116        if ret.is_ok() {
117            self.offset = self.len;
118            self.waiter
119                .set_result_with_cache(Ok(self.recv.take().unwrap()));
120            Ok(buf)
121        } else {
122            self.recv.take();
123            self.waiter
124                .set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
125            Err(ret.err().unwrap())
126        }
127    }
128}
129
130impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> Drop
131    for CmdBodyRead<R, W>
132{
133    fn drop(&mut self) {
134        if self.recv.is_none() || (self.len == self.offset && self.len != 0) {
135            return;
136        }
137        let mut recv = self.recv.take().unwrap();
138        let len = self.len - self.offset;
139        let waiter = self.waiter.clone();
140        if len == 0 {
141            waiter.set_result_with_cache(Ok(recv));
142            return;
143        }
144
145        tokio::spawn(async move {
146            let mut buf = vec![0u8; len];
147            if let Err(e) = recv.read_exact(&mut buf).await {
148                waiter.set_result_with_cache(Err(cmd_err!(
149                    CmdErrorCode::IoError,
150                    "read body error {}",
151                    e
152                )));
153            } else {
154                waiter.set_result_with_cache(Ok(recv));
155            }
156        });
157    }
158}
159
160impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin>
161    tokio::io::AsyncRead for CmdBodyRead<R, W>
162{
163    fn poll_read(
164        self: Pin<&mut Self>,
165        cx: &mut Context<'_>,
166        buf: &mut ReadBuf<'_>,
167    ) -> Poll<std::io::Result<()>> {
168        let this = Pin::into_inner(self);
169        let len = this.len - this.offset;
170        if len == 0 {
171            return Poll::Ready(Ok(()));
172        }
173        let recv = Pin::new(this.recv.as_mut().unwrap().deref_mut());
174        let read_len = std::cmp::min(len, buf.remaining());
175        let mut read_buf = ReadBuf::new(buf.initialize_unfilled_to(read_len));
176        let fut = recv.poll_read(cx, &mut read_buf);
177        match fut {
178            Poll::Ready(Ok(())) => {
179                let len = read_buf.filled().len();
180                drop(read_buf);
181                this.offset += len;
182                buf.advance(len);
183                if this.offset == this.len {
184                    this.waiter
185                        .set_result_with_cache(Ok(this.recv.take().unwrap()));
186                }
187                Poll::Ready(Ok(()))
188            }
189            Poll::Ready(Err(e)) => {
190                this.recv.take();
191                this.waiter
192                    .set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
193                Poll::Ready(Err(e))
194            }
195            Poll::Pending => Poll::Pending,
196        }
197    }
198}
199
/// Handler for a command, registered per command code in a handler map.
///
/// NOTE(review): the dispatch loop that calls `handle` is not in this file;
/// presumably a returned `Ok(Some(body))` is sent back as the response — confirm.
#[callback_trait::callback_trait]
pub trait CmdHandler<LEN, CMD>: Send + Sync + 'static
where
    LEN: RawEncode
        + for<'a> RawDecode<'a>
        + Copy
        + Send
        + Sync
        + 'static
        + FromPrimitive
        + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static,
{
    /// Handles one command.
    ///
    /// Receives the local and remote peer ids, the tunnel id, the decoded
    /// header and the command body. Returns an optional body, or an error.
    async fn handle(
        &self,
        local_id: PeerId,
        peer_id: PeerId,
        tunnel_id: TunnelId,
        header: CmdHeader<LEN, CMD>,
        body: CmdBody,
    ) -> CmdResult<Option<CmdBody>>;
}
222
223pub(crate) struct CmdHandlerMap<LEN, CMD> {
224    map: Mutex<HashMap<CMD, Arc<dyn CmdHandler<LEN, CMD>>>>,
225}
226
227impl<LEN, CMD> CmdHandlerMap<LEN, CMD>
228where
229    LEN: RawEncode
230        + for<'a> RawDecode<'a>
231        + Copy
232        + Send
233        + Sync
234        + 'static
235        + FromPrimitive
236        + ToPrimitive,
237    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Eq + Hash,
238{
239    pub fn new() -> Self {
240        Self {
241            map: Mutex::new(HashMap::new()),
242        }
243    }
244
245    pub fn insert(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
246        self.map.lock().unwrap().insert(cmd, Arc::new(handler));
247    }
248
249    pub fn get(&self, cmd: CMD) -> Option<Arc<dyn CmdHandler<LEN, CMD>>> {
250        self.map.lock().unwrap().get(&cmd).map(|v| v.clone())
251    }
252}
pin_project_lite::pin_project! {
    /// Command payload: a boxed buffered reader plus a declared byte length.
    ///
    /// Construct it with `empty`, `from_bytes`, `from_string`, `from_reader`,
    /// `from_path` or `from_file`.
    pub struct CmdBody {
        // Source of the body bytes.
        #[pin]
        reader: Box<dyn AsyncBufRead + Unpin + Send  + 'static>,
        // Declared body size in bytes (e.g. byte count for `from_bytes`, file size for `from_file`).
        length: u64,
        // Progress counter compared against `length` by the `AsyncRead` impl to detect the end of the body.
        bytes_read: u64,
    }
}
261
262impl CmdBody {
263    pub fn empty() -> Self {
264        Self {
265            reader: Box::new(tokio::io::empty()),
266            length: 0,
267            bytes_read: 0,
268        }
269    }
270
271    pub fn from_reader(reader: impl AsyncBufRead + Unpin + Send + 'static, length: u64) -> Self {
272        Self {
273            reader: Box::new(reader),
274            length,
275            bytes_read: 0,
276        }
277    }
278
279    pub fn into_reader(self) -> Box<dyn AsyncBufRead + Unpin + Send + 'static> {
280        self.reader
281    }
282
283    pub async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
284        let mut buf = Vec::with_capacity(1024);
285        self.read_to_end(&mut buf)
286            .await
287            .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
288        Ok(buf)
289    }
290
291    pub fn from_bytes(bytes: Vec<u8>) -> Self {
292        Self {
293            length: bytes.len() as u64,
294            reader: Box::new(io::Cursor::new(bytes)),
295            bytes_read: 0,
296        }
297    }
298
299    pub async fn into_bytes(mut self) -> CmdResult<Vec<u8>> {
300        let mut buf = Vec::with_capacity(1024);
301        self.read_to_end(&mut buf)
302            .await
303            .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
304        Ok(buf)
305    }
306
307    pub fn from_string(s: String) -> Self {
308        Self {
309            length: s.len() as u64,
310            reader: Box::new(io::Cursor::new(s.into_bytes())),
311            bytes_read: 0,
312        }
313    }
314
315    pub async fn into_string(mut self) -> CmdResult<String> {
316        let mut result = String::with_capacity(self.len() as usize);
317        self.read_to_string(&mut result)
318            .await
319            .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to string failed"))?;
320        Ok(result)
321    }
322
323    pub async fn from_path<P>(path: P) -> io::Result<Self>
324    where
325        P: AsRef<std::path::Path>,
326    {
327        let path = path.as_ref();
328        let file = tokio::fs::File::open(path).await?;
329        Self::from_file(file).await
330    }
331
332    pub async fn from_file(file: tokio::fs::File) -> io::Result<Self> {
333        let len = file.metadata().await?.len();
334
335        Ok(Self {
336            length: len,
337            reader: Box::new(tokio::io::BufReader::new(file)),
338            bytes_read: 0,
339        })
340    }
341
342    pub fn len(&self) -> u64 {
343        self.length
344    }
345
346    /// Returns `true` if the body has a length of zero, and `false` otherwise.
347    pub fn is_empty(&self) -> bool {
348        self.length == 0
349    }
350
351    pub fn chain(self, other: CmdBody) -> Self {
352        let length = (self.length - self.bytes_read)
353            .checked_add(other.length - other.bytes_read)
354            .unwrap_or(0);
355        Self {
356            length,
357            reader: Box::new(tokio::io::AsyncReadExt::chain(self, other)),
358            bytes_read: 0,
359        }
360    }
361}
362
363impl Debug for CmdBody {
364    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365        f.debug_struct("CmdResponse")
366            .field("reader", &"<hidden>")
367            .field("length", &self.length)
368            .field("bytes_read", &self.bytes_read)
369            .finish()
370    }
371}
372
373impl From<String> for CmdBody {
374    fn from(s: String) -> Self {
375        Self::from_string(s)
376    }
377}
378
379impl<'a> From<&'a str> for CmdBody {
380    fn from(s: &'a str) -> Self {
381        Self::from_string(s.to_owned())
382    }
383}
384
385impl From<Vec<u8>> for CmdBody {
386    fn from(b: Vec<u8>) -> Self {
387        Self::from_bytes(b)
388    }
389}
390
391impl<'a> From<&'a [u8]> for CmdBody {
392    fn from(b: &'a [u8]) -> Self {
393        Self::from_bytes(b.to_owned())
394    }
395}
396
397impl AsyncRead for CmdBody {
398    #[allow(rustdoc::missing_doc_code_examples)]
399    fn poll_read(
400        mut self: Pin<&mut Self>,
401        cx: &mut Context<'_>,
402        buf: &mut ReadBuf<'_>,
403    ) -> Poll<io::Result<()>> {
404        let buf = if self.length == self.bytes_read {
405            return Poll::Ready(Ok(()));
406        } else {
407            buf
408        };
409
410        ready!(Pin::new(&mut self.reader).poll_read(cx, buf))?;
411        self.bytes_read += buf.filled().len() as u64;
412        Poll::Ready(Ok(()))
413    }
414}
415
416impl AsyncBufRead for CmdBody {
417    #[allow(rustdoc::missing_doc_code_examples)]
418    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&'_ [u8]>> {
419        self.project().reader.poll_fill_buf(cx)
420    }
421
422    fn consume(mut self: Pin<&mut Self>, amt: usize) {
423        Pin::new(&mut self.reader).consume(amt)
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::{CmdBody, CmdBodyRead, CmdBodyReadAll, CmdHeader};
    use crate::{CmdTunnel, CmdTunnelRead, CmdTunnelWrite, PeerId};
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{
        AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, split,
    };

    /// Read half of an in-memory duplex stream, wrapped so it can implement the
    /// crate's `CmdTunnelRead` trait.
    struct TestRead {
        read: tokio::io::ReadHalf<DuplexStream>,
    }

    // Plain pass-through to the duplex read half.
    impl AsyncRead for TestRead {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            Pin::new(&mut self.read).poll_read(cx, buf)
        }
    }

    // Fixed dummy peer ids; their values are arbitrary test data.
    impl CmdTunnelRead<()> for TestRead {
        fn get_local_peer_id(&self) -> PeerId {
            PeerId::from(vec![9; 32])
        }

        fn get_remote_peer_id(&self) -> PeerId {
            PeerId::from(vec![1; 32])
        }
    }

    /// Write half of an in-memory duplex stream, wrapped so it can implement the
    /// crate's `CmdTunnelWrite` trait.
    struct TestWrite {
        write: tokio::io::WriteHalf<DuplexStream>,
    }

    // Plain pass-through to the duplex write half.
    impl AsyncWrite for TestWrite {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Pin::new(&mut self.write).poll_write(cx, buf)
        }

        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Pin::new(&mut self.write).poll_flush(cx)
        }

        fn poll_shutdown(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<()>> {
            Pin::new(&mut self.write).poll_shutdown(cx)
        }
    }

    // Fixed dummy peer ids; their values are arbitrary test data.
    impl CmdTunnelWrite<()> for TestWrite {
        fn get_local_peer_id(&self) -> PeerId {
            PeerId::from(vec![9; 32])
        }

        fn get_remote_peer_id(&self) -> PeerId {
            PeerId::from(vec![2; 32])
        }
    }

    // `from_bytes` followed by `into_bytes` returns the same bytes.
    #[tokio::test]
    async fn cmd_body_bytes_round_trip() {
        let body = CmdBody::from_bytes(b"hello-body".to_vec());
        let data = body.into_bytes().await.unwrap();
        assert_eq!(data, b"hello-body");
    }

    // `from_string` followed by `into_string` returns the same text.
    #[tokio::test]
    async fn cmd_body_string_round_trip() {
        let body = CmdBody::from_string("hello-string".to_owned());
        let s = body.into_string().await.unwrap();
        assert_eq!(s, "hello-string");
    }

    // After one byte of `first` is consumed, `chain` must only include the
    // unread remainder ("bc") followed by the second body.
    #[tokio::test]
    async fn cmd_body_chain_respects_consumed_prefix() {
        let mut first = CmdBody::from_bytes(b"abc".to_vec());
        let mut buf = [0u8; 1];
        first.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"a");

        let chained = first.chain(CmdBody::from_bytes(b"XYZ".to_vec()));
        let s = chained.into_string().await.unwrap();
        assert_eq!(s, "bcXYZ");
    }

    // `is_empty`/`len` reflect the declared length.
    #[test]
    fn cmd_body_empty_and_len() {
        let empty = CmdBody::empty();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let body = CmdBody::from_bytes(vec![1, 2, 3, 4]);
        assert!(!body.is_empty());
        assert_eq!(body.len(), 4);
    }

    // Covers both `read_all` on the body and reading through `into_reader`.
    #[tokio::test]
    async fn cmd_body_into_reader_and_read_all() {
        let mut body = CmdBody::from_string("reader-body".to_owned());
        let all = body.read_all().await.unwrap();
        assert_eq!(all, b"reader-body");

        let body = CmdBody::from_string("reader-body2".to_owned());
        let mut reader = body.into_reader();
        let mut out = Vec::new();
        reader.read_to_end(&mut out).await.unwrap();
        assert_eq!(out, b"reader-body2");
    }

    // `set_pkg_len` overwrites the value given to `new`.
    #[test]
    fn cmd_header_set_pkg_len() {
        let mut header = CmdHeader::<u16, u8>::new(1, false, Some(7), 0x11, 3);
        assert_eq!(header.pkg_len(), 3);
        header.set_pkg_len(9);
        assert_eq!(header.pkg_len(), 9);
    }

    // A body read over a real tunnel: the first `read_all` returns all 6 bytes,
    // the second returns an empty vector because the body is exhausted.
    #[tokio::test]
    async fn cmd_body_read_all_success_and_empty_after_read() {
        let (side_a, side_b) = tokio::io::duplex(128);
        let (a_read, a_write) = split(side_a);
        // `_b_read` (not `_`) keeps the far read half alive until the end of the test.
        let (_b_read, mut b_write) = split(side_b);
        b_write.write_all(b"abcdef").await.unwrap();
        b_write.flush().await.unwrap();

        let tunnel = CmdTunnel::new(TestRead { read: a_read }, TestWrite { write: a_write });
        // `_writer` keeps the tunnel's write half alive for the duration of the test.
        let (reader, _writer) = tunnel.split();
        let mut body_read = CmdBodyRead::new(reader, 6);

        let first = body_read.read_all().await.unwrap();
        assert_eq!(first, b"abcdef");
        let second = body_read.read_all().await.unwrap();
        assert!(second.is_empty());
    }

    // The far side writes only 2 bytes and then shuts down, so reading a
    // declared 5-byte body must fail.
    #[tokio::test]
    async fn cmd_body_read_all_error_when_source_short() {
        let (side_a, side_b) = tokio::io::duplex(128);
        let (a_read, a_write) = split(side_a);
        let (_b_read, mut b_write) = split(side_b);
        b_write.write_all(b"ab").await.unwrap();
        b_write.shutdown().await.unwrap();

        let tunnel = CmdTunnel::new(TestRead { read: a_read }, TestWrite { write: a_write });
        let (reader, _writer) = tunnel.split();
        let mut body_read = CmdBodyRead::new(reader, 5);
        assert!(body_read.read_all().await.is_err());
    }
}