1use crate::{
21 CasVersion, ErrorCode, KeyValuePair, KeyValuePairs, MetaData, ProtocolVersion, RequestPattern,
22 TransactionId, TypedKeyValuePair, TypedKeyValuePairs, Value, Version,
23 error::{ConnectionError, ConnectionResult},
24};
25use serde::{Deserialize, Serialize, de::DeserializeOwned};
26use std::{
27 fmt::{self, Display},
28 io,
29 time::Duration,
30};
31use tokio::{
32 io::{AsyncRead, AsyncWriteExt, BufReader, Lines},
33 time::timeout,
34};
35use tracing::{debug, error, trace};
36
37#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(rename_all = "camelCase")]
39pub enum ServerMessage {
40 Welcome(Welcome),
41 PState(PState),
42 Ack(Ack),
43 State(State),
44 CState(CState),
45 Err(Err),
46 Authorized(Ack),
47 LsState(LsState),
48}
49
50impl ServerMessage {
51 pub fn transaction_id(&self) -> Option<TransactionId> {
52 match self {
53 ServerMessage::Welcome(_) => None,
54 ServerMessage::PState(msg) => Some(msg.transaction_id),
55 ServerMessage::Ack(msg) => Some(msg.transaction_id),
56 ServerMessage::State(msg) => Some(msg.transaction_id),
57 ServerMessage::CState(msg) => Some(msg.transaction_id),
58 ServerMessage::Err(msg) => Some(msg.transaction_id),
59 ServerMessage::LsState(msg) => Some(msg.transaction_id),
60 ServerMessage::Authorized(_) => Some(0),
61 }
62 }
63}
64
/// Greeting message from the server.
///
/// It is the only server message without a transaction id
/// (`ServerMessage::transaction_id` returns `None` for it).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Welcome {
    /// Server metadata: version, supported protocol versions and whether
    /// authorization is required.
    pub info: ServerInfo,
    /// Client id reported by the server (serialized as `clientId`).
    pub client_id: String,
}
71
/// Server message reporting key/value pairs for a request pattern.
///
/// Because `event` is flattened, the JSON form carries either a `keyValuePairs`
/// or a `deleted` array directly next to `transactionId` and `requestPattern`,
/// e.g. `{"transactionId":1,"requestPattern":"a/#","keyValuePairs":[...]}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PState {
    /// Transaction this message belongs to.
    pub transaction_id: TransactionId,
    /// The pattern these pairs were reported for.
    pub request_pattern: RequestPattern,
    /// The reported pairs; flattened into this struct's JSON object.
    #[serde(flatten)]
    pub event: PStateEvent,
}

