1use crate::{
21 CasVersion, ErrorCode, KeyValuePair, KeyValuePairs, MetaData, ProtocolVersion, RequestPattern,
22 TransactionId, TypedKeyValuePair, TypedKeyValuePairs, Value, Version,
23 error::{ConnectionError, ConnectionResult},
24};
25use serde::{Deserialize, Serialize, de::DeserializeOwned};
26use std::{
27 fmt::{self, Display},
28 io,
29 time::Duration,
30};
31use tokio::{
32 io::{AsyncRead, AsyncWriteExt, BufReader, Lines},
33 time::timeout,
34};
35use tracing::{debug, error, trace};
36
/// A message sent from the server to a client.
///
/// Uses serde's default (externally tagged) enum representation with variant names renamed
/// to camelCase, e.g. `{"pState": {...}}` or `{"lsState": {...}}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ServerMessage {
    /// Greeting carrying server information and the client id.
    Welcome(Welcome),
    /// Pattern-based state (many key/value pairs).
    PState(PState),
    /// Plain acknowledgement of a transaction.
    Ack(Ack),
    /// Single-value state.
    State(State),
    /// Single-value state together with a CAS version.
    CState(CState),
    /// Error response for a transaction.
    Err(Err),
    /// Authorization response; reuses the [`Ack`] payload.
    Authorized(Ack),
    /// Listing of child path segments.
    LsState(LsState),
}
49
50impl ServerMessage {
51 pub fn transaction_id(&self) -> Option<TransactionId> {
52 match self {
53 ServerMessage::Welcome(_) => None,
54 ServerMessage::PState(msg) => Some(msg.transaction_id),
55 ServerMessage::Ack(msg) => Some(msg.transaction_id),
56 ServerMessage::State(msg) => Some(msg.transaction_id),
57 ServerMessage::CState(msg) => Some(msg.transaction_id),
58 ServerMessage::Err(msg) => Some(msg.transaction_id),
59 ServerMessage::LsState(msg) => Some(msg.transaction_id),
60 ServerMessage::Authorized(_) => Some(0),
61 }
62 }
63}
64
/// Payload of [`ServerMessage::Welcome`]: server metadata plus the id assigned to the client.
///
/// Presumably sent once right after a client connects — confirm against the server code.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Welcome {
    /// Information about the server (version, protocol support, auth requirement).
    pub info: ServerInfo,
    /// Identifier of the receiving client.
    pub client_id: String,
}
71
/// State answering a pattern-based request, covering any number of keys.
///
/// `event` is flattened into the object, so the JSON carries either a `keyValuePairs` or a
/// `deleted` array next to `transactionId` and `requestPattern`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PState {
    /// Transaction this state belongs to.
    pub transaction_id: TransactionId,
    /// The pattern this state was produced for.
    pub request_pattern: RequestPattern,
    /// Matching pairs or deleted pairs; serialized inline (flattened).
    #[serde(flatten)]
    pub event: PStateEvent,
}

