typedb_driver/database/
database_manager.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#[cfg(not(feature = "sync"))]
21use std::future::Future;
22use std::{
23    collections::HashMap,
24    sync::{Arc, RwLock},
25};
26
27use super::Database;
28use crate::{
29    common::{address::Address, error::ConnectionError, Result},
30    connection::server_connection::ServerConnection,
31    info::DatabaseInfo,
32    Error,
33};
34
35/// Provides access to all database management methods.
36#[derive(Debug)]
37pub struct DatabaseManager {
38    server_connections: HashMap<Address, ServerConnection>,
39    databases_cache: RwLock<HashMap<String, Arc<Database>>>,
40}
41
42/// Provides access to all database management methods.
43impl DatabaseManager {
44    pub(crate) fn new(
45        server_connections: HashMap<Address, ServerConnection>,
46        database_info: Vec<DatabaseInfo>,
47    ) -> Result<Self> {
48        let mut databases = HashMap::new();
49        for info in database_info {
50            let database = Database::new(info, server_connections.clone())?;
51            databases.insert(database.name().to_owned(), Arc::new(database));
52        }
53        Ok(Self { server_connections, databases_cache: RwLock::new(databases) })
54    }
55
56    /// Retrieves all databases present on the TypeDB server
57    ///
58    /// # Examples
59    ///
60    /// ```rust
61    #[cfg_attr(feature = "sync", doc = "driver.databases().all();")]
62    #[cfg_attr(not(feature = "sync"), doc = "driver.databases().all().await;")]
63    /// ```
64    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
65    pub async fn all(&self) -> Result<Vec<Arc<Database>>> {
66        let mut error_buffer = Vec::with_capacity(self.server_connections.len());
67        for (server_id, server_connection) in self.server_connections.iter() {
68            match server_connection.all_databases().await {
69                Ok(list) => {
70                    let mut new_databases: Vec<Arc<Database>> = Vec::new();
71                    for db_info in list {
72                        new_databases.push(Arc::new(Database::new(db_info, self.server_connections.clone())?));
73                    }
74                    let mut databases = self.databases_cache.write().unwrap();
75                    databases.clear();
76                    databases
77                        .extend(new_databases.iter().map(|database| (database.name().to_owned(), database.clone())));
78                    return Ok(new_databases);
79                }
80                Err(err) => error_buffer.push(format!("- {}: {}", server_id, err)),
81            }
82        }
83        Err(ConnectionError::ServerConnectionFailedWithError { error: error_buffer.join("\n") })?
84    }
85
86    /// Retrieve the database with the given name.
87    ///
88    /// # Arguments
89    ///
90    /// * `name` — The name of the database to retrieve
91    ///
92    /// # Examples
93    ///
94    /// ```rust
95    #[cfg_attr(feature = "sync", doc = "driver.databases().get(name);")]
96    #[cfg_attr(not(feature = "sync"), doc = "driver.databases().get(name).await;")]
97    /// ```
98    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
99    pub async fn get(&self, name: impl AsRef<str>) -> Result<Arc<Database>> {
100        let name = name.as_ref();
101
102        if !self.contains(name.to_owned()).await? {
103            self.databases_cache.write().unwrap().remove(name);
104            return Err(ConnectionError::DatabaseNotFound { name: name.to_owned() }.into());
105        }
106
107        if let Some(cached_database) = self.try_get_cached(name) {
108            return Ok(cached_database);
109        }
110
111        self.cache_insert(Database::get(name.to_owned(), self.server_connections.clone()).await?);
112        Ok(self.try_get_cached(name).unwrap())
113    }
114
115    /// Checks if a database with the given name exists
116    ///
117    /// # Arguments
118    ///
119    /// * `name` — The database name to be checked
120    ///
121    /// # Examples
122    ///
123    /// ```rust
124    #[cfg_attr(feature = "sync", doc = "driver.databases().contains(name);")]
125    #[cfg_attr(not(feature = "sync"), doc = "driver.databases().contains(name).await;")]
126    /// ```
127    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
128    pub async fn contains(&self, name: impl Into<String>) -> Result<bool> {
129        let name = name.into();
130        self.run_failsafe(
131            name,
132            |server_connection, name| async move { server_connection.contains_database(name).await },
133        )
134        .await
135    }
136
137    /// Create a database with the given name
138    ///
139    /// # Arguments
140    ///
141    /// * `name` — The name of the database to be created
142    ///
143    /// # Examples
144    ///
145    /// ```rust
146    #[cfg_attr(feature = "sync", doc = "driver.databases().create(name);")]
147    #[cfg_attr(not(feature = "sync"), doc = "driver.databases().create(name).await;")]
148    /// ```
149    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
150    pub async fn create(&self, name: impl Into<String>) -> Result {
151        let name = name.into();
152        let database_info = self
153            .run_failsafe(name, |server_connection, name| async move { server_connection.create_database(name).await }) // TODO: run_failsafe produces additiona Connection error if the database name is incorrect. Is it ok?
154            .await?;
155        self.cache_insert(Database::new(database_info, self.server_connections.clone())?);
156        Ok(())
157    }
158
159    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
160    pub(crate) async fn get_cached_or_fetch(&self, name: &str) -> Result<Arc<Database>> {
161        match self.try_get_cached(name) {
162            Some(cached_database) => Ok(cached_database),
163            None => self.get(name).await,
164        }
165    }
166
167    fn try_get_cached(&self, name: &str) -> Option<Arc<Database>> {
168        self.databases_cache.read().unwrap().get(name).cloned()
169    }
170
171    fn cache_insert(&self, database: Database) {
172        self.databases_cache.write().unwrap().insert(database.name().to_owned(), Arc::new(database));
173    }
174
175    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
176    async fn run_failsafe<F, P, R>(&self, name: String, task: F) -> Result<R>
177    where
178        F: Fn(ServerConnection, String) -> P,
179        P: Future<Output = Result<R>>,
180    {
181        let mut error_buffer = Vec::with_capacity(self.server_connections.len());
182        for (server_id, server_connection) in self.server_connections.iter() {
183            match task(server_connection.clone(), name.clone()).await {
184                Ok(res) => return Ok(res),
185                // TODO: database manager should never encounter NOT PRIMARY errors since we are failing over server connections, not replicas
186
187                // Err(Error::Connection(ConnectionError::CloudReplicaNotPrimary)) => {
188                //     return Database::get(name, self.connection.clone())
189                //         .await?
190                //         .run_on_primary_replica(|database| {
191                //             let task = &task;
192                //             async move { task(database.connection().clone(), database.name().to_owned()).await }
193                //         })
194                //         .await
195                // }
196                err @ Err(Error::Connection(ConnectionError::ServerConnectionIsClosed)) => return err,
197                Err(err) => error_buffer.push(format!("- {}: {}", server_id, err)),
198            }
199        }
200        Err(ConnectionError::ServerConnectionFailedWithError { error: error_buffer.join("\n") })?
201    }
202}