/// Payload of a [`PState`].
///
/// `KeyValuePairs` carries pairs with their values. `Deleted` carries pairs
/// reported as deleted; their values are still included. With
/// `rename_all = "camelCase"` the JSON keys are `keyValuePairs` and `deleted`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PStateEvent {
    KeyValuePairs(KeyValuePairs),
    Deleted(KeyValuePairs),
}
87
88#[derive(Debug, Clone, PartialEq, Eq)]
89pub enum TypedPStateEvent<T: DeserializeOwned> {
90 KeyValuePairs(TypedKeyValuePairs<T>),
91 Deleted(TypedKeyValuePairs<T>),
92}
93
94impl<T: DeserializeOwned> TryFrom<PStateEvent> for TypedPStateEvent<T> {
95 type Error = serde_json::Error;
96
97 fn try_from(value: PStateEvent) -> Result<Self, Self::Error> {
98 match value {
99 PStateEvent::KeyValuePairs(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
100 try_to_typed_key_value_pairs(kvps)?,
101 )),
102 PStateEvent::Deleted(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
103 try_to_typed_key_value_pairs(kvps)?,
104 )),
105 }
106 }
107}
108
109fn try_to_typed_key_value_pairs<T: DeserializeOwned>(
110 kvps: KeyValuePairs,
111) -> Result<TypedKeyValuePairs<T>, serde_json::Error> {
112 let mut out = vec![];
113
114 for kvp in kvps {
115 out.push(kvp.try_into()?);
116 }
117
118 Ok(out)
119}
120
121pub type TypedPStateEvents<T> = Vec<TypedPStateEvent<T>>;
122
123impl From<PStateEvent> for Vec<Option<Value>> {
124 fn from(e: PStateEvent) -> Self {
125 match e {
126 PStateEvent::KeyValuePairs(kvps) => kvps.into_iter().map(KeyValuePair::into).collect(),
127 PStateEvent::Deleted(keys) => keys.into_iter().map(|_| Option::None).collect(),
128 }
129 }
130}
131
132impl From<PState> for Vec<Option<Value>> {
133 fn from(pstate: PState) -> Self {
134 pstate.event.into()
135 }
136}
137
138impl fmt::Display for PState {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 match &self.event {
141 PStateEvent::KeyValuePairs(key_value_pairs) => {
142 let kvps: Vec<String> = key_value_pairs
143 .iter()
144 .map(|kvp| format!("{}={}", kvp.key, kvp.value))
145 .collect();
146 let joined = kvps.join("\n");
147 write!(f, "{joined}")
148 }
149 PStateEvent::Deleted(key_value_pairs) => {
150 let kvps: Vec<String> = key_value_pairs
151 .iter()
152 .map(|kvp| format!("{}!={}", kvp.key, kvp.value))
153 .collect();
154 let joined = kvps.join("\n");
155 write!(f, "{joined}")
156 }
157 }
158 }
159}
160
161#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
162#[serde(rename_all = "camelCase")]
163pub struct Ack {
164 pub transaction_id: TransactionId,
165}
166
167impl fmt::Display for Ack {
168 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169 write!(f, "ack {}", self.transaction_id)
170 }
171}
172
/// Server message carrying a single value for a transaction.
///
/// Because `event` is flattened, the JSON form is
/// `{"transactionId":1,"value":2}` or `{"transactionId":1,"deleted":2}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct State {
    /// Transaction this message belongs to.
    pub transaction_id: TransactionId,
    /// The value, or the deleted value; flattened into this struct's JSON object.
    #[serde(flatten)]
    pub event: StateEvent,
}

/// Payload of a [`State`]: either a current value or a value reported as
/// deleted (serialized under the `value` / `deleted` key respectively).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum StateEvent {
    Value(Value),
    Deleted(Value),
}
187
/// Server message carrying a value together with its [`CasVersion`]
/// (presumably for compare-and-swap workflows — confirm against callers).
///
/// Because `event` is flattened, `value` and `version` appear directly next to
/// `transactionId` in the JSON object.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CState {
    /// Transaction this message belongs to.
    pub transaction_id: TransactionId,
    /// Value and version; flattened into this struct's JSON object.
    #[serde(flatten)]
    pub event: CStateEvent,
}

