typedb_driver/database/
database_manager.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
20#[cfg(not(feature = "sync"))]
21use std::future::Future;
22use std::{
23    collections::HashMap,
24    io::{BufReader, BufWriter, Cursor, Read},
25    path::Path,
26    sync::{Arc, RwLock},
27};
28
29use prost::{decode_length_delimiter, Message};
30use typedb_protocol::migration::Item;
31
32use super::Database;
33use crate::{
34    common::{address::Address, error::ConnectionError, Result},
35    connection::server_connection::ServerConnection,
36    database::migration::{try_open_import_file, ProtoMessageIterator},
37    info::DatabaseInfo,
38    resolve, Error,
39};
40
41/// Provides access to all database management methods.
42#[derive(Debug)]
43pub struct DatabaseManager {
44    server_connections: HashMap<Address, ServerConnection>,
45    databases_cache: RwLock<HashMap<String, Arc<Database>>>,
46}
47
48/// Provides access to all database management methods.
49impl DatabaseManager {
50    pub(crate) fn new(
51        server_connections: HashMap<Address, ServerConnection>,
52        database_info: Vec<DatabaseInfo>,
53    ) -> Result<Self> {
54        let mut databases = HashMap::new();
55        for info in database_info {
56            let database = Database::new(info, server_connections.clone())?;
57            databases.insert(database.name().to_owned(), Arc::new(database));
58        }
59        Ok(Self { server_connections, databases_cache: RwLock::new(databases) })
60    }
61
62    /// Retrieves all databases present on the TypeDB server
63    ///
64    /// # Examples
65    ///
66    /// ```rust
67    #[cfg_attr(feature = "sync", doc = "driver.databases().all();")]
68    #[cfg_attr(not(feature = "sync"), doc = "driver.databases().all().await;")]
69    /// ```
70    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
71    pub async fn all(&self) -> Result<Vec<Arc<Database>>> {
72        let mut error_buffer = Vec::with_capacity(self.server_connections.len());
73        for (server_id, server_connection) in self.server_connections.iter() {
74            match server_connection.all_databases().await {
75                Ok(list) => {
76                    let mut new_databases: Vec<Arc<Database>> = Vec::new();
77                    for db_info in list {
78                        new_databases.push(Arc::new(Database::new(db_info, self.server_connections.clone())?));
79                    }
80                    let mut databases = self.databases_cache.write().unwrap();
81                    databases.clear();
82                    databases
83                        .extend(new_databases.iter().map(|database| (database.name().to_owned(), database.clone())));
84                    return Ok(new_databases);
85                }
86                Err(err) => error_buffer.push(format!("- {}: {}", server_id, err)),
87            }
88        }
89        Err(ConnectionError::ServerConnectionFailedWithError { error: error_buffer.join("\n") })?
90    }
91
92    /// Retrieve the database with the given name.
93    ///
94    /// # Arguments
95    ///
96    /// * `name` — The name of the database to retrieve
97    ///
98    /// # Examples
99    ///
100    /// ```rust
101    #[cfg_attr(feature = "sync", doc = "driver.databases().get(name);")]
102    #[cfg_attr(not(feature = "sync"), doc = "driver.databases().get(name).await;")]
103    /// ```
104    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
105    pub async fn get(&self, name: impl AsRef<str>) -> Result<Arc<Database>> {
106        let name = name.as_ref();
107
108        if !self.contains(name.to_owned()).await? {
109            self.databases_cache.write().unwrap().remove(name);
110            return Err(ConnectionError::DatabaseNotFound { name: name.to_owned() }.into());
111        }
112
113        if let Some(cached_database) = self.try_get_cached(name) {
114            return Ok(cached_database);
115        }
116
117        self.cache_insert(Database::get(name.to_owned(), self.server_connections.clone()).await?);
118        Ok(self.try_get_cached(name).unwrap())
119    }
120
121    /// Checks if a database with the given name exists
122    ///
123    /// # Arguments
124    ///
125    /// * `name` — The database name to be checked
126    ///
127    /// # Examples
128    ///
129    /// ```rust
130    #[cfg_attr(feature = "sync", doc = "driver.databases().contains(name);")]
131    #[cfg_attr(not(feature = "sync"), doc = "driver.databases().contains(name).await;")]
132    /// ```
133    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
134    pub async fn contains(&self, name: impl Into<String>) -> Result<bool> {
135        let name = name.into();
136        self.run_failsafe(
137            name,
138            |server_connection, name| async move { server_connection.contains_database(name).await },
139        )
140        .await
141    }
142
143    /// Create a database with the given name
144    ///
145    /// # Arguments
146    ///
147    /// * `name` — The name of the database to be created
148    ///
149    /// # Examples
150    ///
151    /// ```rust
152    #[cfg_attr(feature = "sync", doc = "driver.databases().create(name);")]
153    #[cfg_attr(not(feature = "sync"), doc = "driver.databases().create(name).await;")]
154    /// ```
155    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
156    pub async fn create(&self, name: impl Into<String>) -> Result {
157        let name = name.into();
158        let database_info = self
159            .run_failsafe(name, |server_connection, name| async move { server_connection.create_database(name).await }) // TODO: run_failsafe produces additiona Connection error if the database name is incorrect. Is it ok?
160            .await?;
161        self.cache_insert(Database::new(database_info, self.server_connections.clone())?);
162        Ok(())
163    }
164
165    /// Create a database with the given name based on previously exported another database's data
166    /// loaded from a file.
167    /// This is a blocking operation and may take a significant amount of time depending on the
168    /// database size.
169    ///
170    /// # Arguments
171    ///
172    /// * `name` — The name of the database to be created
173    /// * `schema` — The schema definition query string for the database
174    /// * `data_file_path` — The exported database file to import the data from
175    ///
176    /// # Examples
177    ///
178    /// ```rust
179    #[cfg_attr(feature = "sync", doc = "driver.databases().import_from_file(name, schema, data_path);")]
180    #[cfg_attr(not(feature = "sync"), doc = "driver.databases().import_from_file(name, schema, data_path).await;")]
181    /// ```
182    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
183    pub async fn import_from_file(
184        &self,
185        name: impl Into<String>,
186        schema: impl Into<String>,
187        data_file_path: impl AsRef<Path>,
188    ) -> Result {
189        const ITEM_BATCH_SIZE: usize = 250;
190
191        let name = name.into();
192        let schema: String = schema.into();
193        let schema_ref: &str = schema.as_ref();
194        let data_file_path = data_file_path.as_ref();
195
196        self.run_failsafe(name, |server_connection, name| async move {
197            let file = try_open_import_file(data_file_path)?;
198            let mut import_stream = server_connection.import_database(name, schema_ref.to_string()).await?;
199
200            let mut item_buffer = Vec::with_capacity(ITEM_BATCH_SIZE);
201            let mut read_item_iterator = ProtoMessageIterator::<Item, _>::new(BufReader::new(file));
202
203            while let Some(item) = read_item_iterator.next() {
204                let item = item?;
205                item_buffer.push(item);
206                if item_buffer.len() >= ITEM_BATCH_SIZE {
207                    import_stream.send_items(item_buffer.split_off(0))?;
208                }
209            }
210
211            if !item_buffer.is_empty() {
212                import_stream.send_items(item_buffer)?;
213            }
214
215            resolve!(import_stream.done())
216        })
217        .await
218    }
219
220    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
221    pub(crate) async fn get_cached_or_fetch(&self, name: &str) -> Result<Arc<Database>> {
222        match self.try_get_cached(name) {
223            Some(cached_database) => Ok(cached_database),
224            None => self.get(name).await,
225        }
226    }
227
228    fn try_get_cached(&self, name: &str) -> Option<Arc<Database>> {
229        self.databases_cache.read().unwrap().get(name).cloned()
230    }
231
232    fn cache_insert(&self, database: Database) {
233        self.databases_cache.write().unwrap().insert(database.name().to_owned(), Arc::new(database));
234    }
235
236    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
237    async fn run_failsafe<F, P, R>(&self, name: String, task: F) -> Result<R>
238    where
239        F: Fn(ServerConnection, String) -> P,
240        P: Future<Output = Result<R>>,
241    {
242        let mut error_buffer = Vec::with_capacity(self.server_connections.len());
243        for (server_id, server_connection) in self.server_connections.iter() {
244            match task(server_connection.clone(), name.clone()).await {
245                Ok(res) => return Ok(res),
246                // TODO: database manager should never encounter NOT PRIMARY errors since we are failing over server connections, not replicas
247
248                // Err(Error::Connection(ConnectionError::ClusterReplicaNotPrimary)) => {
249                //     return Database::get(name, self.connection.clone())
250                //         .await?
251                //         .run_on_primary_replica(|database| {
252                //             let task = &task;
253                //             async move { task(database.connection().clone(), database.name().to_owned()).await }
254                //         })
255                //         .await
256                // }
257                err @ Err(Error::Connection(ConnectionError::ServerConnectionIsClosed)) => return err,
258                Err(err) => error_buffer.push(format!("- {}: {}", server_id, err)),
259            }
260        }
261        // TODO: With this, every operation fails with
262        // [CXN03] Connection Error: Unable to connect to TypeDB server(s), received errors: .... <stacktrace>
263        // Which is quite confusing as it's not really connected to connection.
264        Err(ConnectionError::ServerConnectionFailedWithError { error: error_buffer.join("\n") })?
265    }
266}