1use crate::{
2 Error, Future, Result, RetryReason, TcpStreamReader, TcpStreamWriter,
3 client::{Config, PreparedCommand},
4 commands::{
5 ClusterCommands, ConnectionCommands, HelloOptions, SentinelCommands, ServerCommands,
6 },
7 resp::{BufferDecoder, Command, CommandEncoder, RespBuf},
8 tcp_connect,
9};
10#[cfg(any(feature = "native-tls", feature = "rustls"))]
11use crate::{TcpTlsStreamReader, TcpTlsStreamWriter, tcp_tls_connect};
12use bytes::BytesMut;
13use futures_util::{SinkExt, StreamExt};
14use log::{Level, debug, log_enabled};
15use serde::de::DeserializeOwned;
16use smallvec::SmallVec;
17use std::future::IntoFuture;
18use tokio::io::AsyncWriteExt;
19use tokio_util::codec::{Encoder, FramedRead, FramedWrite};
20
/// Read and write halves of a server connection, each wrapped in a framed codec.
///
/// Every variant pairs a [`FramedRead`] driven by [`BufferDecoder`] (incoming
/// replies) with a [`FramedWrite`] driven by [`CommandEncoder`] (outgoing commands).
pub(crate) enum Streams {
    /// Streams produced by `tcp_connect` (the non-TLS path).
    Tcp(
        FramedRead<TcpStreamReader, BufferDecoder>,
        FramedWrite<TcpStreamWriter, CommandEncoder>,
    ),
    /// Streams produced by `tcp_tls_connect`; only compiled when the
    /// `native-tls` or `rustls` feature is enabled.
    #[cfg(any(feature = "native-tls", feature = "rustls"))]
    TcpTls(
        FramedRead<TcpTlsStreamReader, BufferDecoder>,
        FramedWrite<TcpTlsStreamWriter, CommandEncoder>,
    ),
}
32
33impl Streams {
34 pub async fn connect(host: &str, port: u16, config: &Config) -> Result<Self> {
35 #[cfg(any(feature = "native-tls", feature = "rustls"))]
36 if let Some(tls_config) = &config.tls_config {
37 let (reader, writer) =
38 tcp_tls_connect(host, port, tls_config, config.connect_timeout).await?;
39 let framed_read = FramedRead::new(reader, BufferDecoder);
40 let framed_write = FramedWrite::new(writer, CommandEncoder);
41 Ok(Streams::TcpTls(framed_read, framed_write))
42 } else {
43 Self::connect_non_secure(host, port, config).await
44 }
45
46 #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
47 Self::connect_non_secure(host, port, config).await
48 }
49
50 pub async fn connect_non_secure(host: &str, port: u16, config: &Config) -> Result<Self> {
51 let (reader, writer) = tcp_connect(host, port, config).await?;
52 let framed_read = FramedRead::new(reader, BufferDecoder);
53 let framed_write = FramedWrite::new(writer, CommandEncoder);
54 Ok(Streams::Tcp(framed_read, framed_write))
55 }
56}
57
/// A connection to a single server endpoint, identified by `host` and `port`.
pub struct StandaloneConnection {
    /// Host given to `connect`; reused by `reconnect`.
    host: String,
    /// Port given to `connect`; reused by `reconnect`.
    port: u16,
    /// Private copy of the configuration the connection was created with.
    config: Config,
    /// Framed read/write halves of the current socket; replaced on `reconnect`.
    streams: Streams,
    /// Scratch buffer into which `write_batch` encodes commands before writing them.
    buffer: BytesMut,
    /// `version` field of the HELLO reply, stored by `post_connect`; empty until then.
    version: String,
    /// Log prefix: `host:port`, or `connection_name:host:port` when a
    /// connection name is configured.
    tag: String,
}
67
68impl StandaloneConnection {
69 pub async fn connect(host: &str, port: u16, config: &Config) -> Result<Self> {
70 let streams = Streams::connect(host, port, config).await?;
71
72 let mut connection = Self {
73 host: host.to_owned(),
74 port,
75 config: config.clone(),
76 streams,
77 buffer: BytesMut::new(),
78 version: String::new(),
79 tag: if config.connection_name.is_empty() {
80 format!("{host}:{port}")
81 } else {
82 format!("{}:{}:{}", config.connection_name, host, port)
83 },
84 };
85
86 connection.post_connect().await?;
87
88 Ok(connection)
89 }
90
91 pub async fn write(&mut self, command: &Command) -> Result<()> {
92 if log_enabled!(Level::Debug) {
93 debug!("[{}] Sending command: {command}", self.tag);
94 }
95 match &mut self.streams {
96 Streams::Tcp(_, framed_write) => framed_write.send(command).await,
97 #[cfg(any(feature = "native-tls", feature = "rustls"))]
98 Streams::TcpTls(_, framed_write) => framed_write.send(command).await,
99 }
100 }
101
102 pub async fn write_batch(
103 &mut self,
104 commands: SmallVec<[&mut Command; 10]>,
105 _retry_reasons: &[RetryReason],
106 ) -> Result<()> {
107 self.buffer.clear();
108
109 let command_encoder = match &mut self.streams {
110 Streams::Tcp(_, framed_write) => framed_write.encoder_mut(),
111 #[cfg(any(feature = "native-tls", feature = "rustls"))]
112 Streams::TcpTls(_, framed_write) => framed_write.encoder_mut(),
113 };
114
115 #[cfg(debug_assertions)]
116 let mut kill_connection = false;
117
118 for command in commands {
119 if log_enabled!(Level::Debug) {
120 debug!("[{}] Sending command: {command}", self.tag);
121 }
122
123 #[cfg(debug_assertions)]
124 if command.kill_connection_on_write > 0 {
125 kill_connection = true;
126 command.kill_connection_on_write -= 1;
127 }
128
129 command_encoder.encode(command, &mut self.buffer)?;
130 }
131
132 #[cfg(debug_assertions)]
133 if kill_connection {
134 let client_id = self.client_id().await?;
135 let mut config = self.config.clone();
136 "killer".clone_into(&mut config.connection_name);
137 let mut connection =
138 StandaloneConnection::connect(&self.host, self.port, &config).await?;
139 connection
140 .client_kill(crate::commands::ClientKillOptions::default().id(client_id))
141 .await?;
142 }
143
144 match &mut self.streams {
145 Streams::Tcp(_, framed_write) => framed_write.get_mut().write_all(&self.buffer).await?,
146 #[cfg(any(feature = "native-tls", feature = "rustls"))]
147 Streams::TcpTls(_, framed_write) => {
148 framed_write.get_mut().write_all(&self.buffer).await?
149 }
150 }
151
152 Ok(())
153 }
154
155 pub async fn read(&mut self) -> Option<Result<RespBuf>> {
156 if let Some(result) = match &mut self.streams {
157 Streams::Tcp(framed_read, _) => framed_read.next().await,
158 #[cfg(any(feature = "native-tls", feature = "rustls"))]
159 Streams::TcpTls(framed_read, _) => framed_read.next().await,
160 } {
161 if log_enabled!(Level::Debug) {
162 match &result {
163 Ok(bytes) => debug!("[{}] Received result {bytes}", self.tag),
164 Err(err) => debug!("[{}] Received result {err:?}", self.tag),
165 }
166 }
167 Some(result)
168 } else {
169 debug!("[{}] Socked is closed", self.tag);
170 None
171 }
172 }
173
174 pub async fn reconnect(&mut self) -> Result<()> {
175 self.streams = Streams::connect(&self.host, self.port, &self.config).await?;
176 self.post_connect().await?;
177
178 Ok(())
179
180 }
182
183 async fn post_connect(&mut self) -> Result<()> {
184 let mut hello_options = HelloOptions::new(3);
186
187 let config_username = self.config.username.clone();
188 let config_password = self.config.password.clone();
189 let config_connection_name = self.config.connection_name.clone();
190
191 if let Some(password) = &config_password {
193 hello_options = hello_options.auth(
194 match &config_username {
195 Some(username) => username,
196 None => "default",
197 },
198 password,
199 );
200 }
201
202 if !config_connection_name.is_empty() {
204 hello_options = hello_options.set_name(&config_connection_name);
205 }
206
207 let hello_result = self.hello(hello_options).await?;
208 self.version = hello_result.version;
209
210 if self.config.database != 0 {
212 self.select(self.config.database).await?;
213 }
214
215 Ok(())
216 }
217
218 pub fn get_version(&self) -> &str {
219 &self.version
220 }
221
222 pub(crate) fn tag(&self) -> &str {
223 &self.tag
224 }
225}
226
227impl<'a, R> IntoFuture for PreparedCommand<'a, &'a mut StandaloneConnection, R>
228where
229 R: DeserializeOwned + Send + 'a,
230{
231 type Output = Result<R>;
232 type IntoFuture = Future<'a, R>;
233
234 fn into_future(self) -> Self::IntoFuture {
235 Box::pin(async move {
236 self.executor.write(&self.command).await?;
237
238 let resp_buf = self.executor.read().await.ok_or_else(|| {
239 Error::Client(format!("[{}] disconnected by peer", self.executor.tag()))
240 })??;
241
242 resp_buf.to()
243 })
244 }
245}
246
// Command traits implemented for a mutable connection reference. The impl
// bodies are empty, so every command method comes from the traits' provided
// (default) methods.
impl<'a> ClusterCommands<'a> for &'a mut StandaloneConnection {}
impl<'a> ConnectionCommands<'a> for &'a mut StandaloneConnection {}
impl<'a> SentinelCommands<'a> for &'a mut StandaloneConnection {}
impl<'a> ServerCommands<'a> for &'a mut StandaloneConnection {}