use async_trait::async_trait;
use std::collections::VecDeque;
use std::io;
use std::mem;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
#[cfg(unix)]
use std::path::Path;
use std::pin::Pin;
use std::task::{self, Poll};
use combine::{parser::combinator::AnySendPartialState, stream::PointerOffset};
#[cfg(all(unix, feature = "tokio-comp"))]
use tokio::net::UnixStream as UnixStreamTokio;
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
sync::{mpsc, oneshot},
};
#[cfg(feature = "tokio-comp")]
use tokio::net::TcpStream as TcpStreamTokio;
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
use tokio_util::codec::Decoder;
use futures_util::{
future::{Future, FutureExt, TryFutureExt},
ready,
sink::Sink,
stream::{Stream, StreamExt, TryStreamExt as _},
};
use pin_project_lite::pin_project;
use crate::cmd::{cmd, Cmd};
use crate::connection::Msg;
use crate::connection::{ConnectionAddr, ConnectionInfo};
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
use crate::parser::ValueCodec;
use crate::types::{ErrorKind, RedisError, RedisFuture, RedisResult, Value};
use crate::{from_redis_value, ToRedisArgs};
#[cfg(feature = "async-std-comp")]
use crate::aio_async_std;
/// Runtime-specific transport establishment.
///
/// Each supported async runtime provides an implementation (e.g. `tokio_aio::Tokio`,
/// and `aio_async_std::AsyncStd` when that feature is enabled) so that
/// `connect_simple` can open a socket without knowing which runtime is in use.
#[async_trait]
pub(crate) trait Connect {
/// Opens a TCP connection to `socket_addr`.
async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult<ActualConnection>;
/// Opens a Unix domain socket connection to `path` (Unix platforms only).
#[cfg(unix)]
async fn connect_unix(path: &Path) -> RedisResult<ActualConnection>;
}
/// Tokio implementation of the `Connect` trait.
#[cfg(feature = "tokio-comp")]
mod tokio_aio {
    use super::{async_trait, ActualConnection, Connect, RedisResult, SocketAddr, TcpStreamTokio};
    #[cfg(unix)]
    use super::{Path, UnixStreamTokio};

    /// Marker type selecting the Tokio runtime in `connect_simple::<Tokio>`.
    pub struct Tokio;

    #[async_trait]
    impl Connect for Tokio {
        async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult<ActualConnection> {
            // `?` converts the `io::Error` into a `RedisError`.
            let stream = TcpStreamTokio::connect(&socket_addr).await?;
            Ok(ActualConnection::TcpTokio(stream))
        }

