1use std::{
21 collections::{HashMap, HashSet},
22 fmt,
23 sync::{Arc, Mutex},
24 time::Duration,
25};
26
27use crossbeam::channel::Sender;
28use futures::future::join_all;
29use itertools::Itertools;
30use tokio::{
31 select,
32 sync::{
33 mpsc::{unbounded_channel as unbounded_async, UnboundedReceiver, UnboundedSender},
34 oneshot::{channel as oneshot_async, Sender as AsyncOneshotSender},
35 },
36 time::{sleep_until, Instant},
37};
38
39use super::{
40 network::transmitter::{RPCTransmitter, TransactionTransmitter},
41 runtime::BackgroundRuntime,
42 TransactionStream,
43};
44use crate::{
45 common::{
46 address::Address,
47 error::{ConnectionError, Error},
48 info::{DatabaseInfo, SessionInfo},
49 Callback, Result, SessionID, SessionType, TransactionType,
50 },
51 connection::message::{Request, Response, TransactionRequest},
52 error::InternalError,
53 user::User,
54 Credential, Options,
55};
56
57#[derive(Clone)]
59pub struct Connection {
60 server_connections: HashMap<Address, ServerConnection>,
61 background_runtime: Arc<BackgroundRuntime>,
62 username: Option<String>,
63 is_cloud: bool,
64}
65
66impl Connection {
67 pub fn new_core(address: impl AsRef<str>) -> Result<Self> {
79 let id = address.as_ref().to_string();
80 let address: Address = id.parse()?;
81 let background_runtime = Arc::new(BackgroundRuntime::new()?);
82 let server_connection = ServerConnection::new_core(background_runtime.clone(), address)?;
83
84 let advertised_address = server_connection
85 .servers_all()?
86 .into_iter()
87 .exactly_one()
88 .map_err(|e| ConnectionError::ServerConnectionFailedStatusError { error: e.to_string() })?;
89
90 match server_connection.validate() {
91 Ok(()) => Ok(Self {
92 server_connections: [(advertised_address, server_connection)].into(),
93 background_runtime,
94 username: None,
95 is_cloud: false,
96 }),
97 Err(err) => Err(err),
98 }
99 }
100
101 pub fn new_cloud<T: AsRef<str> + Sync>(init_addresses: &[T], credential: Credential) -> Result<Self> {
124 let background_runtime = Arc::new(BackgroundRuntime::new()?);
125 let servers = Self::fetch_server_list(background_runtime.clone(), init_addresses, credential.clone())?;
126 let server_to_address = servers.into_iter().map(|address| (address.clone(), address)).collect();
127 Self::new_cloud_impl(server_to_address, background_runtime, credential)
128 }
129
130 pub fn new_cloud_with_translation<T, U>(address_translation: HashMap<T, U>, credential: Credential) -> Result<Self>
151 where
152 T: AsRef<str> + Sync,
153 U: AsRef<str> + Sync,
154 {
155 let background_runtime = Arc::new(BackgroundRuntime::new()?);
156
157 let fetched =
158 Self::fetch_server_list(background_runtime.clone(), address_translation.keys(), credential.clone())?;
159
160 let address_to_server: HashMap<Address, Address> = address_translation
161 .into_iter()
162 .map(|(public, private)| -> Result<_> { Ok((public.as_ref().parse()?, private.as_ref().parse()?)) })
163 .try_collect()?;
164
165 let provided: HashSet<Address> = address_to_server.values().cloned().collect();
166 let unknown = &provided - &fetched;
167 let unmapped = &fetched - &provided;
168 if !unknown.is_empty() || !unmapped.is_empty() {
169 return Err(ConnectionError::AddressTranslationMismatch { unknown, unmapped }.into());
170 }
171
172 debug_assert_eq!(fetched, provided);
173
174 Self::new_cloud_impl(address_to_server, background_runtime, credential)
175 }
176
177 fn new_cloud_impl(
178 address_to_server: HashMap<Address, Address>,
179 background_runtime: Arc<BackgroundRuntime>,
180 credential: Credential,
181 ) -> Result<Connection> {
182 let server_connections: HashMap<Address, ServerConnection> = address_to_server
183 .into_iter()
184 .map(|(public, private)| {
185 ServerConnection::new_cloud(background_runtime.clone(), public, credential.clone())
186 .map(|server_connection| (private, server_connection))
187 })
188 .try_collect()?;
189
190 let errors = server_connections.values().map(|conn| conn.validate()).filter_map(Result::err).collect_vec();
191 if errors.len() == server_connections.len() {
192 Err(ConnectionError::CloudAllNodesFailed {
193 errors: errors.into_iter().map(|err| err.to_string()).join("\n"),
194 })?
195 } else {
196 Ok(Connection {
197 server_connections,
198 background_runtime,
199 username: Some(credential.username().to_owned()),
200 is_cloud: true,
201 })
202 }
203 }
204
205 fn fetch_server_list(
206 background_runtime: Arc<BackgroundRuntime>,
207 addresses: impl IntoIterator<Item = impl AsRef<str>> + Clone,
208 credential: Credential,
209 ) -> Result<HashSet<Address>> {
210 let addresses: Vec<Address> = addresses.into_iter().map(|addr| addr.as_ref().parse()).try_collect()?;
211 for address in &addresses {
212 let server_connection =
213 ServerConnection::new_cloud(background_runtime.clone(), address.clone(), credential.clone());
214 match server_connection {
215 Ok(server_connection) => match server_connection.servers_all() {
216 Ok(servers) => return Ok(servers.into_iter().collect()),
217 Err(Error::Connection(
218 ConnectionError::ServerConnectionFailedStatusError { .. } | ConnectionError::ConnectionFailed,
219 )) => (),
220 Err(err) => Err(err)?,
221 },
222 Err(Error::Connection(
223 ConnectionError::ServerConnectionFailedStatusError { .. } | ConnectionError::ConnectionFailed,
224 )) => (),
225 Err(err) => Err(err)?,
226 }
227 }
228 Err(ConnectionError::ServerConnectionFailed { addresses }.into())
229 }
230
231 pub fn is_open(&self) -> bool {
239 self.background_runtime.is_open()
240 }
241
242 pub fn is_cloud(&self) -> bool {
250 self.is_cloud
251 }
252
253 pub fn force_close(&self) -> Result {
261 let result =
262 self.server_connections.values().map(ServerConnection::force_close).try_collect().map_err(Into::into);
263 self.background_runtime.force_close().and(result)
264 }
265
266 pub(crate) fn server_count(&self) -> usize {
267 self.server_connections.len()
268 }
269
270 pub(crate) fn servers(&self) -> impl Iterator<Item = &Address> {
271 self.server_connections.keys()
272 }
273
274 pub(crate) fn connection(&self, id: &Address) -> Option<&ServerConnection> {
275 self.server_connections.get(id)
276 }
277
278 pub(crate) fn connections(&self) -> impl Iterator<Item = (&Address, &ServerConnection)> + '_ {
279 self.server_connections.iter()
280 }
281
282 pub(crate) fn username(&self) -> Option<&str> {
283 self.username.as_deref()
284 }
285
286 pub(crate) fn unable_to_connect_error(&self) -> Error {
287 Error::Connection(ConnectionError::ServerConnectionFailed {
288 addresses: self.servers().map(Address::clone).collect_vec(),
289 })
290 }
291}
292
293impl fmt::Debug for Connection {
294 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
295 f.debug_struct("Connection").field("server_connections", &self.server_connections).finish()
296 }
297}
298
299#[derive(Clone)]
300pub(crate) struct ServerConnection {
301 background_runtime: Arc<BackgroundRuntime>,
302 open_sessions: Arc<Mutex<HashMap<SessionID, UnboundedSender<()>>>>,
303 request_transmitter: Arc<RPCTransmitter>,
304}
305
306impl ServerConnection {
307 fn new_core(background_runtime: Arc<BackgroundRuntime>, address: Address) -> Result<Self> {
308 let request_transmitter = Arc::new(RPCTransmitter::start_core(address, &background_runtime)?);
309 Ok(Self { background_runtime, open_sessions: Default::default(), request_transmitter })
310 }
311
312 fn new_cloud(background_runtime: Arc<BackgroundRuntime>, address: Address, credential: Credential) -> Result<Self> {
313 let request_transmitter = Arc::new(RPCTransmitter::start_cloud(address, credential, &background_runtime)?);
314 Ok(Self { background_runtime, open_sessions: Default::default(), request_transmitter })
315 }
316
317 pub(crate) fn validate(&self) -> Result {
318 match self.request_blocking(Request::ConnectionOpen)? {
319 Response::ConnectionOpen => Ok(()),
320 other => Err(ConnectionError::UnexpectedResponse { response: format!("{other:?}") }.into()),
321 }
322 }
323
324 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
325 async fn request(&self, request: Request) -> Result<Response> {
326 if !self.background_runtime.is_open() {
327 return Err(ConnectionError::ConnectionIsClosed.into());
328 }
329 self.request_transmitter.request(request).await
330 }
331
332 fn request_blocking(&self, request: Request) -> Result<Response> {
333 if !self.background_runtime.is_open() {
334 return Err(ConnectionError::ConnectionIsClosed.into());
335 }
336 self.request_transmitter.request_blocking(request)
337 }
338
339 pub(crate) fn force_close(&self) -> Result {
340 let session_ids: Vec<SessionID> = self.open_sessions.lock().unwrap().keys().cloned().collect();
341 for session_id in session_ids {
342 self.close_session(session_id).ok();
343 }
344 self.request_transmitter.force_close()
345 }
346
347 pub(crate) fn servers_all(&self) -> Result<Vec<Address>> {
348 match self.request_blocking(Request::ServersAll)? {
349 Response::ServersAll { servers } => Ok(servers),
350 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
351 }
352 }
353
354 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
355 pub(crate) async fn database_exists(&self, database_name: String) -> Result<bool> {
356 match self.request(Request::DatabasesContains { database_name }).await? {
357 Response::DatabasesContains { contains } => Ok(contains),
358 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
359 }
360 }
361
362 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
363 pub(crate) async fn create_database(&self, database_name: String) -> Result {
364 self.request(Request::DatabaseCreate { database_name }).await?;
365 Ok(())
366 }
367
368 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
369 pub(crate) async fn get_database_replicas(&self, database_name: String) -> Result<DatabaseInfo> {
370 match self.request(Request::DatabaseGet { database_name }).await? {
371 Response::DatabaseGet { database } => Ok(database),
372 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
373 }
374 }
375
376 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
377 pub(crate) async fn all_databases(&self) -> Result<Vec<DatabaseInfo>> {
378 match self.request(Request::DatabasesAll).await? {
379 Response::DatabasesAll { databases } => Ok(databases),
380 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
381 }
382 }
383
384 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
385 pub(crate) async fn database_schema(&self, database_name: String) -> Result<String> {
386 match self.request(Request::DatabaseSchema { database_name }).await? {
387 Response::DatabaseSchema { schema } => Ok(schema),
388 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
389 }
390 }
391
392 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
393 pub(crate) async fn database_type_schema(&self, database_name: String) -> Result<String> {
394 match self.request(Request::DatabaseTypeSchema { database_name }).await? {
395 Response::DatabaseTypeSchema { schema } => Ok(schema),
396 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
397 }
398 }
399
400 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
401 pub(crate) async fn database_rule_schema(&self, database_name: String) -> Result<String> {
402 match self.request(Request::DatabaseRuleSchema { database_name }).await? {
403 Response::DatabaseRuleSchema { schema } => Ok(schema),
404 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
405 }
406 }
407
408 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
409 pub(crate) async fn delete_database(&self, database_name: String) -> Result {
410 self.request(Request::DatabaseDelete { database_name }).await?;
411 Ok(())
412 }
413
414 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
415 pub(crate) async fn open_session(
416 &self,
417 database_name: String,
418 session_type: SessionType,
419 options: Options,
420 ) -> Result<SessionInfo> {
421 let start = Instant::now();
422 match self.request(Request::SessionOpen { database_name, session_type, options }).await? {
423 Response::SessionOpen { session_id, server_duration } => {
424 let (on_close_register_sink, on_close_register_source) = unbounded_async();
425 let (pulse_shutdown_sink, pulse_shutdown_source) = unbounded_async();
426 self.open_sessions.lock().unwrap().insert(session_id.clone(), pulse_shutdown_sink);
427 self.background_runtime.spawn(session_pulse(
428 session_id.clone(),
429 self.request_transmitter.clone(),
430 on_close_register_source,
431 self.background_runtime.callback_handler_sink(),
432 pulse_shutdown_source,
433 ));
434 Ok(SessionInfo {
435 session_id,
436 network_latency: start.elapsed().saturating_sub(server_duration),
437 on_close_register_sink,
438 })
439 }
440 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
441 }
442 }
443
444 pub(crate) fn close_session(&self, session_id: SessionID) -> Result {
445 if let Some(sink) = self.open_sessions.lock().unwrap().remove(&session_id) {
446 sink.send(()).ok();
447 }
448 self.request_blocking(Request::SessionClose { session_id })?;
449 Ok(())
450 }
451
452 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
453 pub(crate) async fn open_transaction(
454 &self,
455 session_id: SessionID,
456 transaction_type: TransactionType,
457 options: Options,
458 network_latency: Duration,
459 ) -> Result<(TransactionStream, UnboundedSender<()>)> {
460 match self
461 .request(Request::Transaction(TransactionRequest::Open {
462 session_id,
463 transaction_type,
464 options,
465 network_latency,
466 }))
467 .await?
468 {
469 Response::TransactionOpen { request_sink, response_source } => {
470 let transmitter = TransactionTransmitter::new(
471 &self.background_runtime,
472 request_sink,
473 response_source,
474 self.background_runtime.callback_handler_sink(),
475 );
476 let transmitter_shutdown_sink = transmitter.shutdown_sink().clone();
477 let transaction_stream = TransactionStream::new(transaction_type, options, transmitter);
478 Ok((transaction_stream, transmitter_shutdown_sink))
479 }
480 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
481 }
482 }
483
484 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
485 pub(crate) async fn all_users(&self) -> Result<Vec<User>> {
486 match self.request(Request::UsersAll).await? {
487 Response::UsersAll { users } => Ok(users),
488 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
489 }
490 }
491
492 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
493 pub(crate) async fn contains_user(&self, username: String) -> Result<bool> {
494 match self.request(Request::UsersContain { username }).await? {
495 Response::UsersContain { contains } => Ok(contains),
496 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
497 }
498 }
499
500 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
501 pub(crate) async fn create_user(&self, username: String, password: String) -> Result {
502 match self.request(Request::UsersCreate { username, password }).await? {
503 Response::UsersCreate => Ok(()),
504 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
505 }
506 }
507
508 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
509 pub(crate) async fn delete_user(&self, username: String) -> Result {
510 match self.request(Request::UsersDelete { username }).await? {
511 Response::UsersDelete => Ok(()),
512 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
513 }
514 }
515
516 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
517 pub(crate) async fn get_user(&self, username: String) -> Result<Option<User>> {
518 match self.request(Request::UsersGet { username }).await? {
519 Response::UsersGet { user } => Ok(user),
520 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
521 }
522 }
523
524 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
525 pub(crate) async fn set_user_password(&self, username: String, password: String) -> Result {
526 match self.request(Request::UsersPasswordSet { username, password }).await? {
527 Response::UsersPasswordSet => Ok(()),
528 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
529 }
530 }
531
532 #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
533 pub(crate) async fn update_user_password(
534 &self,
535 username: String,
536 password_old: String,
537 password_new: String,
538 ) -> Result {
539 match self.request(Request::UserPasswordUpdate { username, password_old, password_new }).await? {
540 Response::UserPasswordUpdate => Ok(()),
541 other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
542 }
543 }
544}
545
546impl fmt::Debug for ServerConnection {
547 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
548 f.debug_struct("ServerConnection").field("open_sessions", &self.open_sessions).finish()
549 }
550}
551
552async fn session_pulse(
553 session_id: SessionID,
554 request_transmitter: Arc<RPCTransmitter>,
555 mut on_close_callback_source: UnboundedReceiver<Callback>,
556 callback_handler_sink: Sender<(Callback, AsyncOneshotSender<()>)>,
557 mut shutdown_source: UnboundedReceiver<()>,
558) {
559 const PULSE_INTERVAL: Duration = Duration::from_secs(5);
560 let mut next_pulse = Instant::now();
561 let mut on_close = Vec::new();
562 loop {
563 select! {
564 _ = sleep_until(next_pulse) => {
565 let session_id = session_id.clone();
566 match request_transmitter.request_async(Request::SessionPulse { session_id }).await {
567 Ok(Response::SessionPulse { is_alive: true }) => {
568 next_pulse = (next_pulse + PULSE_INTERVAL).max(Instant::now())
569 }
570 _ => break,
571 }
572 }
573 callback = on_close_callback_source.recv() => {
574 if let Some(callback) = callback {
575 on_close.push(callback)
576 }
577 }
578 _ = shutdown_source.recv() => break,
579 }
580 }
581
582 join_all(on_close.into_iter().map(|callback| {
583 let (response_sink, response) = oneshot_async();
584 callback_handler_sink.send((Box::new(callback), response_sink)).unwrap();
585 response
586 }))
587 .await;
588}