typedb_driver/database/
database_manager.rs1#[cfg(not(feature = "sync"))]
21use std::future::Future;
22use std::{
23 collections::HashMap,
24 sync::{Arc, RwLock},
25};
26
27use super::Database;
28use crate::{
29 common::{address::Address, error::ConnectionError, Result},
30 connection::server_connection::ServerConnection,
31 info::DatabaseInfo,
32 Error,
33};
34
35#[derive(Debug)]
37pub struct DatabaseManager {
38 server_connections: HashMap<Address, ServerConnection>,
39 databases_cache: RwLock<HashMap<String, Arc<Database>>>,
40}
41
42impl DatabaseManager {
44 pub(crate) fn new(
45 server_connections: HashMap<Address, ServerConnection>,
46 database_info: Vec<DatabaseInfo>,
47 ) -> Result<Self> {
48 let mut databases = HashMap::new();
49 for info in database_info {
50 let database = Database::new(info, server_connections.clone())?;
51 databases.insert(database.name().to_owned(), Arc::new(database));
52 }
53 Ok(Self { server_connections, databases_cache: RwLock::new(databases) })
54 }
55
56 #[cfg_attr(feature = "sync", doc = "driver.databases().all();")]
62 #[cfg_attr(not(feature = "sync"), doc = "driver.databases().all().await;")]
63 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
65 pub async fn all(&self) -> Result<Vec<Arc<Database>>> {
66 let mut error_buffer = Vec::with_capacity(self.server_connections.len());
67 for (server_id, server_connection) in self.server_connections.iter() {
68 match server_connection.all_databases().await {
69 Ok(list) => {
70 let mut new_databases: Vec<Arc<Database>> = Vec::new();
71 for db_info in list {
72 new_databases.push(Arc::new(Database::new(db_info, self.server_connections.clone())?));
73 }
74 let mut databases = self.databases_cache.write().unwrap();
75 databases.clear();
76 databases
77 .extend(new_databases.iter().map(|database| (database.name().to_owned(), database.clone())));
78 return Ok(new_databases);
79 }
80 Err(err) => error_buffer.push(format!("- {}: {}", server_id, err)),
81 }
82 }
83 Err(ConnectionError::ServerConnectionFailedWithError { error: error_buffer.join("\n") })?
84 }
85
86 #[cfg_attr(feature = "sync", doc = "driver.databases().get(name);")]
96 #[cfg_attr(not(feature = "sync"), doc = "driver.databases().get(name).await;")]
97 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
99 pub async fn get(&self, name: impl AsRef<str>) -> Result<Arc<Database>> {
100 let name = name.as_ref();
101
102 if !self.contains(name.to_owned()).await? {
103 self.databases_cache.write().unwrap().remove(name);
104 return Err(ConnectionError::DatabaseNotFound { name: name.to_owned() }.into());
105 }
106
107 if let Some(cached_database) = self.try_get_cached(name) {
108 return Ok(cached_database);
109 }
110
111 self.cache_insert(Database::get(name.to_owned(), self.server_connections.clone()).await?);
112 Ok(self.try_get_cached(name).unwrap())
113 }
114
115 #[cfg_attr(feature = "sync", doc = "driver.databases().contains(name);")]
125 #[cfg_attr(not(feature = "sync"), doc = "driver.databases().contains(name).await;")]
126 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
128 pub async fn contains(&self, name: impl Into<String>) -> Result<bool> {
129 let name = name.into();
130 self.run_failsafe(
131 name,
132 |server_connection, name| async move { server_connection.contains_database(name).await },
133 )
134 .await
135 }
136
137 #[cfg_attr(feature = "sync", doc = "driver.databases().create(name);")]
147 #[cfg_attr(not(feature = "sync"), doc = "driver.databases().create(name).await;")]
148 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
150 pub async fn create(&self, name: impl Into<String>) -> Result {
151 let name = name.into();
152 let database_info = self
153 .run_failsafe(name, |server_connection, name| async move { server_connection.create_database(name).await }) .await?;
155 self.cache_insert(Database::new(database_info, self.server_connections.clone())?);
156 Ok(())
157 }
158
159 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
160 pub(crate) async fn get_cached_or_fetch(&self, name: &str) -> Result<Arc<Database>> {
161 match self.try_get_cached(name) {
162 Some(cached_database) => Ok(cached_database),
163 None => self.get(name).await,
164 }
165 }
166
167 fn try_get_cached(&self, name: &str) -> Option<Arc<Database>> {
168 self.databases_cache.read().unwrap().get(name).cloned()
169 }
170
171 fn cache_insert(&self, database: Database) {
172 self.databases_cache.write().unwrap().insert(database.name().to_owned(), Arc::new(database));
173 }
174
175 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
176 async fn run_failsafe<F, P, R>(&self, name: String, task: F) -> Result<R>
177 where
178 F: Fn(ServerConnection, String) -> P,
179 P: Future<Output = Result<R>>,
180 {
181 let mut error_buffer = Vec::with_capacity(self.server_connections.len());
182 for (server_id, server_connection) in self.server_connections.iter() {
183 match task(server_connection.clone(), name.clone()).await {
184 Ok(res) => return Ok(res),
185 err @ Err(Error::Connection(ConnectionError::ServerConnectionIsClosed)) => return err,
197 Err(err) => error_buffer.push(format!("- {}: {}", server_id, err)),
198 }
199 }
200 Err(ConnectionError::ServerConnectionFailedWithError { error: error_buffer.join("\n") })?
201 }
202}