1use lazy_static::lazy_static;
2use redis::{Client, Cmd, Connection, RedisError, RedisResult};
3use std::env;
4use std::path::Path;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::{Arc, Mutex, Once};
7
8#[derive(Clone)]
13pub struct RedisConfigBuilder {
14 pub host: String,
15 pub port: u16,
16 pub db: i64,
17 pub username: Option<String>,
18 pub password: Option<String>,
19 pub use_tls: bool,
20 pub use_unix_socket: bool,
21 pub socket_path: Option<String>,
22 pub connection_timeout: Option<u64>,
23}
24
25impl Default for RedisConfigBuilder {
26 fn default() -> Self {
27 Self {
28 host: "127.0.0.1".to_string(),
29 port: 6379,
30 db: 0,
31 username: None,
32 password: None,
33 use_tls: false,
34 use_unix_socket: false,
35 socket_path: None,
36 connection_timeout: None,
37 }
38 }
39}
40
41impl RedisConfigBuilder {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn host(mut self, host: &str) -> Self {
49 self.host = host.to_string();
50 self
51 }
52
53 pub fn port(mut self, port: u16) -> Self {
55 self.port = port;
56 self
57 }
58
59 pub fn db(mut self, db: i64) -> Self {
61 self.db = db;
62 self
63 }
64
65 pub fn username(mut self, username: &str) -> Self {
67 self.username = Some(username.to_string());
68 self
69 }
70
71 pub fn password(mut self, password: &str) -> Self {
73 self.password = Some(password.to_string());
74 self
75 }
76
77 pub fn use_tls(mut self, use_tls: bool) -> Self {
79 self.use_tls = use_tls;
80 self
81 }
82
83 pub fn use_unix_socket(mut self, use_unix_socket: bool) -> Self {
85 self.use_unix_socket = use_unix_socket;
86 self
87 }
88
89 pub fn socket_path(mut self, socket_path: &str) -> Self {
91 self.socket_path = Some(socket_path.to_string());
92 self.use_unix_socket = true;
93 self
94 }
95
96 pub fn connection_timeout(mut self, seconds: u64) -> Self {
98 self.connection_timeout = Some(seconds);
99 self
100 }
101
102 pub fn build_connection_url(&self) -> String {
104 if self.use_unix_socket {
105 if let Some(ref socket_path) = self.socket_path {
106 return format!("unix://{}", socket_path);
107 } else {
108 let home_dir = env::var("HOME").unwrap_or_else(|_| String::from("/root"));
110 return format!("unix://{}/hero/var/myredis.sock", home_dir);
111 }
112 }
113
114 let mut url = if self.use_tls {
115 format!("rediss://{}:{}", self.host, self.port)
116 } else {
117 format!("redis://{}:{}", self.host, self.port)
118 };
119
120 if let Some(ref username) = self.username {
122 if let Some(ref password) = self.password {
123 url = format!(
124 "redis://{}:{}@{}:{}",
125 username, password, self.host, self.port
126 );
127 } else {
128 url = format!("redis://{}@{}:{}", username, self.host, self.port);
129 }
130 } else if let Some(ref password) = self.password {
131 url = format!("redis://:{}@{}:{}", password, self.host, self.port);
132 }
133
134 url = format!("{}/{}", url, self.db);
136
137 url
138 }
139
140 pub fn build(&self) -> RedisResult<(Client, i64)> {
142 let url = self.build_connection_url();
143 let client = Client::open(url)?;
144 Ok((client, self.db))
145 }
146}
147
148lazy_static! {
150 static ref REDIS_CLIENT: Mutex<Option<Arc<RedisClientWrapper>>> = Mutex::new(None);
151 static ref INIT: Once = Once::new();
152}
153
154pub struct RedisClientWrapper {
156 client: Client,
157 connection: Mutex<Option<Connection>>,
158 db: i64,
159 initialized: AtomicBool,
160}
161
162impl RedisClientWrapper {
163 fn new(client: Client, db: i64) -> Self {
165 RedisClientWrapper {
166 client,
167 connection: Mutex::new(None),
168 db,
169 initialized: AtomicBool::new(false),
170 }
171 }
172
173 pub fn execute<T: redis::FromRedisValue>(&self, cmd: &mut Cmd) -> RedisResult<T> {
175 let mut conn_guard = self.connection.lock().unwrap();
176
177 if conn_guard.is_none() || {
179 if let Some(ref mut conn) = *conn_guard {
180 let ping_result: RedisResult<String> = redis::cmd("PING").query(conn);
181 ping_result.is_err()
182 } else {
183 true
184 }
185 } {
186 *conn_guard = Some(self.client.get_connection()?);
187 }
188 cmd.query(&mut conn_guard.as_mut().unwrap())
189 }
190
191 fn initialize(&self) -> RedisResult<()> {
193 if self.initialized.load(Ordering::Relaxed) {
194 return Ok(());
195 }
196
197 let mut conn = self.client.get_connection()?;
198
199 let ping_result: String = redis::cmd("PING").query(&mut conn)?;
201 if ping_result != "PONG" {
202 return Err(RedisError::from((
203 redis::ErrorKind::ResponseError,
204 "Failed to ping Redis server",
205 )));
206 }
207
208 let _ = redis::cmd("SELECT").arg(self.db).exec(&mut conn);
210
211 self.initialized.store(true, Ordering::Relaxed);
212
213 let mut conn_guard = self.connection.lock().unwrap();
215 *conn_guard = Some(conn);
216
217 Ok(())
218 }
219}
220
221pub fn get_redis_client() -> RedisResult<Arc<RedisClientWrapper>> {
223 {
225 let guard = REDIS_CLIENT.lock().unwrap();
226 if let Some(ref client) = &*guard {
227 return Ok(Arc::clone(client));
228 }
229 }
230
231 let client = create_redis_client()?;
233
234 {
236 let mut guard = REDIS_CLIENT.lock().unwrap();
237 *guard = Some(Arc::clone(&client));
238 }
239
240 Ok(client)
241}
242
243fn create_redis_client() -> RedisResult<Arc<RedisClientWrapper>> {
245 let db = get_redis_db();
247 let password = env::var("REDIS_PASSWORD").ok();
248 let username = env::var("REDIS_USERNAME").ok();
249 let host = env::var("REDIS_HOST").unwrap_or_else(|_| String::from("127.0.0.1"));
250 let port = env::var("REDIS_PORT")
251 .ok()
252 .and_then(|p| p.parse::<u16>().ok())
253 .unwrap_or(6379);
254
255 let mut builder = RedisConfigBuilder::new().host(&host).port(port).db(db);
257
258 if let Some(user) = username {
259 builder = builder.username(&user);
260 }
261
262 if let Some(pass) = password {
263 builder = builder.password(&pass);
264 }
265
266 let home_dir = env::var("HOME").unwrap_or_else(|_| String::from("/root"));
268 let socket_path = format!("{}/hero/var/myredis.sock", home_dir);
269
270 if Path::new(&socket_path).exists() {
271 let socket_builder = builder.clone().socket_path(&socket_path);
273
274 match socket_builder.build() {
275 Ok((client, db)) => {
276 let wrapper = Arc::new(RedisClientWrapper::new(client, db));
277
278 if let Err(err) = wrapper.initialize() {
280 eprintln!(
281 "Socket exists at {} but connection failed: {}",
282 socket_path, err
283 );
284 } else {
285 return Ok(wrapper);
286 }
287 }
288 Err(err) => {
289 eprintln!(
290 "Socket exists at {} but connection failed: {}",
291 socket_path, err
292 );
293 }
294 }
295 }
296
297 match builder.clone().build() {
299 Ok((client, db)) => {
300 let wrapper = Arc::new(RedisClientWrapper::new(client, db));
301
302 wrapper.initialize()?;
304
305 Ok(wrapper)
306 }
307 Err(err) => Err(RedisError::from((
308 redis::ErrorKind::IoError,
309 "Failed to connect to Redis",
310 format!(
311 "Could not connect via socket at {} or via TCP to {}:{}: {}",
312 socket_path, host, port, err
313 ),
314 ))),
315 }
316}
317
/// Database index taken from the `REDISDB` environment variable.
///
/// Falls back to `0` when the variable is unset, is not valid Unicode, or
/// does not parse as an `i64`.
fn get_redis_db() -> i64 {
    match env::var("REDISDB") {
        Ok(raw) => raw.parse::<i64>().unwrap_or(0),
        Err(_) => 0,
    }
}
325
326pub fn reset() -> RedisResult<()> {
328 {
330 let mut client_guard = REDIS_CLIENT.lock().unwrap();
331 *client_guard = None;
332 }
333
334 get_redis_client()?;
337 Ok(())
338}
339
340pub fn execute<T>(cmd: &mut Cmd) -> RedisResult<T>
342where
343 T: redis::FromRedisValue,
344{
345 let client = get_redis_client()?;
346 client.execute(cmd)
347}
348
349pub fn with_config(config: RedisConfigBuilder) -> RedisResult<Client> {
359 let (client, _) = config.build()?;
360 Ok(client)
361}