typedb_driver/
driver.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20use std::{
21    collections::{HashMap, HashSet},
22    fmt,
23    sync::Arc,
24};
25
26use itertools::Itertools;
27
28use crate::{
29    common::{
30        address::Address,
31        error::{ConnectionError, Error},
32        Result,
33    },
34    connection::{runtime::BackgroundRuntime, server_connection::ServerConnection},
35    Credentials, DatabaseManager, DriverOptions, Transaction, TransactionOptions, TransactionType, UserManager,
36};
37
/// A connection to a TypeDB server which serves as the starting point for all interaction.
///
/// Holds one [`ServerConnection`] per server address, the [`DatabaseManager`] and
/// [`UserManager`] built on top of those connections, and the [`BackgroundRuntime`]
/// that every server connection is created with.
pub struct TypeDBDriver {
    // NOTE(review): Rust drops struct fields in declaration order, so `background_runtime` is
    // currently dropped last. Keep it last unless the connections are known not to need the
    // runtime while they are being dropped.
    /// Connections to the server(s), keyed by the address they were opened against.
    server_connections: HashMap<Address, ServerConnection>,
    /// Database operations; constructed from a clone of `server_connections`.
    database_manager: DatabaseManager,
    /// User operations; constructed from a clone of `server_connections`.
    user_manager: UserManager,
    /// Runtime shared (through `Arc`) with every server connection opened by this driver.
    background_runtime: Arc<BackgroundRuntime>,
}
45
46impl TypeDBDriver {
47    const DRIVER_LANG: &'static str = "rust";
48    const VERSION: &'static str = match option_env!("CARGO_PKG_VERSION") {
49        None => "0.0.0",
50        Some(version) => version,
51    };
52
53    pub const DEFAULT_ADDRESS: &'static str = "localhost:1729";
54
55    /// Creates a new TypeDB Server connection.
56    ///
57    /// # Arguments
58    ///
59    /// * `address` — The address (host:port) on which the TypeDB Server is running
60    /// * `credentials` — The Credentials to connect with
61    /// * `driver_options` — The DriverOptions to connect with
62    ///
63    /// # Examples
64    ///
65    /// ```rust
66    #[cfg_attr(
67        feature = "sync",
68        doc = "TypeDBDriver::new(\"127.0.0.1:1729\", Credentials::new(\"username\", \"password\"), DriverOptions::new(true, None))"
69    )]
70    #[cfg_attr(
71        not(feature = "sync"),
72        doc = "TypeDBDriver::new(\"127.0.0.1:1729\", Credentials::new(\"username\", \"password\"), DriverOptions::new(true, None)).await"
73    )]
74    /// ```
75    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
76    pub async fn new(
77        address: impl AsRef<str>,
78        credentials: Credentials,
79        driver_options: DriverOptions,
80    ) -> Result<Self> {
81        Self::new_with_description(address, credentials, driver_options, Self::DRIVER_LANG).await
82    }
83
84    /// Creates a new TypeDB Server connection with a description.
85    /// This method is generally used by TypeDB drivers built on top of the Rust driver.
86    /// In other cases, use [`Self::new`] instead.
87    ///
88    /// # Arguments
89    ///
90    /// * `address` — The address (host:port) on which the TypeDB Server is running
91    /// * `credentials` — The Credentials to connect with
92    /// * `driver_options` — The DriverOptions to connect with
93    /// * `driver_lang` — The language of the driver connecting to the server
94    ///
95    /// # Examples
96    ///
97    /// ```rust
98    #[cfg_attr(
99        feature = "sync",
100        doc = "TypeDBDriver::new_with_description(\"127.0.0.1:1729\", Credentials::new(\"username\", \"password\"), DriverOptions::new(true, None), \"rust\")"
101    )]
102    #[cfg_attr(
103        not(feature = "sync"),
104        doc = "TypeDBDriver::new_with_description(\"127.0.0.1:1729\", Credentials::new(\"username\", \"password\"), DriverOptions::new(true, None), \"rust\").await"
105    )]
106    /// ```
107    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
108    pub async fn new_with_description(
109        address: impl AsRef<str>,
110        credentials: Credentials,
111        driver_options: DriverOptions,
112        driver_lang: impl AsRef<str>,
113    ) -> Result<Self> {
114        let id = address.as_ref().to_string();
115        let address: Address = id.parse()?;
116        let background_runtime = Arc::new(BackgroundRuntime::new()?);
117
118        let (server_connection, database_info) = ServerConnection::new(
119            background_runtime.clone(),
120            address.clone(),
121            credentials,
122            driver_options,
123            driver_lang.as_ref(),
124            Self::VERSION,
125        )
126        .await?;
127
128        // // validate
129        // let advertised_address = server_connection
130        //     .servers_all()?
131        //     .into_iter()
132        //     .exactly_one()
133        //     .map_err(|e| ConnectionError::ServerConnectionFailedStatusError { error: e.to_string() })?;
134
135        // TODO: this solidifies the assumption that servers don't change
136        let server_connections: HashMap<Address, ServerConnection> = [(address, server_connection)].into();
137        let database_manager = DatabaseManager::new(server_connections.clone(), database_info)?;
138        let user_manager = UserManager::new(server_connections.clone());
139
140        Ok(Self { server_connections, database_manager, user_manager, background_runtime })
141    }
142
143    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
144    async fn fetch_server_list(
145        background_runtime: Arc<BackgroundRuntime>,
146        addresses: impl IntoIterator<Item = impl AsRef<str>> + Clone,
147        credentials: Credentials,
148        driver_options: DriverOptions,
149    ) -> Result<HashSet<Address>> {
150        let addresses: Vec<Address> = addresses.into_iter().map(|addr| addr.as_ref().parse()).try_collect()?;
151        for address in &addresses {
152            let server_connection = ServerConnection::new(
153                background_runtime.clone(),
154                address.clone(),
155                credentials.clone(),
156                driver_options.clone(),
157                Self::DRIVER_LANG,
158                Self::VERSION,
159            )
160            .await;
161            match server_connection {
162                Ok((server_connection, _)) => match server_connection.servers_all() {
163                    Ok(servers) => return Ok(servers.into_iter().collect()),
164                    Err(Error::Connection(
165                        ConnectionError::ServerConnectionFailedStatusError { .. } | ConnectionError::ConnectionFailed,
166                    )) => (),
167                    Err(err) => Err(err)?,
168                },
169                Err(Error::Connection(
170                    ConnectionError::ServerConnectionFailedStatusError { .. } | ConnectionError::ConnectionFailed,
171                )) => (),
172                Err(err) => Err(err)?,
173            }
174        }
175        Err(ConnectionError::ServerConnectionFailed { addresses }.into())
176    }
177
178    /// Checks it this connection is opened.
179    //
180    /// # Examples
181    ///
182    /// ```rust
183    /// driver.is_open()
184    /// ```
185    pub fn is_open(&self) -> bool {
186        self.background_runtime.is_open()
187    }
188
189    pub fn databases(&self) -> &DatabaseManager {
190        &self.database_manager
191    }
192
193    pub fn users(&self) -> &UserManager {
194        &self.user_manager
195    }
196
197    /// Opens a transaction with default options.
198    /// See [`TypeDBDriver::transaction_with_options`]
199    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
200    pub async fn transaction(
201        &self,
202        database_name: impl AsRef<str>,
203        transaction_type: TransactionType,
204    ) -> Result<Transaction> {
205        self.transaction_with_options(database_name, transaction_type, TransactionOptions::new()).await
206    }
207
208    /// Performs a TypeQL query in this transaction.
209    ///
210    /// # Arguments
211    ///
212    /// * `database_name` — The name of the database to connect to
213    /// * `transaction_type` — The TransactionType to open the transaction with
214    /// * `options` — The TransactionOptions to open the transaction with
215    ///
216    /// # Examples
217    ///
218    /// ```rust
219    /// transaction.transaction_with_options(database_name, transaction_type, options)
220    /// ```
221    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
222    pub async fn transaction_with_options(
223        &self,
224        database_name: impl AsRef<str>,
225        transaction_type: TransactionType,
226        options: TransactionOptions,
227    ) -> Result<Transaction> {
228        let database_name = database_name.as_ref();
229        let database = self.database_manager.get_cached_or_fetch(database_name).await?;
230        let transaction_stream = database
231            .run_failsafe(|database| async move {
232                database.connection().open_transaction(database.name(), transaction_type, options).await
233            })
234            .await?;
235        Ok(Transaction::new(transaction_stream))
236    }
237
238    /// Closes this connection if it is open.
239    ///
240    /// # Examples
241    ///
242    /// ```rust
243    /// driver.force_close()
244    /// ```
245    pub fn force_close(&self) -> Result {
246        if !self.is_open() {
247            return Ok(());
248        }
249
250        let result =
251            self.server_connections.values().map(ServerConnection::force_close).try_collect().map_err(Into::into);
252        self.background_runtime.force_close().and(result)
253    }
254
255    pub(crate) fn server_count(&self) -> usize {
256        self.server_connections.len()
257    }
258
259    pub(crate) fn servers(&self) -> impl Iterator<Item = &Address> {
260        self.server_connections.keys()
261    }
262
263    pub(crate) fn connection(&self, id: &Address) -> Option<&ServerConnection> {
264        self.server_connections.get(id)
265    }
266
267    pub(crate) fn connections(&self) -> impl Iterator<Item = (&Address, &ServerConnection)> + '_ {
268        self.server_connections.iter()
269    }
270
271    pub(crate) fn unable_to_connect_error(&self) -> Error {
272        Error::Connection(ConnectionError::ServerConnectionFailed {
273            addresses: self.servers().map(Address::clone).collect_vec(),
274        })
275    }
276}
277
278impl fmt::Debug for TypeDBDriver {
279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280        f.debug_struct("Connection").field("server_connections", &self.server_connections).finish()
281    }
282}