/// Payload of a [`CState`]: the value and its version.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CStateEvent {
    pub value: Value,
    pub version: CasVersion,
}
202
203impl From<StateEvent> for Option<Value> {
204 fn from(e: StateEvent) -> Self {
205 match e {
206 StateEvent::Value(v) => Some(v),
207 StateEvent::Deleted(_) => None,
208 }
209 }
210}
211
212impl From<State> for Option<Value> {
213 fn from(state: State) -> Self {
214 state.event.into()
215 }
216}
217
218#[derive(Debug, Clone, PartialEq, Eq)]
219pub enum TypedStateEvent<T: DeserializeOwned> {
220 Value(T),
221 Deleted(T),
222}
223
224impl<T: DeserializeOwned> From<TypedStateEvent<T>> for Option<T> {
225 fn from(e: TypedStateEvent<T>) -> Self {
226 match e {
227 TypedStateEvent::Value(v) => Some(v),
228 TypedStateEvent::Deleted(_) => None,
229 }
230 }
231}
232
233impl<T: DeserializeOwned> From<TypedKeyValuePair<T>> for TypedStateEvent<T> {
234 fn from(kvp: TypedKeyValuePair<T>) -> Self {
235 TypedStateEvent::Value(kvp.value)
236 }
237}
238
239impl<T: DeserializeOwned + TryFrom<Value, Error = serde_json::Error>> TryFrom<StateEvent>
240 for TypedStateEvent<T>
241{
242 type Error = serde_json::Error;
243
244 fn try_from(e: StateEvent) -> Result<Self, Self::Error> {
245 match e {
246 StateEvent::Value(v) => Ok(TypedStateEvent::Value(v.try_into()?)),
247 StateEvent::Deleted(v) => Ok(TypedStateEvent::Deleted(v.try_into()?)),
248 }
249 }
250}
251
252pub type TypedStateEvents<T> = Vec<TypedStateEvent<T>>;
253
254impl fmt::Display for State {
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 match &self.event {
257 StateEvent::Value(v) => write!(f, "{v}"),
258 StateEvent::Deleted(v) => write!(f, "!{v}"),
259 }
260 }
261}
262
263#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
264#[serde(rename_all = "camelCase")]
265pub struct Err {
266 pub transaction_id: TransactionId,
267 pub error_code: ErrorCode,
268 pub metadata: MetaData,
269}
270
271impl fmt::Display for Err {
272 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273 write!(f, "server error {}: {}", self.error_code, self.metadata)
274 }
275}
276
277impl std::error::Error for Err {}
278
279#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
280#[serde(rename_all = "camelCase")]
281pub struct Handshake {
282 pub protocol_version: ProtocolVersion,
283}
284
285impl fmt::Display for Handshake {
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 write!(
288 f,
289 "handshake: supported protocol versions: {}",
290 self.protocol_version
291 )
292 }
293}
294
295#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
296#[serde(rename_all = "camelCase")]
297pub struct LsState {
298 pub transaction_id: TransactionId,
299 pub children: Vec<String>,
300}
301
302impl fmt::Display for LsState {
303 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304 write!(
305 f,
306 "{}",
307 self.children
308 .iter()
309 .map(escape_path_segment)
310 .reduce(|a, b| format!("{a}\t{b}"))
311 .unwrap_or("".to_owned())
312 )
313 }
314}
315
/// Escapes a path segment for tab-separated display.
///
/// - A segment containing a single quote has each `'` replaced by `\'`.
/// - Otherwise, a segment containing whitespace or a double quote is wrapped in
///   single quotes.
/// - Any other segment is returned unchanged.
///
/// NOTE(review): a segment containing both whitespace and a single quote is only
/// backslash-escaped, so its whitespace stays unprotected — confirm this is intended.
fn escape_path_segment(str: impl AsRef<str>) -> String {
    let segment = str.as_ref();

    if segment.contains('\'') {
        return segment.replace('\'', r#"\'"#);
    }

    let needs_quoting = segment.contains('"') || segment.chars().any(char::is_whitespace);
    if needs_quoting {
        format!("'{segment}'")
    } else {
        segment.to_owned()
    }
}
330
331#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
332#[serde(rename_all = "camelCase")]
333pub struct ServerInfo {
334 pub version: Version,
335 pub supported_protocol_versions: Box<[ProtocolVersion]>,
336 #[deprecated(since = "1.1.0", note = "replaced by `supported_protocol_versions`")]
337 protocol_version: String,
338 pub authorization_required: bool,
339}
340
341#[allow(deprecated)]
342impl ServerInfo {
343 pub fn new(
344 version: Version,
345 supported_protocol_versions: Box<[ProtocolVersion]>,
346 authorization_required: bool,
347 ) -> Self {
348 Self {
349 version,
350 supported_protocol_versions,
351 protocol_version: "0.11".to_owned(),
352 authorization_required,
353 }
354 }
355}
356
357pub async fn write_line_and_flush(
358 msg: impl Serialize,
359 mut tx: impl AsyncWriteExt + Unpin,
360 send_timeout: Option<Duration>,
361 remote: impl Display,
362) -> ConnectionResult<()> {
363 let mut json = serde_json::to_string(&msg)?;
364 if json.contains('\n') {
365 return Err(ConnectionError::IoError(Box::new(io::Error::new(
366 io::ErrorKind::InvalidData,
367 format!("invalid JSON: '{json}' contains line break"),
368 ))));
369 }
370 if json.trim().is_empty() {
371 return Err(ConnectionError::IoError(Box::new(io::Error::new(
372 io::ErrorKind::InvalidData,
373 format!("invalid JSON: '{json}' is empty"),
374 ))));
375 }
376
377 json.push('\n');
378 let bytes = json.as_bytes();
379
380 debug!("Sending message with timeout {send_timeout:?}: {json}");
381 trace!("Writing line …");
382 for chunk in bytes.chunks(1024) {
383 let mut written = 0;
384 while written < chunk.len() {
385 if let Some(send_timeout) = send_timeout {
386 written += timeout(send_timeout, tx.write(&chunk[written..]))
387 .await
388 .map_err(|_| {
389 ConnectionError::Timeout(Box::new(format!(
390 "timeout while sending tcp message to {remote}"
391 )))
392 })??;
393 } else {
394 written += tx.write(&chunk[written..]).await?;
395 }
396 }
397 }
398 trace!("Writing line done.");
399 trace!("Flushing channel …");
400 if let Some(send_timeout) = send_timeout {
401 timeout(send_timeout, tx.flush()).await.map_err(|_| {
402 ConnectionError::Timeout(Box::new(format!(
403 "timeout while sending tcp message to {remote}"
404 )))
405 })??;
406 } else {
407 tx.flush().await?;
408 }
409 trace!("Flushing channel done.");
410
411 Ok(())
412}
413
414pub async fn receive_msg<T: DeserializeOwned, R: AsyncRead + Unpin>(
415 rx: &mut Lines<BufReader<R>>,
416) -> ConnectionResult<Option<T>> {
417 let read = rx.next_line().await;
418 match read {
419 Ok(None) => Ok(None),
420 Ok(Some(json)) => {
421 debug!("Received message: {json}");
422 let sm = serde_json::from_str(&json);
423 if let Err(e) = &sm {
424 error!("Error deserializing message '{json}': {e}")
425 }
426 Ok(sm?)
427 }
428 Err(e) => Err(e.into()),
429 }
430}
431
#[cfg(test)]
mod test {
    use super::*;
    use serde_json::json;

    /// `State` fixtures paired with their expected JSON wire form.
    fn state_fixtures() -> Vec<(State, &'static str)> {
        vec![
            (
                State {
                    transaction_id: 1,
                    event: StateEvent::Value(json!(2)),
                },
                r#"{"transactionId":1,"value":2}"#,
            ),
            (
                State {
                    transaction_id: 1,
                    event: StateEvent::Deleted(json!(2)),
                },
                r#"{"transactionId":1,"deleted":2}"#,
            ),
        ]
    }

    /// `PState` fixtures paired with their expected JSON wire form.
    fn pstate_fixtures() -> Vec<(PState, &'static str)> {
        vec![
            (
                PState {
                    transaction_id: 1,
                    request_pattern: "$SYS/clients".to_owned(),
                    event: PStateEvent::KeyValuePairs(vec![KeyValuePair::of("$SYS/clients", 2)]),
                },
                r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#,
            ),
            (
                PState {
                    transaction_id: 1,
                    request_pattern: "$SYS/clients".to_owned(),
                    event: PStateEvent::Deleted(vec![KeyValuePair::of("$SYS/clients", 2)]),
                },
                r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#,
            ),
        ]
    }

    #[test]
    fn state_is_serialized_correctly() {
        for (state, expected) in state_fixtures() {
            assert_eq!(expected, serde_json::to_string(&state).unwrap());
        }
    }

    #[test]
    fn state_is_deserialized_correctly() {
        for (expected, json) in state_fixtures() {
            let decoded: State = serde_json::from_str(json).unwrap();
            assert_eq!(expected, decoded);
        }
    }

    #[test]
    fn pstate_is_serialized_correctly() {
        for (pstate, expected) in pstate_fixtures() {
            assert_eq!(expected, serde_json::to_string(&pstate).unwrap());
        }
    }

    #[test]
    fn pstate_is_deserialized_correctly() {
        for (expected, json) in pstate_fixtures() {
            let decoded: PState = serde_json::from_str(json).unwrap();
            assert_eq!(expected, decoded);
        }
    }
}