1use crate::*;
19use ::std::net::{SocketAddr, ToSocketAddrs};
20use ::std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
21use snafu::{OptionExt, ResultExt};
22
/// Server error id that triggers the rename fallback in `set_name_fallback`
/// (presumably "nickname already in use").
const ERR_NAME_TAKEN: usize = 513;
/// Maximum length, in bytes, accepted for the client name of a managed connection.
pub const MAX_LEN_NAME: usize = 20;

/// Default timeout used when establishing a connection.
pub const DEFAULT_TIMEOUT_CONN: Duration = Duration::from_millis(1_500);
/// Default timeout used for a single command.
pub const DEFAULT_TIMEOUT_CMD: Duration = Duration::from_millis(1_500);
/// Response key parsed out of `whoami` to obtain the own client id.
const KEY_CLIENT_ID_SELF: &str = "client_id";
33
/// Settings used to open (and re-open) a managed query connection.
///
/// Created via `ManagedConfig::new` and adjusted with its builder-style setters.
#[derive(Clone)]
pub struct ManagedConfig {
    /// Resolved socket address of the query interface.
    addr: SocketAddr,
    /// Port of the virtual server selected after login.
    server_port: u16,
    /// Query login name.
    user: String,
    /// Query login password.
    password: String,
    /// Optional client name applied after selecting the server.
    name: Option<String>,
    /// Timeout passed to the client as its connection timeout.
    conn_timeout: Duration,
    /// Timeout passed to the client as its command timeout.
    cmd_timeout: Duration,
}
56
57impl ManagedConfig {
58 pub fn new<A: ToSocketAddrs>(
60 addr: A,
61 server_port: u16,
62 user: String,
63 password: String,
64 ) -> Result<Self> {
65 Ok(Self {
66 addr: addr
67 .to_socket_addrs()
68 .context(Io {
69 context: "invalid socket address",
70 })?
71 .next()
72 .context(InvalidSocketAddress {})?,
73 user,
74 password,
75 server_port,
76 name: Default::default(),
77 conn_timeout: DEFAULT_TIMEOUT_CONN,
78 cmd_timeout: DEFAULT_TIMEOUT_CMD,
79 })
80 }
81
82 pub fn name(mut self, name: String) -> Self {
89 self.name = Some(name);
90 self
91 }
92
93 pub fn connection_timeout(mut self, timeout: Duration) -> Self {
95 self.conn_timeout = timeout;
96 self
97 }
98
99 pub fn timeout(mut self, timeout: Duration) -> Self {
101 self.cmd_timeout = timeout;
102 self
103 }
104}
105
106pub struct ManagedConnection {
108 cfg: ManagedConfig,
109 conn: QueryClient,
110 last_ping: Instant,
111 conn_id: Option<ClientId>,
112}
113
114impl ManagedConnection {
115 fn connect(cfg: &ManagedConfig) -> Result<QueryClient> {
116 let mut conn =
118 QueryClient::with_timeout(&cfg.addr, Some(cfg.conn_timeout), Some(cfg.cmd_timeout))?;
119 conn.login(&cfg.user, &cfg.password)?;
120 conn.select_server_by_port(cfg.server_port)?;
121 if let Some(n) = cfg.name.as_ref() {
122 if n.len() > MAX_LEN_NAME {
124 return InvalidNameLength {
125 length: n.len(),
126 expected: MAX_LEN_NAME,
127 }
128 .fail();
129 }
130 Self::set_name_fallback(&mut conn, n)?;
131 }
132 Ok(conn)
133 }
134
135 fn set_name_fallback(conn: &mut QueryClient, name: &str) -> Result<()> {
137 if let Err(e) = conn.rename(name) {
138 if e.error_response().map_or(true, |r| r.id != ERR_NAME_TAKEN) {
139 return Err(e);
140 } else {
141 conn.rename(&Self::calc_name_retry(name))?;
142 }
143 }
144 Ok(())
145 }
146
147 fn calc_name_retry(name: &str) -> String {
149 let name = if name.len() >= MAX_LEN_NAME - 2 {
151 &name[0..MAX_LEN_NAME / 2]
152 } else {
153 name
154 };
155 let time = SystemTime::now()
156 .duration_since(UNIX_EPOCH)
157 .unwrap()
158 .as_millis()
159 .to_string();
160 let reamining = MAX_LEN_NAME - name.len();
161 let time = if reamining > time.len() {
162 &time
163 } else {
164 &time.as_str()[time.len() - reamining..]
165 };
166
167 format!("{}{}", name, time)
168 }
169
170 pub fn conn_id(&mut self) -> Result<ClientId> {
172 Ok(match self.conn_id {
173 Some(v) => v,
174 None => {
175 let mut res = self.get()?.whoami(false)?;
176 let clid = crate::raw::int_val_parser(&mut res, KEY_CLIENT_ID_SELF)?;
177 self.conn_id = Some(clid);
178 clid
179 }
180 })
181 }
182
183 pub fn clone(&self, new_name: Option<String>) -> Result<Self> {
186 let mut cfg = self.cfg.clone();
187 if new_name.is_some() {
188 cfg.name = new_name;
189 }
190 Self::new(self.cfg.clone())
191 }
192
193 pub fn new(config: ManagedConfig) -> Result<ManagedConnection> {
195 let conn = Self::connect(&config)?;
196 Ok(Self {
197 conn,
198 cfg: config,
199 last_ping: Instant::now(),
200 conn_id: None,
201 })
202 }
203
204 pub fn force_reconnect(&mut self) -> Result<()> {
207 self.conn = Self::connect(&self.cfg)?;
208 self.conn_id = None;
209 Ok(())
210 }
211
212 pub fn get(&mut self) -> Result<&mut QueryClient> {
215 if self.last_ping.elapsed() < Duration::from_secs(0) {
216 return Ok(&mut self.conn);
217 }
218 let conn = match self.conn.ping() {
219 Ok(_) => &mut self.conn,
220 Err(_) => {
221 self.force_reconnect()?;
222 &mut self.conn
223 }
224 };
225 self.last_ping = Instant::now();
226 Ok(conn)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Computes the fallback name for `input` and checks the invariants every
    /// fallback must satisfy: it keeps `expected_prefix`, fits into
    /// `MAX_LEN_NAME`, and appends a non-empty, digits-only suffix.
    fn assert_retry_valid(input: &str, expected_prefix: &str) -> String {
        let name = ManagedConnection::calc_name_retry(input);
        assert!(
            name.starts_with(expected_prefix),
            "{:?} does not start with {:?}",
            name,
            expected_prefix
        );
        assert!(
            name.len() <= MAX_LEN_NAME,
            "{:?} is longer than {} bytes",
            name,
            MAX_LEN_NAME
        );
        assert!(
            name.len() > expected_prefix.len(),
            "{:?} has no suffix after {:?}",
            name,
            expected_prefix
        );
        assert!(
            name[expected_prefix.len()..]
                .bytes()
                .all(|b| b.is_ascii_digit()),
            "suffix of {:?} is not numeric",
            name
        );
        name
    }

    #[test]
    fn test_name_empty() {
        let name = assert_retry_valid("", "");
        assert!(!name.is_empty());
    }

    #[test]
    fn test_name_fallback_normal() {
        assert_retry_valid("ct_bot-fallback", "ct_bot-fallback");
    }

    #[test]
    fn test_name_fallback_underflow() {
        assert_retry_valid("ct_bot", "ct_bot");
    }

    #[test]
    fn test_name_fallback_fit() {
        // Exactly MAX_LEN_NAME bytes: must be truncated to half before the suffix.
        let full = "1234567890123456789D";
        assert_retry_valid(full, &full[..MAX_LEN_NAME / 2]);

        // Guards the `MAX_LEN_NAME - 2` threshold used by the implementation.
        assert!(MAX_LEN_NAME > 3);

        // At the truncation threshold (MAX_LEN_NAME - 2 bytes).
        let threshold = "123456789012345678";
        assert_retry_valid(threshold, &threshold[..MAX_LEN_NAME / 2]);
    }

    #[test]
    fn test_name_fallback_overflow() {
        let name_input = "1234567890123456789012345678901234567890";
        assert!(name_input.len() > MAX_LEN_NAME);
        assert_retry_valid(name_input, &name_input[..MAX_LEN_NAME / 2]);
    }
}