1use crate::{
21 CasVersion, ErrorCode, KeyValuePair, KeyValuePairs, MetaData, ProtocolVersion, RequestPattern,
22 TransactionId, TypedKeyValuePair, TypedKeyValuePairs, Value, Version,
23};
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use std::fmt;
26
27#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
28#[serde(rename_all = "camelCase")]
29pub enum ServerMessage {
30 Welcome(Welcome),
31 PState(PState),
32 Ack(Ack),
33 State(State),
34 CState(CState),
35 Err(Err),
36 Authorized(Ack),
37 LsState(LsState),
38}
39
40impl ServerMessage {
41 pub fn transaction_id(&self) -> Option<TransactionId> {
42 match self {
43 ServerMessage::Welcome(_) => None,
44 ServerMessage::PState(msg) => Some(msg.transaction_id),
45 ServerMessage::Ack(msg) => Some(msg.transaction_id),
46 ServerMessage::State(msg) => Some(msg.transaction_id),
47 ServerMessage::CState(msg) => Some(msg.transaction_id),
48 ServerMessage::Err(msg) => Some(msg.transaction_id),
49 ServerMessage::LsState(msg) => Some(msg.transaction_id),
50 ServerMessage::Authorized(_) => Some(0),
51 }
52 }
53}
54
/// Server message carrying the server's [`ServerInfo`] together with a client id.
///
/// NOTE(review): presumably the first message on a new connection, with `client_id`
/// identifying that connection — confirm against the server implementation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Welcome {
    /// Information about the server (serialized as `info`).
    pub info: ServerInfo,
    /// Client id (serialized as `clientId`).
    pub client_id: String,
}
61
/// Server message reporting key/value pairs for a request pattern.
///
/// `event` is flattened, so the variant name of [`PStateEvent`] becomes a top-level key:
/// `{"transactionId":1,"requestPattern":"...","keyValuePairs":[...]}` or
/// `{"transactionId":1,"requestPattern":"...","deleted":[...]}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PState {
    /// Transaction this message belongs to.
    pub transaction_id: TransactionId,
    /// The pattern the reported pairs belong to.
    pub request_pattern: RequestPattern,
    /// The reported pairs, either as current values or as deletions.
    #[serde(flatten)]
    pub event: PStateEvent,
}
70
/// Payload of a [`PState`]: key/value pairs that are either set or deleted.
///
/// Serialized under the camelCase variant name (`keyValuePairs` / `deleted`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PStateEvent {
    /// Pairs reporting current values.
    KeyValuePairs(KeyValuePairs),
    /// Pairs reported as deleted; conversions to `Vec<Option<Value>>` map each to `None`.
    Deleted(KeyValuePairs),
}
77
/// [`PStateEvent`] with every value deserialized into `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedPStateEvent<T: DeserializeOwned> {
    /// Typed counterpart of [`PStateEvent::KeyValuePairs`].
    KeyValuePairs(TypedKeyValuePairs<T>),
    /// Typed counterpart of [`PStateEvent::Deleted`].
    Deleted(TypedKeyValuePairs<T>),
}
83
84impl<T: DeserializeOwned> TryFrom<PStateEvent> for TypedPStateEvent<T> {
85 type Error = serde_json::Error;
86
87 fn try_from(value: PStateEvent) -> Result<Self, Self::Error> {
88 match value {
89 PStateEvent::KeyValuePairs(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
90 try_to_typed_key_value_pairs(kvps)?,
91 )),
92 PStateEvent::Deleted(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
93 try_to_typed_key_value_pairs(kvps)?,
94 )),
95 }
96 }
97}
98
99fn try_to_typed_key_value_pairs<T: DeserializeOwned>(
100 kvps: KeyValuePairs,
101) -> Result<TypedKeyValuePairs<T>, serde_json::Error> {
102 let mut out = vec![];
103
104 for kvp in kvps {
105 out.push(kvp.try_into()?);
106 }
107
108 Ok(out)
109}
110
111pub type TypedPStateEvents<T> = Vec<TypedPStateEvent<T>>;
112
113impl From<PStateEvent> for Vec<Option<Value>> {
114 fn from(e: PStateEvent) -> Self {
115 match e {
116 PStateEvent::KeyValuePairs(kvps) => kvps.into_iter().map(KeyValuePair::into).collect(),
117 PStateEvent::Deleted(keys) => keys.into_iter().map(|_| Option::None).collect(),
118 }
119 }
120}
121
122impl From<PState> for Vec<Option<Value>> {
123 fn from(pstate: PState) -> Self {
124 pstate.event.into()
125 }
126}
127
128impl fmt::Display for PState {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 match &self.event {
131 PStateEvent::KeyValuePairs(key_value_pairs) => {
132 let kvps: Vec<String> = key_value_pairs
133 .iter()
134 .map(|kvp| format!("{}={}", kvp.key, kvp.value))
135 .collect();
136 let joined = kvps.join("\n");
137 write!(f, "{joined}")
138 }
139 PStateEvent::Deleted(key_value_pairs) => {
140 let kvps: Vec<String> = key_value_pairs
141 .iter()
142 .map(|kvp| format!("{}!={}", kvp.key, kvp.value))
143 .collect();
144 let joined = kvps.join("\n");
145 write!(f, "{joined}")
146 }
147 }
148 }
149}
150
151#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
152#[serde(rename_all = "camelCase")]
153pub struct Ack {
154 pub transaction_id: TransactionId,
155}
156
157impl fmt::Display for Ack {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 write!(f, "ack {}", self.transaction_id)
160 }
161}
162
/// Server message carrying a single value for a transaction.
///
/// `event` is flattened, so the JSON form is `{"transactionId":1,"value":2}` or
/// `{"transactionId":1,"deleted":2}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct State {
    /// Transaction this message belongs to.
    pub transaction_id: TransactionId,
    /// The value, either current or deleted.
    #[serde(flatten)]
    pub event: StateEvent,
}

