Skip to main content

typedb_driver/connection/
connection.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20use std::{
21    collections::{HashMap, HashSet},
22    fmt,
23    sync::{Arc, Mutex},
24    time::Duration,
25};
26
27use crossbeam::channel::Sender;
28use futures::future::join_all;
29use itertools::Itertools;
30use tokio::{
31    select,
32    sync::{
33        mpsc::{unbounded_channel as unbounded_async, UnboundedReceiver, UnboundedSender},
34        oneshot::{channel as oneshot_async, Sender as AsyncOneshotSender},
35    },
36    time::{sleep_until, Instant},
37};
38
39use super::{
40    network::transmitter::{RPCTransmitter, TransactionTransmitter},
41    runtime::BackgroundRuntime,
42    TransactionStream,
43};
44use crate::{
45    common::{
46        address::Address,
47        error::{ConnectionError, Error},
48        info::{DatabaseInfo, SessionInfo},
49        Callback, Result, SessionID, SessionType, TransactionType,
50    },
51    connection::message::{Request, Response, TransactionRequest},
52    error::InternalError,
53    user::User,
54    Credential, Options,
55};
56
57/// A connection to a TypeDB server which serves as the starting point for all interaction.
58#[derive(Clone)]
59pub struct Connection {
60    server_connections: HashMap<Address, ServerConnection>,
61    background_runtime: Arc<BackgroundRuntime>,
62    username: Option<String>,
63    is_cloud: bool,
64}
65
66impl Connection {
67    /// Creates a new TypeDB Server connection.
68    ///
69    /// # Arguments
70    ///
71    /// * `address` -- The address (host:port) on which the TypeDB Server is running
72    ///
73    /// # Examples
74    ///
75    /// ```rust
76    /// Connection::new_core("127.0.0.1:1729")
77    /// ```
78    pub fn new_core(address: impl AsRef<str>) -> Result<Self> {
79        let id = address.as_ref().to_string();
80        let address: Address = id.parse()?;
81        let background_runtime = Arc::new(BackgroundRuntime::new()?);
82        let server_connection = ServerConnection::new_core(background_runtime.clone(), address)?;
83
84        let advertised_address = server_connection
85            .servers_all()?
86            .into_iter()
87            .exactly_one()
88            .map_err(|e| ConnectionError::ServerConnectionFailedStatusError { error: e.to_string() })?;
89
90        match server_connection.validate() {
91            Ok(()) => Ok(Self {
92                server_connections: [(advertised_address, server_connection)].into(),
93                background_runtime,
94                username: None,
95                is_cloud: false,
96            }),
97            Err(err) => Err(err),
98        }
99    }
100
101    /// Creates a new TypeDB Cloud connection.
102    ///
103    /// # Arguments
104    ///
105    /// * `init_addresses` -- Addresses (host:port) on which TypeDB Cloud nodes are running
106    /// * `credential` -- User credential and TLS encryption setting
107    ///
108    /// # Examples
109    ///
110    /// ```rust
111    /// Connection::new_cloud(
112    ///     &["localhost:11729", "localhost:21729", "localhost:31729"],
113    ///     Credential::with_tls(
114    ///         "admin",
115    ///         "password",
116    ///         Some(&PathBuf::from(
117    ///             std::env::var("ROOT_CA")
118    ///                 .expect("ROOT_CA environment variable needs to be set for cloud tests to run"),
119    ///         )),
120    ///     )?,
121    /// )
122    /// ```
123    pub fn new_cloud<T: AsRef<str> + Sync>(init_addresses: &[T], credential: Credential) -> Result<Self> {
124        let background_runtime = Arc::new(BackgroundRuntime::new()?);
125        let servers = Self::fetch_server_list(background_runtime.clone(), init_addresses, credential.clone())?;
126        let server_to_address = servers.into_iter().map(|address| (address.clone(), address)).collect();
127        Self::new_cloud_impl(server_to_address, background_runtime, credential)
128    }
129
130    /// Creates a new TypeDB Cloud connection.
131    ///
132    /// # Arguments
133    ///
134    /// * `address_translation` -- Translation map from addresses to be used by the driver for connection
135    ///    to addresses received from the TypeDB server(s)
136    /// * `credential` -- User credential and TLS encryption setting
137    ///
138    /// # Examples
139    ///
140    /// ```rust
141    /// Connection::new_cloud_with_translation(
142    ///     [
143    ///         ("typedb-cloud.ext:11729", "localhost:11729"),
144    ///         ("typedb-cloud.ext:21729", "localhost:21729"),
145    ///         ("typedb-cloud.ext:31729", "localhost:31729"),
146    ///     ].into(),
147    ///     credential,
148    /// )
149    /// ```
150    pub fn new_cloud_with_translation<T, U>(address_translation: HashMap<T, U>, credential: Credential) -> Result<Self>
151    where
152        T: AsRef<str> + Sync,
153        U: AsRef<str> + Sync,
154    {
155        let background_runtime = Arc::new(BackgroundRuntime::new()?);
156
157        let fetched =
158            Self::fetch_server_list(background_runtime.clone(), address_translation.keys(), credential.clone())?;
159
160        let address_to_server: HashMap<Address, Address> = address_translation
161            .into_iter()
162            .map(|(public, private)| -> Result<_> { Ok((public.as_ref().parse()?, private.as_ref().parse()?)) })
163            .try_collect()?;
164
165        let provided: HashSet<Address> = address_to_server.values().cloned().collect();
166        let unknown = &provided - &fetched;
167        let unmapped = &fetched - &provided;
168        if !unknown.is_empty() || !unmapped.is_empty() {
169            return Err(ConnectionError::AddressTranslationMismatch { unknown, unmapped }.into());
170        }
171
172        debug_assert_eq!(fetched, provided);
173
174        Self::new_cloud_impl(address_to_server, background_runtime, credential)
175    }
176
177    fn new_cloud_impl(
178        address_to_server: HashMap<Address, Address>,
179        background_runtime: Arc<BackgroundRuntime>,
180        credential: Credential,
181    ) -> Result<Connection> {
182        let server_connections: HashMap<Address, ServerConnection> = address_to_server
183            .into_iter()
184            .map(|(public, private)| {
185                ServerConnection::new_cloud(background_runtime.clone(), public, credential.clone())
186                    .map(|server_connection| (private, server_connection))
187            })
188            .try_collect()?;
189
190        let errors = server_connections.values().map(|conn| conn.validate()).filter_map(Result::err).collect_vec();
191        if errors.len() == server_connections.len() {
192            Err(ConnectionError::CloudAllNodesFailed {
193                errors: errors.into_iter().map(|err| err.to_string()).join("\n"),
194            })?
195        } else {
196            Ok(Connection {
197                server_connections,
198                background_runtime,
199                username: Some(credential.username().to_owned()),
200                is_cloud: true,
201            })
202        }
203    }
204
205    fn fetch_server_list(
206        background_runtime: Arc<BackgroundRuntime>,
207        addresses: impl IntoIterator<Item = impl AsRef<str>> + Clone,
208        credential: Credential,
209    ) -> Result<HashSet<Address>> {
210        let addresses: Vec<Address> = addresses.into_iter().map(|addr| addr.as_ref().parse()).try_collect()?;
211        for address in &addresses {
212            let server_connection =
213                ServerConnection::new_cloud(background_runtime.clone(), address.clone(), credential.clone());
214            match server_connection {
215                Ok(server_connection) => match server_connection.servers_all() {
216                    Ok(servers) => return Ok(servers.into_iter().collect()),
217                    Err(Error::Connection(
218                        ConnectionError::ServerConnectionFailedStatusError { .. } | ConnectionError::ConnectionFailed,
219                    )) => (),
220                    Err(err) => Err(err)?,
221                },
222                Err(Error::Connection(
223                    ConnectionError::ServerConnectionFailedStatusError { .. } | ConnectionError::ConnectionFailed,
224                )) => (),
225                Err(err) => Err(err)?,
226            }
227        }
228        Err(ConnectionError::ServerConnectionFailed { addresses }.into())
229    }
230
231    /// Checks it this connection is opened.
232    //
233    /// # Examples
234    ///
235    /// ```rust
236    /// connection.is_open()
237    /// ```
238    pub fn is_open(&self) -> bool {
239        self.background_runtime.is_open()
240    }
241
242    /// Check if the connection is to an Cloud server.
243    ///
244    /// # Examples
245    ///
246    /// ```rust
247    /// connection.is_cloud()
248    /// ```
249    pub fn is_cloud(&self) -> bool {
250        self.is_cloud
251    }
252
253    /// Closes this connection.
254    ///
255    /// # Examples
256    ///
257    /// ```rust
258    /// connection.force_close()
259    /// ```
260    pub fn force_close(&self) -> Result {
261        let result =
262            self.server_connections.values().map(ServerConnection::force_close).try_collect().map_err(Into::into);
263        self.background_runtime.force_close().and(result)
264    }
265
266    pub(crate) fn server_count(&self) -> usize {
267        self.server_connections.len()
268    }
269
270    pub(crate) fn servers(&self) -> impl Iterator<Item = &Address> {
271        self.server_connections.keys()
272    }
273
274    pub(crate) fn connection(&self, id: &Address) -> Option<&ServerConnection> {
275        self.server_connections.get(id)
276    }
277
278    pub(crate) fn connections(&self) -> impl Iterator<Item = (&Address, &ServerConnection)> + '_ {
279        self.server_connections.iter()
280    }
281
282    pub(crate) fn username(&self) -> Option<&str> {
283        self.username.as_deref()
284    }
285
286    pub(crate) fn unable_to_connect_error(&self) -> Error {
287        Error::Connection(ConnectionError::ServerConnectionFailed {
288            addresses: self.servers().map(Address::clone).collect_vec(),
289        })
290    }
291}
292
293impl fmt::Debug for Connection {
294    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
295        f.debug_struct("Connection").field("server_connections", &self.server_connections).finish()
296    }
297}
298
299#[derive(Clone)]
300pub(crate) struct ServerConnection {
301    background_runtime: Arc<BackgroundRuntime>,
302    open_sessions: Arc<Mutex<HashMap<SessionID, UnboundedSender<()>>>>,
303    request_transmitter: Arc<RPCTransmitter>,
304}
305
306impl ServerConnection {
307    fn new_core(background_runtime: Arc<BackgroundRuntime>, address: Address) -> Result<Self> {
308        let request_transmitter = Arc::new(RPCTransmitter::start_core(address, &background_runtime)?);
309        Ok(Self { background_runtime, open_sessions: Default::default(), request_transmitter })
310    }
311
312    fn new_cloud(background_runtime: Arc<BackgroundRuntime>, address: Address, credential: Credential) -> Result<Self> {
313        let request_transmitter = Arc::new(RPCTransmitter::start_cloud(address, credential, &background_runtime)?);
314        Ok(Self { background_runtime, open_sessions: Default::default(), request_transmitter })
315    }
316
317    pub(crate) fn validate(&self) -> Result {
318        match self.request_blocking(Request::ConnectionOpen)? {
319            Response::ConnectionOpen => Ok(()),
320            other => Err(ConnectionError::UnexpectedResponse { response: format!("{other:?}") }.into()),
321        }
322    }
323
324    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
325    async fn request(&self, request: Request) -> Result<Response> {
326        if !self.background_runtime.is_open() {
327            return Err(ConnectionError::ConnectionIsClosed.into());
328        }
329        self.request_transmitter.request(request).await
330    }
331
332    fn request_blocking(&self, request: Request) -> Result<Response> {
333        if !self.background_runtime.is_open() {
334            return Err(ConnectionError::ConnectionIsClosed.into());
335        }
336        self.request_transmitter.request_blocking(request)
337    }
338
339    pub(crate) fn force_close(&self) -> Result {
340        let session_ids: Vec<SessionID> = self.open_sessions.lock().unwrap().keys().cloned().collect();
341        for session_id in session_ids {
342            self.close_session(session_id).ok();
343        }
344        self.request_transmitter.force_close()
345    }
346
347    pub(crate) fn servers_all(&self) -> Result<Vec<Address>> {
348        match self.request_blocking(Request::ServersAll)? {
349            Response::ServersAll { servers } => Ok(servers),
350            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
351        }
352    }
353
354    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
355    pub(crate) async fn database_exists(&self, database_name: String) -> Result<bool> {
356        match self.request(Request::DatabasesContains { database_name }).await? {
357            Response::DatabasesContains { contains } => Ok(contains),
358            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
359        }
360    }
361
362    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
363    pub(crate) async fn create_database(&self, database_name: String) -> Result {
364        self.request(Request::DatabaseCreate { database_name }).await?;
365        Ok(())
366    }
367
368    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
369    pub(crate) async fn get_database_replicas(&self, database_name: String) -> Result<DatabaseInfo> {
370        match self.request(Request::DatabaseGet { database_name }).await? {
371            Response::DatabaseGet { database } => Ok(database),
372            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
373        }
374    }
375
376    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
377    pub(crate) async fn all_databases(&self) -> Result<Vec<DatabaseInfo>> {
378        match self.request(Request::DatabasesAll).await? {
379            Response::DatabasesAll { databases } => Ok(databases),
380            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
381        }
382    }
383
384    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
385    pub(crate) async fn database_schema(&self, database_name: String) -> Result<String> {
386        match self.request(Request::DatabaseSchema { database_name }).await? {
387            Response::DatabaseSchema { schema } => Ok(schema),
388            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
389        }
390    }
391
392    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
393    pub(crate) async fn database_type_schema(&self, database_name: String) -> Result<String> {
394        match self.request(Request::DatabaseTypeSchema { database_name }).await? {
395            Response::DatabaseTypeSchema { schema } => Ok(schema),
396            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
397        }
398    }
399
400    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
401    pub(crate) async fn database_rule_schema(&self, database_name: String) -> Result<String> {
402        match self.request(Request::DatabaseRuleSchema { database_name }).await? {
403            Response::DatabaseRuleSchema { schema } => Ok(schema),
404            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
405        }
406    }
407
408    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
409    pub(crate) async fn delete_database(&self, database_name: String) -> Result {
410        self.request(Request::DatabaseDelete { database_name }).await?;
411        Ok(())
412    }
413
414    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
415    pub(crate) async fn open_session(
416        &self,
417        database_name: String,
418        session_type: SessionType,
419        options: Options,
420    ) -> Result<SessionInfo> {
421        let start = Instant::now();
422        match self.request(Request::SessionOpen { database_name, session_type, options }).await? {
423            Response::SessionOpen { session_id, server_duration } => {
424                let (on_close_register_sink, on_close_register_source) = unbounded_async();
425                let (pulse_shutdown_sink, pulse_shutdown_source) = unbounded_async();
426                self.open_sessions.lock().unwrap().insert(session_id.clone(), pulse_shutdown_sink);
427                self.background_runtime.spawn(session_pulse(
428                    session_id.clone(),
429                    self.request_transmitter.clone(),
430                    on_close_register_source,
431                    self.background_runtime.callback_handler_sink(),
432                    pulse_shutdown_source,
433                ));
434                Ok(SessionInfo {
435                    session_id,
436                    network_latency: start.elapsed().saturating_sub(server_duration),
437                    on_close_register_sink,
438                })
439            }
440            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
441        }
442    }
443
444    pub(crate) fn close_session(&self, session_id: SessionID) -> Result {
445        if let Some(sink) = self.open_sessions.lock().unwrap().remove(&session_id) {
446            sink.send(()).ok();
447        }
448        self.request_blocking(Request::SessionClose { session_id })?;
449        Ok(())
450    }
451
452    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
453    pub(crate) async fn open_transaction(
454        &self,
455        session_id: SessionID,
456        transaction_type: TransactionType,
457        options: Options,
458        network_latency: Duration,
459    ) -> Result<(TransactionStream, UnboundedSender<()>)> {
460        match self
461            .request(Request::Transaction(TransactionRequest::Open {
462                session_id,
463                transaction_type,
464                options,
465                network_latency,
466            }))
467            .await?
468        {
469            Response::TransactionOpen { request_sink, response_source } => {
470                let transmitter = TransactionTransmitter::new(
471                    &self.background_runtime,
472                    request_sink,
473                    response_source,
474                    self.background_runtime.callback_handler_sink(),
475                );
476                let transmitter_shutdown_sink = transmitter.shutdown_sink().clone();
477                let transaction_stream = TransactionStream::new(transaction_type, options, transmitter);
478                Ok((transaction_stream, transmitter_shutdown_sink))
479            }
480            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
481        }
482    }
483
484    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
485    pub(crate) async fn all_users(&self) -> Result<Vec<User>> {
486        match self.request(Request::UsersAll).await? {
487            Response::UsersAll { users } => Ok(users),
488            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
489        }
490    }
491
492    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
493    pub(crate) async fn contains_user(&self, username: String) -> Result<bool> {
494        match self.request(Request::UsersContain { username }).await? {
495            Response::UsersContain { contains } => Ok(contains),
496            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
497        }
498    }
499
500    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
501    pub(crate) async fn create_user(&self, username: String, password: String) -> Result {
502        match self.request(Request::UsersCreate { username, password }).await? {
503            Response::UsersCreate => Ok(()),
504            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
505        }
506    }
507
508    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
509    pub(crate) async fn delete_user(&self, username: String) -> Result {
510        match self.request(Request::UsersDelete { username }).await? {
511            Response::UsersDelete => Ok(()),
512            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
513        }
514    }
515
516    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
517    pub(crate) async fn get_user(&self, username: String) -> Result<Option<User>> {
518        match self.request(Request::UsersGet { username }).await? {
519            Response::UsersGet { user } => Ok(user),
520            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
521        }
522    }
523
524    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
525    pub(crate) async fn set_user_password(&self, username: String, password: String) -> Result {
526        match self.request(Request::UsersPasswordSet { username, password }).await? {
527            Response::UsersPasswordSet => Ok(()),
528            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
529        }
530    }
531
532    #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
533    pub(crate) async fn update_user_password(
534        &self,
535        username: String,
536        password_old: String,
537        password_new: String,
538    ) -> Result {
539        match self.request(Request::UserPasswordUpdate { username, password_old, password_new }).await? {
540            Response::UserPasswordUpdate => Ok(()),
541            other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
542        }
543    }
544}
545
546impl fmt::Debug for ServerConnection {
547    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
548        f.debug_struct("ServerConnection").field("open_sessions", &self.open_sessions).finish()
549    }
550}
551
552async fn session_pulse(
553    session_id: SessionID,
554    request_transmitter: Arc<RPCTransmitter>,
555    mut on_close_callback_source: UnboundedReceiver<Callback>,
556    callback_handler_sink: Sender<(Callback, AsyncOneshotSender<()>)>,
557    mut shutdown_source: UnboundedReceiver<()>,
558) {
559    const PULSE_INTERVAL: Duration = Duration::from_secs(5);
560    let mut next_pulse = Instant::now();
561    let mut on_close = Vec::new();
562    loop {
563        select! {
564            _ = sleep_until(next_pulse) => {
565                let session_id = session_id.clone();
566                match request_transmitter.request_async(Request::SessionPulse { session_id }).await {
567                    Ok(Response::SessionPulse { is_alive: true }) => {
568                        next_pulse = (next_pulse + PULSE_INTERVAL).max(Instant::now())
569                    }
570                    _ => break,
571                }
572            }
573            callback = on_close_callback_source.recv() => {
574                if let Some(callback) = callback {
575                    on_close.push(callback)
576                }
577            }
578            _ = shutdown_source.recv() => break,
579        }
580    }
581
582    join_all(on_close.into_iter().map(|callback| {
583        let (response_sink, response) = oneshot_async();
584        callback_handler_sink.send((Box::new(callback), response_sink)).unwrap();
585        response
586    }))
587    .await;
588}