sbd_server/
cslot.rs

//! Attempt to pre-allocate as much as possible, including our tokio tasks.
//! Ideally this would include a frame buffer that we could fill on ws
//! recv and use as a reference for ws send, but alas, fastwebsockets
//! doesn't seem up to the task. tungstenite allocates buffers for us as it
//! sees fit, but at least we should only be dealing with one at a time per
//! connection.
7
8use super::*;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, Weak};
11
/// Source of process-wide unique connection ids (the `uniq` values stored
/// in each slab entry).
///
/// Starts at 1, so the first id handed out by `fetch_add` is 1.
static U: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(1);
13
14enum TaskMsg {
15    NewWs {
16        uniq: u64,
17        index: usize,
18        ws: Arc<dyn SbdWebsocket>,
19        ip: Arc<std::net::Ipv6Addr>,
20        pk: PubKey,
21        maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
22    },
23    Close,
24}
25
26struct SlotEntry {
27    send: tokio::sync::mpsc::UnboundedSender<TaskMsg>,
28}
29
30struct SlabEntry {
31    uniq: u64,
32    handshake_complete: bool,
33    weak_ws: Weak<dyn SbdWebsocket>,
34}
35
36struct CSlotInner {
37    max_count: usize,
38    slots: Vec<SlotEntry>,
39    slab: slab::Slab<SlabEntry>,
40    pk_to_index: HashMap<PubKey, usize>,
41    ip_to_index: HashMap<Arc<std::net::Ipv6Addr>, Vec<usize>>,
42    task_list: Vec<tokio::task::JoinHandle<()>>,
43}
44
45impl Drop for CSlotInner {
46    fn drop(&mut self) {
47        for task in self.task_list.iter() {
48            task.abort();
49        }
50    }
51}
52
53/// A weak reference to a connection slot container.
54#[derive(Clone)]
55pub struct WeakCSlot(Weak<Mutex<CSlotInner>>);
56
57impl WeakCSlot {
58    /// Upgrade this weak reference to a strong reference.
59    pub fn upgrade(&self) -> Option<CSlot> {
60        self.0.upgrade().map(CSlot)
61    }
62}
63
64/// A connection slot container.
65pub struct CSlot(Arc<Mutex<CSlotInner>>);
66
67impl CSlot {
68    /// Create a new connection slot container.
69    pub fn new(config: Arc<Config>, ip_rate: Arc<IpRate>) -> Self {
70        let count = config.limit_clients as usize;
71        Self(Arc::new_cyclic(|this| {
72            let mut slots = Vec::with_capacity(count);
73            let mut task_list = Vec::with_capacity(count);
74            for _ in 0..count {
75                let (send, recv) = tokio::sync::mpsc::unbounded_channel();
76                slots.push(SlotEntry { send });
77                task_list.push(tokio::task::spawn(top_task(
78                    config.clone(),
79                    ip_rate.clone(),
80                    WeakCSlot(this.clone()),
81                    recv,
82                )));
83            }
84            Mutex::new(CSlotInner {
85                max_count: count,
86                slots,
87                slab: slab::Slab::with_capacity(count),
88                pk_to_index: HashMap::with_capacity(count),
89                ip_to_index: HashMap::with_capacity(count),
90                task_list,
91            })
92        }))
93    }
94
95    /// Get a weak reference to this connection slot container.
96    pub fn weak(&self) -> WeakCSlot {
97        WeakCSlot(Arc::downgrade(&self.0))
98    }
99
100    fn remove(&self, uniq: u64, index: usize) {
101        let mut lock = self.0.lock().unwrap();
102
103        match lock.slab.get(index) {
104            None => return,
105            Some(s) => {
106                if s.uniq != uniq {
107                    return;
108                }
109            }
110        }
111
112        let _ = lock.slots.get(index).unwrap().send.send(TaskMsg::Close);
113        lock.slab.remove(index);
114        lock.pk_to_index.retain(|_, i| *i != index);
115        lock.ip_to_index.retain(|_, v| {
116            v.retain(|i| *i != index);
117            !v.is_empty()
118        });
119    }
120
121    // oi clippy, this is super straight forward...
122    #[allow(clippy::type_complexity)]
123    fn insert_and_get_rate_send_list(
124        &self,
125        ip: Arc<std::net::Ipv6Addr>,
126        pk: PubKey,
127        ws: Arc<dyn SbdWebsocket>,
128        maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
129    ) -> std::result::Result<
130        Vec<(u64, usize, Weak<dyn SbdWebsocket>)>,
131        Arc<dyn SbdWebsocket>,
132    > {
133        let mut lock = self.0.lock().unwrap();
134
135        if lock.slab.len() >= lock.max_count {
136            return Err(ws);
137        }
138
139        let weak_ws = Arc::downgrade(&ws);
140
141        let uniq = U.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
142
143        let index = lock.slab.insert(SlabEntry {
144            uniq,
145            weak_ws,
146            handshake_complete: false,
147        });
148
149        lock.pk_to_index.insert(pk.clone(), index);
150
151        let rate_send_list = {
152            let list = {
153                // WARN - allocation here!
154                // Also, do we want to limit the max connections from same ip?
155
156                let e = lock
157                    .ip_to_index
158                    .entry(ip.clone())
159                    .or_insert_with(|| Vec::with_capacity(1024));
160
161                e.push(index);
162
163                e.clone()
164            };
165
166            let mut rate_send_list = Vec::with_capacity(list.len());
167
168            for index in list.iter() {
169                if let Some(slab) = lock.slab.get(*index) {
170                    rate_send_list.push((
171                        slab.uniq,
172                        *index,
173                        slab.weak_ws.clone(),
174                    ));
175                }
176            }
177
178            rate_send_list
179        };
180
181        let send = lock.slots.get(index).unwrap().send.clone();
182        let _ = send.send(TaskMsg::NewWs {
183            uniq,
184            index,
185            ws,
186            ip,
187            pk,
188            maybe_auth,
189        });
190
191        Ok(rate_send_list)
192    }
193
194    /// Insert a connection to be managed by this container.
195    pub async fn insert(
196        &self,
197        config: &Config,
198        ip: Arc<std::net::Ipv6Addr>,
199        pk: PubKey,
200        ws: Arc<impl SbdWebsocket>,
201        maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
202    ) {
203        let rate_send_list =
204            self.insert_and_get_rate_send_list(ip, pk, ws, maybe_auth);
205
206        match rate_send_list {
207            Ok(rate_send_list) => {
208                let rate = if config.disable_rate_limiting {
209                    1
210                } else {
211                    let mut rate = config.limit_ip_byte_nanos() as u64
212                        * rate_send_list.len() as u64;
213                    if rate > i32::MAX as u64 {
214                        rate = i32::MAX as u64;
215                    }
216                    rate as i32
217                };
218
219                for (uniq, index, weak_ws) in rate_send_list {
220                    if let Some(ws) = weak_ws.upgrade() {
221                        if ws
222                            .send(cmd::SbdCmd::limit_byte_nanos(rate))
223                            .await
224                            .is_err()
225                        {
226                            self.remove(uniq, index);
227                        }
228                    }
229                }
230            }
231            Err(ws) => {
232                ws.close().await;
233                drop(ws);
234            }
235        }
236    }
237
238    fn mark_ready(&self, uniq: u64, index: usize) {
239        let mut lock = self.0.lock().unwrap();
240        if let Some(slab) = lock.slab.get_mut(index) {
241            if slab.uniq == uniq {
242                slab.handshake_complete = true;
243            }
244        }
245    }
246
247    fn get_sender(
248        &self,
249        pk: &PubKey,
250    ) -> Result<(u64, usize, Arc<dyn SbdWebsocket>)> {
251        let lock = self.0.lock().unwrap();
252
253        let index = match lock.pk_to_index.get(pk) {
254            None => return Err(Error::other("no such peer")),
255            Some(index) => *index,
256        };
257
258        let slab = lock.slab.get(index).unwrap();
259
260        if !slab.handshake_complete {
261            return Err(Error::other("no such peer"));
262        }
263
264        let uniq = slab.uniq;
265        let ws = match slab.weak_ws.upgrade() {
266            None => return Err(Error::other("no such peer")),
267            Some(ws) => ws,
268        };
269
270        Ok((uniq, index, ws))
271    }
272
273    async fn send(&self, pk: &PubKey, payload: Payload) -> Result<()> {
274        let (uniq, index, ws) = self.get_sender(pk)?;
275
276        match ws.send(payload).await {
277            Err(err) => {
278                self.remove(uniq, index);
279                Err(err)
280            }
281            Ok(_) => Ok(()),
282        }
283    }
284}
285
286async fn top_task(
287    config: Arc<Config>,
288    ip_rate: Arc<ip_rate::IpRate>,
289    weak: WeakCSlot,
290    mut recv: tokio::sync::mpsc::UnboundedReceiver<TaskMsg>,
291) {
292    let mut item = recv.recv().await;
293    loop {
294        let uitem = match item {
295            None => break,
296            Some(uitem) => uitem,
297        };
298
299        item = if let TaskMsg::NewWs {
300            uniq,
301            index,
302            ws,
303            ip,
304            pk,
305            maybe_auth,
306        } = uitem
307        {
308            let next_i = tokio::select! {
309                i = recv.recv() => Some(i),
310                _ = ws_task(
311                    &config,
312                    &ip_rate,
313                    &weak,
314                    &ws,
315                    ip,
316                    pk,
317                    uniq,
318                    index,
319                    maybe_auth,
320                ) => None,
321            };
322
323            ws.close().await;
324            drop(ws);
325            if let Some(cslot) = weak.upgrade() {
326                cslot.remove(uniq, index);
327            }
328
329            match next_i {
330                Some(i) => i,
331                None => recv.recv().await,
332            }
333        } else {
334            recv.recv().await
335        };
336    }
337}
338
339#[allow(clippy::too_many_arguments)]
340async fn ws_task(
341    config: &Arc<Config>,
342    ip_rate: &ip_rate::IpRate,
343    weak_cslot: &WeakCSlot,
344    ws: &Arc<dyn SbdWebsocket>,
345    ip: Arc<std::net::Ipv6Addr>,
346    pk: PubKey,
347    uniq: u64,
348    index: usize,
349    maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
350) {
351    let auth_res = tokio::time::timeout(config.idle_dur(), async {
352        use rand::Rng;
353        let mut nonce = [0xdb; 32];
354        rand::thread_rng().fill(&mut nonce[..]);
355
356        ws.send(cmd::SbdCmd::auth_req(&nonce)).await?;
357
358        loop {
359            let auth_res = ws.recv().await?;
360
361            if !ip_rate.is_ok(&ip, auth_res.as_ref().len()).await {
362                return Err(Error::other("ip rate limited"));
363            }
364
365            if let Some((token, token_tracker)) = &maybe_auth {
366                // we already know they had a valid token
367                // when they opened this connection.
368                // just using this for side-effect marking token use time
369                let _ =
370                    token_tracker.check_is_token_valid(config, token.clone());
371            }
372
373            match cmd::SbdCmd::parse(auth_res)? {
374                cmd::SbdCmd::AuthRes(sig) => {
375                    if !pk.verify(&sig, &nonce) {
376                        return Err(Error::other("invalid sig"));
377                    }
378                    break;
379                }
380                cmd::SbdCmd::Message(_) => {
381                    return Err(Error::other(
382                        "invalid forward before handshake",
383                    ));
384                }
385                _ => continue,
386            }
387        }
388
389        // NOTE: the byte_nanos limit is sent during the cslot insert
390
391        ws.send(cmd::SbdCmd::limit_idle_millis(config.limit_idle_millis))
392            .await?;
393
394        if let Some(cslot) = weak_cslot.upgrade() {
395            cslot.mark_ready(uniq, index);
396        } else {
397            return Err(Error::other("closed"));
398        }
399
400        ws.send(cmd::SbdCmd::ready()).await?;
401
402        Ok(())
403    })
404    .await;
405
406    if auth_res.is_err() {
407        return;
408    }
409
410    while let Ok(Ok(payload)) =
411        tokio::time::timeout(config.idle_dur(), ws.recv()).await
412    {
413        if !ip_rate.is_ok(&ip, payload.len()).await {
414            break;
415        }
416
417        if let Some((token, token_tracker)) = &maybe_auth {
418            // we already know they had a valid token
419            // when they opened this connection.
420            // just using this for side-effect marking token use time
421            let _ = token_tracker.check_is_token_valid(config, token.clone());
422        }
423
424        let cmd = match cmd::SbdCmd::parse(payload) {
425            Err(_) => break,
426            Ok(cmd) => cmd,
427        };
428
429        match cmd {
430            cmd::SbdCmd::Keepalive => (),
431            cmd::SbdCmd::AuthRes(_) => break,
432            cmd::SbdCmd::Unknown => (),
433            cmd::SbdCmd::Message(mut payload) => {
434                let dest = {
435                    let payload = payload.to_mut();
436
437                    let mut dest = [0; 32];
438                    dest.copy_from_slice(&payload[..32]);
439                    let dest = PubKey(Arc::new(dest));
440
441                    payload[..32].copy_from_slice(&pk.0[..]);
442
443                    dest
444                };
445
446                if let Some(cslot) = weak_cslot.upgrade() {
447                    let _ = cslot.send(&dest, payload).await;
448                } else {
449                    break;
450                }
451            }
452        }
453    }
454}