// sfo_cmd_server/cmd.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::{fmt, io};
5use std::ops::DerefMut;
6use std::pin::Pin;
7use std::sync::{Arc, Mutex};
8use std::task::{Context, Poll};
9use bucky_raw_codec::{RawDecode, RawEncode};
10use callback_result::{SingleCallbackWaiter};
11use futures_lite::ready;
12use num::{FromPrimitive, ToPrimitive};
13use sfo_split::RHalf;
14use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
15use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
16use crate::peer_id::PeerId;
17use crate::{TunnelId};
18
/// Header carried with every command: packet length, protocol version, command code,
/// request/response flag and an optional sequence number.
///
/// NOTE(review): `RawEncode`/`RawDecode` are derived, so the field declaration order is
/// presumably the on-wire order — do not reorder fields without confirming against
/// `bucky_raw_codec`.
#[derive(RawEncode, RawDecode)]
pub struct CmdHeader<LEN, CMD> {
    // Packet length, stored in the generic integer type `LEN`.
    // NOTE(review): presumably the body length in bytes — confirm on the writer side.
    pkg_len: LEN,
    // Protocol version of the sender.
    version: u8,
    // Command code; `CmdHandlerMap` looks handlers up by this value.
    cmd_code: CMD,
    // `true` when this packet is a response rather than a request.
    is_resp: bool,
    // Optional sequence number — presumably used to pair responses with requests; verify.
    seq: Option<u32>,
}
27
28impl<LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
29    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static> CmdHeader<LEN, CMD> {
30    pub fn new(version: u8, is_resp: bool, seq: Option<u32>, cmd_code: CMD, pkg_len: LEN) -> Self {
31        Self {
32            pkg_len,
33            version,
34            seq,
35            cmd_code,
36            is_resp,
37        }
38    }
39
40    pub fn pkg_len(&self) -> LEN {
41        self.pkg_len
42    }
43
44    pub fn version(&self) -> u8 {
45        self.version
46    }
47
48    pub fn seq(&self) -> Option<u32> {
49        self.seq
50    }
51
52    pub fn is_resp(&self) -> bool {
53        self.is_resp
54    }
55
56    pub fn cmd_code(&self) -> CMD {
57        self.cmd_code
58    }
59
60    pub fn set_pkg_len(&mut self, pkg_len: LEN) {
61        self.pkg_len = pkg_len;
62    }
63}
64
65#[async_trait::async_trait]
66pub trait CmdBodyReadAll: tokio::io::AsyncRead + Send + 'static {
67    async fn read_all(&mut self) -> CmdResult<Vec<u8>>;
68}
69
70pub(crate) struct CmdBodyRead<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> {
71    recv: Option<RHalf<R, W>>,
72    len: usize,
73    offset: usize,
74    waiter: Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>>,
75}
76
77impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyRead<R, W> {
78    pub fn new(recv: RHalf<R, W>, len: usize) -> Self {
79        Self {
80            recv: Some(recv),
81            len,
82            offset: 0,
83            waiter: Arc::new(SingleCallbackWaiter::new()),
84        }
85    }
86
87
88    pub(crate) fn get_waiter(&self) -> Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>> {
89        self.waiter.clone()
90    }
91}
92
93#[async_trait::async_trait]
94impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadAll for CmdBodyRead<R, W> {
95    async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
96        if self.offset == self.len {
97            return Ok(Vec::new());
98        }
99        let mut buf = vec![0u8; self.len - self.offset];
100        let ret = self.recv.as_mut().unwrap().read_exact(&mut buf).await.map_err(into_cmd_err!(CmdErrorCode::IoError));
101        if ret.is_ok() {
102            self.offset = self.len;
103            self.waiter.set_result_with_cache(Ok(self.recv.take().unwrap()));
104            Ok(buf)
105        } else {
106            self.recv.take();
107            self.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
108            Err(ret.err().unwrap())
109        }
110    }
111}
112
113impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> Drop for CmdBodyRead<R, W> {
114    fn drop(&mut self) {
115        if self.recv.is_none() || (self.len == self.offset && self.len != 0) {
116            return;
117        }
118        let mut recv = self.recv.take().unwrap();
119        let len = self.len - self.offset;
120        let waiter = self.waiter.clone();
121        if len == 0 {
122            waiter.set_result_with_cache(Ok(recv));
123            return;
124        }
125
126        tokio::spawn(async move {
127            let mut buf = vec![0u8; len];
128            if let Err(e) = recv.read_exact(&mut buf).await {
129                waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error {}", e)));
130            } else {
131                waiter.set_result_with_cache(Ok(recv));
132            }
133        });
134    }
135}
136
137impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> tokio::io::AsyncRead for CmdBodyRead<R, W> {
138    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
139        let this = Pin::into_inner(self);
140        let len = this.len - this.offset;
141        if len == 0 {
142            return Poll::Ready(Ok(()));
143        }
144        let recv = Pin::new(this.recv.as_mut().unwrap().deref_mut());
145        let read_len = std::cmp::min(len, buf.remaining());
146        let mut read_buf = ReadBuf::new(buf.initialize_unfilled_to(read_len));
147        let fut = recv.poll_read(cx, &mut read_buf);
148        match fut {
149            Poll::Ready(Ok(())) => {
150                let len = read_buf.filled().len();
151                drop(read_buf);
152                this.offset += len;
153                buf.advance(len);
154                if this.offset == this.len {
155                    this.waiter.set_result_with_cache(Ok(this.recv.take().unwrap()));
156                }
157                Poll::Ready(Ok(()))
158            }
159            Poll::Ready(Err(e)) => {
160                this.recv.take();
161                this.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
162                Poll::Ready(Err(e))
163            },
164            Poll::Pending => Poll::Pending,
165        }
166    }
167}
168
169
// Async handler for one command code; instances are registered in `CmdHandlerMap`
// keyed by `CMD`.
//
// `handle` receives the peer id, the tunnel id, the decoded header and the command body,
// and returns an optional body.
// NOTE(review): the returned body is presumably sent back as the response — confirm in
// the dispatcher.
// NOTE(review): `callback_trait` presumably also lets async closures implement this
// trait; verify against that crate's docs.
#[callback_trait::callback_trait]
pub trait CmdHandler<LEN, CMD>: Send + Sync + 'static
where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static {
    async fn handle(&self, peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body: CmdBody) -> CmdResult<Option<CmdBody>>;
}
176
177pub(crate) struct CmdHandlerMap<LEN, CMD> {
178    map: Mutex<HashMap<CMD, Arc<dyn CmdHandler<LEN, CMD>>>>,
179}
180
181impl <LEN, CMD> CmdHandlerMap<LEN, CMD>
182where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
183      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Eq + Hash {
184    pub fn new() -> Self {
185        Self {
186            map: Mutex::new(HashMap::new()),
187        }
188    }
189
190    pub fn insert(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
191        self.map.lock().unwrap().insert(cmd, Arc::new(handler));
192    }
193
194    pub fn get(&self, cmd: CMD) -> Option<Arc<dyn CmdHandler<LEN, CMD>>> {
195        self.map.lock().unwrap().get(&cmd).map(|v| v.clone())
196    }
197}
// Body of a command: a boxed buffered async reader plus the declared body length.
// `poll_read` stops returning data once `bytes_read` reaches `length`.
pin_project_lite::pin_project! {
pub struct CmdBody {
        // Source of the body bytes.
        #[pin]
        reader: Box<dyn AsyncBufRead + Unpin + Send  + 'static>,
        // Declared body length in bytes, as reported by `len()`.
        length: u64,
        // Running count of bytes delivered, updated in `poll_read`.
        bytes_read: u64,
    }
}
206
207impl CmdBody {
208    pub fn empty() -> Self {
209        Self {
210            reader: Box::new(tokio::io::empty()),
211            length: 0,
212            bytes_read: 0,
213        }
214    }
215
216    pub fn from_reader(
217        reader: impl AsyncBufRead + Unpin + Send + 'static,
218        length: u64,
219    ) -> Self {
220        Self {
221            reader: Box::new(reader),
222            length,
223            bytes_read: 0,
224        }
225    }
226
227    pub fn into_reader(self) -> Box<dyn AsyncBufRead + Unpin + Send + 'static> {
228        self.reader
229    }
230
231    pub async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
232        let mut buf = Vec::with_capacity(1024);
233        self.read_to_end(&mut buf).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
234        Ok(buf)
235    }
236
237    pub fn from_bytes(bytes: Vec<u8>) -> Self {
238        Self {
239            length: bytes.len() as u64,
240            reader: Box::new(io::Cursor::new(bytes)),
241            bytes_read: 0,
242        }
243    }
244
245    pub async fn into_bytes(mut self) -> CmdResult<Vec<u8>> {
246        let mut buf = Vec::with_capacity(1024);
247        self.read_to_end(&mut buf)
248            .await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
249        Ok(buf)
250    }
251
252    pub fn from_string(s: String) -> Self {
253        Self {
254            length: s.len() as u64,
255            reader: Box::new(io::Cursor::new(s.into_bytes())),
256            bytes_read: 0,
257        }
258    }
259
260    pub async fn into_string(mut self) -> CmdResult<String> {
261        let mut result = String::with_capacity(self.len() as usize);
262        self.read_to_string(&mut result)
263            .await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to string failed"))?;
264        Ok(result)
265    }
266
267    pub async fn from_path<P>(path: P) -> io::Result<Self>
268    where
269        P: AsRef<std::path::Path>,
270    {
271        let path = path.as_ref();
272        let file = tokio::fs::File::open(path).await?;
273        Self::from_file(file).await
274    }
275
276    pub async fn from_file(
277        file: tokio::fs::File,
278    ) -> io::Result<Self> {
279        let len = file.metadata().await?.len();
280
281        Ok(Self {
282            length: len,
283            reader: Box::new(tokio::io::BufReader::new(file)),
284            bytes_read: 0,
285        })
286    }
287
288    pub fn len(&self) -> u64 {
289        self.length
290    }
291
292    /// Returns `true` if the body has a length of zero, and `false` otherwise.
293    pub fn is_empty(&self) -> bool {
294        self.length == 0
295    }
296
297    pub fn chain(self, other: CmdBody) -> Self {
298        let length = (self.length - self.bytes_read).checked_add(other.length - other.bytes_read).unwrap_or(0);
299        Self {
300            length,
301            reader: Box::new(tokio::io::AsyncReadExt::chain(self, other)),
302            bytes_read: 0,
303        }
304    }
305}
306
307impl Debug for CmdBody {
308    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309        f.debug_struct("CmdResponse")
310            .field("reader", &"<hidden>")
311            .field("length", &self.length)
312            .field("bytes_read", &self.bytes_read)
313            .finish()
314    }
315}
316
317
318impl From<String> for CmdBody {
319    fn from(s: String) -> Self {
320        Self::from_string(s)
321    }
322}
323
324impl<'a> From<&'a str> for CmdBody {
325    fn from(s: &'a str) -> Self {
326        Self::from_string(s.to_owned())
327    }
328}
329
330impl From<Vec<u8>> for CmdBody {
331    fn from(b: Vec<u8>) -> Self {
332        Self::from_bytes(b)
333    }
334}
335
336impl<'a> From<&'a [u8]> for CmdBody {
337    fn from(b: &'a [u8]) -> Self {
338        Self::from_bytes(b.to_owned())
339    }
340}
341
342impl AsyncRead for CmdBody {
343    #[allow(rustdoc::missing_doc_code_examples)]
344    fn poll_read(
345        mut self: Pin<&mut Self>,
346        cx: &mut Context<'_>,
347        buf: &mut ReadBuf<'_>,
348    ) -> Poll<io::Result<()>> {
349        let buf = if self.length == self.bytes_read {
350            return Poll::Ready(Ok(()));
351        } else {
352            buf
353        };
354
355        ready!(Pin::new(&mut self.reader).poll_read(cx, buf))?;
356        self.bytes_read += buf.filled().len() as u64;
357        Poll::Ready(Ok(()))
358    }
359}
360
361impl AsyncBufRead for CmdBody {
362    #[allow(rustdoc::missing_doc_code_examples)]
363    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&'_ [u8]>> {
364        self.project().reader.poll_fill_buf(cx)
365    }
366
367    fn consume(mut self: Pin<&mut Self>, amt: usize) {
368        Pin::new(&mut self.reader).consume(amt)
369    }
370}