worterbuch_common/
server.rs

1/*
2 *  Worterbuch server messages module
3 *
4 *  Copyright (C) 2024 Michael Bachmann
5 *
6 *  This program is free software: you can redistribute it and/or modify
7 *  it under the terms of the GNU Affero General Public License as published by
8 *  the Free Software Foundation, either version 3 of the License, or
9 *  (at your option) any later version.
10 *
11 *  This program is distributed in the hope that it will be useful,
12 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *  GNU Affero General Public License for more details.
15 *
16 *  You should have received a copy of the GNU Affero General Public License
17 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20use crate::{
21    CasVersion, ErrorCode, KeyValuePair, KeyValuePairs, MetaData, ProtocolVersion, RequestPattern,
22    TransactionId, TypedKeyValuePair, TypedKeyValuePairs, Value, Version,
23    error::{ConnectionError, ConnectionResult},
24};
25use serde::{Deserialize, Serialize, de::DeserializeOwned};
26use std::{
27    fmt::{self, Display},
28    io,
29    time::Duration,
30};
31use tokio::{
32    io::{AsyncRead, AsyncWriteExt, BufReader, Lines},
33    time::timeout,
34};
35use tracing::{debug, error, trace};
36
37#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(rename_all = "camelCase")]
39pub enum ServerMessage {
40    Welcome(Welcome),
41    PState(PState),
42    Ack(Ack),
43    State(State),
44    CState(CState),
45    Err(Err),
46    Authorized(Ack),
47    LsState(LsState),
48}
49
50impl ServerMessage {
51    pub fn transaction_id(&self) -> Option<TransactionId> {
52        match self {
53            ServerMessage::Welcome(_) => None,
54            ServerMessage::PState(msg) => Some(msg.transaction_id),
55            ServerMessage::Ack(msg) => Some(msg.transaction_id),
56            ServerMessage::State(msg) => Some(msg.transaction_id),
57            ServerMessage::CState(msg) => Some(msg.transaction_id),
58            ServerMessage::Err(msg) => Some(msg.transaction_id),
59            ServerMessage::LsState(msg) => Some(msg.transaction_id),
60            ServerMessage::Authorized(_) => Some(0),
61        }
62    }
63}
64
/// Greeting message carrying the server's [`ServerInfo`] and the client's ID.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Welcome {
    pub info: ServerInfo,
    pub client_id: String,
}
71
/// Response carrying the key/value pairs that belong to a request pattern.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PState {
    pub transaction_id: TransactionId,
    pub request_pattern: RequestPattern,
    // Flattened: the event's variant name becomes a top-level key, i.e.
    // `"keyValuePairs": [...]` or `"deleted": [...]` next to `transactionId`.
    #[serde(flatten)]
    pub event: PStateEvent,
}
80
/// Payload of a [`PState`]: either current key/value pairs or key/value pairs that were deleted.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PStateEvent {
    KeyValuePairs(KeyValuePairs),
    Deleted(KeyValuePairs),
}
87
/// Counterpart of [`PStateEvent`] whose values have been deserialized into `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedPStateEvent<T: DeserializeOwned> {
    KeyValuePairs(TypedKeyValuePairs<T>),
    Deleted(TypedKeyValuePairs<T>),
}
93
94impl<T: DeserializeOwned> TryFrom<PStateEvent> for TypedPStateEvent<T> {
95    type Error = serde_json::Error;
96
97    fn try_from(value: PStateEvent) -> Result<Self, Self::Error> {
98        match value {
99            PStateEvent::KeyValuePairs(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
100                try_to_typed_key_value_pairs(kvps)?,
101            )),
102            PStateEvent::Deleted(kvps) => Ok(TypedPStateEvent::KeyValuePairs(
103                try_to_typed_key_value_pairs(kvps)?,
104            )),
105        }
106    }
107}
108
109fn try_to_typed_key_value_pairs<T: DeserializeOwned>(
110    kvps: KeyValuePairs,
111) -> Result<TypedKeyValuePairs<T>, serde_json::Error> {
112    let mut out = vec![];
113
114    for kvp in kvps {
115        out.push(kvp.try_into()?);
116    }
117
118    Ok(out)
119}
120
/// A list of [`TypedPStateEvent`]s.
pub type TypedPStateEvents<T> = Vec<TypedPStateEvent<T>>;
122
123impl From<PStateEvent> for Vec<Option<Value>> {
124    fn from(e: PStateEvent) -> Self {
125        match e {
126            PStateEvent::KeyValuePairs(kvps) => kvps.into_iter().map(KeyValuePair::into).collect(),
127            PStateEvent::Deleted(keys) => keys.into_iter().map(|_| Option::None).collect(),
128        }
129    }
130}
131
132impl From<PState> for Vec<Option<Value>> {
133    fn from(pstate: PState) -> Self {
134        pstate.event.into()
135    }
136}
137
138impl fmt::Display for PState {
139    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140        match &self.event {
141            PStateEvent::KeyValuePairs(key_value_pairs) => {
142                let kvps: Vec<String> = key_value_pairs
143                    .iter()
144                    .map(|kvp| format!("{}={}", kvp.key, kvp.value))
145                    .collect();
146                let joined = kvps.join("\n");
147                write!(f, "{joined}")
148            }
149            PStateEvent::Deleted(key_value_pairs) => {
150                let kvps: Vec<String> = key_value_pairs
151                    .iter()
152                    .map(|kvp| format!("{}!={}", kvp.key, kvp.value))
153                    .collect();
154                let joined = kvps.join("\n");
155                write!(f, "{joined}")
156            }
157        }
158    }
159}
160
161#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
162#[serde(rename_all = "camelCase")]
163pub struct Ack {
164    pub transaction_id: TransactionId,
165}
166
167impl fmt::Display for Ack {
168    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169        write!(f, "ack {}", self.transaction_id)
170    }
171}
172
/// Response carrying a single value for a transaction.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct State {
    pub transaction_id: TransactionId,
    // Flattened: the event's variant name becomes a top-level key, e.g.
    // `{"transactionId":1,"value":2}` or `{"transactionId":1,"deleted":2}`.
    #[serde(flatten)]
    pub event: StateEvent,
}

/// Payload of a [`State`]: either a current value or a value that was deleted.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum StateEvent {
    Value(Value),
    Deleted(Value),
}
187
/// Response carrying a value together with its [`CasVersion`]
/// (presumably used for compare-and-swap — confirm against the request side).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CState {
    pub transaction_id: TransactionId,
    // Flattened: `value` and `version` appear as top-level keys next to `transactionId`.
    #[serde(flatten)]
    pub event: CStateEvent,
}

/// Payload of a [`CState`]: a value and its version.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CStateEvent {
    pub value: Value,
    pub version: CasVersion,
}
202
203impl From<StateEvent> for Option<Value> {
204    fn from(e: StateEvent) -> Self {
205        match e {
206            StateEvent::Value(v) => Some(v),
207            StateEvent::Deleted(_) => None,
208        }
209    }
210}
211
212impl From<State> for Option<Value> {
213    fn from(state: State) -> Self {
214        state.event.into()
215    }
216}
217
/// Counterpart of [`StateEvent`] whose value has been deserialized into `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypedStateEvent<T: DeserializeOwned> {
    Value(T),
    Deleted(T),
}
223
224impl<T: DeserializeOwned> From<TypedStateEvent<T>> for Option<T> {
225    fn from(e: TypedStateEvent<T>) -> Self {
226        match e {
227            TypedStateEvent::Value(v) => Some(v),
228            TypedStateEvent::Deleted(_) => None,
229        }
230    }
231}
232
233impl<T: DeserializeOwned> From<TypedKeyValuePair<T>> for TypedStateEvent<T> {
234    fn from(kvp: TypedKeyValuePair<T>) -> Self {
235        TypedStateEvent::Value(kvp.value)
236    }
237}
238
239impl<T: DeserializeOwned + TryFrom<Value, Error = serde_json::Error>> TryFrom<StateEvent>
240    for TypedStateEvent<T>
241{
242    type Error = serde_json::Error;
243
244    fn try_from(e: StateEvent) -> Result<Self, Self::Error> {
245        match e {
246            StateEvent::Value(v) => Ok(TypedStateEvent::Value(v.try_into()?)),
247            StateEvent::Deleted(v) => Ok(TypedStateEvent::Deleted(v.try_into()?)),
248        }
249    }
250}
251
/// A list of [`TypedStateEvent`]s.
pub type TypedStateEvents<T> = Vec<TypedStateEvent<T>>;
253
254impl fmt::Display for State {
255    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256        match &self.event {
257            StateEvent::Value(v) => write!(f, "{v}"),
258            StateEvent::Deleted(v) => write!(f, "!{v}"),
259        }
260    }
261}
262
263#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
264#[serde(rename_all = "camelCase")]
265pub struct Err {
266    pub transaction_id: TransactionId,
267    pub error_code: ErrorCode,
268    pub metadata: MetaData,
269}
270
271impl fmt::Display for Err {
272    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273        write!(f, "server error {}: {}", self.error_code, self.metadata)
274    }
275}
276
277impl std::error::Error for Err {}
278
279#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
280#[serde(rename_all = "camelCase")]
281pub struct Handshake {
282    pub protocol_version: ProtocolVersion,
283}
284
285impl fmt::Display for Handshake {
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        write!(
288            f,
289            "handshake: supported protocol versions: {}",
290            self.protocol_version
291        )
292    }
293}
294
295#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
296#[serde(rename_all = "camelCase")]
297pub struct LsState {
298    pub transaction_id: TransactionId,
299    pub children: Vec<String>,
300}
301
302impl fmt::Display for LsState {
303    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304        write!(
305            f,
306            "{}",
307            self.children
308                .iter()
309                .map(escape_path_segment)
310                .reduce(|a, b| format!("{a}\t{b}"))
311                .unwrap_or("".to_owned())
312        )
313    }
314}
315
/// Escapes a single path segment for display.
///
/// - Segments containing a single quote have every `'` replaced by `\'`.
/// - Otherwise, segments containing whitespace or a double quote are wrapped in single quotes.
/// - All other segments are returned unchanged.
///
/// NOTE(review): a segment with both whitespace and a single quote is only backslash-escaped,
/// not quoted. Confirm this matches the tokenizer that consumes this output.
fn escape_path_segment(str: impl AsRef<str>) -> String {
    let str = str.as_ref();

    // Single quotes cannot appear inside single-quoted text, so escape them instead.
    if str.contains('\'') {
        return str.replace('\'', r#"\'"#);
    }

    let needs_quoting = str.contains('"') || str.contains(char::is_whitespace);
    if needs_quoting {
        format!("'{str}'")
    } else {
        str.to_owned()
    }
}
330
331#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
332#[serde(rename_all = "camelCase")]
333pub struct ServerInfo {
334    pub version: Version,
335    pub supported_protocol_versions: Box<[ProtocolVersion]>,
336    #[deprecated(since = "1.1.0", note = "replaced by `supported_protocol_versions`")]
337    protocol_version: String,
338    pub authorization_required: bool,
339}
340
341#[allow(deprecated)]
342impl ServerInfo {
343    pub fn new(
344        version: Version,
345        supported_protocol_versions: Box<[ProtocolVersion]>,
346        authorization_required: bool,
347    ) -> Self {
348        Self {
349            version,
350            supported_protocol_versions,
351            protocol_version: "0.11".to_owned(),
352            authorization_required,
353        }
354    }
355}
356
357pub async fn write_line_and_flush(
358    msg: impl Serialize,
359    mut tx: impl AsyncWriteExt + Unpin,
360    send_timeout: Option<Duration>,
361    remote: impl Display,
362) -> ConnectionResult<()> {
363    let mut json = serde_json::to_string(&msg)?;
364    if json.contains('\n') {
365        return Err(ConnectionError::IoError(Box::new(io::Error::new(
366            io::ErrorKind::InvalidData,
367            format!("invalid JSON: '{json}' contains line break"),
368        ))));
369    }
370    if json.trim().is_empty() {
371        return Err(ConnectionError::IoError(Box::new(io::Error::new(
372            io::ErrorKind::InvalidData,
373            format!("invalid JSON: '{json}' is empty"),
374        ))));
375    }
376
377    json.push('\n');
378    let bytes = json.as_bytes();
379
380    debug!("Sending message with timeout {send_timeout:?}: {json}");
381    trace!("Writing line …");
382    for chunk in bytes.chunks(1024) {
383        let mut written = 0;
384        while written < chunk.len() {
385            if let Some(send_timeout) = send_timeout {
386                written += timeout(send_timeout, tx.write(&chunk[written..]))
387                    .await
388                    .map_err(|_| {
389                        ConnectionError::Timeout(Box::new(format!(
390                            "timeout while sending tcp message to {remote}"
391                        )))
392                    })??;
393            } else {
394                written += tx.write(&chunk[written..]).await?;
395            }
396        }
397    }
398    trace!("Writing line done.");
399    trace!("Flushing channel …");
400    if let Some(send_timeout) = send_timeout {
401        timeout(send_timeout, tx.flush()).await.map_err(|_| {
402            ConnectionError::Timeout(Box::new(format!(
403                "timeout while sending tcp message to {remote}"
404            )))
405        })??;
406    } else {
407        tx.flush().await?;
408    }
409    trace!("Flushing channel done.");
410
411    Ok(())
412}
413
414pub async fn receive_msg<T: DeserializeOwned, R: AsyncRead + Unpin>(
415    rx: &mut Lines<BufReader<R>>,
416) -> ConnectionResult<Option<T>> {
417    let read = rx.next_line().await;
418    match read {
419        Ok(None) => Ok(None),
420        Ok(Some(json)) => {
421            debug!("Received message: {json}");
422            let sm = serde_json::from_str(&json);
423            if let Err(e) = &sm {
424                error!("Error deserializing message '{json}': {e}")
425            }
426            Ok(sm?)
427        }
428        Err(e) => Err(e.into()),
429    }
430}
431
#[cfg(test)]
mod test {
    use serde_json::json;

    use super::*;

    /// Request pattern, and the single key, used by every `PState` fixture.
    const CLIENTS: &str = "$SYS/clients";

    fn make_state(event: StateEvent) -> State {
        State {
            transaction_id: 1,
            event,
        }
    }

    fn make_pstate(event: PStateEvent) -> PState {
        PState {
            transaction_id: 1,
            request_pattern: CLIENTS.to_owned(),
            event,
        }
    }

    fn clients_kvps() -> KeyValuePairs {
        vec![KeyValuePair::of(CLIENTS, 2)]
    }

    /// `State` fixtures paired with their exact JSON encoding.
    fn state_cases() -> [(State, &'static str); 2] {
        [
            (
                make_state(StateEvent::Value(json!(2))),
                r#"{"transactionId":1,"value":2}"#,
            ),
            (
                make_state(StateEvent::Deleted(json!(2))),
                r#"{"transactionId":1,"deleted":2}"#,
            ),
        ]
    }

    /// `PState` fixtures paired with their exact JSON encoding.
    fn pstate_cases() -> [(PState, &'static str); 2] {
        [
            (
                make_pstate(PStateEvent::KeyValuePairs(clients_kvps())),
                r#"{"transactionId":1,"requestPattern":"$SYS/clients","keyValuePairs":[{"key":"$SYS/clients","value":2}]}"#,
            ),
            (
                make_pstate(PStateEvent::Deleted(clients_kvps())),
                r#"{"transactionId":1,"requestPattern":"$SYS/clients","deleted":[{"key":"$SYS/clients","value":2}]}"#,
            ),
        ]
    }

    #[test]
    fn state_is_serialized_correctly() {
        for (state, json) in state_cases() {
            assert_eq!(json, serde_json::to_string(&state).unwrap());
        }
    }

    #[test]
    fn state_is_deserialized_correctly() {
        for (state, json) in state_cases() {
            let parsed: State = serde_json::from_str(json).unwrap();
            assert_eq!(state, parsed);
        }
    }

    #[test]
    fn pstate_is_serialized_correctly() {
        for (pstate, json) in pstate_cases() {
            assert_eq!(json, serde_json::to_string(&pstate).unwrap());
        }
    }

    #[test]
    fn pstate_is_deserialized_correctly() {
        for (pstate, json) in pstate_cases() {
            let parsed: PState = serde_json::from_str(json).unwrap();
            assert_eq!(pstate, parsed);
        }
    }
}