vibesql_server/subscription/manager/lifecycle.rs
1//! Subscription lifecycle management: subscribe and unsubscribe operations.
2
3use std::collections::HashSet;
4use std::sync::atomic::Ordering;
5
6use tokio::sync::mpsc;
7use tracing::debug;
8
9use super::SubscriptionManager;
10use crate::subscription::{
11 extract_table_refs, Subscription, SubscriptionError, SubscriptionId, SubscriptionUpdate,
12};
13
14impl SubscriptionManager {
15 /// Create a new subscription for a query
16 ///
17 /// Parses the query to extract table dependencies and registers the
18 /// subscription for notifications.
19 ///
20 /// # Arguments
21 ///
22 /// * `query` - SQL query to monitor
23 /// * `notify_tx` - Channel to send updates to the subscriber
24 ///
25 /// # Returns
26 ///
27 /// The subscription ID on success, or an error if parsing fails or limits exceeded
28 ///
29 /// # Errors
30 ///
31 /// - `ParseError` if the query cannot be parsed or references no tables
32 /// - `GlobalLimitExceeded` if the global subscription limit is reached
33 ///
34 /// # Example
35 ///
36 /// ```text
37 /// let manager = SubscriptionManager::new();
38 /// let (tx, mut rx) = mpsc::channel(16);
39 ///
40 /// let id = manager.subscribe("SELECT * FROM users".to_string(), tx)?;
41 /// println!("Subscribed with ID: {}", id);
42 /// ```
43 pub fn subscribe(
44 &self,
45 query: String,
46 notify_tx: mpsc::Sender<SubscriptionUpdate>,
47 ) -> Result<SubscriptionId, SubscriptionError> {
48 // Atomically reserve a slot to prevent TOCTOU race condition
49 // Use compare-and-swap loop to safely increment the counter
50 loop {
51 let current_count = self.subscription_count_atomic.load(Ordering::Acquire);
52 if current_count >= self.config.max_global {
53 self.limit_exceeded_count.fetch_add(1, Ordering::Relaxed);
54 return Err(SubscriptionError::GlobalLimitExceeded {
55 current: current_count,
56 max: self.config.max_global,
57 });
58 }
59
60 // Try to atomically increment the count
61 match self.subscription_count_atomic.compare_exchange(
62 current_count,
63 current_count + 1,
64 Ordering::AcqRel,
65 Ordering::Acquire,
66 ) {
67 Ok(_) => break, // Successfully reserved a slot
68 Err(_) => continue, // Another thread changed the count, retry
69 }
70 }
71
72 // Parse query and extract table dependencies
73 let tables = match self.extract_tables(&query) {
74 Ok(tables) => tables,
75 Err(e) => {
76 // Release the reserved slot on parse error
77 self.subscription_count_atomic.fetch_sub(1, Ordering::Release);
78 return Err(e);
79 }
80 };
81
82 if tables.is_empty() {
83 // Release the reserved slot
84 self.subscription_count_atomic.fetch_sub(1, Ordering::Release);
85 return Err(SubscriptionError::ParseError(
86 "Query must reference at least one table".to_string(),
87 ));
88 }
89
90 // Create subscription with default channel buffer size
91 let subscription = Subscription::new(query.clone(), tables.clone(), notify_tx);
92 let id = subscription.id;
93
94 debug!(
95 subscription_id = %id,
96 tables = ?tables,
97 "Creating new subscription"
98 );
99
100 // Register subscription
101 self.subscriptions.insert(id, subscription);
102
103 // Index by tables
104 for table in tables {
105 self.table_index.entry(table).or_default().insert(id);
106 }
107
108 Ok(id)
109 }
110
111 /// Create a new subscription for a specific connection (wire protocol)
112 ///
113 /// This is the primary method for wire protocol subscriptions. It:
114 /// - Checks both global and per-connection limits
115 /// - Associates the subscription with a connection ID for cleanup
116 /// - Stores the wire protocol UUID for lookup
117 ///
118 /// # Arguments
119 ///
120 /// * `query` - SQL query to monitor
121 /// * `notify_tx` - Channel to send updates to the subscriber
122 /// * `connection_id` - The connection/session ID that owns this subscription
123 /// * `wire_subscription_id` - The wire protocol UUID for this subscription
124 /// * `table_dependencies` - Pre-extracted table dependencies (from AST parsing)
125 ///
126 /// # Returns
127 ///
128 /// The subscription ID on success, or an error if limits exceeded
129 ///
130 /// # Errors
131 ///
132 /// - `GlobalLimitExceeded` if the global subscription limit is reached
133 /// - `ConnectionLimitExceeded` if the per-connection limit is reached
134 pub fn subscribe_for_connection(
135 &self,
136 query: String,
137 notify_tx: mpsc::Sender<SubscriptionUpdate>,
138 connection_id: String,
139 wire_subscription_id: [u8; 16],
140 table_dependencies: HashSet<String>,
141 filter: Option<String>,
142 ) -> Result<SubscriptionId, SubscriptionError> {
143 // Check per-connection limit first
144 let conn_count = self
145 .connection_subscription_counts
146 .entry(connection_id.clone())
147 .or_insert_with(|| std::sync::atomic::AtomicUsize::new(0));
148
149 // Use CAS loop for per-connection limit check
150 loop {
151 let current_conn_count = conn_count.load(Ordering::Acquire);
152 if current_conn_count >= self.config.max_per_connection {
153 return Err(SubscriptionError::ConnectionLimitExceeded {
154 current: current_conn_count,
155 max: self.config.max_per_connection,
156 });
157 }
158
159 // Try to atomically increment the per-connection count
160 match conn_count.compare_exchange(
161 current_conn_count,
162 current_conn_count + 1,
163 Ordering::AcqRel,
164 Ordering::Acquire,
165 ) {
166 Ok(_) => break,
167 Err(_) => continue,
168 }
169 }
170
171 // Atomically reserve a global slot to prevent TOCTOU race condition
172 loop {
173 let current_count = self.subscription_count_atomic.load(Ordering::Acquire);
174 if current_count >= self.config.max_global {
175 // Release the per-connection slot we reserved
176 conn_count.fetch_sub(1, Ordering::Release);
177 self.limit_exceeded_count.fetch_add(1, Ordering::Relaxed);
178 return Err(SubscriptionError::GlobalLimitExceeded {
179 current: current_count,
180 max: self.config.max_global,
181 });
182 }
183
184 // Try to atomically increment the count
185 match self.subscription_count_atomic.compare_exchange(
186 current_count,
187 current_count + 1,
188 Ordering::AcqRel,
189 Ordering::Acquire,
190 ) {
191 Ok(_) => break,
192 Err(_) => continue,
193 }
194 }
195
196 // Create subscription with connection tracking
197 let subscription = Subscription::for_connection(
198 query.clone(),
199 table_dependencies.clone(),
200 notify_tx,
201 connection_id.clone(),
202 wire_subscription_id,
203 filter,
204 &self.config,
205 );
206 let id = subscription.id;
207
208 debug!(
209 subscription_id = %id,
210 connection_id = %connection_id,
211 tables = ?table_dependencies,
212 "Creating new subscription for connection"
213 );
214
215 // Register subscription
216 self.subscriptions.insert(id, subscription);
217
218 // Index by tables (lowercase for case-insensitive matching)
219 for table in table_dependencies {
220 self.table_index.entry(table.to_lowercase()).or_default().insert(id);
221 }
222
223 // Index by connection
224 self.connection_index.entry(connection_id).or_default().insert(id);
225
226 // Index by wire ID
227 self.wire_id_index.insert(wire_subscription_id, id);
228
229 Ok(id)
230 }
231
232 /// Remove a subscription
233 ///
234 /// Unregisters the subscription and removes it from all indexes.
235 ///
236 /// # Arguments
237 ///
238 /// * `id` - The subscription ID to remove
239 ///
240 /// # Returns
241 ///
242 /// `true` if the removed subscription was selective-eligible, `false` otherwise
243 pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
244 if let Some((_, subscription)) = self.subscriptions.remove(&id) {
245 debug!(subscription_id = %id, "Removing subscription");
246
247 let was_selective_eligible = subscription.selective_eligible;
248
249 // Decrement the atomic counter
250 self.subscription_count_atomic.fetch_sub(1, Ordering::Release);
251
252 // Remove from table index
253 for table in &subscription.tables {
254 if let Some(mut ids) = self.table_index.get_mut(table) {
255 ids.remove(&id);
256 }
257 }
258
259 // Remove from connection index if present
260 if let Some(ref conn_id) = subscription.connection_id {
261 if let Some(mut ids) = self.connection_index.get_mut(conn_id) {
262 ids.remove(&id);
263 }
264 // Decrement per-connection count
265 if let Some(count) = self.connection_subscription_counts.get(conn_id) {
266 count.fetch_sub(1, Ordering::Release);
267 }
268 }
269
270 // Remove from wire ID index if present
271 if let Some(wire_id) = subscription.wire_subscription_id {
272 self.wire_id_index.remove(&wire_id);
273 }
274
275 return was_selective_eligible;
276 }
277 false
278 }
279
280 /// Remove a subscription by its wire protocol ID
281 ///
282 /// This is used by wire protocol clients that use UUID-based subscription IDs.
283 ///
284 /// # Arguments
285 ///
286 /// * `wire_id` - The wire protocol subscription ID (UUID bytes)
287 ///
288 /// # Returns
289 ///
290 /// `true` if the removed subscription was selective-eligible, `false` otherwise.
291 /// Returns `false` if the subscription was not found.
292 pub fn unsubscribe_by_wire_id(&self, wire_id: &[u8; 16]) -> bool {
293 if let Some((_, id)) = self.wire_id_index.remove(wire_id) {
294 self.unsubscribe(id)
295 } else {
296 false
297 }
298 }
299
300 /// Remove all subscriptions for a connection
301 ///
302 /// This should be called when a connection closes to clean up all its
303 /// subscriptions. This is important for wire protocol connections.
304 ///
305 /// # Arguments
306 ///
307 /// * `connection_id` - The connection ID to clean up
308 ///
309 /// # Returns
310 ///
311 /// A tuple of (total_removed, selective_eligible_removed)
312 pub fn unsubscribe_all_for_connection(&self, connection_id: &str) -> (usize, usize) {
313 let subscription_ids: Vec<SubscriptionId> = if let Some((_, ids)) =
314 self.connection_index.remove(connection_id)
315 {
316 ids.into_iter().collect()
317 } else {
318 return (0, 0);
319 };
320
321 let count = subscription_ids.len();
322 debug!(
323 connection_id = %connection_id,
324 subscription_count = count,
325 "Removing all subscriptions for connection"
326 );
327
328 let mut selective_eligible_count = 0;
329 for id in subscription_ids {
330 // Note: unsubscribe will try to remove from connection_index again,
331 // but it will be a no-op since we already removed it
332 if self.unsubscribe(id) {
333 selective_eligible_count += 1;
334 }
335 }
336
337 // Clean up the per-connection count entry
338 self.connection_subscription_counts.remove(connection_id);
339
340 (count, selective_eligible_count)
341 }
342
343 /// Extract table references from a query
344 pub(crate) fn extract_tables(&self, query: &str) -> Result<HashSet<String>, SubscriptionError> {
345 let stmt = vibesql_parser::Parser::parse_sql(query)
346 .map_err(|e| SubscriptionError::ParseError(e.to_string()))?;
347 Ok(extract_table_refs(&stmt))
348 }
349}