sfo_cmd_server/client/
classified_client.rs

1use std::hash::Hash;
2use std::sync::{Arc, Mutex};
3use bucky_raw_codec::{RawConvertTo, RawDecode, RawEncode, RawFixedBytes, RawFrom};
4use num::{FromPrimitive, ToPrimitive};
5use sfo_pool::{into_pool_err, pool_err, ClassifiedWorker, ClassifiedWorkerFactory, ClassifiedWorkerGuard, ClassifiedWorkerPool, ClassifiedWorkerPoolRef, PoolErrorCode, PoolResult, WorkerClassification};
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::spawn;
8use tokio::task::JoinHandle;
9use crate::{CmdTunnelRead, CmdTunnelWrite, TunnelId, TunnelIdGenerator};
10use crate::client::{ClassifiedCmdClient, CmdClient};
11use crate::cmd::{CmdBodyReadImpl, CmdHandler, CmdHandlerMap, CmdHeader};
12use crate::errors::{into_cmd_err, CmdErrorCode, CmdResult};
13use crate::peer_id::PeerId;
14use std::fmt::Debug;
15use sfo_split::{RHalf, Splittable, WHalf};
16
17pub trait ClassifiedCmdTunnelRead<C: WorkerClassification>: CmdTunnelRead + 'static + Send {
18    fn get_classification(&self) -> C;
19}
20
21pub trait ClassifiedCmdTunnelWrite<C: WorkerClassification>: CmdTunnelWrite + 'static + Send {
22    fn get_classification(&self) -> C;
23}
24
25pub type ClassifiedCmdTunnel<R, W> = Splittable<R, W>;
26pub type ClassifiedCmdTunnelRHalf<R, W> = RHalf<R, W>;
27pub type ClassifiedCmdTunnelWHalf<R, W> = WHalf<R, W>;
28
29#[derive(Debug, Clone, Copy, Eq, Hash)]
30pub struct CmdClientTunnelClassification<C: WorkerClassification> {
31    tunnel_id: Option<TunnelId>,
32    classification: Option<C>,
33}
34
35impl<C: WorkerClassification> PartialEq for CmdClientTunnelClassification<C> {
36    fn eq(&self, other: &Self) -> bool {
37        self.tunnel_id == other.tunnel_id && self.classification == other.classification
38    }
39}
40
41
42#[async_trait::async_trait]
43pub trait ClassifiedCmdTunnelFactory<C: WorkerClassification, R: ClassifiedCmdTunnelRead<C>, W: ClassifiedCmdTunnelWrite<C>>: Send + Sync + 'static {
44    async fn create_tunnel(&self, classification: Option<C>) -> CmdResult<Splittable<R, W>>;
45}
46
47pub struct ClassifiedCmdSend<C, R, W, LEN, CMD>
48where
49    C: WorkerClassification,
50    R: ClassifiedCmdTunnelRead<C>,
51    W: ClassifiedCmdTunnelWrite<C>,
52    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
53    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug,
54{
55    pub(crate) recv_handle: JoinHandle<CmdResult<()>>,
56    pub(crate) write: ClassifiedCmdTunnelWHalf<R, W>,
57    pub(crate) is_work: bool,
58    pub(crate) classification: C,
59    pub(crate) tunnel_id: TunnelId,
60    _p: std::marker::PhantomData<(LEN, CMD)>,
61
62}
63
64impl<C, R, W, LEN, CMD> ClassifiedCmdSend<C, R, W, LEN, CMD>
65where C: WorkerClassification,
66      R: ClassifiedCmdTunnelRead<C>,
67      W: ClassifiedCmdTunnelWrite<C>,
68      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
69      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
70    pub(crate) fn new(tunnel_id: TunnelId, classification: C, recv_handle: JoinHandle<CmdResult<()>>, write: ClassifiedCmdTunnelWHalf<R, W>) -> Self {
71        Self {
72            recv_handle,
73            write,
74            is_work: true,
75            classification,
76            tunnel_id,
77            _p: Default::default(),
78        }
79    }
80
81    pub fn get_tunnel_id(&self) -> TunnelId {
82        self.tunnel_id
83    }
84
85    pub fn set_disable(&mut self) {
86        self.is_work = false;
87        self.recv_handle.abort();
88    }
89
90    pub async fn send(&mut self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
91        log::trace!("client {:?} send cmd: {:?}, len: {}, data: {}", self.tunnel_id, cmd, body.len(), hex::encode(body));
92        let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(body.len() as u64).unwrap());
93        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
94        let ret = self.send_inner(buf.as_slice(), body).await;
95        if let Err(e) = ret {
96            self.set_disable();
97            return Err(e);
98        }
99        Ok(())
100    }
101
102    pub async fn send2(&mut self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
103        let mut len = 0;
104        for b in body.iter() {
105            len += b.len();
106            log::trace!("client {:?} send2 cmd {:?} body: {}", self.tunnel_id, cmd, hex::encode(b));
107        }
108        log::trace!("client {:?} send2 cmd: {:?}, len {}", self.tunnel_id, cmd, len);
109        let header = CmdHeader::<LEN, CMD>::new(version, cmd, LEN::from_u64(len as u64).unwrap());
110        let buf = header.to_vec().map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
111        let ret = self.send_inner2(buf.as_slice(), body).await;
112        if let Err(e) = ret {
113            self.set_disable();
114            return Err(e);
115        }
116        Ok(())
117    }
118
119    async fn send_inner(&mut self, header: &[u8], body: &[u8]) -> CmdResult<()> {
120        self.write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
121        self.write.write_all(body).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
122        self.write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
123        Ok(())
124    }
125
126    async fn send_inner2(&mut self, header: &[u8], body: &[&[u8]]) -> CmdResult<()> {
127        self.write.write_all(header).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
128        for b in body.iter() {
129            self.write.write_all(b).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
130        }
131        self.write.flush().await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
132        Ok(())
133    }
134}
135
136impl<C, R, W, LEN, CMD> Drop for ClassifiedCmdSend<C, R, W, LEN, CMD>
137where C: WorkerClassification,
138      R: ClassifiedCmdTunnelRead<C>,
139      W: ClassifiedCmdTunnelWrite<C>,
140      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
141      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
142    fn drop(&mut self) {
143        self.set_disable();
144    }
145}
146
147impl<C, R, W, LEN, CMD> ClassifiedWorker<CmdClientTunnelClassification<C>> for ClassifiedCmdSend<C, R, W, LEN, CMD>
148where C: WorkerClassification,
149      R: ClassifiedCmdTunnelRead<C>,
150      W: ClassifiedCmdTunnelWrite<C>,
151      LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
152      CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug {
153    fn is_work(&self) -> bool {
154        self.is_work && !self.recv_handle.is_finished()
155    }
156
157    fn is_valid(&self, c: CmdClientTunnelClassification<C>) -> bool {
158        if c.tunnel_id.is_some() {
159            self.tunnel_id == c.tunnel_id.unwrap()
160        } else {
161            if c.classification.is_some() {
162                self.classification == c.classification.unwrap()
163            } else {
164                true
165            }
166        }
167    }
168
169    fn classification(&self) -> CmdClientTunnelClassification<C> {
170        CmdClientTunnelClassification {
171            tunnel_id: Some(self.tunnel_id),
172            classification: Some(self.classification.clone()),
173        }
174    }
175}
176
177struct CmdWriteFactory<C: WorkerClassification,
178    R: ClassifiedCmdTunnelRead<C>,
179    W: ClassifiedCmdTunnelWrite<C>,
180    F: ClassifiedCmdTunnelFactory<C, R, W>,
181    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
182    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug> {
183    tunnel_factory: F,
184    cmd_handler: Arc<dyn CmdHandler<LEN, CMD>>,
185    tunnel_id_generator: TunnelIdGenerator,
186    _p: std::marker::PhantomData<Mutex<(C, R, W)>>,
187}
188
189impl<
190    C: WorkerClassification,
191    R: ClassifiedCmdTunnelRead<C>,
192    W: ClassifiedCmdTunnelWrite<C>,
193    F: ClassifiedCmdTunnelFactory<C, R, W>,
194    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
195    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Debug
196> CmdWriteFactory<C, R, W, F, LEN, CMD> {
197    pub fn new(tunnel_factory: F, cmd_handler: impl CmdHandler<LEN, CMD>) -> Self {
198        Self {
199            tunnel_factory,
200            cmd_handler: Arc::new(cmd_handler),
201            tunnel_id_generator: TunnelIdGenerator::new(),
202            _p: Default::default(),
203        }
204    }
205}
206
207#[async_trait::async_trait]
208impl<C: WorkerClassification,
209    R: ClassifiedCmdTunnelRead<C>,
210    W: ClassifiedCmdTunnelWrite<C>,
211    F: ClassifiedCmdTunnelFactory<C, R, W>,
212    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
213    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Debug> ClassifiedWorkerFactory<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>
214> for CmdWriteFactory<C, R, W, F, LEN, CMD> {
215    async fn create(&self, classification: Option<CmdClientTunnelClassification<C>>) -> PoolResult<ClassifiedCmdSend<C, R, W, LEN, CMD>> {
216        if classification.is_some() && classification.as_ref().unwrap().tunnel_id.is_some() {
217            return Err(pool_err!(PoolErrorCode::Failed, "tunnel {:?} not found", classification.as_ref().unwrap().tunnel_id.unwrap()));
218        }
219
220        let classification = if classification.is_some() && classification.as_ref().unwrap().classification.is_some() {
221            classification.unwrap().classification
222        } else {
223            None
224        };
225        let tunnel = self.tunnel_factory.create_tunnel(classification).await.map_err(into_pool_err!(PoolErrorCode::Failed))?;
226        let classification = tunnel.get_classification();
227        let peer_id = tunnel.get_remote_peer_id();
228        let tunnel_id = self.tunnel_id_generator.generate();
229        let (mut recv, write) = tunnel.split();
230        let cmd_handler = self.cmd_handler.clone();
231        let handle = spawn(async move {
232            let ret: CmdResult<()> = async move {
233                loop {
234                    let mut header = vec![0u8; CmdHeader::<LEN, CMD>::raw_bytes().unwrap()];
235                    let n = recv.read_exact(header.as_mut()).await.map_err(into_cmd_err!(CmdErrorCode::IoError))?;
236                    if n == 0 {
237                        break;
238                    }
239                    let header = CmdHeader::<LEN, CMD>::clone_from_slice(header.as_slice()).map_err(into_cmd_err!(CmdErrorCode::RawCodecError))?;
240                    log::trace!("recv cmd {:?} from {} len {} tunnel {:?}", header.cmd_code(), peer_id, header.pkg_len().to_u64().unwrap(), tunnel_id);
241                    let cmd_read = Box::new(CmdBodyReadImpl::new(recv, header.pkg_len().to_u64().unwrap() as usize));
242                    let waiter = cmd_read.get_waiter();
243                    let future = waiter.create_result_future();
244                    if let Err(e) = cmd_handler.handle(peer_id.clone(), tunnel_id, header, cmd_read).await {
245                        log::error!("handle cmd error: {:?}", e);
246                    }
247                    recv = future.await.map_err(into_cmd_err!(CmdErrorCode::Failed))??;
248                }
249                Ok(())
250            }.await;
251            ret
252        });
253        Ok(ClassifiedCmdSend::new(tunnel_id, classification, handle, write))
254    }
255}
256
257pub struct DefaultClassifiedCmdClient<C: WorkerClassification,
258    R: ClassifiedCmdTunnelRead<C>,
259    W: ClassifiedCmdTunnelWrite<C>,
260    F: ClassifiedCmdTunnelFactory<C, R, W>,
261    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
262    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> {
263    tunnel_pool: ClassifiedWorkerPoolRef<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD>>,
264    cmd_handler_map: Arc<CmdHandlerMap<LEN, CMD>>,
265}
266
267impl<C: WorkerClassification,
268    R: ClassifiedCmdTunnelRead<C>,
269    W: ClassifiedCmdTunnelWrite<C>,
270    F: ClassifiedCmdTunnelFactory<C, R, W>,
271    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
272    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> DefaultClassifiedCmdClient<C, R, W, F, LEN, CMD> {
273    pub fn new(factory: F, tunnel_count: u16) -> Arc<Self> {
274        let cmd_handler_map = Arc::new(CmdHandlerMap::new());
275        let handler_map = cmd_handler_map.clone();
276        Arc::new(Self {
277            tunnel_pool: ClassifiedWorkerPool::new(tunnel_count, CmdWriteFactory::<C, R, W, _, LEN, CMD>::new(factory, move |peer_id: PeerId, tunnel_id, header: CmdHeader<LEN, CMD>, body_read| {
278                let handler_map = handler_map.clone();
279                async move {
280                    if let Some(handler) = handler_map.get(header.cmd_code()) {
281                        handler.handle(peer_id, tunnel_id, header, body_read).await?;
282                    }
283                    Ok(())
284                }
285            })),
286            cmd_handler_map,
287        })
288    }
289
290    async fn get_send(&self) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD>>> {
291        self.tunnel_pool.get_worker().await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
292    }
293
294    async fn get_send_of_tunnel_id(&self, tunnel_id: TunnelId) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD>>> {
295        self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
296            tunnel_id: Some(tunnel_id),
297            classification: None,
298        }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
299    }
300
301    async fn get_classified_send(&self, classification: C) -> CmdResult<ClassifiedWorkerGuard<CmdClientTunnelClassification<C>, ClassifiedCmdSend<C, R, W, LEN, CMD>, CmdWriteFactory<C, R, W, F, LEN, CMD>>> {
302        self.tunnel_pool.get_classified_worker(CmdClientTunnelClassification {
303            tunnel_id: None,
304            classification: Some(classification),
305        }).await.map_err(into_cmd_err!(CmdErrorCode::Failed, "get worker failed"))
306    }
307}
308
309#[async_trait::async_trait]
310impl<C: WorkerClassification,
311    R: ClassifiedCmdTunnelRead<C>,
312    W: ClassifiedCmdTunnelWrite<C>,
313    F: ClassifiedCmdTunnelFactory<C, R, W>,
314    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
315    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> CmdClient<LEN, CMD> for DefaultClassifiedCmdClient<C, R, W, F, LEN, CMD> {
316    fn register_cmd_handler(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
317        self.cmd_handler_map.insert(cmd, handler);
318    }
319
320    async fn send(&self, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
321        let mut send = self.get_send().await?;
322        send.send(cmd, version, body).await
323    }
324
325    async fn send2(&self, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
326        let mut send = self.get_send().await?;
327        send.send2(cmd, version, body).await
328    }
329
330    async fn send_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
331        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
332        send.send(cmd, version, body).await
333    }
334
335    async fn send2_by_specify_tunnel(&self, tunnel_id: TunnelId, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
336        let mut send = self.get_send_of_tunnel_id(tunnel_id).await?;
337        send.send2(cmd, version, body).await
338    }
339
340    async fn clear_all_tunnel(&self) {
341        self.tunnel_pool.clear_all_worker().await;
342    }
343}
344
345#[async_trait::async_trait]
346impl<C: WorkerClassification,
347    R: ClassifiedCmdTunnelRead<C>,
348    W: ClassifiedCmdTunnelWrite<C>,
349    F: ClassifiedCmdTunnelFactory<C, R, W>,
350    LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive + RawFixedBytes,
351    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + RawFixedBytes + Eq + Hash + Debug> ClassifiedCmdClient<LEN, CMD, C> for DefaultClassifiedCmdClient<C, R, W, F, LEN, CMD> {
352    async fn send_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[u8]) -> CmdResult<()> {
353        let mut send = self.get_classified_send(classification).await?;
354        send.send(cmd, version, body).await
355    }
356
357    async fn send2_by_classified_tunnel(&self, classification: C, cmd: CMD, version: u8, body: &[&[u8]]) -> CmdResult<()> {
358        let mut send = self.get_classified_send(classification).await?;
359        send.send2(cmd, version, body).await
360    }
361
362    async fn find_tunnel_id_by_classified(&self, classification: C) -> CmdResult<TunnelId> {
363        let send = self.get_classified_send(classification).await?;
364        Ok(send.get_tunnel_id())
365    }
366}