1mod crypto;
2mod datagram;
3
4use std::{sync::Arc, time::Duration};
5
6use anyhow::Context;
7use async_io::Timer;
8use async_task::Task;
9use async_trait::async_trait;
10use event_listener::Event;
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13use stdcode::StdcodeSerializeExt;
14use virta::{stream_state::StreamState, StreamMessage};
15
16use crate::{
17 crypto::PresharedSecret,
18 datagram::{DgConnection, DgListener},
19};
20use futures_lite::FutureExt;
21use futures_util::io::{AsyncRead, AsyncWrite};
22
23use pin_project::pin_project;
24use sillad::{dialer::Dialer, listener::Listener, Pipe};
25
/// Tuning parameters for a meeklike transport, shared by `MeeklikeDialer`
/// and `MeeklikeListener`.
///
/// NOTE(review): the struct derives serde traits; if it is ever encoded with a
/// positional format, reordering fields would be a breaking change.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, Hash, PartialEq)]
pub struct MeeklikeConfig {
    /// Passed through (as part of the whole config) to the datagram layer;
    /// presumably caps the number of outstanding requests on the underlying
    /// transport — TODO confirm in `datagram`.
    pub max_inflight: usize,
    /// Maximum segment size applied to every stream via `StreamState::set_mss`.
    pub mss: usize,
    /// Not read directly in this module; presumably toggles base64 encoding in
    /// the datagram layer — TODO confirm in `datagram`.
    pub base64: bool,
}
33
34impl Default for MeeklikeConfig {
35 fn default() -> Self {
36 Self {
37 max_inflight: 50,
38 mss: 3000,
39 base64: false,
40 }
41 }
42}
43
/// A byte stream produced by `MeeklikeDialer::dial` or
/// `MeeklikeListener::accept`.
///
/// All I/O is delegated to the inner `virta::Stream`. The pipe also owns the
/// spawned task running `ticker`, which moves data between the stream state and
/// the underlying datagram connection. Dropping an `async_task::Task` cancels
/// it, so that task lives exactly as long as the pipe.
#[pin_project]
pub struct MeeklikePipe {
    /// Stream endpoint that all reads and writes are forwarded to.
    #[pin]
    inner: virta::Stream,

