sbd_server/
cslot.rs

//! Attempt to pre-allocate as much as possible, including our tokio tasks.
//! Ideally this would include a frame buffer that we could fill on ws
//! recv and use as a reference for ws send, but alas, fastwebsockets
//! doesn't seem up to the task. tungstenite will willy-nilly allocate
//! buffers for us, but at least we should only be dealing with one at a
//! time per connection.
7
8use super::*;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, Weak};
11
/// Process-wide source of connection ids (`uniq`), starting at 1.
///
/// Each slotted websocket takes a fresh value so that a slot which has been
/// freed and reused can be told apart from its previous occupant.
static U: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(1);
13
14enum TaskMsg {
15    NewWs {
16        uniq: u64,
17        index: usize,
18        ws: Arc<dyn SbdWebsocket>,
19        ip: Arc<Ipv6Addr>,
20        pk: PubKey,
21        maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
22    },
23    Close,
24}
25
26struct SlotEntry {
27    send: tokio::sync::mpsc::UnboundedSender<TaskMsg>,
28}
29
30struct SlabEntry {
31    uniq: u64,
32    handshake_complete: bool,
33    weak_ws: Weak<dyn SbdWebsocket>,
34}
35
36struct CSlotInner {
37    max_count: usize,
38    slots: Vec<SlotEntry>,
39    slab: slab::Slab<SlabEntry>,
40    pk_to_index: HashMap<PubKey, usize>,
41    ip_to_index: HashMap<Arc<Ipv6Addr>, Vec<usize>>,
42    task_list: Vec<tokio::task::JoinHandle<()>>,
43}
44
45impl Drop for CSlotInner {
46    fn drop(&mut self) {
47        for task in self.task_list.iter() {
48            task.abort();
49        }
50    }
51}
52
53/// A weak reference to a connection slot container.
54#[derive(Clone)]
55pub struct WeakCSlot(Weak<Mutex<CSlotInner>>);
56
57impl WeakCSlot {
58    /// Upgrade this weak reference to a strong reference.
59    pub fn upgrade(&self) -> Option<CSlot> {
60        self.0.upgrade().map(CSlot)
61    }
62}
63
64/// A connection slot container.
65///
66/// Note this is not clone to ensure that when the single top-level handle
67/// is dropped, that everything is shutdown properly.
68pub struct CSlot(Arc<Mutex<CSlotInner>>);
69
70impl CSlot {
71    /// Create a new connection slot container.
72    pub fn new(config: Arc<Config>, ip_rate: Arc<IpRate>) -> Self {
73        let count = config.limit_clients as usize;
74        Self(Arc::new_cyclic(|this| {
75            let mut slots = Vec::with_capacity(count);
76            let mut task_list = Vec::with_capacity(count);
77            for _ in 0..count {
78                let (send, recv) = tokio::sync::mpsc::unbounded_channel();
79                slots.push(SlotEntry { send });
80                task_list.push(tokio::task::spawn(top_task(
81                    config.clone(),
82                    ip_rate.clone(),
83                    WeakCSlot(this.clone()),
84                    recv,
85                )));
86            }
87            Mutex::new(CSlotInner {
88                max_count: count,
89                slots,
90                slab: slab::Slab::with_capacity(count),
91                pk_to_index: HashMap::with_capacity(count),
92                ip_to_index: HashMap::with_capacity(count),
93                task_list,
94            })
95        }))
96    }
97
98    /// Get a weak reference to this connection slot container.
99    pub fn weak(&self) -> WeakCSlot {
100        WeakCSlot(Arc::downgrade(&self.0))
101    }
102
103    /// Remove a websocket from its slot.
104    fn remove(&self, uniq: u64, index: usize) {
105        let mut lock = self.0.lock().unwrap();
106
107        match lock.slab.get(index) {
108            None => return,
109            Some(s) => {
110                if s.uniq != uniq {
111                    return;
112                }
113            }
114        }
115
116        let _ = lock.slots.get(index).unwrap().send.send(TaskMsg::Close);
117        lock.slab.remove(index);
118        lock.pk_to_index.retain(|_, i| *i != index);
119        lock.ip_to_index.retain(|_, v| {
120            v.retain(|i| *i != index);
121            !v.is_empty()
122        });
123    }
124
125    /// Inner helper for inserting a websocket into an available slot.
126    // oi clippy, this is super straight forward...
127    #[allow(clippy::type_complexity)]
128    fn insert_and_get_rate_send_list(
129        &self,
130        ip: Arc<Ipv6Addr>,
131        pk: PubKey,
132        ws: Arc<dyn SbdWebsocket>,
133        maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
134    ) -> std::result::Result<
135        Vec<(u64, usize, Weak<dyn SbdWebsocket>)>,
136        Arc<dyn SbdWebsocket>,
137    > {
138        let mut lock = self.0.lock().unwrap();
139
140        if lock.slab.len() >= lock.max_count {
141            return Err(ws);
142        }
143
144        let weak_ws = Arc::downgrade(&ws);
145
146        let uniq = U.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
147
148        let index = lock.slab.insert(SlabEntry {
149            uniq,
150            weak_ws,
151            handshake_complete: false,
152        });
153
154        lock.pk_to_index.insert(pk.clone(), index);
155
156        let rate_send_list = {
157            let list = {
158                // WARN - allocation here!
159                // Also, do we want to limit the max connections from same ip?
160
161                let e = lock
162                    .ip_to_index
163                    .entry(ip.clone())
164                    .or_insert_with(|| Vec::with_capacity(1024));
165
166                e.push(index);
167
168                e.clone()
169            };
170
171            let mut rate_send_list = Vec::with_capacity(list.len());
172
173            for index in list.iter() {
174                if let Some(slab) = lock.slab.get(*index) {
175                    rate_send_list.push((
176                        slab.uniq,
177                        *index,
178                        slab.weak_ws.clone(),
179                    ));
180                }
181            }
182
183            rate_send_list
184        };
185
186        let send = lock.slots.get(index).unwrap().send.clone();
187        let _ = send.send(TaskMsg::NewWs {
188            uniq,
189            index,
190            ws,
191            ip,
192            pk,
193            maybe_auth,
194        });
195
196        Ok(rate_send_list)
197    }
198
199    /// Insert a connection to be managed by this container.
200    pub async fn insert(
201        &self,
202        config: &Config,
203        ip: Arc<Ipv6Addr>,
204        pk: PubKey,
205        ws: Arc<impl SbdWebsocket>,
206        maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
207    ) {
208        let rate_send_list =
209            self.insert_and_get_rate_send_list(ip, pk, ws, maybe_auth);
210
211        match rate_send_list {
212            Ok(rate_send_list) => {
213                let rate = if config.disable_rate_limiting {
214                    1
215                } else {
216                    let mut rate = config.limit_ip_byte_nanos() as u64
217                        * rate_send_list.len() as u64;
218                    if rate > i32::MAX as u64 {
219                        rate = i32::MAX as u64;
220                    }
221                    rate as i32
222                };
223
224                for (uniq, index, weak_ws) in rate_send_list {
225                    if let Some(ws) = weak_ws.upgrade() {
226                        if ws
227                            .send(cmd::SbdCmd::limit_byte_nanos(rate))
228                            .await
229                            .is_err()
230                        {
231                            self.remove(uniq, index);
232                        }
233                    }
234                }
235            }
236            Err(ws) => {
237                ws.close().await;
238                drop(ws);
239            }
240        }
241    }
242
243    /// Mark a slotted websocket as ready.
244    fn mark_ready(&self, uniq: u64, index: usize) {
245        let mut lock = self.0.lock().unwrap();
246        if let Some(slab) = lock.slab.get_mut(index) {
247            if slab.uniq == uniq {
248                slab.handshake_complete = true;
249            }
250        }
251    }
252
253    /// Get a websocket from its slot.
254    fn get_sender(
255        &self,
256        pk: &PubKey,
257    ) -> Result<(u64, usize, Arc<dyn SbdWebsocket>)> {
258        let lock = self.0.lock().unwrap();
259
260        let index = match lock.pk_to_index.get(pk) {
261            None => return Err(Error::other("no such peer")),
262            Some(index) => *index,
263        };
264
265        let slab = lock.slab.get(index).unwrap();
266
267        if !slab.handshake_complete {
268            return Err(Error::other("no such peer"));
269        }
270
271        let uniq = slab.uniq;
272        let ws = match slab.weak_ws.upgrade() {
273            None => return Err(Error::other("no such peer")),
274            Some(ws) => ws,
275        };
276
277        Ok((uniq, index, ws))
278    }
279
280    /// Send via a slotted websocket.
281    async fn send(&self, pk: &PubKey, payload: Payload) -> Result<()> {
282        let (uniq, index, ws) = self.get_sender(pk)?;
283
284        match ws.send(payload).await {
285            Err(err) => {
286                self.remove(uniq, index);
287                Err(err)
288            }
289            Ok(_) => Ok(()),
290        }
291    }
292}
293
294/// This top-task waits for incoming websockets, processes them until
295/// completion, and then waits for a new incoming websocket.
296async fn top_task(
297    config: Arc<Config>,
298    ip_rate: Arc<IpRate>,
299    weak: WeakCSlot,
300    mut recv: tokio::sync::mpsc::UnboundedReceiver<TaskMsg>,
301) {
302    let mut item = recv.recv().await;
303    loop {
304        let uitem = match item {
305            None => break,
306            Some(uitem) => uitem,
307        };
308
309        item = if let TaskMsg::NewWs {
310            uniq,
311            index,
312            ws,
313            ip,
314            pk,
315            maybe_auth,
316        } = uitem
317        {
318            // we have a websocket! process to completion
319            let next_i = tokio::select! {
320                i = recv.recv() => Some(i),
321                _ = ws_task(
322                    &config,
323                    &ip_rate,
324                    &weak,
325                    &ws,
326                    ip,
327                    pk,
328                    uniq,
329                    index,
330                    maybe_auth,
331                ) => None,
332            };
333
334            // our websocket task ended, clean up
335            ws.close().await;
336            drop(ws);
337            if let Some(cslot) = weak.upgrade() {
338                cslot.remove(uniq, index);
339            }
340
341            match next_i {
342                Some(i) => i,
343                None => recv.recv().await,
344            }
345        } else {
346            recv.recv().await
347        };
348    }
349}
350
351/// Process a single websocket until completion.
352#[allow(clippy::too_many_arguments)]
353async fn ws_task(
354    config: &Arc<Config>,
355    ip_rate: &IpRate,
356    weak_cslot: &WeakCSlot,
357    ws: &Arc<dyn SbdWebsocket>,
358    ip: Arc<Ipv6Addr>,
359    pk: PubKey,
360    uniq: u64,
361    index: usize,
362    maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
363) {
364    let auth_res = tokio::time::timeout(config.idle_dur(), async {
365        use rand::Rng;
366        let mut nonce = [0xdb; 32];
367        rand::thread_rng().fill(&mut nonce[..]);
368
369        // send them a nonce to prove they can sign with private key
370        ws.send(cmd::SbdCmd::auth_req(&nonce)).await?;
371
372        loop {
373            let auth_res = ws.recv().await?;
374
375            if !ip_rate.is_ok(&ip, auth_res.as_ref().len()).await {
376                return Err(Error::other("ip rate limited"));
377            }
378
379            if let Some((token, token_tracker)) = &maybe_auth {
380                // we already know they had a valid token
381                // when they opened this connection.
382                // just using this for side-effect marking token use time
383                let _ =
384                    token_tracker.check_is_token_valid(config, token.clone());
385            }
386
387            match cmd::SbdCmd::parse(auth_res)? {
388                cmd::SbdCmd::AuthRes(sig) => {
389                    if !pk.verify(&sig, &nonce) {
390                        return Err(Error::other("invalid sig"));
391                    }
392                    break;
393                }
394                cmd::SbdCmd::Message(_) => {
395                    return Err(Error::other(
396                        "invalid forward before handshake",
397                    ));
398                }
399                _ => continue,
400            }
401        }
402
403        // NOTE: the byte_nanos limit is sent during the cslot insert
404
405        ws.send(cmd::SbdCmd::limit_idle_millis(config.limit_idle_millis))
406            .await?;
407
408        if let Some(cslot) = weak_cslot.upgrade() {
409            cslot.mark_ready(uniq, index);
410        } else {
411            return Err(Error::other("closed"));
412        }
413
414        ws.send(cmd::SbdCmd::ready()).await?;
415
416        Ok(())
417    })
418    .await;
419
420    if auth_res.is_err() {
421        return;
422    }
423
424    // auth/init complete, now loop over incoming data
425
426    while let Ok(Ok(payload)) =
427        tokio::time::timeout(config.idle_dur(), ws.recv()).await
428    {
429        if !ip_rate.is_ok(&ip, payload.len()).await {
430            break;
431        }
432
433        if let Some((token, token_tracker)) = &maybe_auth {
434            // we already know they had a valid token
435            // when they opened this connection.
436            // just using this for side-effect marking token use time
437            let _ = token_tracker.check_is_token_valid(config, token.clone());
438        }
439
440        let cmd = match cmd::SbdCmd::parse(payload) {
441            Err(_) => break,
442            Ok(cmd) => cmd,
443        };
444
445        match cmd {
446            // don't need to do anything... we just get a new timeout above
447            cmd::SbdCmd::Keepalive => (),
448            // auth responses are invalid at this stage
449            cmd::SbdCmd::AuthRes(_) => break,
450            // ignore unknown messages
451            cmd::SbdCmd::Unknown => (),
452            // forward an actual message to a peer
453            cmd::SbdCmd::Message(mut payload) => {
454                let dest = {
455                    let payload = payload.to_mut();
456
457                    let mut dest = [0; 32];
458                    dest.copy_from_slice(&payload[..32]);
459                    let dest = PubKey(Arc::new(dest));
460
461                    payload[..32].copy_from_slice(&pk.0[..]);
462
463                    dest
464                };
465
466                if let Some(cslot) = weak_cslot.upgrade() {
467                    let _ = cslot.send(&dest, payload).await;
468                } else {
469                    break;
470                }
471            }
472        }
473    }
474
475    tracing::debug!("Closed connection for {ip}");
476}