// sentinel_driver/cache/mod.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicU64, Ordering};
3
4use lru::LruCache;
5use std::num::NonZeroUsize;
6
7use crate::statement::Statement;
8
/// Capacity of the ad-hoc (Tier 2) LRU used by `StatementCache::new`;
/// once this many ad-hoc statements are cached, inserting another one
/// evicts the least-recently-used entry.
const DEFAULT_LRU_CAPACITY: usize = 256;
11
12/// Two-tier prepared statement cache.
13///
14/// - **Tier 1** (HashMap): Pre-registered queries. Never evicted. O(1) lookup.
15/// - **Tier 2** (LRU): Ad-hoc queries. Auto-evicted when full. O(1) amortized.
16///
17/// Statements are keyed by SQL text. Each cached statement has a unique
18/// server-side name for the PG prepared statement protocol.
19pub struct StatementCache {
20    /// Tier 1: registered (permanent) statements, keyed by user-given name.
21    registered: HashMap<String, CachedStatement>,
22    /// Tier 2: ad-hoc statements, keyed by SQL text.
23    adhoc: LruCache<String, CachedStatement>,
24    /// Counter for generating unique statement names.
25    name_counter: AtomicU64,
26    /// Metrics.
27    metrics: CacheMetrics,
28}
29
30/// A cached prepared statement entry.
31#[derive(Debug, Clone)]
32#[non_exhaustive]
33pub struct CachedStatement {
34    /// The server-side statement name.
35    pub name: String,
36    /// The full statement metadata.
37    pub statement: Statement,
38}
39
/// Cache hit/miss metrics.
///
/// The statement cache only ever increments these counters. A snapshot can
/// be compared with another one (`PartialEq`) to check what changed.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct CacheMetrics {
    /// Successful Tier 1 (registered statement) lookups. Tier 1 lookups that
    /// find nothing are not counted anywhere.
    pub tier1_hits: u64,
    /// Successful Tier 2 (ad-hoc statement) lookups.
    pub tier2_hits: u64,
    /// Recorded cache misses (via `record_miss` or `lookup_or_miss`).
    pub misses: u64,
    /// Tier 2 entries evicted to make room for new ad-hoc statements.
    pub evictions: u64,
}
49
50impl CacheMetrics {
51    /// Total cache hits (tier 1 + tier 2).
52    pub fn total_hits(&self) -> u64 {
53        self.tier1_hits + self.tier2_hits
54    }
55
56    /// Hit rate as a fraction (0.0 to 1.0).
57    pub fn hit_rate(&self) -> f64 {
58        let total = self.total_hits() + self.misses;
59        if total == 0 {
60            0.0
61        } else {
62            self.total_hits() as f64 / total as f64
63        }
64    }
65}
66
67impl StatementCache {
68    /// Create a new statement cache with the default LRU capacity (256).
69    pub fn new() -> Self {
70        Self::with_capacity(DEFAULT_LRU_CAPACITY)
71    }
72
73    /// Create a new statement cache with a custom LRU capacity.
74    pub fn with_capacity(lru_capacity: usize) -> Self {
75        Self {
76            registered: HashMap::new(),
77            adhoc: LruCache::new(NonZeroUsize::new(lru_capacity).unwrap_or(NonZeroUsize::MIN)),
78            name_counter: AtomicU64::new(0),
79            metrics: CacheMetrics::default(),
80        }
81    }
82
83    /// Register a statement in Tier 1 (permanent, never evicted).
84    ///
85    /// The `name` is the user-defined name (also used as the server-side name).
86    pub fn register(&mut self, name: &str, statement: Statement) {
87        self.registered.insert(
88            name.to_string(),
89            CachedStatement {
90                name: name.to_string(),
91                statement,
92            },
93        );
94    }
95
96    /// Look up a registered statement by name (Tier 1).
97    pub fn get_registered(&mut self, name: &str) -> Option<&CachedStatement> {
98        let result = self.registered.get(name);
99        if result.is_some() {
100            self.metrics.tier1_hits += 1;
101        }
102        result
103    }
104
105    /// Look up an ad-hoc statement by SQL text (Tier 2).
106    pub fn get_adhoc(&mut self, sql: &str) -> Option<&CachedStatement> {
107        let result = self.adhoc.get(sql);
108        if result.is_some() {
109            self.metrics.tier2_hits += 1;
110        }
111        result
112    }
113
114    /// Insert an ad-hoc statement into Tier 2.
115    ///
116    /// Returns the evicted statement's server-side name if the cache was full,
117    /// so the caller can send a Close message to the server.
118    pub fn insert_adhoc(&mut self, sql: String, statement: Statement) -> Option<String> {
119        let name = self.generate_name();
120
121        // Check if inserting will evict
122        let evicted = if self.adhoc.len() == self.adhoc.cap().get() {
123            // Peek at the LRU entry that will be evicted
124            self.adhoc.peek_lru().map(|(_, cached)| cached.name.clone())
125        } else {
126            None
127        };
128
129        if evicted.is_some() {
130            self.metrics.evictions += 1;
131        }
132
133        self.adhoc.put(sql, CachedStatement { name, statement });
134
135        evicted
136    }
137
138    /// Record a cache miss.
139    pub fn record_miss(&mut self) {
140        self.metrics.misses += 1;
141    }
142
143    /// Get the server-side name for an ad-hoc query, or generate one.
144    ///
145    /// Checks Tier 2 first. If not found, records a miss and returns `None`.
146    pub fn lookup_or_miss(&mut self, sql: &str) -> Option<&CachedStatement> {
147        if self.adhoc.get(sql).is_some() {
148            self.metrics.tier2_hits += 1;
149            self.adhoc.get(sql)
150        } else {
151            self.metrics.misses += 1;
152            None
153        }
154    }
155
156    /// Get cache metrics.
157    pub fn metrics(&self) -> &CacheMetrics {
158        &self.metrics
159    }
160
161    /// Number of registered (Tier 1) statements.
162    pub fn registered_count(&self) -> usize {
163        self.registered.len()
164    }
165
166    /// Number of cached ad-hoc (Tier 2) statements.
167    pub fn adhoc_count(&self) -> usize {
168        self.adhoc.len()
169    }
170
171    /// Generate a unique server-side statement name.
172    pub fn generate_name(&self) -> String {
173        let id = self.name_counter.fetch_add(1, Ordering::Relaxed);
174        format!("_sentinel_s{id}")
175    }
176}
177
178impl Default for StatementCache {
179    fn default() -> Self {
180        Self::new()
181    }
182}