        #[cfg(unix)]
        async fn connect_unix(path: &Path) -> RedisResult<ActualConnection> {
            let stream = UnixStreamTokio::connect(path).await?;
            Ok(ActualConnection::UnixTokio(stream))
        }
    }
}
/// The concrete byte stream underlying an async connection.
///
/// Which variants exist depends on the enabled runtime features (`tokio-comp`,
/// `async-std-comp`) and on the platform (Unix sockets only under `cfg(unix)`).
pub(crate) enum ActualConnection {
/// TCP stream driven by Tokio.
#[cfg(feature = "tokio-comp")]
TcpTokio(TcpStreamTokio),
/// Unix domain socket driven by Tokio.
#[cfg(unix)]
#[cfg(feature = "tokio-comp")]
UnixTokio(UnixStreamTokio),
/// TCP stream driven by async-std (wrapped to expose Tokio's I/O traits).
#[cfg(feature = "async-std-comp")]
TcpAsyncStd(aio_async_std::TcpStreamAsyncStdWrapped),
/// Unix domain socket driven by async-std (wrapped to expose Tokio's I/O traits).
#[cfg(feature = "async-std-comp")]
#[cfg(unix)]
UnixAsyncStd(aio_async_std::UnixStreamAsyncStdWrapped),
}
/// Forwards every write-side operation to the stream held by the active variant.
impl AsyncWrite for ActualConnection {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut task::Context,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Every wrapped stream is `Unpin` (each is re-pinned with `Pin::new`), so the
        // enum is `Unpin` as well and can be unpinned directly.
        match self.get_mut() {
            #[cfg(feature = "tokio-comp")]
            Self::TcpTokio(stream) => Pin::new(stream).poll_write(cx, buf),
            #[cfg(unix)]
            #[cfg(feature = "tokio-comp")]
            Self::UnixTokio(stream) => Pin::new(stream).poll_write(cx, buf),
            #[cfg(feature = "async-std-comp")]
            Self::TcpAsyncStd(stream) => Pin::new(stream).poll_write(cx, buf),
            #[cfg(feature = "async-std-comp")]
            #[cfg(unix)]
            Self::UnixAsyncStd(stream) => Pin::new(stream).poll_write(cx, buf),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<()>> {
        match self.get_mut() {
            #[cfg(feature = "tokio-comp")]
            Self::TcpTokio(stream) => Pin::new(stream).poll_flush(cx),
            #[cfg(unix)]
            #[cfg(feature = "tokio-comp")]
            Self::UnixTokio(stream) => Pin::new(stream).poll_flush(cx),
            #[cfg(feature = "async-std-comp")]
            Self::TcpAsyncStd(stream) => Pin::new(stream).poll_flush(cx),
            #[cfg(feature = "async-std-comp")]
            #[cfg(unix)]
            Self::UnixAsyncStd(stream) => Pin::new(stream).poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<()>> {
        match self.get_mut() {
            #[cfg(feature = "tokio-comp")]
            Self::TcpTokio(stream) => Pin::new(stream).poll_shutdown(cx),
            #[cfg(unix)]
            #[cfg(feature = "tokio-comp")]
            Self::UnixTokio(stream) => Pin::new(stream).poll_shutdown(cx),
            #[cfg(feature = "async-std-comp")]
            Self::TcpAsyncStd(stream) => Pin::new(stream).poll_shutdown(cx),
            #[cfg(feature = "async-std-comp")]
            #[cfg(unix)]
            Self::UnixAsyncStd(stream) => Pin::new(stream).poll_shutdown(cx),
        }
    }
}
/// Forwards reads to the stream held by the active variant.
impl AsyncRead for ActualConnection {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut task::Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // See the `AsyncWrite` impl: all variants are `Unpin`, so unpinning is free.
        match self.get_mut() {
            #[cfg(feature = "tokio-comp")]
            Self::TcpTokio(stream) => Pin::new(stream).poll_read(cx, buf),
            #[cfg(unix)]
            #[cfg(feature = "tokio-comp")]
            Self::UnixTokio(stream) => Pin::new(stream).poll_read(cx, buf),
            #[cfg(feature = "async-std-comp")]
            Self::TcpAsyncStd(stream) => Pin::new(stream).poll_read(cx, buf),
            #[cfg(feature = "async-std-comp")]
            #[cfg(unix)]
            Self::UnixAsyncStd(stream) => Pin::new(stream).poll_read(cx, buf),
        }
    }
}
/// A connection switched into Redis pub/sub mode.
///
/// Created by `Connection::into_pubsub`. Subscribe with the methods below, consume
/// published messages with `on_message`, and call `into_connection` to leave
/// pub/sub mode and get the plain connection back.
pub struct PubSub(Connection);

impl PubSub {
    fn new(con: Connection) -> Self {
        Self(con)
    }

    /// Subscribes to `channel` by sending `SUBSCRIBE`.
    pub async fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
        cmd("SUBSCRIBE").arg(channel).query_async(&mut self.0).await
    }

    /// Subscribes to the pattern `pchannel` by sending `PSUBSCRIBE`.
    pub async fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
        cmd("PSUBSCRIBE").arg(pchannel).query_async(&mut self.0).await
    }

    /// Unsubscribes from `channel` by sending `UNSUBSCRIBE`.
    pub async fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
        cmd("UNSUBSCRIBE").arg(channel).query_async(&mut self.0).await
    }

    /// Unsubscribes from the pattern `pchannel` by sending `PUNSUBSCRIBE`.
    pub async fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
        cmd("PUNSUBSCRIBE").arg(pchannel).query_async(&mut self.0).await
    }

    /// Returns a stream of messages read from the connection.
    ///
    /// Frames that fail to decode, and values that `Msg::from_value` does not turn
    /// into a message, are skipped. The stream borrows this `PubSub` mutably for
    /// its whole lifetime.
    pub fn on_message<'a>(&'a mut self) -> impl Stream<Item = Msg> + 'a {
        ValueCodec::default()
            .framed(&mut self.0.con)
            .into_stream()
            // Each per-item future is boxed, presumably so the resulting stream is
            // `Unpin` and usable with `StreamExt::next` directly.
            .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?) }))
    }

    /// Leaves pub/sub mode and returns the underlying connection.
    ///
    /// Failure to unsubscribe is deliberately not reported here: `exit_pubsub`
    /// leaves the connection flagged as still in pub/sub mode, and the next
    /// request sent on it retries the unsubscribe first.
    pub async fn into_connection(mut self) -> Connection {
        let _ = self.0.exit_pubsub().await;
        self.0
    }
}
/// A single, non-multiplexed asynchronous connection to a Redis server.
pub struct Connection {
    /// The underlying byte stream.
    con: ActualConnection,
    /// Scratch buffer reused when packing outgoing commands.
    buf: Vec<u8>,
    /// Incremental reply-parser state for `con`.
    decoder: combine::stream::Decoder<AnySendPartialState, PointerOffset<[u8]>>,
    /// Database index this connection was configured with.
    db: i64,
    /// `true` after a failed attempt to leave pub/sub mode; the next request
    /// retries the unsubscribe before sending anything else.
    pubsub: bool,
}

