1use super::*;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, Weak};
11
// Process-wide counter used to hand out the `uniq` id of each inserted
// connection (see `CSlot::insert_and_get_rate_send_list`). Starts at 1.
static U: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
13
/// Messages delivered to a slot's `top_task` over its unbounded channel.
enum TaskMsg {
    /// A new websocket has been assigned to this slot; the task should start
    /// driving it with `ws_task`.
    NewWs {
        /// Id taken from the global counter `U` when the connection was
        /// inserted; used to tell this connection apart from a later one
        /// that reuses the same slab index.
        uniq: u64,
        /// Slab index (and slot index) the connection was stored at.
        index: usize,
        /// The websocket to drive.
        ws: Arc<dyn SbdWebsocket>,
        /// Remote address, handed through to `ws_task` for rate limiting.
        ip: Arc<Ipv6Addr>,
        /// Public key the peer claims; verified during the handshake.
        pk: PubKey,
        /// Optional auth token and its tracker, handed through to `ws_task`.
        maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
    },
    /// Sent by `CSlot::remove`; if the slot task is currently running a
    /// `ws_task`, receiving this interrupts it and closes the websocket.
    Close,
}
25
/// One pre-spawned task slot.
struct SlotEntry {
    /// Sending half of the channel whose receiver is owned by this slot's
    /// `top_task`.
    send: tokio::sync::mpsc::UnboundedSender<TaskMsg>,
}
29
/// Bookkeeping for one live connection stored in the slab.
struct SlabEntry {
    /// Id assigned at insert time; compared by `remove` / `mark_ready` so a
    /// stale caller cannot affect a newer connection at the same index.
    uniq: u64,
    /// Set by `CSlot::mark_ready`; `CSlot::get_sender` refuses peers for
    /// which this is still `false`.
    handshake_complete: bool,
    /// Non-owning reference to the connection's websocket.
    weak_ws: Weak<dyn SbdWebsocket>,
}
35
/// Mutex-protected state shared by every `CSlot` / `WeakCSlot` handle.
struct CSlotInner {
    /// Maximum number of simultaneous connections (`config.limit_clients`).
    max_count: usize,
    /// One entry per pre-spawned task; indexed by slab index.
    slots: Vec<SlotEntry>,
    /// Live connections; the slab key doubles as the slot index.
    slab: slab::Slab<SlabEntry>,
    /// Maps a peer's public key to its slab index.
    pk_to_index: HashMap<PubKey, usize>,
    /// Maps a remote address to the slab indices of all its connections.
    ip_to_index: HashMap<Arc<Ipv6Addr>, Vec<usize>>,
    /// Handles of the spawned slot tasks; aborted when this struct drops.
    task_list: Vec<tokio::task::JoinHandle<()>>,
}
44
45impl Drop for CSlotInner {
46 fn drop(&mut self) {
47 for task in self.task_list.iter() {
48 task.abort();
49 }
50 }
51}
52
53#[derive(Clone)]
55pub struct WeakCSlot(Weak<Mutex<CSlotInner>>);
56
57impl WeakCSlot {
58 pub fn upgrade(&self) -> Option<CSlot> {
60 self.0.upgrade().map(CSlot)
61 }
62}
63
64pub struct CSlot(Arc<Mutex<CSlotInner>>);
69
70impl CSlot {
71 pub fn new(config: Arc<Config>, ip_rate: Arc<IpRate>) -> Self {
73 let count = config.limit_clients as usize;
74 Self(Arc::new_cyclic(|this| {
75 let mut slots = Vec::with_capacity(count);
76 let mut task_list = Vec::with_capacity(count);
77 for _ in 0..count {
78 let (send, recv) = tokio::sync::mpsc::unbounded_channel();
79 slots.push(SlotEntry { send });
80 task_list.push(tokio::task::spawn(top_task(
81 config.clone(),
82 ip_rate.clone(),
83 WeakCSlot(this.clone()),
84 recv,
85 )));
86 }
87 Mutex::new(CSlotInner {
88 max_count: count,
89 slots,
90 slab: slab::Slab::with_capacity(count),
91 pk_to_index: HashMap::with_capacity(count),
92 ip_to_index: HashMap::with_capacity(count),
93 task_list,
94 })
95 }))
96 }
97
98 pub fn weak(&self) -> WeakCSlot {
100 WeakCSlot(Arc::downgrade(&self.0))
101 }
102
103 fn remove(&self, uniq: u64, index: usize) {
105 let mut lock = self.0.lock().unwrap();
106
107 match lock.slab.get(index) {
108 None => return,
109 Some(s) => {
110 if s.uniq != uniq {
111 return;
112 }
113 }
114 }
115
116 let _ = lock.slots.get(index).unwrap().send.send(TaskMsg::Close);
117 lock.slab.remove(index);
118 lock.pk_to_index.retain(|_, i| *i != index);
119 lock.ip_to_index.retain(|_, v| {
120 v.retain(|i| *i != index);
121 !v.is_empty()
122 });
123 }
124
125 #[allow(clippy::type_complexity)]
128 fn insert_and_get_rate_send_list(
129 &self,
130 ip: Arc<Ipv6Addr>,
131 pk: PubKey,
132 ws: Arc<dyn SbdWebsocket>,
133 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
134 ) -> std::result::Result<
135 Vec<(u64, usize, Weak<dyn SbdWebsocket>)>,
136 Arc<dyn SbdWebsocket>,
137 > {
138 let mut lock = self.0.lock().unwrap();
139
140 if lock.slab.len() >= lock.max_count {
141 return Err(ws);
142 }
143
144 let weak_ws = Arc::downgrade(&ws);
145
146 let uniq = U.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
147
148 let index = lock.slab.insert(SlabEntry {
149 uniq,
150 weak_ws,
151 handshake_complete: false,
152 });
153
154 lock.pk_to_index.insert(pk.clone(), index);
155
156 let rate_send_list = {
157 let list = {
158 let e = lock
162 .ip_to_index
163 .entry(ip.clone())
164 .or_insert_with(|| Vec::with_capacity(1024));
165
166 e.push(index);
167
168 e.clone()
169 };
170
171 let mut rate_send_list = Vec::with_capacity(list.len());
172
173 for index in list.iter() {
174 if let Some(slab) = lock.slab.get(*index) {
175 rate_send_list.push((
176 slab.uniq,
177 *index,
178 slab.weak_ws.clone(),
179 ));
180 }
181 }
182
183 rate_send_list
184 };
185
186 let send = lock.slots.get(index).unwrap().send.clone();
187 let _ = send.send(TaskMsg::NewWs {
188 uniq,
189 index,
190 ws,
191 ip,
192 pk,
193 maybe_auth,
194 });
195
196 Ok(rate_send_list)
197 }
198
199 pub async fn insert(
201 &self,
202 config: &Config,
203 ip: Arc<Ipv6Addr>,
204 pk: PubKey,
205 ws: Arc<impl SbdWebsocket>,
206 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
207 ) {
208 let rate_send_list =
209 self.insert_and_get_rate_send_list(ip, pk, ws, maybe_auth);
210
211 match rate_send_list {
212 Ok(rate_send_list) => {
213 let rate = if config.disable_rate_limiting {
214 1
215 } else {
216 let mut rate = config.limit_ip_byte_nanos() as u64
217 * rate_send_list.len() as u64;
218 if rate > i32::MAX as u64 {
219 rate = i32::MAX as u64;
220 }
221 rate as i32
222 };
223
224 for (uniq, index, weak_ws) in rate_send_list {
225 if let Some(ws) = weak_ws.upgrade() {
226 if ws
227 .send(cmd::SbdCmd::limit_byte_nanos(rate))
228 .await
229 .is_err()
230 {
231 self.remove(uniq, index);
232 }
233 }
234 }
235 }
236 Err(ws) => {
237 ws.close().await;
238 drop(ws);
239 }
240 }
241 }
242
243 fn mark_ready(&self, uniq: u64, index: usize) {
245 let mut lock = self.0.lock().unwrap();
246 if let Some(slab) = lock.slab.get_mut(index) {
247 if slab.uniq == uniq {
248 slab.handshake_complete = true;
249 }
250 }
251 }
252
253 fn get_sender(
255 &self,
256 pk: &PubKey,
257 ) -> Result<(u64, usize, Arc<dyn SbdWebsocket>)> {
258 let lock = self.0.lock().unwrap();
259
260 let index = match lock.pk_to_index.get(pk) {
261 None => return Err(Error::other("no such peer")),
262 Some(index) => *index,
263 };
264
265 let slab = lock.slab.get(index).unwrap();
266
267 if !slab.handshake_complete {
268 return Err(Error::other("no such peer"));
269 }
270
271 let uniq = slab.uniq;
272 let ws = match slab.weak_ws.upgrade() {
273 None => return Err(Error::other("no such peer")),
274 Some(ws) => ws,
275 };
276
277 Ok((uniq, index, ws))
278 }
279
280 async fn send(&self, pk: &PubKey, payload: Payload) -> Result<()> {
282 let (uniq, index, ws) = self.get_sender(pk)?;
283
284 match ws.send(payload).await {
285 Err(err) => {
286 self.remove(uniq, index);
287 Err(err)
288 }
289 Ok(_) => Ok(()),
290 }
291 }
292}
293
294async fn top_task(
297 config: Arc<Config>,
298 ip_rate: Arc<IpRate>,
299 weak: WeakCSlot,
300 mut recv: tokio::sync::mpsc::UnboundedReceiver<TaskMsg>,
301) {
302 let mut item = recv.recv().await;
303 loop {
304 let uitem = match item {
305 None => break,
306 Some(uitem) => uitem,
307 };
308
309 item = if let TaskMsg::NewWs {
310 uniq,
311 index,
312 ws,
313 ip,
314 pk,
315 maybe_auth,
316 } = uitem
317 {
318 let next_i = tokio::select! {
320 i = recv.recv() => Some(i),
321 _ = ws_task(
322 &config,
323 &ip_rate,
324 &weak,
325 &ws,
326 ip,
327 pk,
328 uniq,
329 index,
330 maybe_auth,
331 ) => None,
332 };
333
334 ws.close().await;
336 drop(ws);
337 if let Some(cslot) = weak.upgrade() {
338 cslot.remove(uniq, index);
339 }
340
341 match next_i {
342 Some(i) => i,
343 None => recv.recv().await,
344 }
345 } else {
346 recv.recv().await
347 };
348 }
349}
350
351#[allow(clippy::too_many_arguments)]
353async fn ws_task(
354 config: &Arc<Config>,
355 ip_rate: &IpRate,
356 weak_cslot: &WeakCSlot,
357 ws: &Arc<dyn SbdWebsocket>,
358 ip: Arc<Ipv6Addr>,
359 pk: PubKey,
360 uniq: u64,
361 index: usize,
362 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
363) {
364 let auth_res = tokio::time::timeout(config.idle_dur(), async {
365 use rand::Rng;
366 let mut nonce = [0xdb; 32];
367 rand::thread_rng().fill(&mut nonce[..]);
368
369 ws.send(cmd::SbdCmd::auth_req(&nonce)).await?;
371
372 loop {
373 let auth_res = ws.recv().await?;
374
375 if !ip_rate.is_ok(&ip, auth_res.as_ref().len()).await {
376 return Err(Error::other("ip rate limited"));
377 }
378
379 if let Some((token, token_tracker)) = &maybe_auth {
380 let _ =
384 token_tracker.check_is_token_valid(config, token.clone());
385 }
386
387 match cmd::SbdCmd::parse(auth_res)? {
388 cmd::SbdCmd::AuthRes(sig) => {
389 if !pk.verify(&sig, &nonce) {
390 return Err(Error::other("invalid sig"));
391 }
392 break;
393 }
394 cmd::SbdCmd::Message(_) => {
395 return Err(Error::other(
396 "invalid forward before handshake",
397 ));
398 }
399 _ => continue,
400 }
401 }
402
403 ws.send(cmd::SbdCmd::limit_idle_millis(config.limit_idle_millis))
406 .await?;
407
408 if let Some(cslot) = weak_cslot.upgrade() {
409 cslot.mark_ready(uniq, index);
410 } else {
411 return Err(Error::other("closed"));
412 }
413
414 ws.send(cmd::SbdCmd::ready()).await?;
415
416 Ok(())
417 })
418 .await;
419
420 if auth_res.is_err() {
421 return;
422 }
423
424 while let Ok(Ok(payload)) =
427 tokio::time::timeout(config.idle_dur(), ws.recv()).await
428 {
429 if !ip_rate.is_ok(&ip, payload.len()).await {
430 break;
431 }
432
433 if let Some((token, token_tracker)) = &maybe_auth {
434 let _ = token_tracker.check_is_token_valid(config, token.clone());
438 }
439
440 let cmd = match cmd::SbdCmd::parse(payload) {
441 Err(_) => break,
442 Ok(cmd) => cmd,
443 };
444
445 match cmd {
446 cmd::SbdCmd::Keepalive => (),
448 cmd::SbdCmd::AuthRes(_) => break,
450 cmd::SbdCmd::Unknown => (),
452 cmd::SbdCmd::Message(mut payload) => {
454 let dest = {
455 let payload = payload.to_mut();
456
457 let mut dest = [0; 32];
458 dest.copy_from_slice(&payload[..32]);
459 let dest = PubKey(Arc::new(dest));
460
461 payload[..32].copy_from_slice(&pk.0[..]);
462
463 dest
464 };
465
466 if let Some(cslot) = weak_cslot.upgrade() {
467 let _ = cslot.send(&dest, payload).await;
468 } else {
469 break;
470 }
471 }
472 }
473 }
474
475 tracing::debug!("Closed connection for {ip}");
476}