Skip to main content

thruster_rate_limit/stores/
map.rs

1use async_trait::async_trait;
2use std::{
3    collections::HashMap,
4    sync::Arc,
5    time::{SystemTime, SystemTimeError, UNIX_EPOCH},
6};
7use tokio::sync::Mutex;
8
9use crate::Store;
10
/// A single entry kept in the map: the stored value plus the data needed to expire it.
///
/// All fields are plain integers, so the type is `Copy` and can be read out of the map
/// without an explicit clone.
#[derive(Clone, Copy, Debug)]
struct MapValue {
    /// The value stored under the key.
    value: usize,
    /// Lifetime of the entry in seconds, counted from `unix`.
    expiry_s: usize,
    /// Unix timestamp (whole seconds) at which the entry was created.
    unix: u64,
}
17
18#[derive(Clone)]
19pub struct MapStore {
20    hash_map: Arc<Mutex<HashMap<String, MapValue>>>,
21}
22
23impl MapStore {
24    pub fn new() -> Self {
25        return Self {
26            hash_map: Arc::new(Mutex::new(HashMap::new())),
27        };
28    }
29}
30
31impl Default for MapStore {
32    fn default() -> Self {
33        return Self::new();
34    }
35}
36
37#[async_trait]
38impl Store for MapStore {
39    type Error = SystemTimeError;
40
41    async fn get(&mut self, key: &str) -> Result<Option<usize>, Self::Error> {
42        let mut hash_map = self.hash_map.lock().await;
43
44        let MapValue {
45            value,
46            expiry_s,
47            unix,
48        } = match hash_map.get(key).cloned() {
49            Some(x) => x,
50            None => return Ok(None),
51        };
52
53        let remove_at = unix + expiry_s as u64;
54        let now_unix = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
55
56        if remove_at <= now_unix {
57            hash_map.remove(key);
58            return Ok(None);
59        }
60
61        return Ok(Some(value));
62    }
63
64    async fn set(&mut self, key: &str, value: usize, expiry_s: usize) -> Result<(), Self::Error> {
65        let mut hash_map = self.hash_map.lock().await;
66
67        if let Some(already) = hash_map.get(key) {
68            let already = already.clone();
69            hash_map.insert(key.to_string(), MapValue { value, ..already });
70
71            return Ok(());
72        }
73
74        hash_map.insert(
75            key.to_string(),
76            MapValue {
77                expiry_s,
78                value,
79                unix: SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(),
80            },
81        );
82
83        return Ok(());
84    }
85}