/// Payload of a [`State`], serialized under the key `value` or `deleted`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum StateEvent {
    /// A current value.
    Value(Value),
    /// A deleted entry; it still carries a value (presumably the removed one — confirm),
    /// but conversions to `Option<Value>` map it to `None`.
    Deleted(Value),
}
177
/// Server message carrying a value together with its [`CasVersion`].
///
/// `event` is flattened, so `value` and `version` appear at the top level of the JSON
/// object next to `transactionId`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CState {
    /// Transaction this message belongs to.
    pub transaction_id: TransactionId,
    /// The value and its version.
    #[serde(flatten)]
    pub event: CStateEvent,
}

/// Payload of a [`CState`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CStateEvent {
    /// The stored value.
    pub value: Value,
    /// The version associated with `value`.
    pub version: CasVersion,
}
192
193impl From<StateEvent> for Option<Value> {
194 fn from(e: StateEvent) -> Self {
195 match e {
196 StateEvent::Value(v) => Some(v),
197 StateEvent::Deleted(_) => None,
198 }
199 }
200}
201
202impl From<State> for Option<Value> {
203 fn from(state: State) -> Self {
204 state.event.into()
205 }
206}
207
/// [`StateEvent`] with its value deserialized into `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedStateEvent<T: DeserializeOwned> {
    /// Typed counterpart of [`StateEvent::Value`].
    Value(T),
    /// Typed counterpart of [`StateEvent::Deleted`].
    Deleted(T),
}
213
214impl<T: DeserializeOwned> From<TypedStateEvent<T>> for Option<T> {
215 fn from(e: TypedStateEvent<T>) -> Self {
216 match e {
217 TypedStateEvent::Value(v) => Some(v),
218 TypedStateEvent::Deleted(_) => None,
219 }
220 }
221}
222
223impl<T: DeserializeOwned> From<TypedKeyValuePair<T>> for TypedStateEvent<T> {
224 fn from(kvp: TypedKeyValuePair<T>) -> Self {
225 TypedStateEvent::Value(kvp.value)
226 }
227}
228
229impl<T: DeserializeOwned + TryFrom<Value, Error = serde_json::Error>> TryFrom<StateEvent>
230 for TypedStateEvent<T>
231{
232 type Error = serde_json::Error;
233
234 fn try_from(e: StateEvent) -> Result<Self, Self::Error> {
235 match e {
236 StateEvent::Value(v) => Ok(TypedStateEvent::Value(v.try_into()?)),
237 StateEvent::Deleted(v) => Ok(TypedStateEvent::Deleted(v.try_into()?)),
238 }
239 }
240}
241
242pub type TypedStateEvents<T> = Vec<TypedStateEvent<T>>;
243
244impl fmt::Display for State {
245 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246 match &self.event {
247 StateEvent::Value(v) => write!(f, "{}", v),
248 StateEvent::Deleted(v) => write!(f, "!{}", v),
249 }
250 }
251}
252
253#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
254#[serde(rename_all = "camelCase")]
255pub struct Err {
256 pub transaction_id: TransactionId,
257 pub error_code: ErrorCode,
258 pub metadata: MetaData,
259}
260
261impl fmt::Display for Err {
262 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263 write!(f, "server error {}: {}", self.error_code, self.metadata)
264 }
265}
266
267impl std::error::Error for Err {}
268
269#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
270#[serde(rename_all = "camelCase")]
271pub struct Handshake {
272 pub protocol_version: ProtocolVersion,
273}
274
275impl fmt::Display for Handshake {
276 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277 write!(
278 f,
279 "handshake: supported protocol versions: {}",
280 self.protocol_version
281 )
282 }
283}
284
285#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
286#[serde(rename_all = "camelCase")]
287pub struct LsState {
288 pub transaction_id: TransactionId,
289 pub children: Vec<String>,
290}
291
292impl fmt::Display for LsState {
293 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294 write!(
295 f,
296 "{}",
297 self.children
298 .iter()
299 .map(escape_path_segment)
300 .reduce(|a, b| format!("{a}\t{b}"))
301 .unwrap_or("".to_owned())
302 )
303 }
304}
305
/// Escapes a single path segment for the tab-separated [`LsState`] listing.
///
/// Rules, checked in this order:
/// 1. A segment containing `'` has every `'` replaced by `\'` (e.g. `it's` -> `it\'s`).
///    It is never wrapped in quotes, even if it also contains whitespace or `"`.
/// 2. Otherwise, a segment containing whitespace or `"` is wrapped in single quotes.
/// 3. Any other segment is returned unchanged.
///
/// NOTE(review): a segment with both `'` and whitespace (e.g. `it's here`) becomes
/// `it\'s here`, leaving the whitespace unprotected — confirm this matches the consumer.
fn escape_path_segment(str: impl AsRef<str>) -> String {
    let segment = str.as_ref();
    let has_single_quote = segment.contains('\'');
    let needs_quoting = segment.contains('"') || segment.contains(char::is_whitespace);

    match (has_single_quote, needs_quoting) {
        (true, _) => segment.replace('\'', r#"\'"#),
        (false, true) => format!("'{segment}'"),
        (false, false) => segment.to_owned(),
    }
}
320
/// Server metadata (embedded as `info` in [`Welcome`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerInfo {
    /// The server's version.
    pub version: Version,
    /// Protocol versions the server supports.
    pub supported_protocol_versions: Box<[ProtocolVersion]>,
    /// Legacy single protocol version. It is still serialized (`protocolVersion`) but is
    /// private, so it can only be set from this module; `ServerInfo::new` fills in `"0.11"`.
    #[deprecated(since = "1.1.0", note = "replaced by `supported_protocol_versions`")]
    protocol_version: String,
    /// Whether the server requires authorization (serialized as `authorizationRequired`).
    pub authorization_required: bool,
}
330
331#[allow(deprecated)]
332impl ServerInfo {
333 pub fn new(
334 version: Version,
335 supported_protocol_versions: Box<[ProtocolVersion]>,
336 authorization_required: bool,
337 ) -> Self {
338 Self {
339 version,
340 supported_protocol_versions,
341 protocol_version: "0.11".to_owned(),
342 authorization_required,
343 }
344 }
345}
346
347#[cfg(test)]
348mod test {
349 use serde_json::json;
350
351 use super::*;
352
353 #[test]
354 fn state_is_serialized_correctly() {
355 let state = State {
356 transaction_id: 1,
357 event: StateEvent::Value(json!(2)),
358 };
359
360 let json = r#"{"transactionId":1,"value":2}"#;
361
362 assert_eq!(json, &serde_json::to_string(&state).unwrap());
363
364 let state = State {
365 transaction_id: 1,
366 event: StateEvent::Deleted(json!(2)),
367 };
368
369 let json = r#"{"transactionId":1,"deleted":2}"#;
370
371 assert_eq!(json, &serde_json::to_string(&state).unwrap());
372 }
373
374 #[test]
375 fn state_is_deserialized_correctly() {
376 let state = State {
377 transaction_id: 1,
378 event: StateEvent::Value(json!(2)),
379 };
380
381 let json = r#"{"transactionId":1,"value":2}"#;
382
383 assert_eq!(state, serde_json::from_str(json).unwrap());
384
385 let state = State {
386 transaction_id: 1,
387 event: StateEvent::Deleted(json!(2)),
388 };
389
390 let json = r#"{"transactionId":1,"deleted":2}"#;
391
392 assert_eq!(state, serde_json::from_str(json).unwrap());
393 }
394
395 #[test]
396 fn pstate_is_serialized_correctly() {
397 let pstate = PState {
398 transaction_id: 1,
399 request_pattern: "$SYS/clients".to_owned(),
400 event: PStateEvent::KeyValuePairs(vec![("$SYS/clients", 2).into()]),
401 };
402
403 let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#;
404
405 assert_eq!(json, &serde_json::to_string(&pstate).unwrap());
406
407 let pstate = PState {
408 transaction_id: 1,
409 request_pattern: "$SYS/clients".to_owned(),
410 event: PStateEvent::Deleted(vec![("$SYS/clients", 2).into()]),
411 };
412
413 let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#;
414
415 assert_eq!(json, &serde_json::to_string(&pstate).unwrap());
416 }
417
418 #[test]
419 fn pstate_is_deserialized_correctly() {
420 let pstate = PState {
421 transaction_id: 1,
422 request_pattern: "$SYS/clients".to_owned(),
423 event: PStateEvent::KeyValuePairs(vec![("$SYS/clients", 2).into()]),
424 };
425
426 let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#;
427
428 assert_eq!(pstate, serde_json::from_str(json).unwrap());
429
430 let pstate = PState {
431 transaction_id: 1,
432 request_pattern: "$SYS/clients".to_owned(),
433 event: PStateEvent::Deleted(vec![("$SYS/clients", 2).into()]),
434 };
435
436 let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#;
437
438 assert_eq!(pstate, serde_json::from_str(json).unwrap());
439 }
440}