1use super::*;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, Weak};
11
/// Process-wide monotonically increasing counter, starting at 1.
///
/// Each call to `fetch_add` hands out a fresh value that is used as the
/// `uniq` id of a connection, so a reused slab index can be told apart
/// from the connection that previously occupied it.
static U: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(1);
13
14enum TaskMsg {
15 NewWs {
16 uniq: u64,
17 index: usize,
18 ws: Arc<dyn SbdWebsocket>,
19 ip: Arc<std::net::Ipv6Addr>,
20 pk: PubKey,
21 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
22 },
23 Close,
24}
25
26struct SlotEntry {
27 send: tokio::sync::mpsc::UnboundedSender<TaskMsg>,
28}
29
30struct SlabEntry {
31 uniq: u64,
32 handshake_complete: bool,
33 weak_ws: Weak<dyn SbdWebsocket>,
34}
35
36struct CSlotInner {
37 max_count: usize,
38 slots: Vec<SlotEntry>,
39 slab: slab::Slab<SlabEntry>,
40 pk_to_index: HashMap<PubKey, usize>,
41 ip_to_index: HashMap<Arc<std::net::Ipv6Addr>, Vec<usize>>,
42 task_list: Vec<tokio::task::JoinHandle<()>>,
43}
44
45impl Drop for CSlotInner {
46 fn drop(&mut self) {
47 for task in self.task_list.iter() {
48 task.abort();
49 }
50 }
51}
52
53#[derive(Clone)]
55pub struct WeakCSlot(Weak<Mutex<CSlotInner>>);
56
57impl WeakCSlot {
58 pub fn upgrade(&self) -> Option<CSlot> {
60 self.0.upgrade().map(CSlot)
61 }
62}
63
64pub struct CSlot(Arc<Mutex<CSlotInner>>);
66
67impl CSlot {
68 pub fn new(config: Arc<Config>, ip_rate: Arc<IpRate>) -> Self {
70 let count = config.limit_clients as usize;
71 Self(Arc::new_cyclic(|this| {
72 let mut slots = Vec::with_capacity(count);
73 let mut task_list = Vec::with_capacity(count);
74 for _ in 0..count {
75 let (send, recv) = tokio::sync::mpsc::unbounded_channel();
76 slots.push(SlotEntry { send });
77 task_list.push(tokio::task::spawn(top_task(
78 config.clone(),
79 ip_rate.clone(),
80 WeakCSlot(this.clone()),
81 recv,
82 )));
83 }
84 Mutex::new(CSlotInner {
85 max_count: count,
86 slots,
87 slab: slab::Slab::with_capacity(count),
88 pk_to_index: HashMap::with_capacity(count),
89 ip_to_index: HashMap::with_capacity(count),
90 task_list,
91 })
92 }))
93 }
94
95 pub fn weak(&self) -> WeakCSlot {
97 WeakCSlot(Arc::downgrade(&self.0))
98 }
99
100 fn remove(&self, uniq: u64, index: usize) {
101 let mut lock = self.0.lock().unwrap();
102
103 match lock.slab.get(index) {
104 None => return,
105 Some(s) => {
106 if s.uniq != uniq {
107 return;
108 }
109 }
110 }
111
112 let _ = lock.slots.get(index).unwrap().send.send(TaskMsg::Close);
113 lock.slab.remove(index);
114 lock.pk_to_index.retain(|_, i| *i != index);
115 lock.ip_to_index.retain(|_, v| {
116 v.retain(|i| *i != index);
117 !v.is_empty()
118 });
119 }
120
121 #[allow(clippy::type_complexity)]
123 fn insert_and_get_rate_send_list(
124 &self,
125 ip: Arc<std::net::Ipv6Addr>,
126 pk: PubKey,
127 ws: Arc<dyn SbdWebsocket>,
128 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
129 ) -> std::result::Result<
130 Vec<(u64, usize, Weak<dyn SbdWebsocket>)>,
131 Arc<dyn SbdWebsocket>,
132 > {
133 let mut lock = self.0.lock().unwrap();
134
135 if lock.slab.len() >= lock.max_count {
136 return Err(ws);
137 }
138
139 let weak_ws = Arc::downgrade(&ws);
140
141 let uniq = U.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
142
143 let index = lock.slab.insert(SlabEntry {
144 uniq,
145 weak_ws,
146 handshake_complete: false,
147 });
148
149 lock.pk_to_index.insert(pk.clone(), index);
150
151 let rate_send_list = {
152 let list = {
153 let e = lock
157 .ip_to_index
158 .entry(ip.clone())
159 .or_insert_with(|| Vec::with_capacity(1024));
160
161 e.push(index);
162
163 e.clone()
164 };
165
166 let mut rate_send_list = Vec::with_capacity(list.len());
167
168 for index in list.iter() {
169 if let Some(slab) = lock.slab.get(*index) {
170 rate_send_list.push((
171 slab.uniq,
172 *index,
173 slab.weak_ws.clone(),
174 ));
175 }
176 }
177
178 rate_send_list
179 };
180
181 let send = lock.slots.get(index).unwrap().send.clone();
182 let _ = send.send(TaskMsg::NewWs {
183 uniq,
184 index,
185 ws,
186 ip,
187 pk,
188 maybe_auth,
189 });
190
191 Ok(rate_send_list)
192 }
193
194 pub async fn insert(
196 &self,
197 config: &Config,
198 ip: Arc<std::net::Ipv6Addr>,
199 pk: PubKey,
200 ws: Arc<impl SbdWebsocket>,
201 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
202 ) {
203 let rate_send_list =
204 self.insert_and_get_rate_send_list(ip, pk, ws, maybe_auth);
205
206 match rate_send_list {
207 Ok(rate_send_list) => {
208 let rate = if config.disable_rate_limiting {
209 1
210 } else {
211 let mut rate = config.limit_ip_byte_nanos() as u64
212 * rate_send_list.len() as u64;
213 if rate > i32::MAX as u64 {
214 rate = i32::MAX as u64;
215 }
216 rate as i32
217 };
218
219 for (uniq, index, weak_ws) in rate_send_list {
220 if let Some(ws) = weak_ws.upgrade() {
221 if ws
222 .send(cmd::SbdCmd::limit_byte_nanos(rate))
223 .await
224 .is_err()
225 {
226 self.remove(uniq, index);
227 }
228 }
229 }
230 }
231 Err(ws) => {
232 ws.close().await;
233 drop(ws);
234 }
235 }
236 }
237
238 fn mark_ready(&self, uniq: u64, index: usize) {
239 let mut lock = self.0.lock().unwrap();
240 if let Some(slab) = lock.slab.get_mut(index) {
241 if slab.uniq == uniq {
242 slab.handshake_complete = true;
243 }
244 }
245 }
246
247 fn get_sender(
248 &self,
249 pk: &PubKey,
250 ) -> Result<(u64, usize, Arc<dyn SbdWebsocket>)> {
251 let lock = self.0.lock().unwrap();
252
253 let index = match lock.pk_to_index.get(pk) {
254 None => return Err(Error::other("no such peer")),
255 Some(index) => *index,
256 };
257
258 let slab = lock.slab.get(index).unwrap();
259
260 if !slab.handshake_complete {
261 return Err(Error::other("no such peer"));
262 }
263
264 let uniq = slab.uniq;
265 let ws = match slab.weak_ws.upgrade() {
266 None => return Err(Error::other("no such peer")),
267 Some(ws) => ws,
268 };
269
270 Ok((uniq, index, ws))
271 }
272
273 async fn send(&self, pk: &PubKey, payload: Payload) -> Result<()> {
274 let (uniq, index, ws) = self.get_sender(pk)?;
275
276 match ws.send(payload).await {
277 Err(err) => {
278 self.remove(uniq, index);
279 Err(err)
280 }
281 Ok(_) => Ok(()),
282 }
283 }
284}
285
286async fn top_task(
287 config: Arc<Config>,
288 ip_rate: Arc<ip_rate::IpRate>,
289 weak: WeakCSlot,
290 mut recv: tokio::sync::mpsc::UnboundedReceiver<TaskMsg>,
291) {
292 let mut item = recv.recv().await;
293 loop {
294 let uitem = match item {
295 None => break,
296 Some(uitem) => uitem,
297 };
298
299 item = if let TaskMsg::NewWs {
300 uniq,
301 index,
302 ws,
303 ip,
304 pk,
305 maybe_auth,
306 } = uitem
307 {
308 let next_i = tokio::select! {
309 i = recv.recv() => Some(i),
310 _ = ws_task(
311 &config,
312 &ip_rate,
313 &weak,
314 &ws,
315 ip,
316 pk,
317 uniq,
318 index,
319 maybe_auth,
320 ) => None,
321 };
322
323 ws.close().await;
324 drop(ws);
325 if let Some(cslot) = weak.upgrade() {
326 cslot.remove(uniq, index);
327 }
328
329 match next_i {
330 Some(i) => i,
331 None => recv.recv().await,
332 }
333 } else {
334 recv.recv().await
335 };
336 }
337}
338
339#[allow(clippy::too_many_arguments)]
340async fn ws_task(
341 config: &Arc<Config>,
342 ip_rate: &ip_rate::IpRate,
343 weak_cslot: &WeakCSlot,
344 ws: &Arc<dyn SbdWebsocket>,
345 ip: Arc<std::net::Ipv6Addr>,
346 pk: PubKey,
347 uniq: u64,
348 index: usize,
349 maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
350) {
351 let auth_res = tokio::time::timeout(config.idle_dur(), async {
352 use rand::Rng;
353 let mut nonce = [0xdb; 32];
354 rand::thread_rng().fill(&mut nonce[..]);
355
356 ws.send(cmd::SbdCmd::auth_req(&nonce)).await?;
357
358 loop {
359 let auth_res = ws.recv().await?;
360
361 if !ip_rate.is_ok(&ip, auth_res.as_ref().len()).await {
362 return Err(Error::other("ip rate limited"));
363 }
364
365 if let Some((token, token_tracker)) = &maybe_auth {
366 let _ =
370 token_tracker.check_is_token_valid(config, token.clone());
371 }
372
373 match cmd::SbdCmd::parse(auth_res)? {
374 cmd::SbdCmd::AuthRes(sig) => {
375 if !pk.verify(&sig, &nonce) {
376 return Err(Error::other("invalid sig"));
377 }
378 break;
379 }
380 cmd::SbdCmd::Message(_) => {
381 return Err(Error::other(
382 "invalid forward before handshake",
383 ));
384 }
385 _ => continue,
386 }
387 }
388
389 ws.send(cmd::SbdCmd::limit_idle_millis(config.limit_idle_millis))
392 .await?;
393
394 if let Some(cslot) = weak_cslot.upgrade() {
395 cslot.mark_ready(uniq, index);
396 } else {
397 return Err(Error::other("closed"));
398 }
399
400 ws.send(cmd::SbdCmd::ready()).await?;
401
402 Ok(())
403 })
404 .await;
405
406 if auth_res.is_err() {
407 return;
408 }
409
410 while let Ok(Ok(payload)) =
411 tokio::time::timeout(config.idle_dur(), ws.recv()).await
412 {
413 if !ip_rate.is_ok(&ip, payload.len()).await {
414 break;
415 }
416
417 if let Some((token, token_tracker)) = &maybe_auth {
418 let _ = token_tracker.check_is_token_valid(config, token.clone());
422 }
423
424 let cmd = match cmd::SbdCmd::parse(payload) {
425 Err(_) => break,
426 Ok(cmd) => cmd,
427 };
428
429 match cmd {
430 cmd::SbdCmd::Keepalive => (),
431 cmd::SbdCmd::AuthRes(_) => break,
432 cmd::SbdCmd::Unknown => (),
433 cmd::SbdCmd::Message(mut payload) => {
434 let dest = {
435 let payload = payload.to_mut();
436
437 let mut dest = [0; 32];
438 dest.copy_from_slice(&payload[..32]);
439 let dest = PubKey(Arc::new(dest));
440
441 payload[..32].copy_from_slice(&pk.0[..]);
442
443 dest
444 };
445
446 if let Some(cslot) = weak_cslot.upgrade() {
447 let _ = cslot.send(&dest, payload).await;
448 } else {
449 break;
450 }
451 }
452 }
453 }
454}