use std::{
fmt,
num::NonZeroUsize,
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
time::Duration,
};
use async_trait::async_trait;
use futures::TryFutureExt;
use lru::LruCache;
use parking_lot::Mutex;
use rmpv::Value;
use tokio::time::timeout;
use tracing::{debug, trace};
use crate::{
builder::ConnectionBuilder,
client::{Executor, Stream, Transaction, TransactionBuilder},
codec::{
consts::TransactionIsolationLevel,
request::{EncodedRequest, Id, Request},
response::ResponseBody,
},
transport::DispatcherSender,
ExecutorExt, Result,
};
/// Client connection handle.
///
/// All state lives behind an `Arc`, so cloning a `Connection` is cheap and every
/// clone shares the same dispatcher, stream-id counter and statement cache.
#[derive(Clone)]
pub struct Connection {
    inner: Arc<ConnectionInner>,
}
/// Shared state behind every clone of a [`Connection`].
struct ConnectionInner {
    /// Sender used to hand encoded requests to the transport dispatcher.
    dispatcher_sender: DispatcherSender,
    /// Counter for allocating stream ids; `Connection::next_stream_id` skips `0`.
    next_stream_id: AtomicU32,
    /// Optional timeout applied to each request sent through `send_encoded_request`.
    timeout: Option<Duration>,
    /// Default transaction timeout, in seconds, handed to every `TransactionBuilder`.
    transaction_timeout_secs: Option<f64>,
    /// Default isolation level handed to every `TransactionBuilder`.
    transaction_isolation_level: TransactionIsolationLevel,
    /// Handle of the runtime that was current when the connection was created; used to
    /// spawn fire-and-forget requests.
    async_rt_handle: tokio::runtime::Handle,
    /// LRU cache mapping SQL text to prepared statement id; `None` when capacity is `0`.
    sql_statement_cache: Option<Mutex<LruCache<String, u64>>>,
    /// Serializes cache-miss preparation so concurrent callers don't all prepare the
    /// same statement. It must be an async mutex because the guard is held across the
    /// `prepare_sql` network round trip.
    sql_statement_cache_update_lock: tokio::sync::Mutex<()>,
}

impl Connection {
    /// Returns a default [`ConnectionBuilder`].
    pub fn builder() -> ConnectionBuilder {
        ConnectionBuilder::default()
    }

    /// Creates a connection that sends its requests through `dispatcher_sender`.
    ///
    /// `transaction_timeout` is stored as fractional seconds. A
    /// `sql_statement_cache_capacity` of `0` disables the prepared-statement cache.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime context, because the runtime handle is
    /// captured with `tokio::runtime::Handle::current()`.
    pub(crate) fn new(
        dispatcher_sender: DispatcherSender,
        timeout: Option<Duration>,
        transaction_timeout: Option<Duration>,
        transaction_isolation_level: TransactionIsolationLevel,
        sql_statement_cache_capacity: usize,
    ) -> Self {
        Self {
            inner: Arc::new(ConnectionInner {
                dispatcher_sender,
                next_stream_id: AtomicU32::new(1),
                timeout,
                transaction_timeout_secs: transaction_timeout.as_ref().map(Duration::as_secs_f64),
                transaction_isolation_level,
                async_rt_handle: tokio::runtime::Handle::current(),
                sql_statement_cache: NonZeroUsize::new(sql_statement_cache_capacity)
                    .map(|x| Mutex::new(LruCache::new(x))),
                sql_statement_cache_update_lock: tokio::sync::Mutex::new(()),
            }),
        }
    }

    /// Builds the request on the calling thread and sends it from a detached task.
    ///
    /// The outcome (response or error) is only logged at `debug` level; the caller
    /// never observes it.
    #[allow(clippy::let_underscore_future)]
    pub(crate) fn send_request_sync_and_forget(&self, body: impl Request, stream_id: Option<u32>) {
        let this = self.clone();
        let req = EncodedRequest::new(body, stream_id);
        // The `JoinHandle` is dropped on purpose to detach the task; that is what the
        // `let_underscore_future` allow above silences.
        let _ = self.inner.async_rt_handle.spawn(async move {
            let res = futures::future::ready(req)
                .err_into()
                .and_then(|x| this.send_encoded_request(x))
                .await;
            debug!("Response for background request: {:?}", res);
        });
    }