/// Payload of a [`PState`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PStateEvent {
    /// Current key/value pairs; serialized under `keyValuePairs`.
    KeyValuePairs(KeyValuePairs),
    /// Pairs that were deleted; serialized under `deleted`.
    Deleted(KeyValuePairs),
}
87
/// Typed counterpart of [`PStateEvent`], with each pair's value converted to `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedPStateEvent<T: DeserializeOwned> {
    /// Typed counterpart of [`PStateEvent::KeyValuePairs`].
    KeyValuePairs(TypedKeyValuePairs<T>),
    /// Typed counterpart of [`PStateEvent::Deleted`].
    Deleted(TypedKeyValuePairs<T>),
}
93
94impl<T: DeserializeOwned> TryFrom<PStateEvent> for TypedPStateEvent<T> {
95 type Error = serde_json::Error;
96
97 fn try_from(value: PStateEvent) -> Result<Self, Self::Error> {
98 match value {
99 PStateEvent::KeyValuePairs(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
100 try_to_typed_key_value_pairs(kvps)?,
101 )),
102 PStateEvent::Deleted(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
103 try_to_typed_key_value_pairs(kvps)?,
104 )),
105 }
106 }
107}
108
109fn try_to_typed_key_value_pairs<T: DeserializeOwned>(
110 kvps: KeyValuePairs,
111) -> Result<TypedKeyValuePairs<T>, serde_json::Error> {
112 let mut out = vec![];
113
114 for kvp in kvps {
115 out.push(kvp.try_into()?);
116 }
117
118 Ok(out)
119}
120
/// A sequence of [`TypedPStateEvent`]s.
pub type TypedPStateEvents<T> = Vec<TypedPStateEvent<T>>;
122
123impl From<PStateEvent> for Vec<Option<Value>> {
124 fn from(e: PStateEvent) -> Self {
125 match e {
126 PStateEvent::KeyValuePairs(kvps) => kvps.into_iter().map(KeyValuePair::into).collect(),
127 PStateEvent::Deleted(keys) => keys.into_iter().map(|_| Option::None).collect(),
128 }
129 }
130}
131
132impl From<PState> for Vec<Option<Value>> {
133 fn from(pstate: PState) -> Self {
134 pstate.event.into()
135 }
136}
137
138impl fmt::Display for PState {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 match &self.event {
141 PStateEvent::KeyValuePairs(key_value_pairs) => {
142 let kvps: Vec<String> = key_value_pairs
143 .iter()
144 .map(|kvp| format!("{}={}", kvp.key, kvp.value))
145 .collect();
146 let joined = kvps.join("\n");
147 write!(f, "{joined}")
148 }
149 PStateEvent::Deleted(key_value_pairs) => {
150 let kvps: Vec<String> = key_value_pairs
151 .iter()
152 .map(|kvp| format!("{}!={}", kvp.key, kvp.value))
153 .collect();
154 let joined = kvps.join("\n");
155 write!(f, "{joined}")
156 }
157 }
158 }
159}
160
/// Acknowledgement of a transaction; also the payload of [`ServerMessage::Authorized`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Ack {
    /// The acknowledged transaction.
    pub transaction_id: TransactionId,
}
166
167impl fmt::Display for Ack {
168 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169 write!(f, "ack {}", self.transaction_id)
170 }
171}
172
/// Single-value state reported for a transaction.
///
/// `event` is flattened, so the JSON looks like `{"transactionId":1,"value":2}` or
/// `{"transactionId":1,"deleted":2}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct State {
    /// Transaction this state belongs to.
    pub transaction_id: TransactionId,
    /// The value or deletion; serialized inline (flattened).
    #[serde(flatten)]
    pub event: StateEvent,
}

/// Payload of a [`State`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum StateEvent {
    /// A current value; serialized under `value`.
    Value(Value),
    /// A deletion; serialized under `deleted`. The carried value is presumably the value that
    /// was removed — confirm against the server.
    Deleted(Value),
}
187
/// Single-value state that also carries a [`CasVersion`].
///
/// `event` is flattened, so the JSON object holds `transactionId`, `value` and `version`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CState {
    /// Transaction this state belongs to.
    pub transaction_id: TransactionId,
    /// Value and version; serialized inline (flattened).
    #[serde(flatten)]
    pub event: CStateEvent,
}