impl Connection {
    /// Converts this connection into a [`PubSub`] connection.
    pub fn into_pubsub(self) -> PubSub {
        PubSub::new(self)
    }

    /// Reads and parses the next reply from the server.
    async fn read_response(&mut self) -> RedisResult<Value> {
        let decoder = &mut self.decoder;
        let stream = &mut self.con;
        crate::parser::parse_redis_value_async(decoder, stream).await
    }

    /// Cancels all subscriptions and records whether that succeeded.
    async fn exit_pubsub(&mut self) -> RedisResult<()> {
        let outcome = self.clear_active_subscriptions().await;
        // Stay flagged as "in pub/sub" unless the unsubscribe was confirmed.
        self.pubsub = outcome.is_err();
        outcome
    }

    /// Sends `UNSUBSCRIBE` and `PUNSUBSCRIBE` and drains their replies.
    async fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
        // Each reply is ("unsubscribe" | "punsubscribe", name, remaining count), and
        // the count covers both channel and pattern subscriptions. Both commands are
        // therefore sent together, and a count of zero only means "done" once both
        // kinds of reply have been seen.
        let packed = crate::Pipeline::new()
            .add_command(cmd("UNSUBSCRIBE"))
            .add_command(cmd("PUNSUBSCRIBE"))
            .get_packed_pipeline();
        self.con.write_all(&packed).await?;