    /// Allocates the next stream id, never returning `0`.
    pub(crate) fn next_stream_id(&self) -> u32 {
        // `fetch_add` wraps on overflow. When the wrap lands on `0` (presumably reserved
        // to mean "no stream" — TODO confirm against the protocol), take the next id.
        let next = self.inner.next_stream_id.fetch_add(1, Ordering::Relaxed);
        if next != 0 {
            next
        } else {
            self.inner.next_stream_id.fetch_add(1, Ordering::Relaxed)
        }
    }

    /// Sends an `Id` request with the given `features`, discarding the response body.
    pub(crate) async fn id(&self, features: Id) -> Result<()> {
        self.send_request(features).await.map(drop)
    }

    /// Creates a new [`Stream`] bound to this connection.
    pub(crate) fn stream(&self) -> Stream {
        Stream::new(self.clone())
    }

    /// Creates a [`TransactionBuilder`] preloaded with this connection's default
    /// transaction timeout and isolation level.
    pub(crate) fn transaction_builder(&self) -> TransactionBuilder {
        TransactionBuilder::new(
            self.clone(),
            self.inner.transaction_timeout_secs,
            self.inner.transaction_isolation_level,
        )
    }

    /// Begins a transaction using the connection's default settings.
    pub(crate) async fn transaction(&self) -> Result<Transaction> {
        self.transaction_builder().begin().await
    }

    /// Returns the prepared-statement id for `statement`, preparing and caching it on a miss.
    ///
    /// Returns `None` when the cache is disabled, or when preparing the statement fails
    /// (the failure is logged at `debug` level).
    async fn get_cached_sql_statement_id_inner(&self, statement: &str) -> Option<u64> {
        let cache = self.inner.sql_statement_cache.as_ref()?;
        // Fast path: the sync cache guard is released at the end of this statement,
        // before any `.await`.
        if let Some(stmt_id) = cache.lock().get(statement).copied() {
            return Some(stmt_id);
        }

        // Serialize misses so concurrent callers don't each prepare the statement. The
        // guard lives until the end of the function, i.e. across `prepare_sql`.
        let _update_guard = self.inner.sql_statement_cache_update_lock.lock().await;

        // Another task may have cached this statement while we were waiting for the lock.
        if let Some(stmt_id) = cache.lock().get(statement).copied() {
            return Some(stmt_id);
        }

        let stmt_id = match self.prepare_sql(statement).await {
            Ok(x) => x.stmt_id(),
            Err(err) => {
                debug!("Failed to prepare statement for cache: {:#}", err);
                return None;
            }
        };
        trace!(statement, "Statement prepared with id {stmt_id}");
        let _ = cache.lock().put(statement.into(), stmt_id);
        Some(stmt_id)
    }
}
#[async_trait]
impl Executor for Connection {
async fn send_encoded_request(&self, request: EncodedRequest) -> Result<Value> {
let fut = self.inner.dispatcher_sender.send(request);
let resp = match self.inner.timeout {
Some(x) => timeout(x, fut).await??,
None => fut.await?,
};
match resp.body {
ResponseBody::Ok(x) => Ok(x),
ResponseBody::Error(x) => Err(x.into()),
}
}
fn stream(&self) -> Stream {
self.stream()
}
fn transaction_builder(&self) -> TransactionBuilder {
self.transaction_builder()
}
async fn transaction(&self) -> Result<Transaction> {
self.transaction().await
}
async fn get_cached_sql_statement_id(&self, statement: &str) -> Option<u64> {
self.get_cached_sql_statement_id_inner(statement).await
}
}
impl fmt::Debug for Connection {
    /// Prints only the type name; none of the shared internal state is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Connection")
    }
}