worterbuch_common/
server.rs

1/*
2 *  Worterbuch server messages module
3 *
4 *  Copyright (C) 2024 Michael Bachmann
5 *
6 *  This program is free software: you can redistribute it and/or modify
7 *  it under the terms of the GNU Affero General Public License as published by
8 *  the Free Software Foundation, either version 3 of the License, or
9 *  (at your option) any later version.
10 *
11 *  This program is distributed in the hope that it will be useful,
12 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *  GNU Affero General Public License for more details.
15 *
16 *  You should have received a copy of the GNU Affero General Public License
17 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20use crate::{
21    ErrorCode, KeyValuePair, KeyValuePairs, MetaData, ProtocolVersion, RequestPattern,
22    TransactionId, TypedKeyValuePair, TypedKeyValuePairs, Value, Version,
23};
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use std::fmt;
26
/// A message sent from the server to a client.
///
/// Uses serde's default (externally tagged) enum representation; the tag is the
/// variant name in camelCase (e.g. `welcome`, `pState`, `lsState`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ServerMessage {
    /// Greeting carrying server info and a client id.
    Welcome(Welcome),
    /// Key/value pairs (or deleted pairs) for a request pattern.
    PState(PState),
    /// Acknowledgement carrying a transaction id.
    Ack(Ack),
    /// A single value (or deleted value).
    State(State),
    /// Error response.
    Err(Err),
    /// Authorization response; carries an [`Ack`] payload.
    Authorized(Ack),
    /// Listing of child path segments.
    LsState(LsState),
}
38
39impl ServerMessage {
40    pub fn transaction_id(&self) -> Option<TransactionId> {
41        match self {
42            ServerMessage::Welcome(_) => None,
43            ServerMessage::PState(msg) => Some(msg.transaction_id),
44            ServerMessage::Ack(msg) => Some(msg.transaction_id),
45            ServerMessage::State(msg) => Some(msg.transaction_id),
46            ServerMessage::Err(msg) => Some(msg.transaction_id),
47            ServerMessage::LsState(msg) => Some(msg.transaction_id),
48            ServerMessage::Authorized(_) => Some(0),
49        }
50    }
51}
52
/// Greeting sent to a client, carrying [`ServerInfo`] and a client id.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Welcome {
    /// Server version, protocol version and whether authorization is required.
    pub info: ServerInfo,
    /// Client identifier (serialized as `clientId`).
    pub client_id: String,
}
59
/// Key/value pairs (or deleted pairs) reported for a request pattern.
///
/// The event is flattened into the same JSON object, e.g.
/// `{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[...]}` or
/// `{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[...]}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PState {
    /// Transaction this message belongs to.
    pub transaction_id: TransactionId,
    /// Pattern the reported pairs were matched against.
    pub request_pattern: RequestPattern,
    /// Updated or deleted pairs; serialized inline (flattened) rather than nested.
    #[serde(flatten)]
    pub event: PStateEvent,
}
68
/// Payload of a [`PState`]: current key/value pairs, or pairs that were deleted.
///
/// Serialized under the key `keyValuePairs` or `deleted` respectively.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PStateEvent {
    /// Current key/value pairs.
    KeyValuePairs(KeyValuePairs),
    /// Pairs that were deleted (each still carries its key and value).
    Deleted(KeyValuePairs),
}
75
/// A [`PStateEvent`] whose values have been converted to `T`.
///
/// Built via `TryFrom<PStateEvent>`; this type does not derive serde traits itself.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedPStateEvent<T: DeserializeOwned> {
    /// Current key/value pairs with typed values.
    KeyValuePairs(TypedKeyValuePairs<T>),
    /// Deleted pairs with typed values.
    Deleted(TypedKeyValuePairs<T>),
}
81
82impl<T: DeserializeOwned> TryFrom<PStateEvent> for TypedPStateEvent<T> {
83    type Error = serde_json::Error;
84
85    fn try_from(value: PStateEvent) -> Result<Self, Self::Error> {
86        match value {
87            PStateEvent::KeyValuePairs(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
88                try_to_typed_key_value_pairs(kvps)?,
89            )),
90            PStateEvent::Deleted(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
91                try_to_typed_key_value_pairs(kvps)?,
92            )),
93        }
94    }
95}
96
97fn try_to_typed_key_value_pairs<T: DeserializeOwned>(
98    kvps: KeyValuePairs,
99) -> Result<TypedKeyValuePairs<T>, serde_json::Error> {
100    let mut out = vec![];
101
102    for kvp in kvps {
103        out.push(kvp.try_into()?);
104    }
105
106    Ok(out)
107}
108
/// A list of [`TypedPStateEvent`]s.
pub type TypedPStateEvents<T> = Vec<TypedPStateEvent<T>>;
110
111impl From<PStateEvent> for Vec<Option<Value>> {
112    fn from(e: PStateEvent) -> Self {
113        match e {
114            PStateEvent::KeyValuePairs(kvps) => kvps.into_iter().map(KeyValuePair::into).collect(),
115            PStateEvent::Deleted(keys) => keys.into_iter().map(|_| Option::None).collect(),
116        }
117    }
118}
119
120impl From<PState> for Vec<Option<Value>> {
121    fn from(pstate: PState) -> Self {
122        pstate.event.into()
123    }
124}
125
126impl fmt::Display for PState {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        match &self.event {
129            PStateEvent::KeyValuePairs(key_value_pairs) => {
130                let kvps: Vec<String> = key_value_pairs
131                    .iter()
132                    .map(|kvp| format!("{}={}", kvp.key, kvp.value))
133                    .collect();
134                let joined = kvps.join("\n");
135                write!(f, "{joined}")
136            }
137            PStateEvent::Deleted(key_value_pairs) => {
138                let kvps: Vec<String> = key_value_pairs
139                    .iter()
140                    .map(|kvp| format!("{}!={}", kvp.key, kvp.value))
141                    .collect();
142                let joined = kvps.join("\n");
143                write!(f, "{joined}")
144            }
145        }
146    }
147}
148
/// Acknowledgement carrying the transaction id it refers to.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Ack {
    /// Acknowledged transaction (serialized as `transactionId`).
    pub transaction_id: TransactionId,
}
154
155impl fmt::Display for Ack {
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        write!(f, "ack {}", self.transaction_id)
158    }
159}
160
/// A single value (or deleted value) reported for a transaction.
///
/// The event is flattened into the same JSON object, e.g. `{"transactionId":1,"value":2}`
/// or `{"transactionId":1,"deleted":2}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct State {
    /// Transaction this message belongs to.
    pub transaction_id: TransactionId,
    /// Current or deleted value; serialized inline (flattened) rather than nested.
    #[serde(flatten)]
    pub event: StateEvent,
}
168
/// Payload of a [`State`]: a current value, or the value that was deleted.
///
/// Serialized under the key `value` or `deleted` respectively.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum StateEvent {
    /// Current value.
    Value(Value),
    /// Value that was deleted.
    Deleted(Value),
}
175
176impl From<StateEvent> for Option<Value> {
177    fn from(e: StateEvent) -> Self {
178        match e {
179            StateEvent::Value(v) => Some(v),
180            StateEvent::Deleted(_) => None,
181        }
182    }
183}
184
185impl From<State> for Option<Value> {
186    fn from(state: State) -> Self {
187        state.event.into()
188    }
189}
190
/// A [`StateEvent`] whose value has been converted to `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedStateEvent<T: DeserializeOwned> {
    /// Current value.
    Value(T),
    /// Value that was deleted.
    Deleted(T),
}
196
197impl<T: DeserializeOwned> From<TypedStateEvent<T>> for Option<T> {
198    fn from(e: TypedStateEvent<T>) -> Self {
199        match e {
200            TypedStateEvent::Value(v) => Some(v),
201            TypedStateEvent::Deleted(_) => None,
202        }
203    }
204}
205
206impl<T: DeserializeOwned> From<TypedKeyValuePair<T>> for TypedStateEvent<T> {
207    fn from(kvp: TypedKeyValuePair<T>) -> Self {
208        TypedStateEvent::Value(kvp.value)
209    }
210}
211
212impl<T: DeserializeOwned + TryFrom<Value, Error = serde_json::Error>> TryFrom<StateEvent>
213    for TypedStateEvent<T>
214{
215    type Error = serde_json::Error;
216
217    fn try_from(e: StateEvent) -> Result<Self, Self::Error> {
218        match e {
219            StateEvent::Value(v) => Ok(TypedStateEvent::Value(v.try_into()?)),
220            StateEvent::Deleted(v) => Ok(TypedStateEvent::Deleted(v.try_into()?)),
221        }
222    }
223}
224
/// A list of [`TypedStateEvent`]s.
pub type TypedStateEvents<T> = Vec<TypedStateEvent<T>>;
226
227impl fmt::Display for State {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        match &self.event {
230            StateEvent::Value(v) => write!(f, "{}", v),
231            StateEvent::Deleted(v) => write!(f, "!{}", v),
232        }
233    }
234}
235
/// Error response from the server.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Err {
    /// Transaction the error belongs to.
    pub transaction_id: TransactionId,
    /// Error code, shown first by `Display`.
    pub error_code: ErrorCode,
    /// Additional error details, shown after the code by `Display`.
    pub metadata: MetaData,
}
243
244impl fmt::Display for Err {
245    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246        write!(f, "server error {}: {}", self.error_code, self.metadata)
247    }
248}
249
// Lets `Err` be used wherever a `std::error::Error` is expected; its message comes
// from the `Display` impl above it.
impl std::error::Error for Err {}
251
/// Handshake carrying a protocol version.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Handshake {
    /// Protocol version (serialized as `protocolVersion`).
    pub protocol_version: ProtocolVersion,
}
257
258impl fmt::Display for Handshake {
259    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260        write!(
261            f,
262            "handshake: supported protocol versions: {}",
263            self.protocol_version
264        )
265    }
266}
267
/// Listing of child path segments for a transaction.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LsState {
    /// Transaction this listing belongs to.
    pub transaction_id: TransactionId,
    /// Child segment names; `Display` escapes each one and separates them with tabs.
    pub children: Vec<String>,
}
274
275impl fmt::Display for LsState {
276    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277        write!(
278            f,
279            "{}",
280            self.children
281                .iter()
282                .map(escape_path_segment)
283                .reduce(|a, b| format!("{a}\t{b}"))
284                .unwrap_or("".to_owned())
285        )
286    }
287}
288
/// Escapes a single path segment for display.
///
/// - A segment containing a single quote has every `'` replaced by `\'`. Whitespace
///   and double quotes are left as they are in this case.
/// - Otherwise, a segment containing whitespace or a double quote is wrapped in
///   single quotes.
/// - Any other segment is returned unchanged.
fn escape_path_segment(str: impl AsRef<str>) -> String {
    let segment = str.as_ref();

    if segment.contains('\'') {
        return segment.replace('\'', r#"\'"#);
    }

    let needs_quoting = segment.contains('"') || segment.contains(char::is_whitespace);
    if needs_quoting {
        format!("'{segment}'")
    } else {
        segment.to_owned()
    }
}
303
/// Information about the server, sent as part of [`Welcome`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerInfo {
    /// Server version.
    pub version: Version,
    /// Protocol version (serialized as `protocolVersion`).
    pub protocol_version: ProtocolVersion,
    /// Whether authorization is required (serialized as `authorizationRequired`).
    pub authorization_required: bool,
}
311
#[cfg(test)]
mod test {
    use serde_json::json;

    use super::*;

    /// Each `State` fixture paired with the exact JSON it must map to.
    fn state_fixtures() -> Vec<(State, &'static str)> {
        vec![
            (
                State {
                    transaction_id: 1,
                    event: StateEvent::Value(json!(2)),
                },
                r#"{"transactionId":1,"value":2}"#,
            ),
            (
                State {
                    transaction_id: 1,
                    event: StateEvent::Deleted(json!(2)),
                },
                r#"{"transactionId":1,"deleted":2}"#,
            ),
        ]
    }

    /// Each `PState` fixture paired with the exact JSON it must map to.
    fn pstate_fixtures() -> Vec<(PState, &'static str)> {
        let clients = |event: PStateEvent| PState {
            transaction_id: 1,
            request_pattern: "$SYS/clients".to_owned(),
            event,
        };
        vec![
            (
                clients(PStateEvent::KeyValuePairs(vec![("$SYS/clients", 2).into()])),
                r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#,
            ),
            (
                clients(PStateEvent::Deleted(vec![("$SYS/clients", 2).into()])),
                r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#,
            ),
        ]
    }

    #[test]
    fn state_is_serialized_correctly() {
        for (state, expected) in state_fixtures() {
            assert_eq!(expected, serde_json::to_string(&state).unwrap());
        }
    }

    #[test]
    fn state_is_deserialized_correctly() {
        for (expected, json) in state_fixtures() {
            let parsed: State = serde_json::from_str(json).unwrap();
            assert_eq!(expected, parsed);
        }
    }

    #[test]
    fn pstate_is_serialized_correctly() {
        for (pstate, expected) in pstate_fixtures() {
            assert_eq!(expected, serde_json::to_string(&pstate).unwrap());
        }
    }

    #[test]
    fn pstate_is_deserialized_correctly() {
        for (expected, json) in pstate_fixtures() {
            let parsed: PState = serde_json::from_str(json).unwrap();
            assert_eq!(expected, parsed);
        }
    }
}