rustis/network/standalone_connection.rs

1use crate::{
2    Error, Future, Result, RetryReason, TcpStreamReader, TcpStreamWriter,
3    client::{Config, PreparedCommand},
4    commands::{
5        ClusterCommands, ConnectionCommands, HelloOptions, SentinelCommands, ServerCommands,
6    },
7    resp::{BufferDecoder, Command, CommandEncoder, RespBuf},
8    tcp_connect,
9};
10#[cfg(any(feature = "native-tls", feature = "rustls"))]
11use crate::{TcpTlsStreamReader, TcpTlsStreamWriter, tcp_tls_connect};
12use bytes::BytesMut;
13use futures_util::{SinkExt, StreamExt};
14use log::{Level, debug, log_enabled};
15use serde::de::DeserializeOwned;
16use smallvec::SmallVec;
17use std::future::IntoFuture;
18use tokio::io::AsyncWriteExt;
19use tokio_util::codec::{Encoder, FramedRead, FramedWrite};
20
21pub(crate) enum Streams {
22    Tcp(
23        FramedRead<TcpStreamReader, BufferDecoder>,
24        FramedWrite<TcpStreamWriter, CommandEncoder>,
25    ),
26    #[cfg(any(feature = "native-tls", feature = "rustls"))]
27    TcpTls(
28        FramedRead<TcpTlsStreamReader, BufferDecoder>,
29        FramedWrite<TcpTlsStreamWriter, CommandEncoder>,
30    ),
31}
32
33impl Streams {
34    pub async fn connect(host: &str, port: u16, config: &Config) -> Result<Self> {
35        #[cfg(any(feature = "native-tls", feature = "rustls"))]
36        if let Some(tls_config) = &config.tls_config {
37            let (reader, writer) =
38                tcp_tls_connect(host, port, tls_config, config.connect_timeout).await?;
39            let framed_read = FramedRead::new(reader, BufferDecoder);
40            let framed_write = FramedWrite::new(writer, CommandEncoder);
41            Ok(Streams::TcpTls(framed_read, framed_write))
42        } else {
43            Self::connect_non_secure(host, port, config).await
44        }
45
46        #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
47        Self::connect_non_secure(host, port, config).await
48    }
49
50    pub async fn connect_non_secure(host: &str, port: u16, config: &Config) -> Result<Self> {
51        let (reader, writer) = tcp_connect(host, port, config).await?;
52        let framed_read = FramedRead::new(reader, BufferDecoder);
53        let framed_write = FramedWrite::new(writer, CommandEncoder);
54        Ok(Streams::Tcp(framed_read, framed_write))
55    }
56}
57
58pub struct StandaloneConnection {
59    host: String,
60    port: u16,
61    config: Config,
62    streams: Streams,
63    buffer: BytesMut,
64    version: String,
65    tag: String,
66}
67
68impl StandaloneConnection {
69    pub async fn connect(host: &str, port: u16, config: &Config) -> Result<Self> {
70        let streams = Streams::connect(host, port, config).await?;
71
72        let mut connection = Self {
73            host: host.to_owned(),
74            port,
75            config: config.clone(),
76            streams,
77            buffer: BytesMut::new(),
78            version: String::new(),
79            tag: if config.connection_name.is_empty() {
80                format!("{host}:{port}")
81            } else {
82                format!("{}:{}:{}", config.connection_name, host, port)
83            },
84        };
85
86        connection.post_connect().await?;
87
88        Ok(connection)
89    }
90
91    pub async fn write(&mut self, command: &Command) -> Result<()> {
92        if log_enabled!(Level::Debug) {
93            debug!("[{}] Sending command: {command}", self.tag);
94        }
95        match &mut self.streams {
96            Streams::Tcp(_, framed_write) => framed_write.send(command).await,
97            #[cfg(any(feature = "native-tls", feature = "rustls"))]
98            Streams::TcpTls(_, framed_write) => framed_write.send(command).await,
99        }
100    }
101
102    pub async fn write_batch(
103        &mut self,
104        commands: SmallVec<[&mut Command; 10]>,
105        _retry_reasons: &[RetryReason],
106    ) -> Result<()> {
107        self.buffer.clear();
108
109        let command_encoder = match &mut self.streams {
110            Streams::Tcp(_, framed_write) => framed_write.encoder_mut(),
111            #[cfg(any(feature = "native-tls", feature = "rustls"))]
112            Streams::TcpTls(_, framed_write) => framed_write.encoder_mut(),
113        };
114
115        #[cfg(debug_assertions)]
116        let mut kill_connection = false;
117
118        for command in commands {
119            if log_enabled!(Level::Debug) {
120                debug!("[{}] Sending command: {command}", self.tag);
121            }
122
123            #[cfg(debug_assertions)]
124            if command.kill_connection_on_write > 0 {
125                kill_connection = true;
126                command.kill_connection_on_write -= 1;
127            }
128
129            command_encoder.encode(command, &mut self.buffer)?;
130        }
131
132        #[cfg(debug_assertions)]
133        if kill_connection {
134            let client_id = self.client_id().await?;
135            let mut config = self.config.clone();
136            "killer".clone_into(&mut config.connection_name);
137            let mut connection =
138                StandaloneConnection::connect(&self.host, self.port, &config).await?;
139            connection
140                .client_kill(crate::commands::ClientKillOptions::default().id(client_id))
141                .await?;
142        }
143
144        match &mut self.streams {
145            Streams::Tcp(_, framed_write) => framed_write.get_mut().write_all(&self.buffer).await?,
146            #[cfg(any(feature = "native-tls", feature = "rustls"))]
147            Streams::TcpTls(_, framed_write) => {
148                framed_write.get_mut().write_all(&self.buffer).await?
149            }
150        }
151
152        Ok(())
153    }
154
155    pub async fn read(&mut self) -> Option<Result<RespBuf>> {
156        if let Some(result) = match &mut self.streams {
157            Streams::Tcp(framed_read, _) => framed_read.next().await,
158            #[cfg(any(feature = "native-tls", feature = "rustls"))]
159            Streams::TcpTls(framed_read, _) => framed_read.next().await,
160        } {
161            if log_enabled!(Level::Debug) {
162                match &result {
163                    Ok(bytes) => debug!("[{}] Received result {bytes}", self.tag),
164                    Err(err) => debug!("[{}] Received result {err:?}", self.tag),
165                }
166            }
167            Some(result)
168        } else {
169            debug!("[{}] Socked is closed", self.tag);
170            None
171        }
172    }
173
174    pub async fn reconnect(&mut self) -> Result<()> {
175        self.streams = Streams::connect(&self.host, self.port, &self.config).await?;
176        self.post_connect().await?;
177
178        Ok(())
179
180        // TODO improve reconnection strategy with multiple retries
181    }
182
183    async fn post_connect(&mut self) -> Result<()> {
184        // RESP3
185        let mut hello_options = HelloOptions::new(3);
186
187        let config_username = self.config.username.clone();
188        let config_password = self.config.password.clone();
189        let config_connection_name = self.config.connection_name.clone();
190
191        // authentication
192        if let Some(password) = &config_password {
193            hello_options = hello_options.auth(
194                match &config_username {
195                    Some(username) => username,
196                    None => "default",
197                },
198                password,
199            );
200        }
201
202        // connection name
203        if !config_connection_name.is_empty() {
204            hello_options = hello_options.set_name(&config_connection_name);
205        }
206
207        let hello_result = self.hello(hello_options).await?;
208        self.version = hello_result.version;
209
210        // select database
211        if self.config.database != 0 {
212            self.select(self.config.database).await?;
213        }
214
215        Ok(())
216    }
217
218    pub fn get_version(&self) -> &str {
219        &self.version
220    }
221
222    pub(crate) fn tag(&self) -> &str {
223        &self.tag
224    }
225}
226
227impl<'a, R> IntoFuture for PreparedCommand<'a, &'a mut StandaloneConnection, R>
228where
229    R: DeserializeOwned + Send + 'a,
230{
231    type Output = Result<R>;
232    type IntoFuture = Future<'a, R>;
233
234    fn into_future(self) -> Self::IntoFuture {
235        Box::pin(async move {
236            self.executor.write(&self.command).await?;
237
238            let resp_buf = self.executor.read().await.ok_or_else(|| {
239                Error::Client(format!("[{}] disconnected by peer", self.executor.tag()))
240            })??;
241
242            resp_buf.to()
243        })
244    }
245}
246
247impl<'a> ClusterCommands<'a> for &'a mut StandaloneConnection {}
248impl<'a> ConnectionCommands<'a> for &'a mut StandaloneConnection {}
249impl<'a> SentinelCommands<'a> for &'a mut StandaloneConnection {}
250impl<'a> ServerCommands<'a> for &'a mut StandaloneConnection {}