        let (mut saw_unsub, mut saw_punsub) = (false, false);
        loop {
            let reply = self.read_response().await?;
            let (kind, (), remaining): (Vec<u8>, (), isize) = from_redis_value(&reply)?;
            match kind.first() {
                Some(&b'u') => saw_unsub = true,
                Some(&b'p') => saw_punsub = true,
                _ => {}
            }
            if saw_unsub && saw_punsub && remaining == 0 {
                return Ok(());
            }
        }
    }
}
/// Opens a plain (non-multiplexed) connection on the Tokio runtime and runs the
/// AUTH/SELECT handshake configured in `connection_info`.
#[cfg(feature = "tokio-comp")]
pub async fn connect_tokio(connection_info: &ConnectionInfo) -> RedisResult<Connection> {
    let stream = connect_simple::<tokio_aio::Tokio>(connection_info).await?;
    prepare_connection(stream, connection_info).await
}
/// Opens a plain (non-multiplexed) connection on the async-std runtime and runs
/// the AUTH/SELECT handshake configured in `connection_info`.
#[cfg(feature = "async-std-comp")]
pub async fn connect_async_std(connection_info: &ConnectionInfo) -> RedisResult<Connection> {
    let stream = connect_simple::<aio_async_std::AsyncStd>(connection_info).await?;
    prepare_connection(stream, connection_info).await
}
/// Wraps a freshly opened stream in a [`Connection`] and runs the handshake.
async fn prepare_connection(
    con: ActualConnection,
    connection_info: &ConnectionInfo,
) -> RedisResult<Connection> {
    let mut connection = Connection {
        db: connection_info.db,
        pubsub: false,
        buf: Vec::new(),
        decoder: combine::stream::Decoder::new(),
        con,
    };
    authenticate(connection_info, &mut connection).await?;
    Ok(connection)
}
async fn authenticate<C>(connection_info: &ConnectionInfo, con: &mut C) -> RedisResult<()>
where
C: ConnectionLike,
{
if let Some(passwd) = &connection_info.passwd {
match cmd("AUTH").arg(passwd).query_async(con).await {
Ok(Value::Okay) => (),
_ => {
fail!((
ErrorKind::AuthenticationFailed,
"Password authentication failed"
));
}
}
}
if connection_info.db != 0 {
match cmd("SELECT").arg(connection_info.db).query_async(con).await {
Ok(Value::Okay) => (),
_ => fail!((
ErrorKind::ResponseError,
"Redis server refused to switch database"
)),
}
}
Ok(())
}
/// Opens the transport described by `connection_info.addr` using runtime `T`.
///
/// TCP hosts are resolved with `get_socket_addrs`. Unix socket addresses are only
/// supported on Unix platforms; elsewhere they yield `InvalidClientConfig`.
async fn connect_simple<T: Connect>(
    connection_info: &ConnectionInfo,
) -> RedisResult<ActualConnection> {
    let connection = match *connection_info.addr {
        ConnectionAddr::Tcp(ref host, port) => {
            let addr = get_socket_addrs(host, port)?;
            T::connect_tcp(addr).await?
        }
        #[cfg(unix)]
        ConnectionAddr::Unix(ref path) => T::connect_unix(path).await?,
        #[cfg(not(unix))]
        ConnectionAddr::Unix(_) => {
            return Err(RedisError::from((
                ErrorKind::InvalidClientConfig,
                "Cannot connect to unix sockets on this platform",
            )));
        }
    };
    Ok(connection)
}
/// Resolves `host`/`port` and returns the first resulting address.
///
/// # Errors
///
/// Propagates resolution failures, and returns `ErrorKind::InvalidClientConfig`
/// when resolution succeeds but yields no address.
///
/// NOTE(review): `ToSocketAddrs` resolves synchronously and may block the calling
/// thread, yet this is called from async code (`connect_simple`).
fn get_socket_addrs(host: &str, port: u16) -> RedisResult<SocketAddr> {
    let first = (host, port).to_socket_addrs()?.next();
    first.ok_or_else(|| {
        RedisError::from((
            ErrorKind::InvalidClientConfig,
            "No address found for host",
        ))
    })
}
/// An async connection that can execute packed commands and pipelines.
///
/// Implemented in this module by `Connection`, `MultiplexedConnection` and, with
/// the `connection-manager` feature, `ConnectionManager`.
pub trait ConnectionLike: Sized {
/// Sends a single packed command and resolves to its reply.
fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value>;
/// Sends a packed pipeline, discards the first `offset` replies, and resolves
/// to the following `count` replies.
fn req_packed_commands<'a>(
&'a mut self,
cmd: &'a crate::Pipeline,
offset: usize,
count: usize,
) -> RedisFuture<'a, Vec<Value>>;
/// Returns the database index this connection was configured with.
fn get_db(&self) -> i64;
}
impl ConnectionLike for Connection {
    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
        let request = async move {
            // A previous attempt to leave pub/sub mode failed: retry it before
            // sending anything else.
            if self.pubsub {
                self.exit_pubsub().await?;
            }
            // Reuse the connection's scratch buffer for the encoded command.
            self.buf.clear();
            cmd.write_packed_command(&mut self.buf);
            self.con.write_all(&self.buf).await?;
            self.read_response().await
        };
        request.boxed()
    }

    fn req_packed_commands<'a>(
        &'a mut self,
        cmd: &'a crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisFuture<'a, Vec<Value>> {
        let request = async move {
            if self.pubsub {
                self.exit_pubsub().await?;
            }
            self.buf.clear();
            cmd.write_packed_pipeline(&mut self.buf);
            self.con.write_all(&self.buf).await?;

            // The first `offset` replies are read and discarded.
            let mut skipped = 0;
            while skipped < offset {
                self.read_response().await?;
                skipped += 1;
            }

            // The next `count` replies go back to the caller.
            let mut replies = Vec::with_capacity(count);
            while replies.len() < count {
                let reply = self.read_response().await?;
                replies.push(reply);
            }
            Ok(replies)
        };
        request.boxed()
    }

    fn get_db(&self) -> i64 {
        self.db
    }
}
/// One-shot channel on which a request receives all of its replies, or an error.
type PipelineOutput<O, E> = oneshot::Sender<Result<Vec<O>, E>>;

