1use std::{
21 collections::{HashMap, HashSet},
22 fmt,
23 sync::Arc,
24};
25
26use itertools::Itertools;
27
28use crate::{
29 common::{
30 address::Address,
31 error::{ConnectionError, Error},
32 Result,
33 },
34 connection::{runtime::BackgroundRuntime, server_connection::ServerConnection},
35 Credentials, DatabaseManager, DriverOptions, Transaction, TransactionOptions, TransactionType, UserManager,
36};
37
/// A TypeDB driver: its server connection(s), together with the database and user managers
/// built on top of them and the background runtime they share.
// NOTE(review): Rust drops struct fields in declaration order, so the connections and managers
// below are dropped before `background_runtime` — confirm that is intended before reordering.
pub struct TypeDBDriver {
    // One entry per server, keyed by the server's parsed address.
    server_connections: HashMap<Address, ServerConnection>,
    // Constructed from a clone of `server_connections` in `new_with_description`.
    database_manager: DatabaseManager,
    // Constructed from a clone of `server_connections` in `new_with_description`.
    user_manager: UserManager,
    // Shared via `Arc` with each `ServerConnection`; `is_open` and `force_close` delegate to it.
    background_runtime: Arc<BackgroundRuntime>,
}
45
46impl TypeDBDriver {
47 const DRIVER_LANG: &'static str = "rust";
48 const VERSION: &'static str = match option_env!("CARGO_PKG_VERSION") {
49 None => "0.0.0",
50 Some(version) => version,
51 };
52
53 pub const DEFAULT_ADDRESS: &'static str = "localhost:1729";
54
55 #[cfg_attr(
67 feature = "sync",
68 doc = "TypeDBDriver::new(\"127.0.0.1:1729\", Credentials::new(\"username\", \"password\"), DriverOptions::new(true, None))"
69 )]
70 #[cfg_attr(
71 not(feature = "sync"),
72 doc = "TypeDBDriver::new(\"127.0.0.1:1729\", Credentials::new(\"username\", \"password\"), DriverOptions::new(true, None)).await"
73 )]
74 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
76 pub async fn new(
77 address: impl AsRef<str>,
78 credentials: Credentials,
79 driver_options: DriverOptions,
80 ) -> Result<Self> {
81 Self::new_with_description(address, credentials, driver_options, Self::DRIVER_LANG).await
82 }
83
84 #[cfg_attr(
99 feature = "sync",
100 doc = "TypeDBDriver::new_with_description(\"127.0.0.1:1729\", Credentials::new(\"username\", \"password\"), DriverOptions::new(true, None), \"rust\")"
101 )]
102 #[cfg_attr(
103 not(feature = "sync"),
104 doc = "TypeDBDriver::new_with_description(\"127.0.0.1:1729\", Credentials::new(\"username\", \"password\"), DriverOptions::new(true, None), \"rust\").await"
105 )]
106 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
108 pub async fn new_with_description(
109 address: impl AsRef<str>,
110 credentials: Credentials,
111 driver_options: DriverOptions,
112 driver_lang: impl AsRef<str>,
113 ) -> Result<Self> {
114 let id = address.as_ref().to_string();
115 let address: Address = id.parse()?;
116 let background_runtime = Arc::new(BackgroundRuntime::new()?);
117
118 let (server_connection, database_info) = ServerConnection::new(
119 background_runtime.clone(),
120 address.clone(),
121 credentials,
122 driver_options,
123 driver_lang.as_ref(),
124 Self::VERSION,
125 )
126 .await?;
127
128 let server_connections: HashMap<Address, ServerConnection> = [(address, server_connection)].into();
137 let database_manager = DatabaseManager::new(server_connections.clone(), database_info)?;
138 let user_manager = UserManager::new(server_connections.clone());
139
140 Ok(Self { server_connections, database_manager, user_manager, background_runtime })
141 }
142
143 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
144 async fn fetch_server_list(
145 background_runtime: Arc<BackgroundRuntime>,
146 addresses: impl IntoIterator<Item = impl AsRef<str>> + Clone,
147 credentials: Credentials,
148 driver_options: DriverOptions,
149 ) -> Result<HashSet<Address>> {
150 let addresses: Vec<Address> = addresses.into_iter().map(|addr| addr.as_ref().parse()).try_collect()?;
151 for address in &addresses {
152 let server_connection = ServerConnection::new(
153 background_runtime.clone(),
154 address.clone(),
155 credentials.clone(),
156 driver_options.clone(),
157 Self::DRIVER_LANG,
158 Self::VERSION,
159 )
160 .await;
161 match server_connection {
162 Ok((server_connection, _)) => match server_connection.servers_all() {
163 Ok(servers) => return Ok(servers.into_iter().collect()),
164 Err(Error::Connection(
165 ConnectionError::ServerConnectionFailedStatusError { .. } | ConnectionError::ConnectionFailed,
166 )) => (),
167 Err(err) => Err(err)?,
168 },
169 Err(Error::Connection(
170 ConnectionError::ServerConnectionFailedStatusError { .. } | ConnectionError::ConnectionFailed,
171 )) => (),
172 Err(err) => Err(err)?,
173 }
174 }
175 Err(ConnectionError::ServerConnectionFailed { addresses }.into())
176 }
177
178 pub fn is_open(&self) -> bool {
186 self.background_runtime.is_open()
187 }
188
189 pub fn databases(&self) -> &DatabaseManager {
190 &self.database_manager
191 }
192
193 pub fn users(&self) -> &UserManager {
194 &self.user_manager
195 }
196
197 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
200 pub async fn transaction(
201 &self,
202 database_name: impl AsRef<str>,
203 transaction_type: TransactionType,
204 ) -> Result<Transaction> {
205 self.transaction_with_options(database_name, transaction_type, TransactionOptions::new()).await
206 }
207
208 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
222 pub async fn transaction_with_options(
223 &self,
224 database_name: impl AsRef<str>,
225 transaction_type: TransactionType,
226 options: TransactionOptions,
227 ) -> Result<Transaction> {
228 let database_name = database_name.as_ref();
229 let database = self.database_manager.get_cached_or_fetch(database_name).await?;
230 let transaction_stream = database
231 .run_failsafe(|database| async move {
232 database.connection().open_transaction(database.name(), transaction_type, options).await
233 })
234 .await?;
235 Ok(Transaction::new(transaction_stream))
236 }
237
238 pub fn force_close(&self) -> Result {
246 if !self.is_open() {
247 return Ok(());
248 }
249
250 let result =
251 self.server_connections.values().map(ServerConnection::force_close).try_collect().map_err(Into::into);
252 self.background_runtime.force_close().and(result)
253 }
254
255 pub(crate) fn server_count(&self) -> usize {
256 self.server_connections.len()
257 }
258
259 pub(crate) fn servers(&self) -> impl Iterator<Item = &Address> {
260 self.server_connections.keys()
261 }
262
263 pub(crate) fn connection(&self, id: &Address) -> Option<&ServerConnection> {
264 self.server_connections.get(id)
265 }
266
267 pub(crate) fn connections(&self) -> impl Iterator<Item = (&Address, &ServerConnection)> + '_ {
268 self.server_connections.iter()
269 }
270
271 pub(crate) fn unable_to_connect_error(&self) -> Error {
272 Error::Connection(ConnectionError::ServerConnectionFailed {
273 addresses: self.servers().map(Address::clone).collect_vec(),
274 })
275 }
276}
277
278impl fmt::Debug for TypeDBDriver {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 f.debug_struct("Connection").field("server_connections", &self.server_connections).finish()
281 }
282}