1#[cfg(not(feature = "sync"))]
21use std::future::Future;
22use std::{
23 collections::HashMap,
24 io::{BufReader, BufWriter, Cursor, Read},
25 path::Path,
26 sync::{Arc, RwLock},
27};
28
29use prost::{decode_length_delimiter, Message};
30use typedb_protocol::migration::Item;
31
32use super::Database;
33use crate::{
34 common::{address::Address, error::ConnectionError, Result},
35 connection::server_connection::ServerConnection,
36 database::migration::{try_open_import_file, ProtoMessageIterator},
37 info::DatabaseInfo,
38 resolve, Error,
39};
40
/// Entry point for database management operations (listing, lookup, existence checks, creation
/// and import) across the set of known server connections.
#[derive(Debug)]
pub struct DatabaseManager {
    /// One connection per server, keyed by server address. Operations that need a server try these
    /// one at a time; `HashMap` iteration order is unspecified, so no server is preferred.
    server_connections: HashMap<Address, ServerConnection>,
    /// Local cache of database handles keyed by database name. Behind an `RwLock` so it can be
    /// refreshed through `&self` (replaced by `all`, filled by `get`/`create`, evicted by `get`
    /// when the server reports the database missing).
    databases_cache: RwLock<HashMap<String, Arc<Database>>>,
}
47
48impl DatabaseManager {
50 pub(crate) fn new(
51 server_connections: HashMap<Address, ServerConnection>,
52 database_info: Vec<DatabaseInfo>,
53 ) -> Result<Self> {
54 let mut databases = HashMap::new();
55 for info in database_info {
56 let database = Database::new(info, server_connections.clone())?;
57 databases.insert(database.name().to_owned(), Arc::new(database));
58 }
59 Ok(Self { server_connections, databases_cache: RwLock::new(databases) })
60 }
61
62 #[cfg_attr(feature = "sync", doc = "driver.databases().all();")]
68 #[cfg_attr(not(feature = "sync"), doc = "driver.databases().all().await;")]
69 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
71 pub async fn all(&self) -> Result<Vec<Arc<Database>>> {
72 let mut error_buffer = Vec::with_capacity(self.server_connections.len());
73 for (server_id, server_connection) in self.server_connections.iter() {
74 match server_connection.all_databases().await {
75 Ok(list) => {
76 let mut new_databases: Vec<Arc<Database>> = Vec::new();
77 for db_info in list {
78 new_databases.push(Arc::new(Database::new(db_info, self.server_connections.clone())?));
79 }
80 let mut databases = self.databases_cache.write().unwrap();
81 databases.clear();
82 databases
83 .extend(new_databases.iter().map(|database| (database.name().to_owned(), database.clone())));
84 return Ok(new_databases);
85 }
86 Err(err) => error_buffer.push(format!("- {}: {}", server_id, err)),
87 }
88 }
89 Err(ConnectionError::ServerConnectionFailedWithError { error: error_buffer.join("\n") })?
90 }
91
92 #[cfg_attr(feature = "sync", doc = "driver.databases().get(name);")]
102 #[cfg_attr(not(feature = "sync"), doc = "driver.databases().get(name).await;")]
103 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
105 pub async fn get(&self, name: impl AsRef<str>) -> Result<Arc<Database>> {
106 let name = name.as_ref();
107
108 if !self.contains(name.to_owned()).await? {
109 self.databases_cache.write().unwrap().remove(name);
110 return Err(ConnectionError::DatabaseNotFound { name: name.to_owned() }.into());
111 }
112
113 if let Some(cached_database) = self.try_get_cached(name) {
114 return Ok(cached_database);
115 }
116
117 self.cache_insert(Database::get(name.to_owned(), self.server_connections.clone()).await?);
118 Ok(self.try_get_cached(name).unwrap())
119 }
120
121 #[cfg_attr(feature = "sync", doc = "driver.databases().contains(name);")]
131 #[cfg_attr(not(feature = "sync"), doc = "driver.databases().contains(name).await;")]
132 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
134 pub async fn contains(&self, name: impl Into<String>) -> Result<bool> {
135 let name = name.into();
136 self.run_failsafe(
137 name,
138 |server_connection, name| async move { server_connection.contains_database(name).await },
139 )
140 .await
141 }
142
143 #[cfg_attr(feature = "sync", doc = "driver.databases().create(name);")]
153 #[cfg_attr(not(feature = "sync"), doc = "driver.databases().create(name).await;")]
154 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
156 pub async fn create(&self, name: impl Into<String>) -> Result {
157 let name = name.into();
158 let database_info = self
159 .run_failsafe(name, |server_connection, name| async move { server_connection.create_database(name).await }) .await?;
161 self.cache_insert(Database::new(database_info, self.server_connections.clone())?);
162 Ok(())
163 }
164
165 #[cfg_attr(feature = "sync", doc = "driver.databases().import_from_file(name, schema, data_path);")]
180 #[cfg_attr(not(feature = "sync"), doc = "driver.databases().import_from_file(name, schema, data_path).await;")]
181 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
183 pub async fn import_from_file(
184 &self,
185 name: impl Into<String>,
186 schema: impl Into<String>,
187 data_file_path: impl AsRef<Path>,
188 ) -> Result {
189 const ITEM_BATCH_SIZE: usize = 250;
190
191 let name = name.into();
192 let schema: String = schema.into();
193 let schema_ref: &str = schema.as_ref();
194 let data_file_path = data_file_path.as_ref();
195
196 self.run_failsafe(name, |server_connection, name| async move {
197 let file = try_open_import_file(data_file_path)?;
198 let mut import_stream = server_connection.import_database(name, schema_ref.to_string()).await?;
199
200 let mut item_buffer = Vec::with_capacity(ITEM_BATCH_SIZE);
201 let mut read_item_iterator = ProtoMessageIterator::<Item, _>::new(BufReader::new(file));
202
203 while let Some(item) = read_item_iterator.next() {
204 let item = item?;
205 item_buffer.push(item);
206 if item_buffer.len() >= ITEM_BATCH_SIZE {
207 import_stream.send_items(item_buffer.split_off(0))?;
208 }
209 }
210
211 if !item_buffer.is_empty() {
212 import_stream.send_items(item_buffer)?;
213 }
214
215 resolve!(import_stream.done())
216 })
217 .await
218 }
219
220 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
221 pub(crate) async fn get_cached_or_fetch(&self, name: &str) -> Result<Arc<Database>> {
222 match self.try_get_cached(name) {
223 Some(cached_database) => Ok(cached_database),
224 None => self.get(name).await,
225 }
226 }
227
228 fn try_get_cached(&self, name: &str) -> Option<Arc<Database>> {
229 self.databases_cache.read().unwrap().get(name).cloned()
230 }
231
232 fn cache_insert(&self, database: Database) {
233 self.databases_cache.write().unwrap().insert(database.name().to_owned(), Arc::new(database));
234 }
235
236 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
237 async fn run_failsafe<F, P, R>(&self, name: String, task: F) -> Result<R>
238 where
239 F: Fn(ServerConnection, String) -> P,
240 P: Future<Output = Result<R>>,
241 {
242 let mut error_buffer = Vec::with_capacity(self.server_connections.len());
243 for (server_id, server_connection) in self.server_connections.iter() {
244 match task(server_connection.clone(), name.clone()).await {
245 Ok(res) => return Ok(res),
246 err @ Err(Error::Connection(ConnectionError::ServerConnectionIsClosed)) => return err,
258 Err(err) => error_buffer.push(format!("- {}: {}", server_id, err)),
259 }
260 }
261 Err(ConnectionError::ServerConnectionFailedWithError { error: error_buffer.join("\n") })?
265 }
266}