/// A request that has been written to the connection and is still waiting for
/// (some of) its replies.
struct InFlight<O, E> {
    /// Receives the outcome once every expected reply has been accounted for.
    output: PipelineOutput<O, E>,
    /// Total number of replies this request produces.
    response_count: usize,
    /// Number of replies (successful or not) received so far.
    current_response_count: usize,
    /// Successful replies received so far, in arrival order.
    buffer: Vec<O>,
    /// First error received for this request; reported instead of `buffer`.
    first_err: Option<E>,
}

/// A request submitted through a `Pipeline` handle.
struct PipelineMessage<S, I, E> {
    /// Item written to the sink (packed command bytes for `MultiplexedConnection`).
    input: S,
    /// Where the outcome of the request is delivered.
    output: PipelineOutput<I, E>,
    /// Number of replies `input` will produce.
    response_count: usize,
}

/// Cloneable handle used to submit requests to the pipeline driver future.
struct Pipeline<SinkItem, I, E>(mpsc::Sender<PipelineMessage<SinkItem, I, E>>);

// Written by hand because `#[derive(Clone)]` would add `SinkItem: Clone`,
// `I: Clone` and `E: Clone` bounds; only the channel sender is cloned.
impl<SinkItem, I, E> Clone for Pipeline<SinkItem, I, E> {
    fn clone(&self) -> Self {
        Pipeline(self.0.clone())
    }
}

pin_project! {
    // Driver side of the pipeline: writes requests into `sink_stream` and routes
    // the replies read back from it to the waiting callers, oldest request first.
    struct PipelineSink<T, I, E> {
        // Sink for outgoing requests and stream of incoming replies.
        #[pin]
        sink_stream: T,
        // Requests waiting for replies, in the order they were written.
        in_flight: VecDeque<InFlight<I, E>>,
        // Error from `poll_ready`, reported to the next request submitted.
        error: Option<E>,
    }
}

impl<T, I, E> PipelineSink<T, I, E>
where
    T: Stream<Item = Result<I, E>> + 'static,
{
    /// Wraps `sink_stream` with no requests in flight.
    fn new<SinkItem>(sink_stream: T) -> Self
    where
        T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
    {
        PipelineSink {
            sink_stream,
            in_flight: VecDeque::new(),
            error: None,
        }
    }

    /// Reads replies and dispatches them to the in-flight requests.
    ///
    /// Returns `Ready(Ok(()))` once no request is waiting for a reply, `Pending`
    /// while replies are outstanding, and `Ready(Err(()))` once the reply stream
    /// has ended, which stops the driver.
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
        loop {
            // Nobody is waiting for a reply, so there is nothing to read. Without this
            // check the loop can never report completion: `poll_flush` could not
            // return `Ready(Ok(()))`, and `poll_close` with requests outstanding could
            // only finish when the connection itself ended.
            if self.in_flight.is_empty() {
                return Poll::Ready(Ok(()));
            }
            let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
                Some(result) => result,
                // No more replies will ever arrive: error out to stop `forward`.
                None => return Poll::Ready(Err(())),
            };
            self.as_mut().send_result(item);
        }
    }

    /// Accounts one reply (or error) against the oldest in-flight request.
    ///
    /// Every reply, successful or not, counts towards the request's
    /// `response_count`. An error part-way through a multi-reply request therefore
    /// does not hand that request's remaining replies to the next request. Once
    /// all replies are in, the request receives its first error if there was one,
    /// otherwise all of its replies. Replies arriving with nothing in flight are
    /// dropped.
    fn send_result(self: Pin<&mut Self>, result: Result<I, E>) {
        let self_ = self.project();
        let response = {
            let entry = match self_.in_flight.front_mut() {
                Some(entry) => entry,
                None => return,
            };
            match result {
                Ok(item) => entry.buffer.push(item),
                Err(err) => {
                    if entry.first_err.is_none() {
                        entry.first_err = Some(err);
                    }
                }
            }
            entry.current_response_count += 1;
            if entry.current_response_count < entry.response_count {
                // More replies still belong to this request.
                return;
            }
            match entry.first_err.take() {
                Some(err) => Err(err),
                None => Ok(mem::replace(&mut entry.buffer, Vec::new())),
            }
        };
        // `front_mut()` returned `Some` above, so the queue is non-empty.
        let entry = self_.in_flight.pop_front().unwrap();
        // A send error only means the caller stopped waiting; drop the response.
        entry.output.send(response).ok();
    }
}

