Skip to main content

vibesql_server/subscription/manager/
lifecycle.rs

1//! Subscription lifecycle management: subscribe and unsubscribe operations.
2
3use std::collections::HashSet;
4use std::sync::atomic::Ordering;
5
6use tokio::sync::mpsc;
7use tracing::debug;
8
9use super::SubscriptionManager;
10use crate::subscription::{
11    extract_table_refs, Subscription, SubscriptionError, SubscriptionId, SubscriptionUpdate,
12};
13
14impl SubscriptionManager {
15    /// Create a new subscription for a query
16    ///
17    /// Parses the query to extract table dependencies and registers the
18    /// subscription for notifications.
19    ///
20    /// # Arguments
21    ///
22    /// * `query` - SQL query to monitor
23    /// * `notify_tx` - Channel to send updates to the subscriber
24    ///
25    /// # Returns
26    ///
27    /// The subscription ID on success, or an error if parsing fails or limits exceeded
28    ///
29    /// # Errors
30    ///
31    /// - `ParseError` if the query cannot be parsed or references no tables
32    /// - `GlobalLimitExceeded` if the global subscription limit is reached
33    ///
34    /// # Example
35    ///
36    /// ```text
37    /// let manager = SubscriptionManager::new();
38    /// let (tx, mut rx) = mpsc::channel(16);
39    ///
40    /// let id = manager.subscribe("SELECT * FROM users".to_string(), tx)?;
41    /// println!("Subscribed with ID: {}", id);
42    /// ```
43    pub fn subscribe(
44        &self,
45        query: String,
46        notify_tx: mpsc::Sender<SubscriptionUpdate>,
47    ) -> Result<SubscriptionId, SubscriptionError> {
48        // Atomically reserve a slot to prevent TOCTOU race condition
49        // Use compare-and-swap loop to safely increment the counter
50        loop {
51            let current_count = self.subscription_count_atomic.load(Ordering::Acquire);
52            if current_count >= self.config.max_global {
53                self.limit_exceeded_count.fetch_add(1, Ordering::Relaxed);
54                return Err(SubscriptionError::GlobalLimitExceeded {
55                    current: current_count,
56                    max: self.config.max_global,
57                });
58            }
59
60            // Try to atomically increment the count
61            match self.subscription_count_atomic.compare_exchange(
62                current_count,
63                current_count + 1,
64                Ordering::AcqRel,
65                Ordering::Acquire,
66            ) {
67                Ok(_) => break,     // Successfully reserved a slot
68                Err(_) => continue, // Another thread changed the count, retry
69            }
70        }
71
72        // Parse query and extract table dependencies
73        let tables = match self.extract_tables(&query) {
74            Ok(tables) => tables,
75            Err(e) => {
76                // Release the reserved slot on parse error
77                self.subscription_count_atomic.fetch_sub(1, Ordering::Release);
78                return Err(e);
79            }
80        };
81
82        if tables.is_empty() {
83            // Release the reserved slot
84            self.subscription_count_atomic.fetch_sub(1, Ordering::Release);
85            return Err(SubscriptionError::ParseError(
86                "Query must reference at least one table".to_string(),
87            ));
88        }
89
90        // Create subscription with default channel buffer size
91        let subscription = Subscription::new(query.clone(), tables.clone(), notify_tx);
92        let id = subscription.id;
93
94        debug!(
95            subscription_id = %id,
96            tables = ?tables,
97            "Creating new subscription"
98        );
99
100        // Register subscription
101        self.subscriptions.insert(id, subscription);
102
103        // Index by tables
104        for table in tables {
105            self.table_index.entry(table).or_default().insert(id);
106        }
107
108        Ok(id)
109    }
110
111    /// Create a new subscription for a specific connection (wire protocol)
112    ///
113    /// This is the primary method for wire protocol subscriptions. It:
114    /// - Checks both global and per-connection limits
115    /// - Associates the subscription with a connection ID for cleanup
116    /// - Stores the wire protocol UUID for lookup
117    ///
118    /// # Arguments
119    ///
120    /// * `query` - SQL query to monitor
121    /// * `notify_tx` - Channel to send updates to the subscriber
122    /// * `connection_id` - The connection/session ID that owns this subscription
123    /// * `wire_subscription_id` - The wire protocol UUID for this subscription
124    /// * `table_dependencies` - Pre-extracted table dependencies (from AST parsing)
125    ///
126    /// # Returns
127    ///
128    /// The subscription ID on success, or an error if limits exceeded
129    ///
130    /// # Errors
131    ///
132    /// - `GlobalLimitExceeded` if the global subscription limit is reached
133    /// - `ConnectionLimitExceeded` if the per-connection limit is reached
134    pub fn subscribe_for_connection(
135        &self,
136        query: String,
137        notify_tx: mpsc::Sender<SubscriptionUpdate>,
138        connection_id: String,
139        wire_subscription_id: [u8; 16],
140        table_dependencies: HashSet<String>,
141        filter: Option<String>,
142    ) -> Result<SubscriptionId, SubscriptionError> {
143        // Check per-connection limit first
144        let conn_count = self
145            .connection_subscription_counts
146            .entry(connection_id.clone())
147            .or_insert_with(|| std::sync::atomic::AtomicUsize::new(0));
148
149        // Use CAS loop for per-connection limit check
150        loop {
151            let current_conn_count = conn_count.load(Ordering::Acquire);
152            if current_conn_count >= self.config.max_per_connection {
153                return Err(SubscriptionError::ConnectionLimitExceeded {
154                    current: current_conn_count,
155                    max: self.config.max_per_connection,
156                });
157            }
158
159            // Try to atomically increment the per-connection count
160            match conn_count.compare_exchange(
161                current_conn_count,
162                current_conn_count + 1,
163                Ordering::AcqRel,
164                Ordering::Acquire,
165            ) {
166                Ok(_) => break,
167                Err(_) => continue,
168            }
169        }
170
171        // Atomically reserve a global slot to prevent TOCTOU race condition
172        loop {
173            let current_count = self.subscription_count_atomic.load(Ordering::Acquire);
174            if current_count >= self.config.max_global {
175                // Release the per-connection slot we reserved
176                conn_count.fetch_sub(1, Ordering::Release);
177                self.limit_exceeded_count.fetch_add(1, Ordering::Relaxed);
178                return Err(SubscriptionError::GlobalLimitExceeded {
179                    current: current_count,
180                    max: self.config.max_global,
181                });
182            }
183
184            // Try to atomically increment the count
185            match self.subscription_count_atomic.compare_exchange(
186                current_count,
187                current_count + 1,
188                Ordering::AcqRel,
189                Ordering::Acquire,
190            ) {
191                Ok(_) => break,
192                Err(_) => continue,
193            }
194        }
195
196        // Create subscription with connection tracking
197        let subscription = Subscription::for_connection(
198            query.clone(),
199            table_dependencies.clone(),
200            notify_tx,
201            connection_id.clone(),
202            wire_subscription_id,
203            filter,
204            &self.config,
205        );
206        let id = subscription.id;
207
208        debug!(
209            subscription_id = %id,
210            connection_id = %connection_id,
211            tables = ?table_dependencies,
212            "Creating new subscription for connection"
213        );
214
215        // Register subscription
216        self.subscriptions.insert(id, subscription);
217
218        // Index by tables (lowercase for case-insensitive matching)
219        for table in table_dependencies {
220            self.table_index.entry(table.to_lowercase()).or_default().insert(id);
221        }
222
223        // Index by connection
224        self.connection_index.entry(connection_id).or_default().insert(id);
225
226        // Index by wire ID
227        self.wire_id_index.insert(wire_subscription_id, id);
228
229        Ok(id)
230    }
231
232    /// Remove a subscription
233    ///
234    /// Unregisters the subscription and removes it from all indexes.
235    ///
236    /// # Arguments
237    ///
238    /// * `id` - The subscription ID to remove
239    ///
240    /// # Returns
241    ///
242    /// `true` if the removed subscription was selective-eligible, `false` otherwise
243    pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
244        if let Some((_, subscription)) = self.subscriptions.remove(&id) {
245            debug!(subscription_id = %id, "Removing subscription");
246
247            let was_selective_eligible = subscription.selective_eligible;
248
249            // Decrement the atomic counter
250            self.subscription_count_atomic.fetch_sub(1, Ordering::Release);
251
252            // Remove from table index
253            for table in &subscription.tables {
254                if let Some(mut ids) = self.table_index.get_mut(table) {
255                    ids.remove(&id);
256                }
257            }
258
259            // Remove from connection index if present
260            if let Some(ref conn_id) = subscription.connection_id {
261                if let Some(mut ids) = self.connection_index.get_mut(conn_id) {
262                    ids.remove(&id);
263                }
264                // Decrement per-connection count
265                if let Some(count) = self.connection_subscription_counts.get(conn_id) {
266                    count.fetch_sub(1, Ordering::Release);
267                }
268            }
269
270            // Remove from wire ID index if present
271            if let Some(wire_id) = subscription.wire_subscription_id {
272                self.wire_id_index.remove(&wire_id);
273            }
274
275            return was_selective_eligible;
276        }
277        false
278    }
279
280    /// Remove a subscription by its wire protocol ID
281    ///
282    /// This is used by wire protocol clients that use UUID-based subscription IDs.
283    ///
284    /// # Arguments
285    ///
286    /// * `wire_id` - The wire protocol subscription ID (UUID bytes)
287    ///
288    /// # Returns
289    ///
290    /// `true` if the removed subscription was selective-eligible, `false` otherwise.
291    /// Returns `false` if the subscription was not found.
292    pub fn unsubscribe_by_wire_id(&self, wire_id: &[u8; 16]) -> bool {
293        if let Some((_, id)) = self.wire_id_index.remove(wire_id) {
294            self.unsubscribe(id)
295        } else {
296            false
297        }
298    }
299
300    /// Remove all subscriptions for a connection
301    ///
302    /// This should be called when a connection closes to clean up all its
303    /// subscriptions. This is important for wire protocol connections.
304    ///
305    /// # Arguments
306    ///
307    /// * `connection_id` - The connection ID to clean up
308    ///
309    /// # Returns
310    ///
311    /// A tuple of (total_removed, selective_eligible_removed)
312    pub fn unsubscribe_all_for_connection(&self, connection_id: &str) -> (usize, usize) {
313        let subscription_ids: Vec<SubscriptionId> = if let Some((_, ids)) =
314            self.connection_index.remove(connection_id)
315        {
316            ids.into_iter().collect()
317        } else {
318            return (0, 0);
319        };
320
321        let count = subscription_ids.len();
322        debug!(
323            connection_id = %connection_id,
324            subscription_count = count,
325            "Removing all subscriptions for connection"
326        );
327
328        let mut selective_eligible_count = 0;
329        for id in subscription_ids {
330            // Note: unsubscribe will try to remove from connection_index again,
331            // but it will be a no-op since we already removed it
332            if self.unsubscribe(id) {
333                selective_eligible_count += 1;
334            }
335        }
336
337        // Clean up the per-connection count entry
338        self.connection_subscription_counts.remove(connection_id);
339
340        (count, selective_eligible_count)
341    }
342
343    /// Extract table references from a query
344    pub(crate) fn extract_tables(&self, query: &str) -> Result<HashSet<String>, SubscriptionError> {
345        let stmt = vibesql_parser::Parser::parse_sql(query)
346            .map_err(|e| SubscriptionError::ParseError(e.to_string()))?;
347        Ok(extract_table_refs(&stmt))
348    }
349}