// sfo_cmd_server/cmd.rs

use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::DerefMut;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::{fmt, io};

use bucky_raw_codec::{RawDecode, RawEncode};
use callback_result::SingleCallbackWaiter;
use futures_lite::ready;
use num::{FromPrimitive, ToPrimitive};
use sfo_split::RHalf;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};

use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
use crate::peer_id::PeerId;
use crate::TunnelId;

19#[derive(RawEncode, RawDecode)]
20pub struct CmdHeader<LEN, CMD> {
21    pkg_len: LEN,
22    version: u8,
23    cmd_code: CMD,
24    is_resp: bool,
25    seq: Option<u32>,
26}
27
28impl<LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
29    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static> CmdHeader<LEN, CMD> {
30    pub fn new(version: u8, is_resp: bool, seq: Option<u32>, cmd_code: CMD, pkg_len: LEN) -> Self {
31        Self {
32            pkg_len,
33            version,
34            seq,
35            cmd_code,
36            is_resp,
37        }
38    }
39
40    pub fn pkg_len(&self) -> LEN {
41        self.pkg_len
42    }
43
44    pub fn version(&self) -> u8 {
45        self.version
46    }
47
48    pub fn seq(&self) -> Option<u32> {
49        self.seq
50    }
51
52    pub fn is_resp(&self) -> bool {
53        self.is_resp
54    }
55
56    pub fn cmd_code(&self) -> CMD {
57        self.cmd_code
58    }
59
60    pub fn set_pkg_len(&mut self, pkg_len: LEN) {
61        self.pkg_len = pkg_len;
62    }
63}
64
65#[async_trait::async_trait]
66pub trait CmdBodyReadAll: tokio::io::AsyncRead + Send + 'static {
67    async fn read_all(&mut self) -> CmdResult<Vec<u8>>;
68}
69
70pub(crate) struct CmdBodyRead<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> {
71    recv: Option<RHalf<R, W>>,
72    len: usize,
73    offset: usize,
74    waiter: Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>>,
75}
76
77impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyRead<R, W> {
78    pub fn new(recv: RHalf<R, W>, len: usize) -> Self {
79        Self {
80            recv: Some(recv),
81            len,
82            offset: 0,
83            waiter: Arc::new(SingleCallbackWaiter::new()),
84        }
85    }
86
87
88    pub(crate) fn get_waiter(&self) -> Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>> {
89        self.waiter.clone()
90    }
91}
92
93#[async_trait::async_trait]
94impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadAll for CmdBodyRead<R, W> {
95    async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
96        if self.offset == self.len {
97            return Ok(Vec::new());
98        }
99        let mut buf = vec![0u8; self.len - self.offset];
100        let ret = self.recv.as_mut().unwrap().read_exact(&mut buf).await.map_err(into_cmd_err!(CmdErrorCode::IoError));
101        if ret.is_ok() {
102            self.offset = self.len;
103            self.waiter.set_result_with_cache(Ok(self.recv.take().unwrap()));
104            Ok(buf)
105        } else {
106            self.recv.take();
107            self.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
108            Err(ret.err().unwrap())
109        }
110    }
111}
112
113impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> Drop for CmdBodyRead<R, W> {
114    fn drop(&mut self) {
115        if self.recv.is_none() || (self.len == self.offset && self.len != 0) {
116            return;
117        }
118        let mut recv = self.recv.take().unwrap();
119        let len = self.len - self.offset;
120        let waiter = self.waiter.clone();
121        if len == 0 {
122            waiter.set_result_with_cache(Ok(recv));
123            return;
124        }
125
126        tokio::spawn(async move {
127            let mut buf = vec![0u8; len];
128            if let Err(e) = recv.read_exact(&mut buf).await {
129                waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error {}", e)));
130            } else {
131                waiter.set_result_with_cache(Ok(recv));
132            }
133        });
134    }
135}
136
137impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> tokio::io::AsyncRead for CmdBodyRead<R, W> {
138    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
139        let this = Pin::into_inner(self);
140        let len = this.len - this.offset;
141        if len == 0 {
142            return Poll::Ready(Ok(()));
143        }
144        let recv = Pin::new(this.recv.as_mut().unwrap().deref_mut());
145        let fut = recv.poll_read(cx, buf);
146        match fut {
147            Poll::Ready(Ok(())) => {
148                this.offset += buf.filled().len();
149                if this.offset == this.len {
150                    this.waiter.set_result_with_cache(Ok(this.recv.take().unwrap()));
151                }
152                Poll::Ready(Ok(()))
153            }
154            Poll::Ready(Err(e)) => {
155                this.recv.take();
156                this.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
157                Poll::Ready(Err(e))
158            },
159            Poll::Pending => Poll::Pending,
160        }
161    }
162}
163
164
165#[callback_trait::callback_trait]
166pub trait CmdHandler<LEN, CMD>: Send + Sync + 'static
167where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
168      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static {
169    async fn handle(&self, peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body: CmdBody) -> CmdResult<Option<CmdBody>>;
170}
171
172pub(crate) struct CmdHandlerMap<LEN, CMD> {
173    map: Mutex<HashMap<CMD, Arc<dyn CmdHandler<LEN, CMD>>>>,
174}
175
176impl <LEN, CMD> CmdHandlerMap<LEN, CMD>
177where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
178      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Eq + Hash {
179    pub fn new() -> Self {
180        Self {
181            map: Mutex::new(HashMap::new()),
182        }
183    }
184
185    pub fn insert(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
186        self.map.lock().unwrap().insert(cmd, Arc::new(handler));
187    }
188
189    pub fn get(&self, cmd: CMD) -> Option<Arc<dyn CmdHandler<LEN, CMD>>> {
190        self.map.lock().unwrap().get(&cmd).map(|v| v.clone())
191    }
192}
193pin_project_lite::pin_project! {
194pub struct CmdBody {
195        #[pin]
196        reader: Box<dyn AsyncBufRead + Unpin + Send  + 'static>,
197        length: u64,
198        bytes_read: u64,
199    }
200}
201
202impl CmdBody {
203    pub fn empty() -> Self {
204        Self {
205            reader: Box::new(tokio::io::empty()),
206            length: 0,
207            bytes_read: 0,
208        }
209    }
210
211    pub fn from_reader(
212        reader: impl AsyncBufRead + Unpin + Send + 'static,
213        length: u64,
214    ) -> Self {
215        Self {
216            reader: Box::new(reader),
217            length,
218            bytes_read: 0,
219        }
220    }
221
222    pub fn into_reader(self) -> Box<dyn AsyncBufRead + Unpin + Send + 'static> {
223        self.reader
224    }
225
226    pub async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
227        let mut buf = Vec::with_capacity(1024);
228        self.read_to_end(&mut buf).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
229        Ok(buf)
230    }
231
232    pub fn from_bytes(bytes: Vec<u8>) -> Self {
233        Self {
234            length: bytes.len() as u64,
235            reader: Box::new(io::Cursor::new(bytes)),
236            bytes_read: 0,
237        }
238    }
239
240    pub async fn into_bytes(mut self) -> CmdResult<Vec<u8>> {
241        let mut buf = Vec::with_capacity(1024);
242        self.read_to_end(&mut buf)
243            .await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
244        Ok(buf)
245    }
246
247    pub fn from_string(s: String) -> Self {
248        Self {
249            length: s.len() as u64,
250            reader: Box::new(io::Cursor::new(s.into_bytes())),
251            bytes_read: 0,
252        }
253    }
254
255    pub async fn into_string(mut self) -> CmdResult<String> {
256        let mut result = String::with_capacity(self.len() as usize);
257        self.read_to_string(&mut result)
258            .await.map_err(into_cmd_err!(CmdErrorCode::Failed, "read to string failed"))?;
259        Ok(result)
260    }
261
262    pub async fn from_path<P>(path: P) -> io::Result<Self>
263    where
264        P: AsRef<std::path::Path>,
265    {
266        let path = path.as_ref();
267        let file = tokio::fs::File::open(path).await?;
268        Self::from_file(file).await
269    }
270
271    pub async fn from_file(
272        file: tokio::fs::File,
273    ) -> io::Result<Self> {
274        let len = file.metadata().await?.len();
275
276        Ok(Self {
277            length: len,
278            reader: Box::new(tokio::io::BufReader::new(file)),
279            bytes_read: 0,
280        })
281    }
282
283    pub fn len(&self) -> u64 {
284        self.length
285    }
286
287    /// Returns `true` if the body has a length of zero, and `false` otherwise.
288    pub fn is_empty(&self) -> bool {
289        self.length == 0
290    }
291
292    pub fn chain(self, other: CmdBody) -> Self {
293        let length = (self.length - self.bytes_read).checked_add(other.length - other.bytes_read).unwrap_or(0);
294        Self {
295            length,
296            reader: Box::new(tokio::io::AsyncReadExt::chain(self, other)),
297            bytes_read: 0,
298        }
299    }
300}
301
302impl Debug for CmdBody {
303    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304        f.debug_struct("CmdResponse")
305            .field("reader", &"<hidden>")
306            .field("length", &self.length)
307            .field("bytes_read", &self.bytes_read)
308            .finish()
309    }
310}
311
312
313impl From<String> for CmdBody {
314    fn from(s: String) -> Self {
315        Self::from_string(s)
316    }
317}
318
319impl<'a> From<&'a str> for CmdBody {
320    fn from(s: &'a str) -> Self {
321        Self::from_string(s.to_owned())
322    }
323}
324
325impl From<Vec<u8>> for CmdBody {
326    fn from(b: Vec<u8>) -> Self {
327        Self::from_bytes(b)
328    }
329}
330
331impl<'a> From<&'a [u8]> for CmdBody {
332    fn from(b: &'a [u8]) -> Self {
333        Self::from_bytes(b.to_owned())
334    }
335}
336
337impl AsyncRead for CmdBody {
338    #[allow(rustdoc::missing_doc_code_examples)]
339    fn poll_read(
340        mut self: Pin<&mut Self>,
341        cx: &mut Context<'_>,
342        buf: &mut ReadBuf<'_>,
343    ) -> Poll<io::Result<()>> {
344        let buf = if self.length == self.bytes_read {
345            return Poll::Ready(Ok(()));
346        } else {
347            buf
348        };
349
350        ready!(Pin::new(&mut self.reader).poll_read(cx, buf))?;
351        self.bytes_read += buf.filled().len() as u64;
352        Poll::Ready(Ok(()))
353    }
354}
355
356impl AsyncBufRead for CmdBody {
357    #[allow(rustdoc::missing_doc_code_examples)]
358    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&'_ [u8]>> {
359        self.project().reader.poll_fill_buf(cx)
360    }
361
362    fn consume(mut self: Pin<&mut Self>, amt: usize) {
363        Pin::new(&mut self.reader).consume(amt)
364    }
365}