impl<SinkItem, T, I, E> Sink<PipelineMessage<SinkItem, I, E>> for PipelineSink<T, I, E>
where
    T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
{
    type Error = ();

    // Always reports readiness. A failure of the inner sink is stashed in `error`
    // and reported to the request passed to the next `start_send`.
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
            Ok(()) => Ok(()).into(),
            Err(err) => {
                *self.project().error = Some(err);
                Ok(()).into()
            }
        }
    }

    fn start_send(
        mut self: Pin<&mut Self>,
        PipelineMessage {
            input,
            output,
            response_count,
        }: PipelineMessage<SinkItem, I, E>,
    ) -> Result<(), Self::Error> {
        let self_ = self.as_mut().project();

        if let Some(err) = self_.error.take() {
            let _ = output.send(Err(err));
            return Err(());
        }

        match self_.sink_stream.start_send(input) {
            Ok(()) => {
                if response_count == 0 {
                    // No reply will ever belong to this request. Queueing it would
                    // make it swallow the next request's first reply, so answer now.
                    let _ = output.send(Ok(Vec::new()));
                } else {
                    self_.in_flight.push_back(InFlight {
                        output,
                        response_count,
                        current_response_count: 0,
                        buffer: Vec::new(),
                        first_err: None,
                    });
                }
                Ok(())
            }
            Err(err) => {
                let _ = output.send(Err(err));
                Err(())
            }
        }
    }

    // Flushes pending writes, then reads replies until no request is waiting.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        ready!(self
            .as_mut()
            .project()
            .sink_stream
            .poll_flush(cx)
            .map_err(|err| {
                self.as_mut().send_result(Err(err));
            }))?;
        self.poll_read(cx)
    }

    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        // No new requests arrive once closing starts, but the ones already in flight
        // must be answered before the connection is closed.
        if !self.in_flight.is_empty() {
            ready!(self.as_mut().poll_flush(cx))?;
        }
        let this = self.as_mut().project();
        this.sink_stream.poll_close(cx).map_err(|err| {
            self.send_result(Err(err));
        })
    }
}
impl<SinkItem, I, E> Pipeline<SinkItem, I, E>
where
    SinkItem: Send + 'static,
    I: Send + 'static,
    E: Send + 'static,
{
    /// Creates a pipeline handle over `sink_stream`, plus the driver future.
    ///
    /// Requests submitted through the handle only progress while the driver is
    /// being polled (e.g. spawned on a runtime).
    fn new<T>(sink_stream: T) -> (Self, impl Future<Output = ()>)
    where
        T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
        T: Send + 'static,
        T::Item: Send,
        T::Error: Send,
        T::Error: ::std::fmt::Debug,
    {
        // Capacity of the request channel; senders wait while it is full.
        const BUFFER_SIZE: usize = 50;
        let (sender, receiver) = mpsc::channel(BUFFER_SIZE);
        let sink = PipelineSink::new::<SinkItem>(sink_stream);
        // The driver ends when all handles are dropped or the sink fails; its
        // result carries no information, so it is discarded.
        let driver = receiver.map(Ok).forward(sink).map(|_| ());
        (Pipeline(sender), driver)
    }

    /// Sends a request that produces exactly one reply and returns that reply.
    async fn send(&mut self, item: SinkItem) -> Result<I, Option<E>> {
        // A successful one-reply request always yields exactly one value.
        self.send_recv_multiple(item, 1)
            .map_ok(|mut replies| replies.pop().unwrap())
            .await
    }

    /// Sends `input`, which produces `count` replies, and waits for all of them.
    ///
    /// `Err(Some(e))` carries an error reported for the request; `Err(None)` means
    /// the driver is gone (the channel is closed or the reply sender was dropped).
    async fn send_recv_multiple(
        &mut self,
        input: SinkItem,
        count: usize,
    ) -> Result<Vec<I>, Option<E>> {
        let (sender, receiver) = oneshot::channel();
        let message = PipelineMessage {
            input,
            response_count: count,
            output: sender,
        };
        if self.0.send(message).await.is_err() {
            // The driver has stopped and dropped the receiving end.
            return Err(None);
        }
        match receiver.await {
            Ok(outcome) => outcome.map_err(Some),
            // Our sender was dropped without an answer, most likely because the
            // driver failed.
            Err(_) => Err(None),
        }
    }
}
#[derive(Clone)]
pub struct MultiplexedConnection {
pipeline: Pipeline<Vec<u8>, Value, RedisError>,
db: i64,
}
impl MultiplexedConnection {
#[cfg(feature = "tokio-comp")]
pub(crate) async fn new_tokio(
connection_info: &ConnectionInfo,
) -> RedisResult<(Self, impl Future<Output = ()>)> {
let con = connect_simple::<tokio_aio::Tokio>(connection_info).await?;
Ok(MultiplexedConnection::create_connection(connection_info, con).await?)
}
#[cfg(feature = "async-std-comp")]
pub(crate) async fn new_async_std(
connection_info: &ConnectionInfo,
) -> RedisResult<(Self, impl Future<Output = ()>)> {
let con = connect_simple::<aio_async_std::AsyncStd>(connection_info).await?;
MultiplexedConnection::create_connection(connection_info, con).await
}
async fn create_connection(
connection_info: &ConnectionInfo,
con: ActualConnection,
) -> RedisResult<(Self, impl Future<Output = ()>)> {
fn boxed(
f: impl Future<Output = ()> + Send + 'static,
) -> Pin<Box<dyn Future<Output = ()> + Send>> {
Box::pin(f)
}
#[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
compile_error!("tokio-comp or async-std-comp features required for aio feature");
let (pipeline, driver) = match con {
#[cfg(feature = "tokio-comp")]
ActualConnection::TcpTokio(tcp) => {
let codec = ValueCodec::default().framed(tcp);
let (pipeline, driver) = Pipeline::new(codec);
(pipeline, boxed(driver))
}
#[cfg(feature = "async-std-comp")]
ActualConnection::TcpAsyncStd(tcp) => {
let codec = ValueCodec::default().framed(tcp);
let (pipeline, driver) = Pipeline::new(codec);
(pipeline, boxed(driver))
}
#[cfg(unix)]
#[cfg(feature = "tokio-comp")]
ActualConnection::UnixTokio(unix) => {
let codec = ValueCodec::default().framed(unix);
let (pipeline, driver) = Pipeline::new(codec);
(pipeline, boxed(driver))
}
#[cfg(unix)]
#[cfg(feature = "async-std-comp")]
ActualConnection::UnixAsyncStd(unix) => {
let codec = ValueCodec::default().framed(unix);
let (pipeline, driver) = Pipeline::new(codec);
(pipeline, boxed(driver))
}
};
let mut con = MultiplexedConnection {
pipeline,
db: connection_info.db,
};
authenticate(connection_info, &mut con).await?;
Ok((con, driver))
}
}
impl ConnectionLike for MultiplexedConnection {
fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
(async move {
let value = self
.pipeline
.send(cmd.get_packed_command())
.await
.map_err(|err| {
err.unwrap_or_else(|| {
RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))
})
})?;
Ok(value)
})
.boxed()
}
fn req_packed_commands<'a>(
&'a mut self,
cmd: &'a crate::Pipeline,
offset: usize,
count: usize,
) -> RedisFuture<'a, Vec<Value>> {
(async move {
let mut value = self
.pipeline
.send_recv_multiple(cmd.get_packed_pipeline(), offset + count)
.await
.map_err(|err| {
err.unwrap_or_else(|| {
RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))
})
})?;
value.drain(..offset);
Ok(value)
})
.boxed()
}
fn get_db(&self) -> i64 {
self.db
}
}
#[cfg(feature = "connection-manager")]
mod connection_manager {
    use super::*;

