1use crate::{
21 ErrorCode, KeyValuePair, KeyValuePairs, MetaData, ProtocolVersion, RequestPattern,
22 TransactionId, TypedKeyValuePair, TypedKeyValuePairs, Value, Version,
23};
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use std::fmt;
26
/// Top-level message enum covering every payload type defined in this file
/// (presumably everything a server sends to a client — confirm against callers).
///
/// Serde renders the variant names in camelCase (e.g. `pState`, `lsState`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ServerMessage {
    /// Carries a [`Welcome`] payload (server info and client id).
    Welcome(Welcome),
    /// Carries a [`PState`] payload (key/value pairs for a request pattern).
    PState(PState),
    /// Carries an [`Ack`] payload.
    Ack(Ack),
    /// Carries a [`State`] payload (a single value event).
    State(State),
    /// Carries an [`Err`] payload (error code and metadata).
    Err(Err),
    /// Carries an [`Ack`] payload, reusing the same type as the `Ack` variant.
    Authorized(Ack),
    /// Carries an [`LsState`] payload (a list of children).
    LsState(LsState),
}
38
39impl ServerMessage {
40 pub fn transaction_id(&self) -> Option<TransactionId> {
41 match self {
42 ServerMessage::Welcome(_) => None,
43 ServerMessage::PState(msg) => Some(msg.transaction_id),
44 ServerMessage::Ack(msg) => Some(msg.transaction_id),
45 ServerMessage::State(msg) => Some(msg.transaction_id),
46 ServerMessage::Err(msg) => Some(msg.transaction_id),
47 ServerMessage::LsState(msg) => Some(msg.transaction_id),
48 ServerMessage::Authorized(_) => Some(0),
49 }
50 }
51}
52
/// Payload of [`ServerMessage::Welcome`].
///
/// Field names are serialized in camelCase (`info`, `clientId`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Welcome {
    /// Information about the server (versions, authorization flag).
    pub info: ServerInfo,
    /// Client identifier; presumably assigned by the server — confirm.
    pub client_id: String,
}
59
/// Payload of [`ServerMessage::PState`]: a set of key/value pairs reported for
/// a request pattern.
///
/// Because `event` is flattened, its variant key (`keyValuePairs` or
/// `deleted`) appears directly next to `transactionId` and `requestPattern`
/// in the serialized object.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PState {
    /// Transaction this message belongs to.
    pub transaction_id: TransactionId,
    /// Request pattern this message refers to.
    pub request_pattern: RequestPattern,
    /// The reported pairs, tagged as either current or deleted.
    #[serde(flatten)]
    pub event: PStateEvent,
}
68
/// Event carried by a [`PState`]; both variants hold a list of key/value pairs.
///
/// Serialized with the camelCase variant name as key (`keyValuePairs` or
/// `deleted`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PStateEvent {
    /// Key/value pairs that are reported as present.
    KeyValuePairs(KeyValuePairs),
    /// Key/value pairs that are reported as deleted.
    Deleted(KeyValuePairs),
}
75
/// Typed counterpart of [`PStateEvent`], holding pairs whose values are of
/// type `T` instead of untyped values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedPStateEvent<T: DeserializeOwned> {
    /// Typed key/value pairs that are reported as present.
    KeyValuePairs(TypedKeyValuePairs<T>),
    /// Typed key/value pairs that are reported as deleted.
    Deleted(TypedKeyValuePairs<T>),
}
81
82impl<T: DeserializeOwned> TryFrom<PStateEvent> for TypedPStateEvent<T> {
83 type Error = serde_json::Error;
84
85 fn try_from(value: PStateEvent) -> Result<Self, Self::Error> {
86 match value {
87 PStateEvent::KeyValuePairs(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
88 try_to_typed_key_value_pairs(kvps)?,
89 )),
90 PStateEvent::Deleted(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
91 try_to_typed_key_value_pairs(kvps)?,
92 )),
93 }
94 }
95}
96
97fn try_to_typed_key_value_pairs<T: DeserializeOwned>(
98 kvps: KeyValuePairs,
99) -> Result<TypedKeyValuePairs<T>, serde_json::Error> {
100 let mut out = vec![];
101
102 for kvp in kvps {
103 out.push(kvp.try_into()?);
104 }
105
106 Ok(out)
107}
108
/// A list of [`TypedPStateEvent`]s with values of type `T`.
pub type TypedPStateEvents<T> = Vec<TypedPStateEvent<T>>;
110
111impl From<PStateEvent> for Vec<Option<Value>> {
112 fn from(e: PStateEvent) -> Self {
113 match e {
114 PStateEvent::KeyValuePairs(kvps) => kvps.into_iter().map(KeyValuePair::into).collect(),
115 PStateEvent::Deleted(keys) => keys.into_iter().map(|_| Option::None).collect(),
116 }
117 }
118}
119
120impl From<PState> for Vec<Option<Value>> {
121 fn from(pstate: PState) -> Self {
122 pstate.event.into()
123 }
124}
125
126impl fmt::Display for PState {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 match &self.event {
129 PStateEvent::KeyValuePairs(key_value_pairs) => {
130 let kvps: Vec<String> = key_value_pairs
131 .iter()
132 .map(|kvp| format!("{}={}", kvp.key, kvp.value))
133 .collect();
134 let joined = kvps.join("\n");
135 write!(f, "{joined}")
136 }
137 PStateEvent::Deleted(key_value_pairs) => {
138 let kvps: Vec<String> = key_value_pairs
139 .iter()
140 .map(|kvp| format!("{}!={}", kvp.key, kvp.value))
141 .collect();
142 let joined = kvps.join("\n");
143 write!(f, "{joined}")
144 }
145 }
146 }
147}
148
/// Acknowledgement payload carrying only a transaction id.
///
/// Used by both [`ServerMessage::Ack`] and [`ServerMessage::Authorized`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Ack {
    /// Transaction this acknowledgement refers to (serialized as `transactionId`).
    pub transaction_id: TransactionId,
}
154
155impl fmt::Display for Ack {
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 write!(f, "ack {}", self.transaction_id)
158 }
159}
160
/// Payload of [`ServerMessage::State`]: a single value event for a transaction.
///
/// Because `event` is flattened, the serialized object looks like
/// `{"transactionId":1,"value":2}` or `{"transactionId":1,"deleted":2}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct State {
    /// Transaction this message belongs to.
    pub transaction_id: TransactionId,
    /// The reported value, tagged as either current or deleted.
    #[serde(flatten)]
    pub event: StateEvent,
}
168
/// Event carried by a [`State`]; both variants hold a single value.
///
/// Serialized with the camelCase variant name as key (`value` or `deleted`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum StateEvent {
    /// A value that is reported as present.
    Value(Value),
    /// A value that is reported as deleted (presumably the value that was
    /// removed — confirm against the protocol).
    Deleted(Value),
}
175
176impl From<StateEvent> for Option<Value> {
177 fn from(e: StateEvent) -> Self {
178 match e {
179 StateEvent::Value(v) => Some(v),
180 StateEvent::Deleted(_) => None,
181 }
182 }
183}
184
185impl From<State> for Option<Value> {
186 fn from(state: State) -> Self {
187 state.event.into()
188 }
189}
190
/// Typed counterpart of [`StateEvent`], holding a value of type `T` instead of
/// an untyped value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedStateEvent<T: DeserializeOwned> {
    /// A typed value that is reported as present.
    Value(T),
    /// A typed value that is reported as deleted.
    Deleted(T),
}
196
197impl<T: DeserializeOwned> From<TypedStateEvent<T>> for Option<T> {
198 fn from(e: TypedStateEvent<T>) -> Self {
199 match e {
200 TypedStateEvent::Value(v) => Some(v),
201 TypedStateEvent::Deleted(_) => None,
202 }
203 }
204}
205
206impl<T: DeserializeOwned> From<TypedKeyValuePair<T>> for TypedStateEvent<T> {
207 fn from(kvp: TypedKeyValuePair<T>) -> Self {
208 TypedStateEvent::Value(kvp.value)
209 }
210}
211
212impl<T: DeserializeOwned + TryFrom<Value, Error = serde_json::Error>> TryFrom<StateEvent>
213 for TypedStateEvent<T>
214{
215 type Error = serde_json::Error;
216
217 fn try_from(e: StateEvent) -> Result<Self, Self::Error> {
218 match e {
219 StateEvent::Value(v) => Ok(TypedStateEvent::Value(v.try_into()?)),
220 StateEvent::Deleted(v) => Ok(TypedStateEvent::Deleted(v.try_into()?)),
221 }
222 }
223}
224
/// A list of [`TypedStateEvent`]s with values of type `T`.
pub type TypedStateEvents<T> = Vec<TypedStateEvent<T>>;
226
227impl fmt::Display for State {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 match &self.event {
230 StateEvent::Value(v) => write!(f, "{}", v),
231 StateEvent::Deleted(v) => write!(f, "!{}", v),
232 }
233 }
234}
235
/// Payload of [`ServerMessage::Err`]: an error report tied to a transaction.
///
/// Note that within this module the name `Err` refers to this struct in type
/// position.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Err {
    /// Transaction this error refers to.
    pub transaction_id: TransactionId,
    /// Code identifying the kind of error.
    pub error_code: ErrorCode,
    /// Additional information about the error; printed after the error code
    /// by the `Display` implementation.
    pub metadata: MetaData,
}
243
244impl fmt::Display for Err {
245 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246 write!(f, "server error {}: {}", self.error_code, self.metadata)
247 }
248}
249
// Marks `Err` as a standard error type using only the trait's default
// methods; the required `Debug` is derived and `Display` is implemented
// for `Err` in this file.
impl std::error::Error for Err {}
251
/// Handshake payload carrying a protocol version.
///
/// NOTE(review): this type is not a variant of `ServerMessage` in this file;
/// where it is used cannot be determined from here.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Handshake {
    /// Protocol version (serialized as `protocolVersion`).
    pub protocol_version: ProtocolVersion,
}
257
258impl fmt::Display for Handshake {
259 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260 write!(
261 f,
262 "handshake: supported protocol versions: {}",
263 self.protocol_version
264 )
265 }
266}
267
/// Payload of [`ServerMessage::LsState`]: a list of child names for a
/// transaction.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LsState {
    /// Transaction this message belongs to.
    pub transaction_id: TransactionId,
    /// The listed children; printed tab-separated by the `Display`
    /// implementation.
    pub children: Vec<String>,
}
274
275impl fmt::Display for LsState {
276 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277 write!(
278 f,
279 "{}",
280 self.children
281 .iter()
282 .map(escape_path_segment)
283 .reduce(|a, b| format!("{a}\t{b}"))
284 .unwrap_or("".to_owned())
285 )
286 }
287}
288
/// Escapes a single path segment for display.
///
/// - A segment containing a single quote has every `'` replaced by `\'`
///   and is otherwise left as is.
/// - Otherwise, a segment containing whitespace or a double quote is wrapped
///   in single quotes.
/// - Any other segment is returned unchanged.
///
/// NOTE(review): a segment with both a single quote and whitespace is not
/// wrapped in quotes — confirm that this is acceptable for consumers.
fn escape_path_segment(str: impl AsRef<str>) -> String {
    let segment = str.as_ref();
    let has_single_quote = segment.contains('\'');
    let needs_quoting = segment.contains('"') || segment.contains(char::is_whitespace);

    match (has_single_quote, needs_quoting) {
        (true, _) => segment.replace('\'', r#"\'"#),
        (false, true) => format!("'{segment}'"),
        (false, false) => segment.to_owned(),
    }
}
303
/// Server information carried inside a [`Welcome`] message.
///
/// Field names are serialized in camelCase (`version`, `protocolVersion`,
/// `authorizationRequired`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerInfo {
    /// Version of the server software (presumably — confirm).
    pub version: Version,
    /// Protocol version reported by the server.
    pub protocol_version: ProtocolVersion,
    /// Presumably whether clients must authorize before issuing requests —
    /// inferred from the name; confirm against the protocol.
    pub authorization_required: bool,
}
311
/// Round-trip checks for the JSON representation of `State` and `PState`,
/// covering both the "present" and the "deleted" event variants.
#[cfg(test)]
mod test {
    use super::*;
    use serde_json::json;

