worterbuch_common/
server.rs

1/*
2 *  Worterbuch server messages module
3 *
4 *  Copyright (C) 2024 Michael Bachmann
5 *
6 *  This program is free software: you can redistribute it and/or modify
7 *  it under the terms of the GNU Affero General Public License as published by
8 *  the Free Software Foundation, either version 3 of the License, or
9 *  (at your option) any later version.
10 *
11 *  This program is distributed in the hope that it will be useful,
12 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *  GNU Affero General Public License for more details.
15 *
16 *  You should have received a copy of the GNU Affero General Public License
17 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20use crate::{
21    CasVersion, ErrorCode, KeyValuePair, KeyValuePairs, MetaData, ProtocolVersion, RequestPattern,
22    TransactionId, TypedKeyValuePair, TypedKeyValuePairs, Value, Version,
23    error::{ConnectionError, ConnectionResult},
24};
25use serde::{Deserialize, Serialize, de::DeserializeOwned};
26use std::{
27    fmt::{self, Display},
28    io,
29    time::Duration,
30};
31use tokio::{
32    io::{AsyncRead, AsyncWriteExt, BufReader, Lines},
33    time::timeout,
34};
35use tracing::{debug, error, trace};
36
/// A message sent from the server to a client.
///
/// Uses serde's default (externally tagged) enum representation with camelCase
/// variant names, i.e. each message is a JSON object with a single key such as
/// `"welcome"`, `"pState"` or `"ack"` wrapping the payload.
37#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(rename_all = "camelCase")]
39pub enum ServerMessage {
    /// Server greeting carrying a [`ServerInfo`] and a client id.
40    Welcome(Welcome),
    /// Key/value pairs (or deleted pairs) for a request pattern.
41    PState(PState),
    /// Acknowledgement of a transaction.
42    Ack(Ack),
    /// A single value (or deleted value) for a transaction.
43    State(State),
    /// A value together with its `CasVersion` (presumably a compare-and-swap version).
44    CState(CState),
    /// An error response for a transaction.
45    Err(Err),
    /// Authorization response (presumably sent on successful authorization — confirm);
    /// `transaction_id()` always reports `0` for it.
46    Authorized(Ack),
    /// A list of child path segments (presumably the reply to an `ls` request).
47    LsState(LsState),
48}
49
50impl ServerMessage {
51    pub fn transaction_id(&self) -> Option<TransactionId> {
52        match self {
53            ServerMessage::Welcome(_) => None,
54            ServerMessage::PState(msg) => Some(msg.transaction_id),
55            ServerMessage::Ack(msg) => Some(msg.transaction_id),
56            ServerMessage::State(msg) => Some(msg.transaction_id),
57            ServerMessage::CState(msg) => Some(msg.transaction_id),
58            ServerMessage::Err(msg) => Some(msg.transaction_id),
59            ServerMessage::LsState(msg) => Some(msg.transaction_id),
60            ServerMessage::Authorized(_) => Some(0),
61        }
62    }
63}
64
/// Server greeting containing the server's [`ServerInfo`] and a client id.
65#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
66#[serde(rename_all = "camelCase")]
67pub struct Welcome {
    /// Information about the server (serialized as `info`).
68    pub info: ServerInfo,
    /// Identifier of the client (serialized as `clientId`).
69    pub client_id: String,
70}
71
/// Key/value pairs, or deleted pairs, associated with a request pattern.
///
/// `event` is flattened, so the JSON object carries either a `keyValuePairs` or a
/// `deleted` field next to `transactionId` and `requestPattern`.
72#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
73#[serde(rename_all = "camelCase")]
74pub struct PState {
    /// Transaction this message belongs to.
75    pub transaction_id: TransactionId,
    /// The pattern the pairs were matched against.
76    pub request_pattern: RequestPattern,
    /// The pairs, tagged as set or deleted; flattened into the enclosing object.
77    #[serde(flatten)]
78    pub event: PStateEvent,
79}
80
/// Payload of a [`PState`]: pairs that carry values, or pairs that were deleted.
81#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
82#[serde(rename_all = "camelCase")]
83pub enum PStateEvent {
    /// Key/value pairs (serialized as `keyValuePairs`).
84    KeyValuePairs(KeyValuePairs),
    /// Deleted keys together with their values (serialized as `deleted`).
85    Deleted(KeyValuePairs),
86}
87
/// A [`PStateEvent`] whose values have been converted into `T`.
88#[derive(Debug, Clone, PartialEq, Eq)]
89pub enum TypedPStateEvent<T: DeserializeOwned> {
    /// Typed counterpart of [`PStateEvent::KeyValuePairs`].
90    KeyValuePairs(TypedKeyValuePairs<T>),
    /// Typed counterpart of [`PStateEvent::Deleted`].
91    Deleted(TypedKeyValuePairs<T>),
92}
93
94impl<T: DeserializeOwned> TryFrom<PStateEvent> for TypedPStateEvent<T> {
95    type Error = serde_json::Error;
96
97    fn try_from(value: PStateEvent) -> Result<Self, Self::Error> {
98        match value {
99            PStateEvent::KeyValuePairs(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
100                try_to_typed_key_value_pairs(kvps)?,
101            )),
102            PStateEvent::Deleted(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
103                try_to_typed_key_value_pairs(kvps)?,
104            )),
105        }
106    }
107}
108
109fn try_to_typed_key_value_pairs<T: DeserializeOwned>(
110    kvps: KeyValuePairs,
111) -> Result<TypedKeyValuePairs<T>, serde_json::Error> {
112    let mut out = vec![];
113
114    for kvp in kvps {
115        out.push(kvp.try_into()?);
116    }
117
118    Ok(out)
119}
120
/// A list of [`TypedPStateEvent`]s.
121pub type TypedPStateEvents<T> = Vec<TypedPStateEvent<T>>;
122
123impl From<PStateEvent> for Vec<Option<Value>> {
124    fn from(e: PStateEvent) -> Self {
125        match e {
126            PStateEvent::KeyValuePairs(kvps) => kvps.into_iter().map(KeyValuePair::into).collect(),
127            PStateEvent::Deleted(keys) => keys.into_iter().map(|_| Option::None).collect(),
128        }
129    }
130}
131
132impl From<PState> for Vec<Option<Value>> {
133    fn from(pstate: PState) -> Self {
134        pstate.event.into()
135    }
136}
137
138impl fmt::Display for PState {
139    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140        match &self.event {
141            PStateEvent::KeyValuePairs(key_value_pairs) => {
142                let kvps: Vec<String> = key_value_pairs
143                    .iter()
144                    .map(|kvp| format!("{}={}", kvp.key, kvp.value))
145                    .collect();
146                let joined = kvps.join("\n");
147                write!(f, "{joined}")
148            }
149            PStateEvent::Deleted(key_value_pairs) => {
150                let kvps: Vec<String> = key_value_pairs
151                    .iter()
152                    .map(|kvp| format!("{}!={}", kvp.key, kvp.value))
153                    .collect();
154                let joined = kvps.join("\n");
155                write!(f, "{joined}")
156            }
157        }
158    }
159}
160
/// Acknowledgement of a transaction (serialized as `{"transactionId": ...}`).
161#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
162#[serde(rename_all = "camelCase")]
163pub struct Ack {
    /// The acknowledged transaction.
164    pub transaction_id: TransactionId,
165}
166
167impl fmt::Display for Ack {
168    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169        write!(f, "ack {}", self.transaction_id)
170    }
171}
172
/// A single value, or a deleted value, for a transaction.
///
/// `event` is flattened, so the JSON is `{"transactionId":1,"value":2}` or
/// `{"transactionId":1,"deleted":2}` (see the tests at the bottom of this file).
173#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
174#[serde(rename_all = "camelCase")]
175pub struct State {
    /// Transaction this message belongs to.
176    pub transaction_id: TransactionId,
    /// The value, tagged as present or deleted; flattened into the enclosing object.
177    #[serde(flatten)]
178    pub event: StateEvent,
179}
180
/// Payload of a [`State`]: a value, or a value that was deleted.
181#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
182#[serde(rename_all = "camelCase")]
183pub enum StateEvent {
    /// A value (serialized as `value`).
184    Value(Value),
    /// A deleted value (serialized as `deleted`).
185    Deleted(Value),
186}
187
/// A value together with its `CasVersion` for a transaction.
///
/// `event` is flattened, so `value` and `version` appear next to `transactionId`.
188#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
189#[serde(rename_all = "camelCase")]
190pub struct CState {
    /// Transaction this message belongs to.
191    pub transaction_id: TransactionId,
    /// Value and version; flattened into the enclosing object.
192    #[serde(flatten)]
193    pub event: CStateEvent,
194}
195
/// Payload of a [`CState`].
196#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
197#[serde(rename_all = "camelCase")]
198pub struct CStateEvent {
    /// The stored value.
199    pub value: Value,
    /// Version of the value (presumably used for compare-and-swap — confirm).
200    pub version: CasVersion,
201}
202
203impl From<StateEvent> for Option<Value> {
204    fn from(e: StateEvent) -> Self {
205        match e {
206            StateEvent::Value(v) => Some(v),
207            StateEvent::Deleted(_) => None,
208        }
209    }
210}
211
212impl From<State> for Option<Value> {
213    fn from(state: State) -> Self {
214        state.event.into()
215    }
216}
217
/// A [`StateEvent`] whose value has been converted into `T`.
218#[derive(Debug, Clone, PartialEq, Eq)]
219pub enum TypedStateEvent<T: DeserializeOwned> {
    /// Typed counterpart of [`StateEvent::Value`].
220    Value(T),
    /// Typed counterpart of [`StateEvent::Deleted`].
221    Deleted(T),
222}
223
224impl<T: DeserializeOwned> From<TypedStateEvent<T>> for Option<T> {
225    fn from(e: TypedStateEvent<T>) -> Self {
226        match e {
227            TypedStateEvent::Value(v) => Some(v),
228            TypedStateEvent::Deleted(_) => None,
229        }
230    }
231}
232
233impl<T: DeserializeOwned> From<TypedKeyValuePair<T>> for TypedStateEvent<T> {
234    fn from(kvp: TypedKeyValuePair<T>) -> Self {
235        TypedStateEvent::Value(kvp.value)
236    }
237}
238
239impl<T: DeserializeOwned + TryFrom<Value, Error = serde_json::Error>> TryFrom<StateEvent>
240    for TypedStateEvent<T>
241{
242    type Error = serde_json::Error;
243
244    fn try_from(e: StateEvent) -> Result<Self, Self::Error> {
245        match e {
246            StateEvent::Value(v) => Ok(TypedStateEvent::Value(v.try_into()?)),
247            StateEvent::Deleted(v) => Ok(TypedStateEvent::Deleted(v.try_into()?)),
248        }
249    }
250}
251
/// A list of [`TypedStateEvent`]s.
252pub type TypedStateEvents<T> = Vec<TypedStateEvent<T>>;
253
254impl fmt::Display for State {
255    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256        match &self.event {
257            StateEvent::Value(v) => write!(f, "{}", v),
258            StateEvent::Deleted(v) => write!(f, "!{}", v),
259        }
260    }
261}
262
/// Error response for a transaction.
///
/// Note: this type shares its name with `Result::Err`. Because it has named fields it
/// only occupies the type namespace, so `Err(...)` expressions and patterns in this
/// module still refer to `Result::Err`.
263#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
264#[serde(rename_all = "camelCase")]
265pub struct Err {
    /// The transaction that failed.
266    pub transaction_id: TransactionId,
    /// Code identifying the kind of error.
267    pub error_code: ErrorCode,
    /// Additional error details; printed by the `Display` impl.
268    pub metadata: MetaData,
269}
270
271impl fmt::Display for Err {
272    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273        write!(f, "server error {}: {}", self.error_code, self.metadata)
274    }
275}
276
277impl std::error::Error for Err {}
278
/// Handshake message carrying a single protocol version.
279#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
280#[serde(rename_all = "camelCase")]
281pub struct Handshake {
    /// The protocol version (serialized as `protocolVersion`).
282    pub protocol_version: ProtocolVersion,
283}
284
285impl fmt::Display for Handshake {
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        write!(
288            f,
289            "handshake: supported protocol versions: {}",
290            self.protocol_version
291        )
292    }
293}
294
/// A list of child path segments for a transaction.
295#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
296#[serde(rename_all = "camelCase")]
297pub struct LsState {
    /// Transaction this message belongs to.
298    pub transaction_id: TransactionId,
    /// The child segment names, unescaped.
299    pub children: Vec<String>,
300}
301
302impl fmt::Display for LsState {
303    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304        write!(
305            f,
306            "{}",
307            self.children
308                .iter()
309                .map(escape_path_segment)
310                .reduce(|a, b| format!("{a}\t{b}"))
311                .unwrap_or("".to_owned())
312        )
313    }
314}
315
/// Escapes a single path segment for display in a listing.
///
/// - A segment containing a single quote is never wrapped. Instead, every `'` is
///   backslash-escaped, and any whitespace or double quotes are left as they are.
/// - Otherwise, a segment containing whitespace or a double quote is wrapped in single
///   quotes.
/// - Any other segment (including the empty one) is returned unchanged.
///
/// NOTE(review): a segment with both a `'` and whitespace keeps its whitespace
/// unprotected — confirm this matches how listings are parsed.
fn escape_path_segment(str: impl AsRef<str>) -> String {
    let segment = str.as_ref();

    if segment.contains('\'') {
        return segment.replace('\'', "\\'");
    }

    let needs_wrapping = segment.contains('"') || segment.contains(char::is_whitespace);
    if needs_wrapping {
        format!("'{segment}'")
    } else {
        segment.to_owned()
    }
}
330
/// Information about the server, sent to clients as part of a [`Welcome`] message.
331#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
332#[serde(rename_all = "camelCase")]
333pub struct ServerInfo {
    /// The server's version.
334    pub version: Version,
    /// All protocol versions the server supports.
335    pub supported_protocol_versions: Box<[ProtocolVersion]>,
    // Legacy single protocol version. It is not skipped by serde, so it is still
    // serialized as `protocolVersion` (presumably for older clients — confirm).
    // `ServerInfo::new` always sets it to "0.11".
336    #[deprecated(since = "1.1.0", note = "replaced by `supported_protocol_versions`")]
337    protocol_version: String,
    /// Whether the server requires clients to authorize (presumably before other
    /// requests are accepted — confirm).
338    pub authorization_required: bool,
339}
340
341#[allow(deprecated)]
342impl ServerInfo {
343    pub fn new(
344        version: Version,
345        supported_protocol_versions: Box<[ProtocolVersion]>,
346        authorization_required: bool,
347    ) -> Self {
348        Self {
349            version,
350            supported_protocol_versions,
351            protocol_version: "0.11".to_owned(),
352            authorization_required,
353        }
354    }
355}
356
357pub async fn write_line_and_flush(
358    msg: impl Serialize,
359    mut tx: impl AsyncWriteExt + Unpin,
360    send_timeout: Duration,
361    remote: impl Display,
362) -> ConnectionResult<()> {
363    let mut json = serde_json::to_string(&msg)?;
364    if json.contains('\n') {
365        return Err(ConnectionError::IoError(io::Error::new(
366            io::ErrorKind::InvalidData,
367            format!("invalid JSON: '{json}' contains line break"),
368        )));
369    }
370    if json.trim().is_empty() {
371        return Err(ConnectionError::IoError(io::Error::new(
372            io::ErrorKind::InvalidData,
373            format!("invalid JSON: '{json}' is empty"),
374        )));
375    }
376
377    json.push('\n');
378    let bytes = json.as_bytes();
379
380    debug!("Sending message: {json}");
381    trace!("Writing line …");
382    for chunk in bytes.chunks(1024) {
383        let mut written = 0;
384        while written < chunk.len() {
385            written += timeout(send_timeout, tx.write(&chunk[written..]))
386                .await
387                .map_err(|_| {
388                    ConnectionError::Timeout(format!(
389                        "timeout while sending tcp message to {remote}"
390                    ))
391                })??;
392        }
393    }
394    trace!("Writing line done.");
395    trace!("Flushing channel …");
396    tx.flush().await?;
397    trace!("Flushing channel done.");
398
399    Ok(())
400}
401
402pub async fn receive_msg<T: DeserializeOwned, R: AsyncRead + Unpin>(
403    rx: &mut Lines<BufReader<R>>,
404) -> ConnectionResult<Option<T>> {
405    let read = rx.next_line().await;
406    match read {
407        Ok(None) => Ok(None),
408        Ok(Some(json)) => {
409            debug!("Received message: {json}");
410            let sm = serde_json::from_str(&json);
411            if let Err(e) = &sm {
412                error!("Error deserializing message '{json}': {e}")
413            }
414            Ok(sm?)
415        }
416        Err(e) => Err(e.into()),
417    }
418}
419
// Round-trip tests pinning the JSON wire format of `State` and `PState`, in particular
// that the flattened event appears as a sibling field of `transactionId`.
420#[cfg(test)]
421mod test {
422    use serde_json::json;
423
424    use super::*;
425
    // `StateEvent::Value` serializes as `value`, `StateEvent::Deleted` as `deleted`.
426    #[test]
427    fn state_is_serialized_correctly() {
428        let state = State {
429            transaction_id: 1,
430            event: StateEvent::Value(json!(2)),
431        };
432
433        let json = r#"{"transactionId":1,"value":2}"#;
434
435        assert_eq!(json, &serde_json::to_string(&state).unwrap());
436
437        let state = State {
438            transaction_id: 1,
439            event: StateEvent::Deleted(json!(2)),
440        };
441
442        let json = r#"{"transactionId":1,"deleted":2}"#;
443
444        assert_eq!(json, &serde_json::to_string(&state).unwrap());
445    }
446
    // The same two JSON shapes deserialize back into the matching `StateEvent` variant.
447    #[test]
448    fn state_is_deserialized_correctly() {
449        let state = State {
450            transaction_id: 1,
451            event: StateEvent::Value(json!(2)),
452        };
453
454        let json = r#"{"transactionId":1,"value":2}"#;
455
456        assert_eq!(state, serde_json::from_str(json).unwrap());
457
458        let state = State {
459            transaction_id: 1,
460            event: StateEvent::Deleted(json!(2)),
461        };
462
463        let json = r#"{"transactionId":1,"deleted":2}"#;
464
465        assert_eq!(state, serde_json::from_str(json).unwrap());
466    }
467
    // `PStateEvent` is flattened next to `requestPattern` as `keyValuePairs` or `deleted`.
468    #[test]
469    fn pstate_is_serialized_correctly() {
470        let pstate = PState {
471            transaction_id: 1,
472            request_pattern: "$SYS/clients".to_owned(),
473            event: PStateEvent::KeyValuePairs(vec![KeyValuePair::of("$SYS/clients", 2)]),
474        };
475
476        let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#;
477
478        assert_eq!(json, &serde_json::to_string(&pstate).unwrap());
479
480        let pstate = PState {
481            transaction_id: 1,
482            request_pattern: "$SYS/clients".to_owned(),
483            event: PStateEvent::Deleted(vec![KeyValuePair::of("$SYS/clients", 2)]),
484        };
485
486        let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#;
487
488        assert_eq!(json, &serde_json::to_string(&pstate).unwrap());
489    }
490
    // Both `PState` JSON shapes deserialize back into the matching `PStateEvent` variant.
491    #[test]
492    fn pstate_is_deserialized_correctly() {
493        let pstate = PState {
494            transaction_id: 1,
495            request_pattern: "$SYS/clients".to_owned(),
496            event: PStateEvent::KeyValuePairs(vec![KeyValuePair::of("$SYS/clients", 2)]),
497        };
498
499        let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#;
500
501        assert_eq!(pstate, serde_json::from_str(json).unwrap());
502
503        let pstate = PState {
504            transaction_id: 1,
505            request_pattern: "$SYS/clients".to_owned(),
506            event: PStateEvent::Deleted(vec![KeyValuePair::of("$SYS/clients", 2)]),
507        };
508
509        let json = r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#;
510
511        assert_eq!(pstate, serde_json::from_str(json).unwrap());
512    }
513}