    use std::sync::Arc;

    use arc_swap::{self, ArcSwap};
    use futures::future::{self, Shared};
    use futures_util::future::BoxFuture;

    /// A `MultiplexedConnection` handle that reconnects automatically.
    ///
    /// When a request fails because the connection was dropped, or connecting
    /// failed with an I/O error, a replacement connection is started and used by
    /// later requests. Clones share the same connection slot.
    #[derive(Clone)]
    pub struct ConnectionManager {
        /// Used to open replacement connections.
        connection_info: ConnectionInfo,
        /// The current connection, held as a shared future so that concurrent
        /// callers can all wait on an in-progress reconnect.
        connection: Arc<ArcSwap<SharedRedisFuture<MultiplexedConnection>>>,
    }

    /// `Shared` requires a `Clone` output, so errors are wrapped in an `Arc`.
    type CloneableRedisResult<T> = Result<T, Arc<RedisError>>;

    /// A boxed connection future that any number of callers can await.
    type SharedRedisFuture<T> = Shared<BoxFuture<'static, CloneableRedisResult<T>>>;

    /// Opens a multiplexed connection and spawns its driver.
    ///
    /// With both runtime features enabled, the Tokio connector is used when a Tokio
    /// runtime is current and the async-std connector otherwise.
    ///
    /// NOTE(review): the driver is always spawned with `tokio::spawn`, even when the
    /// async-std connector was chosen, so a Tokio runtime must be running; confirm.
    async fn connect_and_spawn(
        connection_info: &ConnectionInfo,
    ) -> RedisResult<MultiplexedConnection> {
        #[cfg(all(feature = "tokio-comp", not(feature = "async-std-comp")))]
        let con = connect_simple::<tokio_aio::Tokio>(connection_info).await?;
        #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
        let con = connect_simple::<aio_async_std::AsyncStd>(connection_info).await?;
        #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))]
        let con = if tokio::runtime::Handle::try_current().is_ok() {
            connect_simple::<tokio_aio::Tokio>(connection_info).await?
        } else {
            connect_simple::<aio_async_std::AsyncStd>(connection_info).await?
        };
        #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
        compile_error!("tokio-comp or async-std-comp features required for aio feature");

        let (connection, driver) =
            MultiplexedConnection::create_connection(connection_info, con).await?;
        tokio::spawn(driver);
        Ok(connection)
    }