    #[test]
    fn state_is_serialized_correctly() {
        let with_value = State {
            transaction_id: 1,
            event: StateEvent::Value(json!(2)),
        };
        let expected = r#"{"transactionId":1,"value":2}"#;
        assert_eq!(expected, serde_json::to_string(&with_value).unwrap());

        let with_deletion = State {
            transaction_id: 1,
            event: StateEvent::Deleted(json!(2)),
        };
        let expected = r#"{"transactionId":1,"deleted":2}"#;
        assert_eq!(expected, serde_json::to_string(&with_deletion).unwrap());
    }

    #[test]
    fn state_is_deserialized_correctly() {
        let expected = State {
            transaction_id: 1,
            event: StateEvent::Value(json!(2)),
        };
        let input = r#"{"transactionId":1,"value":2}"#;
        let parsed: State = serde_json::from_str(input).unwrap();
        assert_eq!(expected, parsed);

        let expected = State {
            transaction_id: 1,
            event: StateEvent::Deleted(json!(2)),
        };
        let input = r#"{"transactionId":1,"deleted":2}"#;
        let parsed: State = serde_json::from_str(input).unwrap();
        assert_eq!(expected, parsed);
    }

    #[test]
    fn pstate_is_serialized_correctly() {
        let with_pairs = PState {
            transaction_id: 1,
            request_pattern: "$SYS/clients".to_owned(),
            event: PStateEvent::KeyValuePairs(vec![("$SYS/clients", 2).into()]),
        };
        let expected = r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#;
        assert_eq!(expected, serde_json::to_string(&with_pairs).unwrap());

        let with_deletions = PState {
            transaction_id: 1,
            request_pattern: "$SYS/clients".to_owned(),
            event: PStateEvent::Deleted(vec![("$SYS/clients", 2).into()]),
        };
        let expected = r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#;
        assert_eq!(expected, serde_json::to_string(&with_deletions).unwrap());
    }

    #[test]
    fn pstate_is_deserialized_correctly() {
        let expected = PState {
            transaction_id: 1,
            request_pattern: "$SYS/clients".to_owned(),
            event: PStateEvent::KeyValuePairs(vec![("$SYS/clients", 2).into()]),
        };
        let input = r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#;
        let parsed: PState = serde_json::from_str(input).unwrap();
        assert_eq!(expected, parsed);

        let expected = PState {
            transaction_id: 1,
            request_pattern: "$SYS/clients".to_owned(),
            event: PStateEvent::Deleted(vec![("$SYS/clients", 2).into()]),
        };
        let input = r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#;
        let parsed: PState = serde_json::from_str(input).unwrap();
        assert_eq!(expected, parsed);
    }
}
405}