1use super::*;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, Weak};
11
/// Process-wide counter handing out connection ids (`uniq`); the first id is 1
/// and each inserted connection takes the next value via `fetch_add`.
static U: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1_u64);
13
/// Command delivered to a per-slot `top_task` worker over its channel.
enum TaskMsg {
    /// Begin serving a newly admitted websocket in this slot.
    NewWs {
        /// Connection id taken from `U`; distinguishes this connection from
        /// earlier ones that used the same slab index.
        uniq: u64,
        /// Slab index (and slot index) assigned to this connection.
        index: usize,
        /// The websocket to serve.
        ws: Arc<dyn SbdWebsocket>,
        /// Remote address, used for per-IP rate limiting.
        ip: Arc<Ipv6Addr>,
        /// Public key the client claims; verified during the handshake.
        pk: PubKey,
        /// Optional (token, tracker) pair passed to
        /// `AuthTokenTracker::check_is_token_valid` on each received frame.
        maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
    },
    /// Stop serving the current connection (sent by `CSlot::remove`).
    Close,
}
25
/// One fixed slot in the table: the sending half of the channel that feeds
/// that slot's `top_task` worker.
struct SlotEntry {
    /// Sender for `TaskMsg`s addressed to this slot's worker task.
    send: tokio::sync::mpsc::UnboundedSender<TaskMsg>,
}
29
30struct SlabEntry {
31 uniq: u64,
32 handshake_complete: bool,
33 weak_ws: Weak<dyn SbdWebsocket>,
34}
35
/// Mutable state shared by every `CSlot` / `WeakCSlot` handle, guarded by a
/// `Mutex`.
struct CSlotInner {
    /// Connection limit; set from `config.limit_clients` when constructed.
    max_count: usize,
    /// Per-slot channel senders, indexed by the same index as `slab`.
    slots: Vec<SlotEntry>,
    /// Live connections, keyed by slot index.
    slab: slab::Slab<SlabEntry>,
    /// Public key -> slab index of the connection currently mapped to it.
    pk_to_index: HashMap<PubKey, usize>,
    /// Remote IP -> slab indices of all connections from that IP.
    ip_to_index: HashMap<Arc<Ipv6Addr>, Vec<usize>>,
    /// Join handles of the per-slot worker tasks; aborted in `Drop`.
    task_list: Vec<tokio::task::JoinHandle<()>>,
    /// Metric: incremented on insert, decremented on remove.
    open_connections: opentelemetry::metrics::UpDownCounter<i64>,
}
45
46impl Drop for CSlotInner {
47 fn drop(&mut self) {
48 for task in self.task_list.iter() {
49 task.abort();
50 }
51 }
52}
53
/// Non-owning handle to a `CSlot`'s shared state.
///
/// Given to the per-slot worker tasks so they do not keep the table alive.
#[derive(Clone)]
pub struct WeakCSlot(Weak<Mutex<CSlotInner>>);
57
58impl WeakCSlot {
59 pub fn upgrade(&self) -> Option<CSlot> {
61 self.0.upgrade().map(CSlot)
62 }
63}
64
/// Owning handle to the table of client connection slots.
///
/// Use `weak()` to obtain a non-owning `WeakCSlot`.
pub struct CSlot(Arc<Mutex<CSlotInner>>);
70
71impl CSlot {
72 pub fn new(
74 config: Arc<Config>,
75 ip_rate: Arc<IpRate>,
76 meter: opentelemetry::metrics::Meter,
77 ) -> Self {
78 let count = config.limit_clients as usize;
79
80 let ip_rate_counter = meter
81 .u64_counter("sbd.server.ip_rate_limited")
82 .with_description("Total number of IP rate limited events")
83 .with_unit("count")
84 .build();
85
86 Self(Arc::new_cyclic(|this| {
87 let mut slots = Vec::with_capacity(count);
88 let mut task_list = Vec::with_capacity(count);
89 for _ in 0..count {
90 let (send, recv) = tokio::sync::mpsc::unbounded_channel();
91 slots.push(SlotEntry { send });
92 task_list.push(tokio::task::spawn(top_task(
93 config.clone(),
94 ip_rate.clone(),
95 WeakCSlot(this.clone()),
96 recv,
97 ip_rate_counter.clone(),
98 )));
99 }
100
101 let open_connections = meter
102 .i64_up_down_counter("sbd.server.open_connections")
103 .with_description("Number of open client connections")
104 .build();
105
106 Mutex::new(CSlotInner {
107 max_count: count,
108 slots,
109 slab: slab::Slab::with_capacity(count),
110 pk_to_index: HashMap::with_capacity(count),
111 ip_to_index: HashMap::with_capacity(count),
112 task_list,
113 open_connections,
114 })
115 }))
116 }
117
118 pub fn weak(&self) -> WeakCSlot {
120 WeakCSlot(Arc::downgrade(&self.0))
121 }
122
123 fn remove(&self, uniq: u64, index: usize) {
125 let mut lock = self.0.lock().unwrap();
126
127 match lock.slab.get(index) {
128 None => return,
129 Some(s) => {
130 if s.uniq != uniq {
131 return;
132 }
133 }
134 }
135
136 let _ = lock.slots.get(index).unwrap().send.send(TaskMsg::Close);
137 lock.slab.remove(index);
138 lock.pk_to_index.retain(|_, i| *i != index);
139 lock.ip_to_index.retain(|_, v| {
140 v.retain(|i| *i != index);
141 !v.is_empty()
142 });
143
144 lock.open_connections.add(-1, &[])
146 }
147
148 #[allow(clippy::type_complexity)]
151 fn insert_and_get_rate_send_list(
152 &self,
153 ip: Arc<Ipv6Addr>,
154 pk: PubKey,
155 ws: Arc<dyn SbdWebsocket>,
156 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
157 ) -> std::result::Result<
158 Vec<(u64, usize, Weak<dyn SbdWebsocket>)>,
159 Arc<dyn SbdWebsocket>,
160 > {
161 let mut lock = self.0.lock().unwrap();
162
163 if lock.slab.len() >= lock.max_count {
164 return Err(ws);
165 }
166
167 let weak_ws = Arc::downgrade(&ws);
168
169 let uniq = U.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
170
171 let index = lock.slab.insert(SlabEntry {
172 uniq,
173 weak_ws,
174 handshake_complete: false,
175 });
176
177 lock.pk_to_index.insert(pk.clone(), index);
178
179 let rate_send_list = {
180 let list = {
181 let e = lock
185 .ip_to_index
186 .entry(ip.clone())
187 .or_insert_with(|| Vec::with_capacity(1024));
188
189 e.push(index);
190
191 e.clone()
192 };
193
194 let mut rate_send_list = Vec::with_capacity(list.len());
195
196 for index in list.iter() {
197 if let Some(slab) = lock.slab.get(*index) {
198 rate_send_list.push((
199 slab.uniq,
200 *index,
201 slab.weak_ws.clone(),
202 ));
203 }
204 }
205
206 rate_send_list
207 };
208
209 let send = lock.slots.get(index).unwrap().send.clone();
210 let _ = send.send(TaskMsg::NewWs {
211 uniq,
212 index,
213 ws,
214 ip,
215 pk,
216 maybe_auth,
217 });
218
219 lock.open_connections.add(1, &[]);
221
222 Ok(rate_send_list)
223 }
224
225 pub async fn insert(
227 &self,
228 config: &Config,
229 ip: Arc<Ipv6Addr>,
230 pk: PubKey,
231 ws: Arc<impl SbdWebsocket>,
232 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
233 ) {
234 let rate_send_list =
235 self.insert_and_get_rate_send_list(ip, pk, ws, maybe_auth);
236
237 match rate_send_list {
238 Ok(rate_send_list) => {
239 let rate = if config.disable_rate_limiting {
240 1
241 } else {
242 let mut rate = config.limit_ip_byte_nanos() as u64
243 * rate_send_list.len() as u64;
244 if rate > i32::MAX as u64 {
245 rate = i32::MAX as u64;
246 }
247 rate as i32
248 };
249
250 for (uniq, index, weak_ws) in rate_send_list {
251 if let Some(ws) = weak_ws.upgrade() {
252 if ws
253 .send(cmd::SbdCmd::limit_byte_nanos(rate))
254 .await
255 .is_err()
256 {
257 self.remove(uniq, index);
258 }
259 }
260 }
261 }
262 Err(ws) => {
263 ws.close().await;
264 drop(ws);
265 }
266 }
267 }
268
269 fn mark_ready(&self, uniq: u64, index: usize) {
271 let mut lock = self.0.lock().unwrap();
272 if let Some(slab) = lock.slab.get_mut(index) {
273 if slab.uniq == uniq {
274 slab.handshake_complete = true;
275 }
276 }
277 }
278
279 fn get_sender(
281 &self,
282 pk: &PubKey,
283 ) -> Result<(u64, usize, Arc<dyn SbdWebsocket>)> {
284 let lock = self.0.lock().unwrap();
285
286 let index = match lock.pk_to_index.get(pk) {
287 None => return Err(Error::other("no such peer")),
288 Some(index) => *index,
289 };
290
291 let slab = lock.slab.get(index).unwrap();
292
293 if !slab.handshake_complete {
294 return Err(Error::other("no such peer"));
295 }
296
297 let uniq = slab.uniq;
298 let ws = match slab.weak_ws.upgrade() {
299 None => return Err(Error::other("no such peer")),
300 Some(ws) => ws,
301 };
302
303 Ok((uniq, index, ws))
304 }
305
306 async fn send(&self, pk: &PubKey, payload: Payload) -> Result<()> {
308 let (uniq, index, ws) = self.get_sender(pk)?;
309
310 match ws.send(payload).await {
311 Err(err) => {
312 self.remove(uniq, index);
313 Err(err)
314 }
315 Ok(_) => Ok(()),
316 }
317 }
318}
319
320async fn top_task(
323 config: Arc<Config>,
324 ip_rate: Arc<IpRate>,
325 weak: WeakCSlot,
326 mut recv: tokio::sync::mpsc::UnboundedReceiver<TaskMsg>,
327 ip_rate_counter: opentelemetry::metrics::Counter<u64>,
328) {
329 let mut item = recv.recv().await;
330 loop {
331 let uitem = match item {
332 None => break,
333 Some(uitem) => uitem,
334 };
335
336 item = if let TaskMsg::NewWs {
337 uniq,
338 index,
339 ws,
340 ip,
341 pk,
342 maybe_auth,
343 } = uitem
344 {
345 let next_i = tokio::select! {
347 i = recv.recv() => Some(i),
348 _ = ws_task(
349 &config,
350 &ip_rate,
351 &weak,
352 &ws,
353 ip,
354 pk,
355 uniq,
356 index,
357 maybe_auth,
358 &ip_rate_counter,
359 ) => None,
360 };
361
362 ws.close().await;
364 drop(ws);
365 if let Some(cslot) = weak.upgrade() {
366 cslot.remove(uniq, index);
367 }
368
369 match next_i {
370 Some(i) => i,
371 None => recv.recv().await,
372 }
373 } else {
374 recv.recv().await
375 };
376 }
377}
378
379#[allow(clippy::too_many_arguments)]
381async fn ws_task(
382 config: &Arc<Config>,
383 ip_rate: &IpRate,
384 weak_cslot: &WeakCSlot,
385 ws: &Arc<dyn SbdWebsocket>,
386 ip: Arc<Ipv6Addr>,
387 pk: PubKey,
388 uniq: u64,
389 index: usize,
390 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
391 ip_rate_counter: &opentelemetry::metrics::Counter<u64>,
392) {
393 let pub_key =
394 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(*pk.0);
395 let auth_res = tokio::time::timeout(config.idle_dur(), async {
396 use rand::Rng;
397 let mut nonce = [0xdb; 32];
398 rand::thread_rng().fill(&mut nonce[..]);
399
400 ws.send(cmd::SbdCmd::auth_req(&nonce)).await?;
402
403 loop {
404 let auth_res = ws.recv().await?;
405
406 if !ip_rate.is_ok(&ip, auth_res.as_ref().len()).await {
407 ip_rate_counter.add(
408 1,
409 &[
410 opentelemetry::KeyValue::new(
411 "pub_key",
412 pub_key.clone(),
413 ),
414 opentelemetry::KeyValue::new("kind", "auth"),
415 ],
416 );
417
418 return Err(Error::other("ip rate limited"));
419 }
420
421 if let Some((token, token_tracker)) = &maybe_auth {
422 let _ =
426 token_tracker.check_is_token_valid(config, token.clone());
427 }
428
429 match cmd::SbdCmd::parse(auth_res)? {
430 cmd::SbdCmd::AuthRes(sig) => {
431 if !pk.verify(&sig, &nonce) {
432 return Err(Error::other("invalid sig"));
433 }
434 break;
435 }
436 cmd::SbdCmd::Message(_) => {
437 return Err(Error::other(
438 "invalid forward before handshake",
439 ));
440 }
441 _ => continue,
442 }
443 }
444
445 ws.send(cmd::SbdCmd::limit_idle_millis(config.limit_idle_millis))
448 .await?;
449
450 if let Some(cslot) = weak_cslot.upgrade() {
451 cslot.mark_ready(uniq, index);
452 } else {
453 return Err(Error::other("closed"));
454 }
455
456 ws.send(cmd::SbdCmd::ready()).await?;
457
458 Ok(())
459 })
460 .await;
461
462 if auth_res.is_err() {
463 return;
464 }
465
466 while let Ok(Ok(payload)) =
469 tokio::time::timeout(config.idle_dur(), ws.recv()).await
470 {
471 if !ip_rate.is_ok(&ip, payload.len()).await {
472 ip_rate_counter.add(
473 1,
474 &[
475 opentelemetry::KeyValue::new("pub_key", pub_key),
476 opentelemetry::KeyValue::new("kind", "msg"),
477 ],
478 );
479
480 break;
481 }
482
483 if let Some((token, token_tracker)) = &maybe_auth {
484 let _ = token_tracker.check_is_token_valid(config, token.clone());
488 }
489
490 let cmd = match cmd::SbdCmd::parse(payload) {
491 Err(_) => break,
492 Ok(cmd) => cmd,
493 };
494
495 match cmd {
496 cmd::SbdCmd::Keepalive => (),
498 cmd::SbdCmd::AuthRes(_) => break,
500 cmd::SbdCmd::Unknown => (),
502 cmd::SbdCmd::Message(mut payload) => {
504 let dest = {
505 let payload = payload.to_mut();
506
507 let mut dest = [0; 32];
508 dest.copy_from_slice(&payload[..32]);
509 let dest = PubKey(Arc::new(dest));
510
511 payload[..32].copy_from_slice(&pk.0[..]);
512
513 dest
514 };
515
516 if let Some(cslot) = weak_cslot.upgrade() {
517 let _ = cslot.send(&dest, payload).await;
518 } else {
519 break;
520 }
521 }
522 }
523 }
524
525 tracing::debug!("Closed connection for {ip}");
526}