Skip to main content

vibesql_server/
registry.rs

//! Database registry for shared database instances across connections.
//!
//! This module provides a registry that maps database names to shared database
//! instances, allowing multiple connections to the same database to share data.

6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use vibesql_storage::Database;
10
11/// Shared database handle that can be cloned across connections.
12pub type SharedDatabase = Arc<RwLock<Database>>;
13
14/// Registry managing shared database instances.
15///
16/// When a connection requests a database by name, the registry either returns
17/// an existing shared instance or creates a new one. This ensures all connections
18/// to the same database name share the same data.
19#[derive(Clone)]
20pub struct DatabaseRegistry {
21    databases: Arc<RwLock<HashMap<String, SharedDatabase>>>,
22}
23
24impl DatabaseRegistry {
25    /// Create a new empty database registry.
26    pub fn new() -> Self {
27        Self { databases: Arc::new(RwLock::new(HashMap::new())) }
28    }
29
30    /// Get or create a shared database instance for the given name.
31    ///
32    /// If a database with the given name already exists, returns a clone of
33    /// its shared handle. Otherwise, creates a new database and returns it.
34    pub async fn get_or_create(&self, name: &str) -> SharedDatabase {
35        // First try read lock to check if database exists
36        {
37            let databases = self.databases.read().await;
38            if let Some(db) = databases.get(name) {
39                return Arc::clone(db);
40            }
41        }
42
43        // Need to create - acquire write lock
44        let mut databases = self.databases.write().await;
45
46        // Double-check after acquiring write lock (another task may have created it)
47        if let Some(db) = databases.get(name) {
48            return Arc::clone(db);
49        }
50
51        // Create new database
52        let db = Arc::new(RwLock::new(Database::new()));
53        databases.insert(name.to_string(), Arc::clone(&db));
54        db
55    }
56
57    /// Get a shared database instance if it exists.
58    ///
59    /// Returns None if no database with the given name exists.
60    #[allow(dead_code)]
61    pub async fn get(&self, name: &str) -> Option<SharedDatabase> {
62        let databases = self.databases.read().await;
63        databases.get(name).cloned()
64    }
65
66    /// List all database names in the registry.
67    #[allow(dead_code)]
68    pub async fn list_databases(&self) -> Vec<String> {
69        let databases = self.databases.read().await;
70        databases.keys().cloned().collect()
71    }
72
73    /// Get the number of databases in the registry.
74    #[allow(dead_code)]
75    pub async fn database_count(&self) -> usize {
76        let databases = self.databases.read().await;
77        databases.len()
78    }
79
80    /// Register a pre-built database instance.
81    ///
82    /// This is useful for benchmarks where the database is pre-loaded with data
83    /// before starting the server.
84    pub async fn register_database(&self, name: &str, db: Database) {
85        let mut databases = self.databases.write().await;
86        databases.insert(name.to_string(), Arc::new(RwLock::new(db)));
87    }
88}
89
90impl Default for DatabaseRegistry {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the one-column `users` table schema shared by several tests.
    fn users_schema() -> vibesql_catalog::TableSchema {
        vibesql_catalog::TableSchema::new(
            "users".to_string(),
            vec![vibesql_catalog::ColumnSchema::new(
                "id".to_string(),
                vibesql_types::DataType::Integer,
                true,
            )],
        )
    }

    #[tokio::test]
    async fn test_get_or_create_new_database() {
        let registry = DatabaseRegistry::new();

        let db1 = registry.get_or_create("testdb").await;
        assert_eq!(registry.database_count().await, 1);

        // Same name should return the same database
        let db2 = registry.get_or_create("testdb").await;
        assert_eq!(registry.database_count().await, 1);

        // Verify they point to the same database
        assert!(Arc::ptr_eq(&db1, &db2));
    }

    #[tokio::test]
    async fn test_different_databases() {
        let registry = DatabaseRegistry::new();

        let db1 = registry.get_or_create("db1").await;
        let db2 = registry.get_or_create("db2").await;

        assert_eq!(registry.database_count().await, 2);
        assert!(!Arc::ptr_eq(&db1, &db2));
    }

    #[tokio::test]
    async fn test_concurrent_get_or_create_shares_instance() {
        let registry = DatabaseRegistry::new();

        // Drive several lookups of the same name concurrently; every one of them
        // must resolve to a single registered instance.
        let (a, b, c) = tokio::join!(
            registry.get_or_create("race"),
            registry.get_or_create("race"),
            registry.get_or_create("race")
        );

        assert_eq!(registry.database_count().await, 1);
        assert!(Arc::ptr_eq(&a, &b));
        assert!(Arc::ptr_eq(&b, &c));
    }

    #[tokio::test]
    async fn test_shared_data_across_connections() {
        let registry = DatabaseRegistry::new();

        // Simulate two connections to the same database
        let db1 = registry.get_or_create("shared").await;
        let db2 = registry.get_or_create("shared").await;

        // Create a table through first "connection"
        {
            let mut db = db1.write().await;
            db.create_table(users_schema()).unwrap();
        }

        // Should be visible through second "connection"
        {
            let db = db2.read().await;
            assert!(db.get_table("users").is_some());
        }
    }

    #[tokio::test]
    async fn test_get_missing_returns_none() {
        let registry = DatabaseRegistry::new();

        assert!(registry.get("missing").await.is_none());
        // `get` must not create the database as a side effect.
        assert_eq!(registry.database_count().await, 0);
    }

    #[tokio::test]
    async fn test_get_returns_existing_database() {
        let registry = DatabaseRegistry::new();

        let created = registry.get_or_create("existing").await;
        let fetched = registry.get("existing").await.expect("database was just created");

        assert!(Arc::ptr_eq(&created, &fetched));
    }

    #[tokio::test]
    async fn test_register_database_is_shared() {
        let registry = DatabaseRegistry::new();

        let mut preloaded = Database::new();
        preloaded.create_table(users_schema()).unwrap();
        registry.register_database("bench", preloaded).await;

        // A later lookup must reuse the pre-built instance, not create a new one.
        let db = registry.get_or_create("bench").await;
        assert_eq!(registry.database_count().await, 1);
        assert!(db.read().await.get_table("users").is_some());
    }

    #[tokio::test]
    async fn test_register_database_replaces_existing() {
        let registry = DatabaseRegistry::new();
        let original = registry.get_or_create("bench").await;

        registry.register_database("bench", Database::new()).await;
        let replacement = registry.get("bench").await.expect("database was just registered");

        // The name now maps to the new instance; the old handle is detached.
        assert_eq!(registry.database_count().await, 1);
        assert!(!Arc::ptr_eq(&original, &replacement));
    }

    #[tokio::test]
    async fn test_list_databases() {
        let registry = DatabaseRegistry::new();

        registry.get_or_create("alpha").await;
        registry.get_or_create("beta").await;
        registry.get_or_create("gamma").await;

        let names = registry.list_databases().await;
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
        assert!(names.contains(&"gamma".to_string()));
    }
}