    // Fields are dropped in declaration order: `inner` goes first, then this
    // handle, whose drop cancels the ticker task.
    _task: Task<()>,
}
52
53impl AsyncRead for MeeklikePipe {
54 fn poll_read(
55 self: std::pin::Pin<&mut Self>,
56 cx: &mut std::task::Context<'_>,
57 buf: &mut [u8],
58 ) -> std::task::Poll<std::io::Result<usize>> {
59 self.project().inner.poll_read(cx, buf)
60 }
61}
62
63impl AsyncWrite for MeeklikePipe {
64 fn poll_write(
65 self: std::pin::Pin<&mut Self>,
66 cx: &mut std::task::Context<'_>,
67 buf: &[u8],
68 ) -> std::task::Poll<std::io::Result<usize>> {
69 self.project().inner.poll_write(cx, buf)
70 }
71
72 fn poll_flush(
73 self: std::pin::Pin<&mut Self>,
74 cx: &mut std::task::Context<'_>,
75 ) -> std::task::Poll<std::io::Result<()>> {
76 self.project().inner.poll_flush(cx)
77 }
78
79 fn poll_close(
80 self: std::pin::Pin<&mut Self>,
81 cx: &mut std::task::Context<'_>,
82 ) -> std::task::Poll<std::io::Result<()>> {
83 self.project().inner.poll_close(cx)
84 }
85}
86
87impl Pipe for MeeklikePipe {
88 fn protocol(&self) -> &str {
89 "meeklike"
90 }
91 fn remote_addr(&self) -> Option<&str> {
92 Some("0.0.0.0:0")
93 }
94}
95
/// Dialer that opens meeklike streams on top of connections made by `inner`.
pub struct MeeklikeDialer<D: Dialer> {
    /// Underlying dialer; this `Arc` is cloned into every `DgConnection`.
    pub inner: Arc<D>,
    /// Pre-shared secret used to build the `PresharedSecret` for each
    /// connection; presumably must match the listener's key.
    pub key: [u8; 32],
    /// Transport configuration used for every dialed stream.
    pub cfg: MeeklikeConfig,
}
101
102#[async_trait]
103impl<D: Dialer> Dialer for MeeklikeDialer<D> {
104 type P = MeeklikePipe;
105 async fn dial(&self) -> std::io::Result<Self::P> {
106 let stream_id: u128 = rand::random();
107 let dg_conn = DgConnection::new(
108 self.cfg,
109 PresharedSecret::new(&self.key).into(),
110 stream_id,
111 self.inner.clone(),
112 );
113 let notify = Arc::new(Event::new());
114 let (mut state, inner) = virta::stream_state::StreamState::new_pending({
115 let notify = notify.clone();
116 move || {
117 notify.notify(1);
118 }
119 });
120 state.set_mss(self.cfg.mss);
121 let _task = smolscale::spawn(ticker(notify, state, dg_conn));
122 inner.wait_connected().await?;
123 Ok(MeeklikePipe { inner, _task })
124 }
125}
126
127async fn ticker(notify: Arc<Event>, state: StreamState, dg_conn: DgConnection) {
128 let state = Mutex::new(state);
129 let up = async {
130 let mut timer = Timer::after(Duration::from_secs(10));
131 loop {
132 let evt = notify.listen();
133 let next = state.lock().tick(|b| dg_conn.send(b.stdcode().into()));
134 if let Some(next) = next {
135 timer.set_at(next);
136 async {
137 (&mut timer).await;
138 }
139 .race(async {
140 evt.await;
141 })
142 .await
143 } else {
144 break;
145 }
146 }
147 anyhow::Ok(())
148 };
149
150 let dn = async {
151 loop {
152 let bts = dg_conn
153 .recv()
154 .await
155 .context("could not received from underlying")?;
156 let msg: Result<StreamMessage, _> = stdcode::deserialize(&bts);
157 match msg {
158 Ok(msg) => {
159 state.lock().inject_incoming(msg);
160 }
161 Err(err) => {
162 tracing::warn!(err = debug(err), "error getting message")
163 }
164 }
165 }
166 };
167
168 if let Err(err) = up.race(dn).await {
169 tracing::warn!(err = debug(err), "ticker died abnormally")
170 }
171}
172
/// Listener that accepts meeklike streams arriving over an inner listener `L`.
pub struct MeeklikeListener<L: Listener> {
    /// Datagram-level listener that owns the inner `L`; each `accept` yields
    /// the connection for one new stream.
    listener: DgListener,
    /// Configuration applied to every accepted stream.
    cfg: MeeklikeConfig,
    // `L` is moved into `DgListener`, so it is not stored here directly; this
    // marker keeps the type parameter in use.
    _phantom: std::marker::PhantomData<L>,
}
178
179impl<L: Listener> MeeklikeListener<L> {
180 pub fn new(inner: L, key: [u8; 32], cfg: MeeklikeConfig) -> Self {
181 let listener = DgListener::new(cfg, PresharedSecret::new(&key).into(), inner);
182 Self {
183 listener,
184 cfg,
185 _phantom: std::marker::PhantomData,
186 }
187 }
188}
189
190#[async_trait]
191impl<L: Listener> Listener for MeeklikeListener<L> {
192 type P = MeeklikePipe;
193 async fn accept(&mut self) -> std::io::Result<Self::P> {
194 let dg_conn = self
195 .listener
196 .accept()
197 .await
198 .map_err(std::io::Error::other)?;
199 let notify = Arc::new(Event::new());
200 let (mut state, inner) = virta::stream_state::StreamState::new_established({
201 let notify = notify.clone();
202 move || {
203 notify.notify(1);
204 }
205 });
206 state.set_mss(self.cfg.mss);
207 let _task = smolscale::spawn(ticker(notify, state, dg_conn));
208 Ok(MeeklikePipe { inner, _task })
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{AsyncReadExt, AsyncWriteExt};
    use sillad::tcp::{TcpDialer, TcpListener};
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// How long the round trip may take before the test is declared hung.
    const TEST_TIMEOUT: Duration = Duration::from_secs(30);

    /// End-to-end check over loopback TCP: the client sends `ping` and expects
    /// the server to answer `pong`.
    #[test]
    fn ping_pong() {
        smolscale::block_on(async {
            let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
                .await
                .unwrap();
            let addr = listener.local_addr().await;
            let key = [7u8; 32];
            let mut meek_listener = MeeklikeListener::new(listener, key, Default::default());

            let dialer = MeeklikeDialer {
                inner: TcpDialer { dest_addr: addr }.into(),
                key,
                cfg: Default::default(),
            };

            let server = smolscale::spawn(async move {
                let mut pipe = meek_listener.accept().await.unwrap();
                let mut buf = [0u8; 4];
                pipe.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"ping");
                pipe.write_all(b"pong").await.unwrap();
                pipe.flush().await.unwrap();
            });

            let client = smolscale::spawn(async move {
                let mut pipe = dialer.dial().await.unwrap();
                pipe.write_all(b"ping").await.unwrap();
                pipe.flush().await.unwrap();
                let mut buf = [0u8; 4];
                pipe.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"pong");
            });

            let exchange = async {
                server.await;
                client.await;
            };
            // Without a deadline, a stuck transport would hang the test run
            // forever instead of failing.
            exchange
                .race(async {
                    Timer::after(TEST_TIMEOUT).await;
                    panic!("ping_pong did not finish within {:?}", TEST_TIMEOUT);
                })
                .await;
        });
    }
}