sillad_meeklike/
lib.rs

1mod crypto;
2mod datagram;
3
4use std::{sync::Arc, time::Duration};
5
6use anyhow::Context;
7use async_io::Timer;
8use async_task::Task;
9use async_trait::async_trait;
10use event_listener::Event;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use stdcode::StdcodeSerializeExt;
14use virta::{stream_state::StreamState, StreamMessage};
15
16use crate::{
17    crypto::PresharedSecret,
18    datagram::{DgConnection, DgListener},
19};
20use futures_lite::FutureExt;
21use futures_util::io::{AsyncRead, AsyncWrite};
22
23use pin_project::pin_project;
24use sillad::{dialer::Dialer, listener::Listener, Pipe};
25
/// Tuning parameters for a meeklike dialer or listener.
///
/// The same value is handed to the datagram layer (`DgConnection::new` / `DgListener::new`),
/// and its `mss` is additionally applied to each reliable stream via `set_mss`.
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`; if it is ever encoded with a
/// positional format such as stdcode, field order is part of the wire format — verify before
/// reordering or inserting fields.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, Hash, PartialEq)]
pub struct MeeklikeConfig {
    /// Consumed by the datagram layer; presumably caps the number of concurrently outstanding
    /// requests on the underlying transport — TODO confirm against `datagram`.
    pub max_inflight: usize,
    /// Maximum segment size passed to the stream state via `set_mss`.
    pub mss: usize,
    /// Consumed by the datagram layer; presumably toggles base64 encoding of payloads —
    /// TODO confirm against `datagram`.
    pub base64: bool,
}
33
34impl Default for MeeklikeConfig {
35    fn default() -> Self {
36        Self {
37            max_inflight: 50,
38            mss: 3000,
39            base64: false,
40        }
41    }
42}
43
#[pin_project]
/// A bidirectional byte stream over the meeklike transport, produced by
/// `MeeklikeDialer::dial` and `MeeklikeListener::accept`.
///
/// All reads, writes, flushes and closes are forwarded to the wrapped `virta::Stream`.
pub struct MeeklikePipe {
    /// The reliable stream; its paired `StreamState` is owned by the background `ticker` task.
    #[pin]
    inner: virta::Stream,

    /// Handle to the background `ticker` task. It is stored (not detached) so the task lives
    /// exactly as long as the pipe: dropping an `async_task::Task` cancels the task.
    _task: Task<()>,
}
52
53impl AsyncRead for MeeklikePipe {
54    fn poll_read(
55        self: std::pin::Pin<&mut Self>,
56        cx: &mut std::task::Context<'_>,
57        buf: &mut [u8],
58    ) -> std::task::Poll<std::io::Result<usize>> {
59        self.project().inner.poll_read(cx, buf)
60    }
61}
62
63impl AsyncWrite for MeeklikePipe {
64    fn poll_write(
65        self: std::pin::Pin<&mut Self>,
66        cx: &mut std::task::Context<'_>,
67        buf: &[u8],
68    ) -> std::task::Poll<std::io::Result<usize>> {
69        self.project().inner.poll_write(cx, buf)
70    }
71
72    fn poll_flush(
73        self: std::pin::Pin<&mut Self>,
74        cx: &mut std::task::Context<'_>,
75    ) -> std::task::Poll<std::io::Result<()>> {
76        self.project().inner.poll_flush(cx)
77    }
78
79    fn poll_close(
80        self: std::pin::Pin<&mut Self>,
81        cx: &mut std::task::Context<'_>,
82    ) -> std::task::Poll<std::io::Result<()>> {
83        self.project().inner.poll_close(cx)
84    }
85}
86
87impl Pipe for MeeklikePipe {
88    fn protocol(&self) -> &str {
89        "meeklike"
90    }
91    fn remote_addr(&self) -> Option<&str> {
92        Some("0.0.0.0:0")
93    }
94}
95
/// Dials meeklike streams on top of another [`Dialer`].
///
/// Each `dial` picks a random 128-bit stream id, builds a datagram connection over `inner`
/// keyed with a `PresharedSecret` derived from `key`, and drives a `virta` stream over it.
pub struct MeeklikeDialer<D: Dialer> {
    /// Underlying dialer; an `Arc` clone of it is given to every datagram connection.
    pub inner: Arc<D>,
    /// 32-byte preshared secret; presumably must equal the key the `MeeklikeListener` was
    /// built with.
    pub key: [u8; 32],
    /// Tuning parameters for the datagram layer and the stream state.
    pub cfg: MeeklikeConfig,
}
101
102#[async_trait]
103impl<D: Dialer> Dialer for MeeklikeDialer<D> {
104    type P = MeeklikePipe;
105    async fn dial(&self) -> std::io::Result<Self::P> {
106        let stream_id: u128 = rand::random();
107        let dg_conn = DgConnection::new(
108            self.cfg,
109            PresharedSecret::new(&self.key).into(),
110            stream_id,
111            self.inner.clone(),
112        );
113        let notify = Arc::new(Event::new());
114        let (mut state, inner) = virta::stream_state::StreamState::new_pending({
115            let notify = notify.clone();
116            move || {
117                notify.notify(1);
118            }
119        });
120        state.set_mss(self.cfg.mss);
121        let _task = smolscale::spawn(ticker(notify, state, dg_conn));
122        inner.wait_connected().await?;
123        Ok(MeeklikePipe { inner, _task })
124    }
125}
126
127async fn ticker(notify: Arc<Event>, state: StreamState, dg_conn: DgConnection) {
128    let state = Mutex::new(state);
129    let up = async {
130        let mut timer = Timer::after(Duration::from_secs(10));
131        loop {
132            let evt = notify.listen();
133            let next = state.lock().tick(|b| dg_conn.send(b.stdcode().into()));
134            if let Some(next) = next {
135                timer.set_at(next);
136                async {
137                    (&mut timer).await;
138                }
139                .race(async {
140                    evt.await;
141                })
142                .await
143            } else {
144                break;
145            }
146        }
147        anyhow::Ok(())
148    };
149
150    let dn = async {
151        loop {
152            let bts = dg_conn
153                .recv()
154                .await
155                .context("could not received from underlying")?;
156            let msg: Result<StreamMessage, _> = stdcode::deserialize(&bts);
157            match msg {
158                Ok(msg) => {
159                    state.lock().inject_incoming(msg);
160                }
161                Err(err) => {
162                    tracing::warn!(err = debug(err), "error getting message")
163                }
164            }
165        }
166    };
167
168    if let Err(err) = up.race(dn).await {
169        tracing::warn!(err = debug(err), "ticker died abnormally")
170    }
171}
172
/// Accepts meeklike streams arriving over an inner [`Listener`].
pub struct MeeklikeListener<L: Listener> {
    /// Datagram-level listener built from the inner listener and the preshared key.
    listener: DgListener,
    /// Tuning parameters; `mss` is applied to every accepted stream.
    cfg: MeeklikeConfig,
    /// The inner `L` is moved into the (non-generic) `DgListener`, so `L` is not stored by
    /// type. Presumably the parameter is kept so the listener's type still names its
    /// transport.
    _phantom: std::marker::PhantomData<L>,
}
178
179impl<L: Listener> MeeklikeListener<L> {
180    pub fn new(inner: L, key: [u8; 32], cfg: MeeklikeConfig) -> Self {
181        let listener = DgListener::new(cfg, PresharedSecret::new(&key).into(), inner);
182        Self {
183            listener,
184            cfg,
185            _phantom: std::marker::PhantomData,
186        }
187    }
188}
189
190#[async_trait]
191impl<L: Listener> Listener for MeeklikeListener<L> {
192    type P = MeeklikePipe;
193    async fn accept(&mut self) -> std::io::Result<Self::P> {
194        let dg_conn = self
195            .listener
196            .accept()
197            .await
198            .map_err(std::io::Error::other)?;
199        let notify = Arc::new(Event::new());
200        let (mut state, inner) = virta::stream_state::StreamState::new_established({
201            let notify = notify.clone();
202            move || {
203                notify.notify(1);
204            }
205        });
206        state.set_mss(self.cfg.mss);
207        let _task = smolscale::spawn(ticker(notify, state, dg_conn));
208        Ok(MeeklikePipe { inner, _task })
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{AsyncReadExt, AsyncWriteExt};
    use sillad::tcp::{TcpDialer, TcpListener};
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// End-to-end round trip over loopback TCP: the client sends "ping" and the server
    /// answers "pong".
    #[test]
    fn ping_pong() {
        smolscale::block_on(async {
            let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
                .await
                .unwrap();
            let addr = listener.local_addr().await;
            let key = [7u8; 32];
            let mut meek_listener = MeeklikeListener::new(listener, key, Default::default());

            let dialer = MeeklikeDialer {
                inner: TcpDialer { dest_addr: addr }.into(),
                key,
                cfg: Default::default(),
            };

            let server = smolscale::spawn(async move {
                let mut pipe = meek_listener.accept().await.unwrap();
                let mut buf = [0u8; 4];
                pipe.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"ping");
                pipe.write_all(b"pong").await.unwrap();
                pipe.flush().await.unwrap();
                // Hand the pipe (and the listener) back instead of dropping them here.
                // Dropping a `MeeklikePipe` drops its ticker `Task`, which cancels it, so
                // "pong" could be lost before reaching the client. We do not assume that
                // `flush` already waits for delivery.
                (pipe, meek_listener)
            });

            let client = smolscale::spawn(async move {
                let mut pipe = dialer.dial().await.unwrap();
                pipe.write_all(b"ping").await.unwrap();
                pipe.flush().await.unwrap();
                let mut buf = [0u8; 4];
                pipe.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"pong");
            });

            // Keep the server side alive until the client has read its reply.
            let _server_side = server.await;
            client.await;
        });
    }
}