// sql_cli/redis_cache_module/redis_cache.rs

1use redis::{Client, Commands, Connection};
2use sha2::{Digest, Sha256};
3use std::time::Duration;
4use tracing::{debug, info};
5
6pub struct RedisCache {
7    connection: Option<Connection>,
8    enabled: bool,
9}
10
11impl RedisCache {
12    /// Try to create a new Redis cache connection
13    pub fn new() -> Self {
14        // Cache is OPT-IN - must explicitly enable via SQL_CLI_CACHE=true
15        // This ensures sql-cli works exactly as before for users without Redis
16        match std::env::var("SQL_CLI_CACHE") {
17            Ok(val) => {
18                // Only proceed if explicitly enabled
19                if !val.eq_ignore_ascii_case("true")
20                    && !val.eq_ignore_ascii_case("yes")
21                    && val != "1"
22                {
23                    debug!("Cache not enabled (SQL_CLI_CACHE != true)");
24                    return Self {
25                        connection: None,
26                        enabled: false,
27                    };
28                }
29            }
30            Err(_) => {
31                // No SQL_CLI_CACHE variable = cache disabled (default)
32                debug!("Cache disabled by default (set SQL_CLI_CACHE=true to enable)");
33                return Self {
34                    connection: None,
35                    enabled: false,
36                };
37            }
38        }
39
40        debug!("Cache explicitly enabled via SQL_CLI_CACHE=true");
41
42        // Get Redis URL from environment or use default
43        let redis_url = std::env::var("SQL_CLI_REDIS_URL")
44            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
45
46        // Try to connect with a short timeout
47        match Client::open(redis_url.as_str()) {
48            Ok(client) => {
49                // Try to get connection with timeout
50                match client.get_connection_with_timeout(Duration::from_secs(1)) {
51                    Ok(mut conn) => {
52                        // Test the connection
53                        match redis::cmd("PING").query::<String>(&mut conn) {
54                            Ok(_) => {
55                                debug!("Redis cache connected successfully");
56                                Self {
57                                    connection: Some(conn),
58                                    enabled: true,
59                                }
60                            }
61                            Err(e) => {
62                                debug!("Redis ping failed: {}", e);
63                                Self {
64                                    connection: None,
65                                    enabled: false,
66                                }
67                            }
68                        }
69                    }
70                    Err(e) => {
71                        debug!("Redis connection failed: {}", e);
72                        Self {
73                            connection: None,
74                            enabled: false,
75                        }
76                    }
77                }
78            }
79            Err(e) => {
80                debug!("Redis client creation failed: {}", e);
81                Self {
82                    connection: None,
83                    enabled: false,
84                }
85            }
86        }
87    }
88
89    /// Generate a cache key from Web CTE components (legacy - kept for compatibility)
90    pub fn generate_key(
91        table_name: &str, // CTE name to prevent collisions
92        url: &str,
93        method: Option<&str>,
94        headers: &[(String, String)],
95        body: Option<&str>,
96    ) -> String {
97        // Call the new method with empty context for backward compatibility
98        Self::generate_key_with_context(table_name, url, method, headers, body, "")
99    }
100
101    /// Generate a cache key from Web CTE components with query context
102    pub fn generate_key_with_context(
103        table_name: &str, // CTE name
104        url: &str,
105        method: Option<&str>,
106        headers: &[(String, String)],
107        body: Option<&str>,
108        query_context: &str, // Hash or unique identifier of the full query
109    ) -> String {
110        Self::generate_key_full(
111            table_name,
112            url,
113            method,
114            headers,
115            body,
116            query_context,
117            None, // json_path
118            &[],  // form_files
119            &[],  // form_fields
120        )
121    }
122
123    /// Generate a complete cache key from all Web CTE components
124    pub fn generate_key_full(
125        table_name: &str,
126        url: &str,
127        method: Option<&str>,
128        headers: &[(String, String)],
129        body: Option<&str>,
130        _query_context: &str, // Kept for API compatibility but not used
131        json_path: Option<&str>,
132        form_files: &[(String, String)],
133        form_fields: &[(String, String)],
134    ) -> String {
135        let mut hasher = Sha256::new();
136
137        // NOTE: We do NOT include query_context - each WEB CTE should be
138        // independent and cache based only on its own properties
139
140        // Hash the CTE name first
141        hasher.update(table_name.as_bytes());
142        hasher.update(b":::"); // Separator
143
144        // Hash URL
145        hasher.update(url.as_bytes());
146        hasher.update(b":::");
147
148        // Hash method
149        if let Some(method) = method {
150            hasher.update(method.as_bytes());
151            hasher.update(b":::");
152        }
153
154        // Hash headers (sorted for consistency)
155        let mut sorted_headers = headers.to_vec();
156        sorted_headers.sort_by(|a, b| a.0.cmp(&b.0));
157        for (key, value) in sorted_headers {
158            hasher.update(key.as_bytes());
159            hasher.update(b":");
160            hasher.update(value.as_bytes());
161            hasher.update(b";");
162        }
163
164        // Hash body
165        if let Some(body) = body {
166            hasher.update(b"body:");
167            hasher.update(body.as_bytes());
168            hasher.update(b":::");
169        }
170
171        // Hash json_path
172        if let Some(path) = json_path {
173            hasher.update(b"json_path:");
174            hasher.update(path.as_bytes());
175            hasher.update(b":::");
176        }
177
178        // Hash form_files (sorted for consistency)
179        if !form_files.is_empty() {
180            let mut sorted_files = form_files.to_vec();
181            sorted_files.sort_by(|a, b| a.0.cmp(&b.0));
182            for (field, path) in sorted_files {
183                hasher.update(b"file:");
184                hasher.update(field.as_bytes());
185                hasher.update(b"=");
186                hasher.update(path.as_bytes());
187                hasher.update(b";");
188            }
189        }
190
191        // Hash form_fields (sorted for consistency)
192        if !form_fields.is_empty() {
193            let mut sorted_fields = form_fields.to_vec();
194            sorted_fields.sort_by(|a, b| a.0.cmp(&b.0));
195            for (field, value) in sorted_fields {
196                hasher.update(b"field:");
197                hasher.update(field.as_bytes());
198                hasher.update(b"=");
199                hasher.update(value.as_bytes());
200                hasher.update(b";");
201            }
202        }
203
204        format!("sql-cli:web:{}:{:x}", table_name, hasher.finalize())
205    }
206
207    /// Check if cache is enabled
208    pub fn is_enabled(&self) -> bool {
209        self.enabled
210    }
211
212    /// Get data from cache
213    pub fn get(&mut self, key: &str) -> Option<Vec<u8>> {
214        if !self.enabled {
215            return None;
216        }
217
218        if let Some(ref mut conn) = self.connection {
219            match conn.get::<_, Vec<u8>>(key) {
220                Ok(data) => {
221                    debug!("Cache HIT: {}", &key[0..32.min(key.len())]);
222                    Some(data)
223                }
224                Err(_) => {
225                    debug!("Cache MISS: {}", &key[0..32.min(key.len())]);
226                    None
227                }
228            }
229        } else {
230            None
231        }
232    }
233
234    /// Store data in cache with TTL
235    pub fn set(&mut self, key: &str, data: &[u8], ttl_seconds: u64) -> Result<(), String> {
236        if !self.enabled {
237            return Ok(());
238        }
239
240        if let Some(ref mut conn) = self.connection {
241            match conn.set_ex::<_, _, String>(key, data, ttl_seconds as usize) {
242                Ok(_) => {
243                    info!("Cached {} bytes with TTL {}s", data.len(), ttl_seconds);
244                    Ok(())
245                }
246                Err(e) => {
247                    debug!("Failed to cache: {}", e);
248                    // Don't fail the query just because caching failed
249                    Ok(())
250                }
251            }
252        } else {
253            Ok(())
254        }
255    }
256
257    /// Get TTL of a key (for debugging)
258    pub fn ttl(&mut self, key: &str) -> Option<i64> {
259        if !self.enabled {
260            return None;
261        }
262
263        if let Some(ref mut conn) = self.connection {
264            conn.ttl(key).ok()
265        } else {
266            None
267        }
268    }
269
270    /// Check cache statistics
271    pub fn stats(&mut self) -> Option<String> {
272        if !self.enabled {
273            return None;
274        }
275
276        if let Some(ref mut conn) = self.connection {
277            // Get all sql-cli keys
278            if let Ok(keys) = redis::cmd("KEYS")
279                .arg("sql-cli:*")
280                .query::<Vec<String>>(conn)
281            {
282                let count = keys.len();
283                let mut total_size = 0;
284                let mut expiring_soon = 0;
285
286                for key in &keys {
287                    // Get memory usage
288                    if let Ok(size) = redis::cmd("MEMORY")
289                        .arg("USAGE")
290                        .arg(key)
291                        .query::<Option<usize>>(conn)
292                    {
293                        total_size += size.unwrap_or(0);
294                    }
295
296                    // Check TTL
297                    if let Ok(ttl) = conn.ttl::<_, i64>(key) {
298                        if ttl > 0 && ttl < 300 {
299                            expiring_soon += 1;
300                        }
301                    }
302                }
303
304                return Some(format!(
305                    "Cache stats: {} entries, {:.2} MB, {} expiring soon",
306                    count,
307                    total_size as f64 / 1_048_576.0,
308                    expiring_soon
309                ));
310            }
311        }
312
313        None
314    }
315}
316
317impl Default for RedisCache {
318    fn default() -> Self {
319        Self::new()
320    }
321}