    impl ConnectionManager {
        /// Connects to the server described by `connection_info`.
        pub async fn new(connection_info: ConnectionInfo) -> RedisResult<Self> {
            let connection = connect_and_spawn(&connection_info).await?;
            Ok(Self {
                connection_info,
                connection: Arc::new(ArcSwap::from_pointee(
                    future::ok(connection).boxed().shared(),
                )),
            })
        }

        /// Replaces the connection observed as `current` with a fresh one.
        ///
        /// The swap only happens if `current` is still the stored connection, so
        /// concurrent failures trigger a single reconnect. The caller that wins the
        /// swap spawns the new connection future so it makes progress even before
        /// anyone awaits it.
        fn reconnect(
            &self,
            current: arc_swap::Guard<'_, Arc<SharedRedisFuture<MultiplexedConnection>>>,
        ) {
            let connection_info = self.connection_info.clone();
            let new_connection: SharedRedisFuture<MultiplexedConnection> = async move {
                connect_and_spawn(&connection_info)
                    .await
                    .map_err(Arc::new)
            }
            .boxed()
            .shared();

            let new_connection_arc = Arc::new(new_connection.clone());
            let prev = self
                .connection
                .compare_and_swap(&current, new_connection_arc);
            if Arc::ptr_eq(&prev, &current) {
                tokio::spawn(new_connection);
            }
        }
    }

    /// Triggers a reconnect when a request failed because the connection dropped.
    macro_rules! reconnect_if_dropped {
        ($self:expr, $result:expr, $current:expr) => {
            if let Err(ref e) = $result {
                if e.is_connection_dropped() {
                    $self.reconnect($current);
                }
            }
        };
    }

    /// Returns early on a connection error, reconnecting first if it was an I/O error.
    macro_rules! reconnect_if_io_error {
        ($self:expr, $result:expr, $current:expr) => {
            if let Err(e) = $result {
                if e.is_io_error() {
                    $self.reconnect($current);
                }
                return Err(e);
            }
        };
    }

    impl ConnectionLike for ConnectionManager {
        fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
            (async move {
                let guard = self.connection.load();
                let connection_result = (**guard)
                    .clone()
                    .await
                    .map_err(|e| e.clone_mostly("Reconnecting failed"));
                reconnect_if_io_error!(self, connection_result, guard);
                let result = connection_result?.req_packed_command(cmd).await;
                reconnect_if_dropped!(self, &result, guard);
                result
            })
            .boxed()
        }

        fn req_packed_commands<'a>(
            &'a mut self,
            cmd: &'a crate::Pipeline,
            offset: usize,
            count: usize,
        ) -> RedisFuture<'a, Vec<Value>> {
            (async move {
                let guard = self.connection.load();
                let connection_result = (**guard)
                    .clone()
                    .await
                    .map_err(|e| e.clone_mostly("Reconnecting failed"));
                reconnect_if_io_error!(self, connection_result, guard);
                let result = connection_result?
                    .req_packed_commands(cmd, offset, count)
                    .await;
                reconnect_if_dropped!(self, &result, guard);
                result
            })
            .boxed()
        }

        fn get_db(&self) -> i64 {
            self.connection_info.db
        }
    }
}
#[cfg(feature = "connection-manager")]
pub use connection_manager::ConnectionManager;