/// Payload of a [`CState`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CStateEvent {
    /// The current value.
    pub value: Value,
    /// Version of the value; presumably used for compare-and-swap updates (TODO confirm).
    pub version: CasVersion,
}
202
203impl From<StateEvent> for Option<Value> {
204 fn from(e: StateEvent) -> Self {
205 match e {
206 StateEvent::Value(v) => Some(v),
207 StateEvent::Deleted(_) => None,
208 }
209 }
210}
211
212impl From<State> for Option<Value> {
213 fn from(state: State) -> Self {
214 state.event.into()
215 }
216}
217
/// Typed counterpart of [`StateEvent`], with the payload decoded into `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedStateEvent<T: DeserializeOwned> {
    /// Typed counterpart of [`StateEvent::Value`].
    Value(T),
    /// Typed counterpart of [`StateEvent::Deleted`].
    Deleted(T),
}
223
224impl<T: DeserializeOwned> From<TypedStateEvent<T>> for Option<T> {
225 fn from(e: TypedStateEvent<T>) -> Self {
226 match e {
227 TypedStateEvent::Value(v) => Some(v),
228 TypedStateEvent::Deleted(_) => None,
229 }
230 }
231}
232
233impl<T: DeserializeOwned> From<TypedKeyValuePair<T>> for TypedStateEvent<T> {
234 fn from(kvp: TypedKeyValuePair<T>) -> Self {
235 TypedStateEvent::Value(kvp.value)
236 }
237}
238
239impl<T: DeserializeOwned + TryFrom<Value, Error = serde_json::Error>> TryFrom<StateEvent>
240 for TypedStateEvent<T>
241{
242 type Error = serde_json::Error;
243
244 fn try_from(e: StateEvent) -> Result<Self, Self::Error> {
245 match e {
246 StateEvent::Value(v) => Ok(TypedStateEvent::Value(v.try_into()?)),
247 StateEvent::Deleted(v) => Ok(TypedStateEvent::Deleted(v.try_into()?)),
248 }
249 }
250}
251
/// A sequence of [`TypedStateEvent`]s.
pub type TypedStateEvents<T> = Vec<TypedStateEvent<T>>;
253
254impl fmt::Display for State {
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 match &self.event {
257 StateEvent::Value(v) => write!(f, "{}", v),
258 StateEvent::Deleted(v) => write!(f, "!{}", v),
259 }
260 }
261}
262
/// Error response from the server for a specific transaction.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Err {
    /// Transaction that failed.
    pub transaction_id: TransactionId,
    /// Machine-readable error code.
    pub error_code: ErrorCode,
    /// Additional error details.
    pub metadata: MetaData,
}
270
271impl fmt::Display for Err {
272 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273 write!(f, "server error {}: {}", self.error_code, self.metadata)
274 }
275}
276
277impl std::error::Error for Err {}
278
/// Handshake message carrying a protocol version.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Handshake {
    /// Protocol version offered in the handshake.
    pub protocol_version: ProtocolVersion,
}
284
285impl fmt::Display for Handshake {
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 write!(
288 f,
289 "handshake: supported protocol versions: {}",
290 self.protocol_version
291 )
292 }
293}
294
/// Result of a listing request: the names of the children found.
///
/// Presumably each child is a single path segment below the requested key — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LsState {
    /// Transaction this listing belongs to.
    pub transaction_id: TransactionId,
    /// Child names, unescaped.
    pub children: Vec<String>,
}
301
302impl fmt::Display for LsState {
303 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304 write!(
305 f,
306 "{}",
307 self.children
308 .iter()
309 .map(escape_path_segment)
310 .reduce(|a, b| format!("{a}\t{b}"))
311 .unwrap_or("".to_owned())
312 )
313 }
314}
315
/// Escapes a single path segment for display.
///
/// - Segments containing a single quote get every `'` replaced by `\'` and are otherwise
///   left as-is.
/// - Otherwise, segments containing whitespace or a double quote are wrapped in single quotes.
/// - All other segments are returned unchanged.
///
/// NOTE(review): a segment with both whitespace and a single quote is escaped but not
/// quoted, so its whitespace stays bare — confirm the consumer's tokenizer handles that.
fn escape_path_segment(segment: impl AsRef<str>) -> String {
    let segment = segment.as_ref();

    if segment.contains('\'') {
        // A single quote cannot be wrapped in single quotes, so escape it instead.
        return segment.replace('\'', r#"\'"#);
    }

    let needs_quoting = segment.contains('"') || segment.chars().any(char::is_whitespace);
    if needs_quoting {
        format!("'{segment}'")
    } else {
        segment.to_owned()
    }
}
330
/// Server metadata sent to clients as part of [`Welcome`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerInfo {
    /// The server's version.
    pub version: Version,
    /// All protocol versions this server supports.
    pub supported_protocol_versions: Box<[ProtocolVersion]>,
    /// Legacy single protocol-version string, still serialized as `protocolVersion`.
    /// Presumably kept so older clients reading that field keep working (TODO confirm).
    #[deprecated(since = "1.1.0", note = "replaced by `supported_protocol_versions`")]
    protocol_version: String,
    /// Whether clients are required to authorize.
    pub authorization_required: bool,
}
340
341#[allow(deprecated)]
342impl ServerInfo {
343 pub fn new(
344 version: Version,
345 supported_protocol_versions: Box<[ProtocolVersion]>,
346 authorization_required: bool,
347 ) -> Self {
348 Self {
349 version,
350 supported_protocol_versions,
351 protocol_version: "0.11".to_owned(),
352 authorization_required,
353 }
354 }
355}
356
357pub async fn write_line_and_flush(
358 msg: impl Serialize,
359 mut tx: impl AsyncWriteExt + Unpin,
360 send_timeout: Duration,
361 remote: impl Display,
362) -> ConnectionResult<()> {
363 let mut json = serde_json::to_string(&msg)?;
364 if json.contains('\n') {
365 return Err(ConnectionError::IoError(io::Error::new(
366 io::ErrorKind::InvalidData,
367 format!("invalid JSON: '{json}' contains line break"),
368 )));
369 }
370 if json.trim().is_empty() {
371 return Err(ConnectionError::IoError(io::Error::new(
372 io::ErrorKind::InvalidData,
373 format!("invalid JSON: '{json}' is empty"),
374 )));
375 }
376
377 json.push('\n');
378 let bytes = json.as_bytes();
379
380 debug!("Sending message: {json}");
381 trace!("Writing line …");
382 for chunk in bytes.chunks(1024) {
383 let mut written = 0;
384 while written < chunk.len() {
385 written += timeout(send_timeout, tx.write(&chunk[written..]))
386 .await
387 .map_err(|_| {
388 ConnectionError::Timeout(format!(
389 "timeout while sending tcp message to {remote}"
390 ))
391 })??;
392 }
393 }
394 trace!("Writing line done.");
395 trace!("Flushing channel …");
396 tx.flush().await?;
397 trace!("Flushing channel done.");
398
399 Ok(())
400}
401
402pub async fn receive_msg<T: DeserializeOwned, R: AsyncRead + Unpin>(
403 rx: &mut Lines<BufReader<R>>,
404) -> ConnectionResult<Option<T>> {
405 let read = rx.next_line().await;
406 match read {
407 Ok(None) => Ok(None),
408 Ok(Some(json)) => {
409 debug!("Received message: {json}");
410 let sm = serde_json::from_str(&json);
411 if let Err(e) = &sm {
412 error!("Error deserializing message '{json}': {e}")
413 }
414 Ok(sm?)
415 }
416 Err(e) => Err(e.into()),
417 }
418}
419
420#[cfg(test)]
421mod test {
422 use serde_json::json;
423
424 use super::*;
425
426 #[test]
427 fn state_is_serialized_correctly() {
428 let state = State {
429 transaction_id: 1,
430 event: StateEvent::Value(json!(2)),
431 };
432
433 let json = r#"{"transactionId":1,"value":2}"#;
434
435 assert_eq!(json, &serde_json::to_string(&state).unwrap());
436
437 let state = State {
438 transaction_id: 1,
439 event: StateEvent::Deleted(json!(2)),
440 };
441
442 let json = r#"{"transactionId":1,"deleted":2}"#;
443
444 assert_eq!(json, &serde_json::to_string(&state).unwrap());
445 }
446
447 #[test]
448 fn state_is_deserialized_correctly() {
449 let state = State {
450 transaction_id: 1,
451 event: StateEvent::Value(json!(2)),
452 };
453
454 let json = r#"{"transactionId":1,"value":2}"#;
455
456 assert_eq!(state, serde_json::from_str(json).unwrap());
457
458 let state = State {
459 transaction_id: 1,
460 event: StateEvent::Deleted(json!(2)),
461 };
462
463 let json = r#"{"transactionId":1,"deleted":2}"#;
464
465 assert_eq!(state, serde_json::from_str(json).unwrap());
466 }
467
468 #[test]
469 fn pstate_is_serialized_correctly() {
470 let pstate = PState {
471 transaction_id: 1,
472 request_pattern: "$SYS/clients".to_owned(),
473 event: PStateEvent::KeyValuePairs(vec![KeyValuePair::of("$SYS/clients", 2)]),
474 };
475
476 let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#;
477
478 assert_eq!(json, &serde_json::to_string(&pstate).unwrap());
479
480 let pstate = PState {
481 transaction_id: 1,
482 request_pattern: "$SYS/clients".to_owned(),
483 event: PStateEvent::Deleted(vec![KeyValuePair::of("$SYS/clients", 2)]),
484 };
485
486 let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#;
487
488 assert_eq!(json, &serde_json::to_string(&pstate).unwrap());
489 }
490
491 #[test]
492 fn pstate_is_deserialized_correctly() {
493 let pstate = PState {
494 transaction_id: 1,
495 request_pattern: "$SYS/clients".to_owned(),
496 event: PStateEvent::KeyValuePairs(vec![KeyValuePair::of("$SYS/clients", 2)]),
497 };
498
499 let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#;
500
501 assert_eq!(pstate, serde_json::from_str(json).unwrap());
502
503 let pstate = PState {
504 transaction_id: 1,
505 request_pattern: "$SYS/clients".to_owned(),
506 event: PStateEvent::Deleted(vec![KeyValuePair::of("$SYS/clients", 2)]),
507 };
508
509 let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#;
510
511 assert_eq!(pstate, serde_json::from_str(json).unwrap());
512 }
513}