vibesql_server/
registry.rs1use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use vibesql_storage::Database;
10
11pub type SharedDatabase = Arc<RwLock<Database>>;
13
14#[derive(Clone)]
20pub struct DatabaseRegistry {
21 databases: Arc<RwLock<HashMap<String, SharedDatabase>>>,
22}
23
24impl DatabaseRegistry {
25 pub fn new() -> Self {
27 Self { databases: Arc::new(RwLock::new(HashMap::new())) }
28 }
29
30 pub async fn get_or_create(&self, name: &str) -> SharedDatabase {
35 {
37 let databases = self.databases.read().await;
38 if let Some(db) = databases.get(name) {
39 return Arc::clone(db);
40 }
41 }
42
43 let mut databases = self.databases.write().await;
45
46 if let Some(db) = databases.get(name) {
48 return Arc::clone(db);
49 }
50
51 let db = Arc::new(RwLock::new(Database::new()));
53 databases.insert(name.to_string(), Arc::clone(&db));
54 db
55 }
56
57 #[allow(dead_code)]
61 pub async fn get(&self, name: &str) -> Option<SharedDatabase> {
62 let databases = self.databases.read().await;
63 databases.get(name).cloned()
64 }
65
66 #[allow(dead_code)]
68 pub async fn list_databases(&self) -> Vec<String> {
69 let databases = self.databases.read().await;
70 databases.keys().cloned().collect()
71 }
72
73 #[allow(dead_code)]
75 pub async fn database_count(&self) -> usize {
76 let databases = self.databases.read().await;
77 databases.len()
78 }
79
80 pub async fn register_database(&self, name: &str, db: Database) {
85 let mut databases = self.databases.write().await;
86 databases.insert(name.to_string(), Arc::new(RwLock::new(db)));
87 }
88}
89
90impl Default for DatabaseRegistry {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_get_or_create_new_database() {
        let registry = DatabaseRegistry::new();

        let db1 = registry.get_or_create("testdb").await;
        assert_eq!(registry.database_count().await, 1);

        // A second request for the same name must not create another database.
        let db2 = registry.get_or_create("testdb").await;
        assert_eq!(registry.database_count().await, 1);

        assert!(Arc::ptr_eq(&db1, &db2));
    }

    #[tokio::test]
    async fn test_different_databases() {
        let registry = DatabaseRegistry::new();

        let db1 = registry.get_or_create("db1").await;
        let db2 = registry.get_or_create("db2").await;

        assert_eq!(registry.database_count().await, 2);
        assert!(!Arc::ptr_eq(&db1, &db2));
    }

    #[tokio::test]
    async fn test_shared_data_across_connections() {
        let registry = DatabaseRegistry::new();

        // Two "connections" asking for the same database name.
        let db1 = registry.get_or_create("shared").await;
        let db2 = registry.get_or_create("shared").await;

        {
            // Create a table through the first handle.
            let mut db = db1.write().await;
            let schema = vibesql_catalog::TableSchema::new(
                "users".to_string(),
                vec![vibesql_catalog::ColumnSchema::new(
                    "id".to_string(),
                    vibesql_types::DataType::Integer,
                    true,
                )],
            );
            db.create_table(schema).unwrap();
        }

        {
            // The table must be visible through the second handle.
            let db = db2.read().await;
            assert!(db.get_table("users").is_some());
        }
    }

    #[tokio::test]
    async fn test_list_databases() {
        let registry = DatabaseRegistry::new();

        registry.get_or_create("alpha").await;
        registry.get_or_create("beta").await;
        registry.get_or_create("gamma").await;

        let names = registry.list_databases().await;
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
        assert!(names.contains(&"gamma".to_string()));
    }

    #[tokio::test]
    async fn test_get_does_not_create() {
        let registry = DatabaseRegistry::new();

        assert!(registry.get("missing").await.is_none());
        assert_eq!(registry.database_count().await, 0);

        let created = registry.get_or_create("present").await;
        let fetched = registry.get("present").await.expect("database was just created");
        assert!(Arc::ptr_eq(&created, &fetched));
    }

    #[tokio::test]
    async fn test_register_database_replaces_existing() {
        let registry = DatabaseRegistry::new();

        let original = registry.get_or_create("db").await;
        registry.register_database("db", Database::new()).await;

        let replaced = registry.get("db").await.expect("database was registered");
        assert!(!Arc::ptr_eq(&original, &replaced));
        assert_eq!(registry.database_count().await, 1);
    }

    #[tokio::test]
    async fn test_default_is_empty() {
        let registry = DatabaseRegistry::default();

        assert_eq!(registry.database_count().await, 0);
        assert!(registry.list_databases().await.is_empty());
    }

    #[tokio::test]
    async fn test_clones_share_registry() {
        let registry = DatabaseRegistry::new();
        let clone = registry.clone();

        let via_clone = clone.get_or_create("x").await;
        let via_original = registry.get_or_create("x").await;

        assert!(Arc::ptr_eq(&via_clone, &via_original));
        assert_eq!(registry.database_count().await, 1);
    }

    #[tokio::test]
    async fn test_concurrent_get_or_create_yields_single_instance() {
        let registry = DatabaseRegistry::new();

        // Interleave several lookups of the same missing name on one task.
        let (a, b, c) = tokio::join!(
            registry.get_or_create("race"),
            registry.get_or_create("race"),
            registry.get_or_create("race"),
        );

        assert_eq!(registry.database_count().await, 1);
        assert!(Arc::ptr_eq(&a, &b));
        assert!(Arc::ptr_eq(&b, &c));
    }
}