worterbuch_common/
server.rs

1/*
2 *  Worterbuch server messages module
3 *
4 *  Copyright (C) 2024 Michael Bachmann
5 *
6 *  This program is free software: you can redistribute it and/or modify
7 *  it under the terms of the GNU Affero General Public License as published by
8 *  the Free Software Foundation, either version 3 of the License, or
9 *  (at your option) any later version.
10 *
11 *  This program is distributed in the hope that it will be useful,
12 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *  GNU Affero General Public License for more details.
15 *
16 *  You should have received a copy of the GNU Affero General Public License
17 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20use crate::{
21    CasVersion, ErrorCode, KeyValuePair, KeyValuePairs, MetaData, ProtocolVersion, RequestPattern,
22    TransactionId, TypedKeyValuePair, TypedKeyValuePairs, Value, Version,
23};
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use std::fmt;
26
27#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
28#[serde(rename_all = "camelCase")]
29pub enum ServerMessage {
30    Welcome(Welcome),
31    PState(PState),
32    Ack(Ack),
33    State(State),
34    CState(CState),
35    Err(Err),
36    Authorized(Ack),
37    LsState(LsState),
38}
39
40impl ServerMessage {
41    pub fn transaction_id(&self) -> Option<TransactionId> {
42        match self {
43            ServerMessage::Welcome(_) => None,
44            ServerMessage::PState(msg) => Some(msg.transaction_id),
45            ServerMessage::Ack(msg) => Some(msg.transaction_id),
46            ServerMessage::State(msg) => Some(msg.transaction_id),
47            ServerMessage::CState(msg) => Some(msg.transaction_id),
48            ServerMessage::Err(msg) => Some(msg.transaction_id),
49            ServerMessage::LsState(msg) => Some(msg.transaction_id),
50            ServerMessage::Authorized(_) => Some(0),
51        }
52    }
53}
54
55#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
56#[serde(rename_all = "camelCase")]
57pub struct Welcome {
58    pub info: ServerInfo,
59    pub client_id: String,
60}
61
62#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
63#[serde(rename_all = "camelCase")]
64pub struct PState {
65    pub transaction_id: TransactionId,
66    pub request_pattern: RequestPattern,
67    #[serde(flatten)]
68    pub event: PStateEvent,
69}
70
71#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
72#[serde(rename_all = "camelCase")]
73pub enum PStateEvent {
74    KeyValuePairs(KeyValuePairs),
75    Deleted(KeyValuePairs),
76}
77
78#[derive(Debug, Clone, PartialEq, Eq)]
79pub enum TypedPStateEvent<T: DeserializeOwned> {
80    KeyValuePairs(TypedKeyValuePairs<T>),
81    Deleted(TypedKeyValuePairs<T>),
82}
83
84impl<T: DeserializeOwned> TryFrom<PStateEvent> for TypedPStateEvent<T> {
85    type Error = serde_json::Error;
86
87    fn try_from(value: PStateEvent) -> Result<Self, Self::Error> {
88        match value {
89            PStateEvent::KeyValuePairs(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
90                try_to_typed_key_value_pairs(kvps)?,
91            )),
92            PStateEvent::Deleted(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
93                try_to_typed_key_value_pairs(kvps)?,
94            )),
95        }
96    }
97}
98
99fn try_to_typed_key_value_pairs<T: DeserializeOwned>(
100    kvps: KeyValuePairs,
101) -> Result<TypedKeyValuePairs<T>, serde_json::Error> {
102    let mut out = vec![];
103
104    for kvp in kvps {
105        out.push(kvp.try_into()?);
106    }
107
108    Ok(out)
109}
110
111pub type TypedPStateEvents<T> = Vec<TypedPStateEvent<T>>;
112
113impl From<PStateEvent> for Vec<Option<Value>> {
114    fn from(e: PStateEvent) -> Self {
115        match e {
116            PStateEvent::KeyValuePairs(kvps) => kvps.into_iter().map(KeyValuePair::into).collect(),
117            PStateEvent::Deleted(keys) => keys.into_iter().map(|_| Option::None).collect(),
118        }
119    }
120}
121
122impl From<PState> for Vec<Option<Value>> {
123    fn from(pstate: PState) -> Self {
124        pstate.event.into()
125    }
126}
127
128impl fmt::Display for PState {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        match &self.event {
131            PStateEvent::KeyValuePairs(key_value_pairs) => {
132                let kvps: Vec<String> = key_value_pairs
133                    .iter()
134                    .map(|kvp| format!("{}={}", kvp.key, kvp.value))
135                    .collect();
136                let joined = kvps.join("\n");
137                write!(f, "{joined}")
138            }
139            PStateEvent::Deleted(key_value_pairs) => {
140                let kvps: Vec<String> = key_value_pairs
141                    .iter()
142                    .map(|kvp| format!("{}!={}", kvp.key, kvp.value))
143                    .collect();
144                let joined = kvps.join("\n");
145                write!(f, "{joined}")
146            }
147        }
148    }
149}
150
151#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
152#[serde(rename_all = "camelCase")]
153pub struct Ack {
154    pub transaction_id: TransactionId,
155}
156
157impl fmt::Display for Ack {
158    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159        write!(f, "ack {}", self.transaction_id)
160    }
161}
162
163#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
164#[serde(rename_all = "camelCase")]
165pub struct State {
166    pub transaction_id: TransactionId,
167    #[serde(flatten)]
168    pub event: StateEvent,
169}
170
171#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
172#[serde(rename_all = "camelCase")]
173pub enum StateEvent {
174    Value(Value),
175    Deleted(Value),
176}
177
178#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
179#[serde(rename_all = "camelCase")]
180pub struct CState {
181    pub transaction_id: TransactionId,
182    #[serde(flatten)]
183    pub event: CStateEvent,
184}
185
186#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
187#[serde(rename_all = "camelCase")]
188pub struct CStateEvent {
189    pub value: Value,
190    pub version: CasVersion,
191}
192
193impl From<StateEvent> for Option<Value> {
194    fn from(e: StateEvent) -> Self {
195        match e {
196            StateEvent::Value(v) => Some(v),
197            StateEvent::Deleted(_) => None,
198        }
199    }
200}
201
202impl From<State> for Option<Value> {
203    fn from(state: State) -> Self {
204        state.event.into()
205    }
206}
207
208#[derive(Debug, Clone, PartialEq, Eq)]
209pub enum TypedStateEvent<T: DeserializeOwned> {
210    Value(T),
211    Deleted(T),
212}
213
214impl<T: DeserializeOwned> From<TypedStateEvent<T>> for Option<T> {
215    fn from(e: TypedStateEvent<T>) -> Self {
216        match e {
217            TypedStateEvent::Value(v) => Some(v),
218            TypedStateEvent::Deleted(_) => None,
219        }
220    }
221}
222
223impl<T: DeserializeOwned> From<TypedKeyValuePair<T>> for TypedStateEvent<T> {
224    fn from(kvp: TypedKeyValuePair<T>) -> Self {
225        TypedStateEvent::Value(kvp.value)
226    }
227}
228
229impl<T: DeserializeOwned + TryFrom<Value, Error = serde_json::Error>> TryFrom<StateEvent>
230    for TypedStateEvent<T>
231{
232    type Error = serde_json::Error;
233
234    fn try_from(e: StateEvent) -> Result<Self, Self::Error> {
235        match e {
236            StateEvent::Value(v) => Ok(TypedStateEvent::Value(v.try_into()?)),
237            StateEvent::Deleted(v) => Ok(TypedStateEvent::Deleted(v.try_into()?)),
238        }
239    }
240}
241
242pub type TypedStateEvents<T> = Vec<TypedStateEvent<T>>;
243
244impl fmt::Display for State {
245    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246        match &self.event {
247            StateEvent::Value(v) => write!(f, "{}", v),
248            StateEvent::Deleted(v) => write!(f, "!{}", v),
249        }
250    }
251}
252
253#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
254#[serde(rename_all = "camelCase")]
255pub struct Err {
256    pub transaction_id: TransactionId,
257    pub error_code: ErrorCode,
258    pub metadata: MetaData,
259}
260
261impl fmt::Display for Err {
262    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263        write!(f, "server error {}: {}", self.error_code, self.metadata)
264    }
265}
266
267impl std::error::Error for Err {}
268
269#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
270#[serde(rename_all = "camelCase")]
271pub struct Handshake {
272    pub protocol_version: ProtocolVersion,
273}
274
275impl fmt::Display for Handshake {
276    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277        write!(
278            f,
279            "handshake: supported protocol versions: {}",
280            self.protocol_version
281        )
282    }
283}
284
285#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
286#[serde(rename_all = "camelCase")]
287pub struct LsState {
288    pub transaction_id: TransactionId,
289    pub children: Vec<String>,
290}
291
292impl fmt::Display for LsState {
293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294        write!(
295            f,
296            "{}",
297            self.children
298                .iter()
299                .map(escape_path_segment)
300                .reduce(|a, b| format!("{a}\t{b}"))
301                .unwrap_or("".to_owned())
302        )
303    }
304}
305
/// Escapes a single path segment for display.
///
/// - A segment containing `'` has every `'` replaced by `\'`; no surrounding quotes are added.
/// - Otherwise, a segment containing whitespace or `"` is wrapped in single quotes.
/// - Any other segment is returned unchanged.
///
/// NOTE(review): a segment containing both `'` and whitespace is returned with its
/// whitespace unquoted — confirm that this is what the key parser expects.
fn escape_path_segment(segment: impl AsRef<str>) -> String {
    let segment = segment.as_ref();

    if segment.contains('\'') {
        return segment.replace('\'', r#"\'"#);
    }

    let needs_quoting = segment.contains('"') || segment.contains(char::is_whitespace);
    if needs_quoting {
        format!("'{segment}'")
    } else {
        segment.to_owned()
    }
}
320
321#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
322#[serde(rename_all = "camelCase")]
323pub struct ServerInfo {
324    pub version: Version,
325    pub supported_protocol_versions: Box<[ProtocolVersion]>,
326    #[deprecated(since = "1.1.0", note = "replaced by `supported_protocol_versions`")]
327    protocol_version: String,
328    pub authorization_required: bool,
329}
330
331#[allow(deprecated)]
332impl ServerInfo {
333    pub fn new(
334        version: Version,
335        supported_protocol_versions: Box<[ProtocolVersion]>,
336        authorization_required: bool,
337    ) -> Self {
338        Self {
339            version,
340            supported_protocol_versions,
341            protocol_version: "0.11".to_owned(),
342            authorization_required,
343        }
344    }
345}
346
#[cfg(test)]
mod test {
    use super::*;
    use serde_json::json;

    /// Builds a `State` for transaction 1 around `event`.
    fn make_state(event: StateEvent) -> State {
        State {
            transaction_id: 1,
            event,
        }
    }

    /// Builds a `PState` for transaction 1 on the `$SYS/clients` pattern around `event`.
    fn make_pstate(event: PStateEvent) -> PState {
        PState {
            transaction_id: 1,
            request_pattern: "$SYS/clients".to_owned(),
            event,
        }
    }

    /// Both `State` variants paired with their expected wire representation.
    fn state_cases() -> [(State, &'static str); 2] {
        [
            (
                make_state(StateEvent::Value(json!(2))),
                r#"{"transactionId":1,"value":2}"#,
            ),
            (
                make_state(StateEvent::Deleted(json!(2))),
                r#"{"transactionId":1,"deleted":2}"#,
            ),
        ]
    }

    /// Both `PState` variants paired with their expected wire representation.
    fn pstate_cases() -> [(PState, &'static str); 2] {
        [
            (
                make_pstate(PStateEvent::KeyValuePairs(vec![("$SYS/clients", 2).into()])),
                r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#,
            ),
            (
                make_pstate(PStateEvent::Deleted(vec![("$SYS/clients", 2).into()])),
                r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#,
            ),
        ]
    }

    #[test]
    fn state_is_serialized_correctly() {
        for (state, expected) in state_cases() {
            assert_eq!(expected, serde_json::to_string(&state).unwrap());
        }
    }

    #[test]
    fn state_is_deserialized_correctly() {
        for (expected, json) in state_cases() {
            assert_eq!(expected, serde_json::from_str::<State>(json).unwrap());
        }
    }

    #[test]
    fn pstate_is_serialized_correctly() {
        for (pstate, expected) in pstate_cases() {
            assert_eq!(expected, serde_json::to_string(&pstate).unwrap());
        }
    }

    #[test]
    fn pstate_is_deserialized_correctly() {
        for (expected, json) in pstate_cases() {
            assert_eq!(expected, serde_json::from_str::<PState>(json).unwrap());
        }
    }
}