sfo_cmd_server/
cmd.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3use std::ops::DerefMut;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll};
7use bucky_raw_codec::{RawDecode, RawEncode, RawFixedBytes};
8use callback_result::{SingleCallbackWaiter};
9use num::{FromPrimitive, ToPrimitive};
10use sfo_split::RHalf;
11use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
12use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
13use crate::peer_id::PeerId;
14use crate::{TunnelId};
15
16#[derive(RawEncode, RawDecode)]
17pub struct CmdHeader<LEN, CMD> {
18    pkg_len: LEN,
19    version: u8,
20    cmd_code: CMD,
21}
22
23impl<LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
24    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static> CmdHeader<LEN, CMD> {
25    pub fn new(version: u8, cmd_code: CMD, pkg_len: LEN) -> Self {
26        Self {
27            pkg_len,
28            version,
29            cmd_code,
30        }
31    }
32
33    pub fn pkg_len(&self) -> LEN {
34        self.pkg_len
35    }
36
37    pub fn version(&self) -> u8 {
38        self.version
39    }
40
41    pub fn cmd_code(&self) -> CMD {
42        self.cmd_code
43    }
44
45    pub fn set_pkg_len(&mut self, pkg_len: LEN) {
46        self.pkg_len = pkg_len;
47    }
48}
49
50impl<LEN: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes,
51    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes> RawFixedBytes for CmdHeader<LEN, CMD> {
52    fn raw_bytes() -> Option<usize> {
53        Some(LEN::raw_bytes().unwrap() + u8::raw_bytes().unwrap() + CMD::raw_bytes().unwrap())
54    }
55}
56
57#[async_trait::async_trait]
58pub trait CmdBodyReadAll: tokio::io::AsyncRead + Send + 'static {
59    async fn read_all(&mut self) -> CmdResult<Vec<u8>>;
60}
61pub type CmdBodyRead = Box<dyn CmdBodyReadAll>;
62
63pub(crate) struct CmdBodyReadImpl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> {
64    recv: Option<RHalf<R, W>>,
65    len: usize,
66    offset: usize,
67    waiter: Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>>,
68}
69
70impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadImpl<R, W> {
71    pub fn new(recv: RHalf<R, W>, len: usize) -> Self {
72        Self {
73            recv: Some(recv),
74            len,
75            offset: 0,
76            waiter: Arc::new(SingleCallbackWaiter::new()),
77        }
78    }
79
80
81    pub(crate) fn get_waiter(&self) -> Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>> {
82        self.waiter.clone()
83    }
84}
85
86#[async_trait::async_trait]
87impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadAll for CmdBodyReadImpl<R, W> {
88    async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
89        if self.offset == self.len {
90            return Ok(Vec::new());
91        }
92        let mut buf = vec![0u8; self.len - self.offset];
93        let ret = self.recv.as_mut().unwrap().read_exact(&mut buf).await.map_err(into_cmd_err!(CmdErrorCode::IoError));
94        if ret.is_ok() {
95            self.offset = self.len;
96            self.waiter.set_result_with_cache(Ok(self.recv.take().unwrap()));
97            Ok(buf)
98        } else {
99            self.recv.take();
100            self.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
101            Err(ret.err().unwrap())
102        }
103    }
104}
105
106impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> Drop for CmdBodyReadImpl<R, W> {
107    fn drop(&mut self) {
108        if self.recv.is_none() || self.len == self.offset {
109            return;
110        }
111        let mut recv = self.recv.take().unwrap();
112        let len = self.len - self.offset;
113        let waiter = self.waiter.clone();
114        if len == 0 {
115            waiter.set_result_with_cache(Ok(recv));
116            return;
117        }
118
119        tokio::spawn(async move {
120            let mut buf = vec![0u8; len];
121            if let Err(e) = recv.read_exact(&mut buf).await {
122                waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error {}", e)));
123            } else {
124                waiter.set_result_with_cache(Ok(recv));
125            }
126        });
127    }
128}
129
130impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> tokio::io::AsyncRead for CmdBodyReadImpl<R, W> {
131    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
132        let this = Pin::into_inner(self);
133        let len = this.len - this.offset;
134        if len == 0 {
135            return Poll::Ready(Ok(()));
136        }
137        let buf = buf.initialize_unfilled();
138        let mut buf = ReadBuf::new(&mut buf[..len]);
139        let recv = Pin::new(this.recv.as_mut().unwrap().deref_mut());
140        let fut = recv.poll_read(cx, &mut buf);
141        match fut {
142            Poll::Ready(Ok(())) => {
143                this.offset += buf.filled().len();
144                if this.offset == this.len {
145                    this.waiter.set_result_with_cache(Ok(this.recv.take().unwrap()));
146                }
147                Poll::Ready(Ok(()))
148            }
149            Poll::Ready(Err(e)) => {
150                this.recv.take();
151                this.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
152                Poll::Ready(Err(e))
153            },
154            Poll::Pending => Poll::Pending,
155        }
156    }
157}
158
159
160#[callback_trait::callback_trait]
161pub trait CmdHandler<LEN, CMD>: Send + Sync + 'static
162where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
163      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static {
164    async fn handle(&self, peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body: CmdBodyRead) -> CmdResult<()>;
165}
166
167pub(crate) struct CmdHandlerMap<LEN, CMD> {
168    map: Mutex<HashMap<CMD, Arc<dyn CmdHandler<LEN, CMD>>>>,
169}
170
171impl <LEN, CMD> CmdHandlerMap<LEN, CMD>
172where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
173      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Eq + Hash {
174    pub fn new() -> Self {
175        Self {
176            map: Mutex::new(HashMap::new()),
177        }
178    }
179
180    pub fn insert(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
181        self.map.lock().unwrap().insert(cmd, Arc::new(handler));
182    }
183
184    pub fn get(&self, cmd: CMD) -> Option<Arc<dyn CmdHandler<LEN, CMD>>> {
185        self.map.lock().unwrap().get(&cmd).map(|v| v.clone())
186    }
187}