// sfo_cmd_server/client/client.rs

use std::fmt::Debug;
use std::hash::Hash;
use std::io;
use std::sync::{Arc, Mutex};

use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
use num::{FromPrimitive, ToPrimitive};
use sfo_pool::{into_pool_err, pool_err, ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult};
use sfo_split::{Splittable, WHalf};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::spawn;
use tokio::task::JoinHandle;

use crate::client::CmdClient;
use crate::cmd::{CmdBodyReadImpl, CmdHandler, CmdHandlerMap, CmdHeader};
use crate::errors::{into_cmd_err, CmdErrorCode, CmdResult};
use crate::peer_id::PeerId;
use crate::{CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
17#[async_trait::async_trait]
18pub trait CmdTunnelFactory<R: CmdTunnelRead, W: CmdTunnelWrite>: Send + Sync + 'static {
19    async fn create_tunnel(&self) -> CmdResult<Splittable<R, W>>;
20}
21
22pub struct CmdSend<R: CmdTunnelRead, W: CmdTunnelWrite, LEN, CMD>
23where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
24      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
25    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
26    pub(crate) write: WHalf<R, W>,
27    pub(crate) is_work: bool,
28    pub(crate) tunnel_id: TunnelId,
29    _p: std::marker::PhantomData<(LEN, CMD)>,
30
31}
32
33impl<R, W, LEN, CMD> CmdSend<R, W, LEN, CMD>
34where R: CmdTunnelRead,
35      W: CmdTunnelWrite,
36      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
37      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
38    pub fn new(tunnel_id: TunnelId, recv_handle: JoinHandle<CmdResult<()>>, write: WHalf<R, W>) -> Self {
39        Self {
40            recv_handle,
41            write,
42            is_work: true,
43            tunnel_id,
44            _p: Default::default(),
45        }
46    }
47
48    pub fn get_tunnel_id(&self) -> TunnelId {
49        self.tunnel_id
50    }
51
52    pub fn set_disable(&mut self) {
53        self.is_work = false;
54        self.recv_handle.abort();
55    }
56
57    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
58        log::trace!("client {:?} send cmd: {:?}, len: {} data:{}", self.tunnel_id, cmd, body.len(), hex::encode(body));
59        let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(body.len() as u64).unwrap());
60        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
61        let ret = self.send_inner(buf.as_slice(), body).await;
62        if let Err(e) = ret {
63            self.set_disable();
64            return Err(e);
65        }
66        Ok(())
67    }
68
69    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
70        let mut len = 0;
71        for b in body.iter() {
72            len += b.len();
73            log::trace!("client {:?} send2 cmd: {:?}, data {}", self.tunnel_id, cmd, hex::encode(b));
74        }
75        log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
76        let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(len as u64).unwrap());
77        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
78        let ret = self.send_inner2(buf.as_slice(), body).await;
79        if let Err(e) = ret {
80            self.set_disable();
81            return Err(e);
82        }
83        Ok(())
84    }
85
86    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
87        self.write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
88        self.write.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
89        self.write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
90        Ok(())
91    }
92
93    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
94        self.write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
95        for b in body.iter() {
96            self.write.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
97        }
98        self.write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
99        Ok(())
100    }
101}
102
103impl<R, W, LEN, CMD> Drop for CmdSend<R, W, LEN, CMD>
104where R: CmdTunnelRead,
105      W: CmdTunnelWrite,
106      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
107      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
108    fn drop(&mut self) {
109        self.set_disable();
110    }
111}
112
113impl<R, W, LEN, CMD> ClassifiedWorker<TunnelId> for CmdSend<R, W, LEN, CMD>
114where R: CmdTunnelRead,
115      W: CmdTunnelWrite,
116      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
117      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug{
118    fn is_work(&self) -> bool {
119        self.is_work && !self.recv_handle.is_finished()
120    }
121
122    fn is_valid(&self, c: TunnelId) -> bool {
123        self.tunnel_id == c
124    }
125
126    fn classification(&self) -> TunnelId {
127        self.tunnel_id
128    }
129}
130
131
132struct CmdWriteFactory<R: CmdTunnelRead,
133    W: CmdTunnelWrite,
134    F: CmdTunnelFactory<R, W>,
135    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
136    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug> {
137    tunnel_factory: F,
138    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
139    tunnel_id_generator: TunnelIdGenerator,
140    p: std::marker::PhantomData<Mutex<(R, W)>>,
141}
142
143impl<R: CmdTunnelRead,
144    W: CmdTunnelWrite,
145    F: CmdTunnelFactory<R, W>,
146    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
147    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug> CmdWriteFactory<R, W, F, LEN, CMD> {
148    pub fn new(tunnel_factory: F, cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
149        Self {
150            tunnel_factory,
151            cmd_handler: Arc::new(cmd_handler),
152            tunnel_id_generator: TunnelIdGenerator::new(),
153            p: Default::default(),
154        }
155    }
156}
157
158#[async_trait::async_trait]
159impl<R: CmdTunnelRead,
160    W: CmdTunnelWrite,
161    F: CmdTunnelFactory<R, W>,
162    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
163    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug> ClassifiedWorkerFactory<TunnelId, CmdSend<R, W, LEN, CMD>> for CmdWriteFactory<R, W, F, LEN, CMD> {
164    async fn create(&self, c: Option<TunnelId>) -> PoolResult<CmdSend<R, W, LEN, CMD>> {
165        if c.is_some() {
166            return Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", c.unwrap()));
167        }
168        let tunnel = self.tunnel_factory.create_tunnel().await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
169        let peer_id = tunnel.get_remote_peer_id();
170        let tunnel_id = self.tunnel_id_generator.generate();
171        let (mut recv, write) = tunnel.split();
172        let cmd_handler = self.cmd_handler.clone();
173        let handle = spawn(async move {
174            let ret: CmdResult<()> = async move {
175                loop {
176                    let mut header = vec![0u8; CmdHeader::<LEN, CMD>::raw_bytes().unwrap()];
177                    let n = recv.read_exact(header.as_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
178                    if n == 0 {
179                        break;
180                    }
181                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice()).map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
182                    log::trace!("recv cmd {:?} from {} len {}", header.cmd_code(), peer_id.to_base58(), header.pkg_len().to_u64().unwrap());
183                    let cmd_read = Box::new(CmdBodyReadImpl::new(recv, header.pkg_len().to_u64().unwrap() as usize));
184                    let waiter = cmd_read.get_waiter();
185                    let future = waiter.create_result_future();
186                    if let Err(e) = cmd_handler.handle(peer_id.clone(), tunnel_id, header, cmd_read).await {
187                        log::error!("handle cmd error: {:?}", e);
188                    }
189                    recv = future.await.map_err(into_cmd_err!(CmdErrorCode::Failed))??;
190                }
191                Ok(())
192            }.await;
193            ret
194        });
195        Ok(CmdSend::new(tunnel_id, handle, write))
196    }
197}
198
199pub struct DefaultCmdClient<R: CmdTunnelRead,
200    W: CmdTunnelWrite,
201    F: CmdTunnelFactory<R, W>,
202    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
203    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> {
204    tunnel_pool: ClassifiedWorkerPoolRef<TunnelId, CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD>>,
205    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
206}
207
208impl<R: CmdTunnelRead,
209    W: CmdTunnelWrite,
210    F: CmdTunnelFactory<R, W>,
211    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
212    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> DefaultCmdClient<R, W, F, LEN, CMD> {
213    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
214        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
215        let handler_map = cmd_handler_map.clone();
216        Arc::new(Self {
217            tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, CmdWriteFactory::<R, W, _, LEN, CMD>::new(factory, move |peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body_read| {
218                let handler_map = handler_map.clone();
219                async move {
220                    if let Some(handler) = handler_map.get(header.cmd_code()) {
221                        handler.handle(peer_id, tunnel_id, header, body_read).await?;
222                    }
223                    Ok(())
224                }
225            })),
226            cmd_handler_map,
227        })
228    }
229
230    async fn get_send(&self) -> CmdResult<ClassifiedWorkerGuard<TunnelId, CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD>>> {
231        self.tunnel_pool.get_worker().await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
232    }
233
234    async fn get_send_of_tunnel_id(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<TunnelId, CmdSend<R, W, LEN, CMD>, CmdWriteFactory<R, W, F, LEN, CMD>>> {
235        self.tunnel_pool.get_classified_worker(tunnel_id).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
236    }
237}
238
239#[async_trait::async_trait]
240impl<R: CmdTunnelRead,
241    W: CmdTunnelWrite,
242    F: CmdTunnelFactory<R, W>,
243    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
244    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> CmdClient<LEN, CMD> for DefaultCmdClient<R, W, F, LEN, CMD> {
245    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
246        self.cmd_handler_map.insert(cmd, handler);
247    }
248
249    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
250        let mut send = self.get_send().await?;
251        send.send(cmd, version, body).await
252    }
253
254    async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
255        let mut send = self.get_send().await?;
256        send.send2(cmd, version, body).await
257    }
258
259    async fn send_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
260        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
261        send.send(cmd, version, body).await
262    }
263
264    async fn send2_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
265        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
266        send.send2(cmd, version, body).await
267    }
268
269    async fn clear_all_tunnel(&self) {
270        self.tunnel_pool.clear_all_worker().await;
271    }
272}