simple_redis_wrapper/client/redis_async_client.rs

1use crate::client::types::{EvictionPolicy, Key, Namespace, Prefix};
2use anyhow::anyhow;
3use futures::stream::StreamExt;
4use redis::aio::ConnectionManager;
5use redis::{AsyncCommands, AsyncIter, ScanOptions, cmd};
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8use std::env;
9
10pub struct RedisAsyncClient {
11    pub url: String,
12    pub connection: ConnectionManager,
13    pub namespace: Namespace,
14}
15
16impl Clone for RedisAsyncClient {
17    fn clone(&self) -> Self {
18        Self {
19            url: self.url.clone(),
20            connection: self.connection.clone(),
21            namespace: self.namespace.clone(),
22        }
23    }
24}
25
26impl RedisAsyncClient {
27    pub async fn new(url: Option<String>, namespace: Namespace) -> anyhow::Result<Self> {
28        let url = url.unwrap_or(env::var("REDIS_URL")?);
29        let url = if url.ends_with("#insecure") {
30            url
31        } else {
32            format!("{}#insecure", url)
33        };
34        let client = redis::Client::open(url.clone())?;
35        let connection = ConnectionManager::new(client).await?;
36        // let connection = client.get_multiplexed_async_connection().await?;
37        Ok(Self {
38            url,
39            connection,
40            namespace,
41        })
42    }
43
44    pub async fn set_eviction_policy(
45        &self,
46        eviction_policy: EvictionPolicy,
47    ) -> anyhow::Result<String> {
48        let _: () = cmd("CONFIG")
49            .arg("SET")
50            .arg("maxmemory-policy")
51            .arg(eviction_policy.to_string())
52            .query_async(&mut self.connection())
53            .await?;
54        self.get_eviction_policy().await
55    }
56
57    pub async fn get_eviction_policy(&self) -> anyhow::Result<String> {
58        let current_policy: Vec<String> = cmd("CONFIG")
59            .arg("GET")
60            .arg("maxmemory-policy")
61            .query_async(&mut self.connection())
62            .await?;
63        Ok(current_policy.join(""))
64    }
65
66    pub fn key(&self, prefix: &Prefix, key: &Key) -> String {
67        format!("{}:{}:{}", self.namespace.0, prefix.0, key.0)
68    }
69
70    pub fn connection(&self) -> ConnectionManager {
71        self.connection.clone()
72    }
73
74    pub async fn get(&self, key: &str) -> anyhow::Result<Option<String>> {
75        let redis_str: Option<String> = AsyncCommands::get(&mut self.connection(), key).await?;
76        Ok(redis_str)
77    }
78
79    pub async fn set_ex(&self, key: &str, value: &str, expiry: Option<u64>) -> anyhow::Result<()> {
80        match expiry {
81            Some(expiry) => {
82                let _: () =
83                    AsyncCommands::set_ex(&mut self.connection(), key, value, expiry).await?;
84            }
85            None => {
86                let _: () = AsyncCommands::set(&mut self.connection(), key, value).await?;
87            }
88        }
89        Ok(())
90    }
91
92    pub async fn get_all(&self) -> anyhow::Result<Vec<(String, String)>> {
93        let mut output: Vec<(String, String)> = Vec::new();
94        let keys: Vec<String> = AsyncCommands::keys(&mut self.connection(), "*").await?;
95        for key in keys {
96            if let Some(value) = self.get(&key).await? {
97                output.push((key, value))
98            }
99        }
100        Ok(output)
101    }
102
103    pub async fn remove(&self, key: &str) -> anyhow::Result<()> {
104        let _: () = AsyncCommands::del(&mut self.connection(), key).await?;
105        Ok(())
106    }
107
108    pub async fn get_entity<T>(&self, prefix: &Prefix, key: &Key) -> anyhow::Result<T>
109    where
110        T: DeserializeOwned + Serialize,
111    {
112        let redis_str: Option<String> =
113            AsyncCommands::get(&mut self.connection(), self.key(prefix, key)).await?;
114        match redis_str {
115            Some(string) => {
116                let redis_entity: T = serde_json::from_str(&string)
117                    .map_err(|e| anyhow!("get_entity serde_json error: {}", e))?;
118                Ok(redis_entity)
119            }
120            None => Err(anyhow!("Didn't find entity")),
121        }
122    }
123
124    pub async fn save_entity<T>(
125        &self,
126        prefix: &Prefix,
127        key: &Key,
128        value: &T,
129        expiry: Option<u64>,
130    ) -> anyhow::Result<()>
131    where
132        T: DeserializeOwned + Serialize,
133    {
134        let value_str = serde_json::to_string(&value)
135            .map_err(|e| anyhow!("save_entity serde_json error: {}", e))?;
136        match expiry {
137            Some(expiry) => {
138                let _: () = AsyncCommands::set_ex(
139                    &mut self.connection(),
140                    self.key(prefix, key),
141                    value_str,
142                    expiry,
143                )
144                .await?;
145            }
146            None => {
147                let _: () =
148                    AsyncCommands::set(&mut self.connection(), self.key(prefix, key), value_str)
149                        .await?;
150            }
151        }
152        Ok(())
153    }
154
155    pub async fn remove_entity<T>(&self, prefix: &Prefix, key: &Key) -> anyhow::Result<()> {
156        let _: () = AsyncCommands::del(&mut self.connection(), self.key(prefix, key)).await?;
157        Ok(())
158    }
159
160    pub async fn scan<T>(&self, pattern: &str, chunk_size: usize) -> anyhow::Result<Vec<T>>
161    where
162        T: DeserializeOwned + Serialize,
163    {
164        let opts = ScanOptions::default().with_pattern(pattern);
165        let mut con = self.connection();
166        let iter: AsyncIter<Option<String>> = AsyncCommands::scan_options(&mut con, opts).await?;
167        let keys: Vec<Option<String>> = iter.map(Result::unwrap_or_default).collect().await;
168        let keys: Vec<String> = keys.into_iter().filter_map(|i| i).collect();
169        let mut output: Vec<T> = Vec::with_capacity(keys.len());
170        for chunk in keys.chunks(chunk_size) {
171            let values: Vec<Option<String>> = AsyncCommands::mget(&mut con, chunk).await?;
172            let values: Vec<String> = values.into_iter().filter_map(|i| i).collect();
173            for value in values {
174                match serde_json::from_str::<T>(&value) {
175                    Ok(v) => output.push(v),
176                    Err(_) => {}
177                }
178            }
179        }
180        Ok(output)
181    }
182}