worterbuch_common/
lib.rs

1/*
2 *  Worterbuch common modules library
3 *
4 *  Copyright (C) 2024 Michael Bachmann
5 *
6 *  This program is free software: you can redistribute it and/or modify
7 *  it under the terms of the GNU Affero General Public License as published by
8 *  the Free Software Foundation, either version 3 of the License, or
9 *  (at your option) any later version.
10 *
11 *  This program is distributed in the hope that it will be useful,
12 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *  GNU Affero General Public License for more details.
15 *
16 *  You should have received a copy of the GNU Affero General Public License
17 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20pub mod benchmark;
21mod client;
22pub mod error;
23mod server;
24pub mod tcp;
25
26pub use client::*;
27use serde_json::json;
28pub use server::*;
29
30use error::WorterbuchResult;
31use serde::{Deserialize, Serialize, de::DeserializeOwned};
32use serde_repr::*;
33use sha2::{Digest, Sha256};
34use std::{fmt, ops::Deref};
35
// Names used when building system topics; presumably combined (e.g. with
// `topic!`) beneath `SYSTEM_TOPIC_ROOT`.

/// Root segment of all system topics.
pub const SYSTEM_TOPIC_ROOT: &str = "$SYS";
/// [`SYSTEM_TOPIC_ROOT`] followed by the `/` segment separator.
pub const SYSTEM_TOPIC_ROOT_PREFIX: &str = "$SYS/";

/// System topic segment `clients`.
pub const SYSTEM_TOPIC_CLIENTS: &str = "clients";
/// System topic segment `version`.
pub const SYSTEM_TOPIC_VERSION: &str = "version";
/// System topic segment `license`.
pub const SYSTEM_TOPIC_LICENSE: &str = "license";
/// System topic segment `source-code`.
pub const SYSTEM_TOPIC_SOURCES: &str = "source-code";
/// System topic segment `subscriptions`.
pub const SYSTEM_TOPIC_SUBSCRIPTIONS: &str = "subscriptions";

/// Per-client segment `protocol`.
pub const SYSTEM_TOPIC_CLIENTS_PROTOCOL: &str = "protocol";
/// Per-client segment `address`.
pub const SYSTEM_TOPIC_CLIENTS_ADDRESS: &str = "address";
/// System topic segment `lastWill`.
pub const SYSTEM_TOPIC_LAST_WILL: &str = "lastWill";
/// System topic segment `graveGoods`.
pub const SYSTEM_TOPIC_GRAVE_GOODS: &str = "graveGoods";
/// System topic segment `clientName`.
pub const SYSTEM_TOPIC_CLIENT_NAME: &str = "clientName";
/// System topic segment `protocolVersion`.
pub const SYSTEM_TOPIC_SUPPORTED_PROTOCOL_VERSION: &str = "protocolVersion";
/// System topic segment `mode`.
pub const SYSTEM_TOPIC_MODE: &str = "mode";
50
51pub type TransactionId = u64;
52pub type RequestPattern = String;
53pub type RequestPatterns = Vec<RequestPattern>;
54pub type Key = String;
55pub type Value = serde_json::Value;
56pub type KeyValuePairs = Vec<KeyValuePair>;
57pub type TypedKeyValuePairs<T> = Vec<TypedKeyValuePair<T>>;
58pub type MetaData = String;
59pub type Path = String;
60pub type ProtocolVersionSegment = u32;
61pub type ProtocolVersions = Vec<ProtocolVersion>;
62pub type LastWill = KeyValuePairs;
63pub type GraveGoods = RequestPatterns;
64pub type UniqueFlag = bool;
65pub type LiveOnlyFlag = bool;
66pub type AuthToken = String;
67pub type CasVersion = u64;
68
69#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
70#[serde(rename_all = "camelCase")]
71pub enum Privilege {
72    Read,
73    Write,
74    Delete,
75}
76
77impl fmt::Display for Privilege {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        match self {
80            Privilege::Read => "read".fmt(f),
81            Privilege::Write => "write".fmt(f),
82            Privilege::Delete => "delete".fmt(f),
83        }
84    }
85}
86
87#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize_repr, Deserialize_repr)]
88#[repr(u8)]
89pub enum ErrorCode {
90    IllegalWildcard = 0b00000000,
91    IllegalMultiWildcard = 0b00000001,
92    MultiWildcardAtIllegalPosition = 0b00000010,
93    IoError = 0b00000011,
94    SerdeError = 0b00000100,
95    NoSuchValue = 0b00000101,
96    NotSubscribed = 0b00000110,
97    ProtocolNegotiationFailed = 0b00000111,
98    InvalidServerResponse = 0b00001000,
99    ReadOnlyKey = 0b00001001,
100    AuthorizationFailed = 0b00001010,
101    AuthorizationRequired = 0b00001011,
102    AlreadyAuthorized = 0b00001100,
103    MissingValue = 0b00001101,
104    Unauthorized = 0b00001110,
105    NoPubStream = 0b00001111,
106    NotLeader = 0b00010000,
107    Cas = 0b00010001,
108    CasVersionMismatch = 0b00010010,
109    NotImplemented = 0b00010011,
110    KeyIsLocked = 0b00010100,
111    KeyIsNotLocked = 0b00010101,
112    Other = 0b11111111,
113}
114
115impl fmt::Display for ErrorCode {
116    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117        (self.to_owned() as u8).fmt(f)
118    }
119}
120
/// Builds a topic string by joining the `to_string()` of every argument
/// with `/`.
///
/// Arguments may be any expressions whose types implement `Display`. They
/// are evaluated left to right, and a trailing comma is accepted. For
/// example, `topic!("a", 1, 'b')` yields `"a/1/b"`.
#[macro_export]
macro_rules! topic {
    ($( $x:expr ),+ $(,)?) => {
        [$( $x.to_string() ),+].join("/")
    };
}
133
/// Free-form version string.
pub type Version = String;
135
136#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
137pub struct ProtocolVersion(ProtocolVersionSegment, ProtocolVersionSegment);
138
139impl ProtocolVersion {
140    pub const fn new(major: ProtocolVersionSegment, minor: ProtocolVersionSegment) -> Self {
141        Self(major, minor)
142    }
143
144    pub const fn major(&self) -> ProtocolVersionSegment {
145        self.0
146    }
147
148    pub const fn minor(&self) -> ProtocolVersionSegment {
149        self.1
150    }
151
152    pub fn is_compatible_with_server(&self, server_version: &ProtocolVersion) -> bool {
153        self.major() == server_version.major() && self.minor() <= server_version.minor()
154    }
155
156    pub fn is_compatible_with_client_version(&self, client_version: &ProtocolVersion) -> bool {
157        self.major() == client_version.major() && self.minor() >= client_version.minor()
158    }
159}
160
161impl fmt::Display for ProtocolVersion {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        write!(f, "{}.{}", self.0, self.1)
164    }
165}
166
167#[derive(Debug, Clone, PartialEq, Eq, Serialize, Hash, Deserialize)]
168pub enum Protocol {
169    TCP,
170    WS,
171    HTTP,
172    UNIX,
173}
174
175#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
176#[serde(rename_all = "camelCase")]
177pub struct KeyValuePair {
178    pub key: Key,
179    pub value: Value,
180}
181
182impl fmt::Display for KeyValuePair {
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        write!(f, "{}={}", self.key, self.value)
185    }
186}
187
188impl From<KeyValuePair> for Option<Value> {
189    fn from(kvp: KeyValuePair) -> Self {
190        Some(kvp.value)
191    }
192}
193
194impl From<KeyValuePair> for Value {
195    fn from(kvp: KeyValuePair) -> Self {
196        kvp.value
197    }
198}
199
200impl KeyValuePair {
201    pub fn new<S: Serialize>(key: String, value: S) -> Self {
202        (key, value).into()
203    }
204
205    pub fn of<S: Serialize>(key: String, value: S) -> Self {
206        KeyValuePair::new(key, value)
207    }
208}
209
210#[derive(Debug, Clone, PartialEq, Eq)]
211pub struct TypedKeyValuePair<T: DeserializeOwned> {
212    pub key: Key,
213    pub value: T,
214}
215
216impl<T: DeserializeOwned> TryFrom<KeyValuePair> for TypedKeyValuePair<T> {
217    type Error = serde_json::Error;
218
219    fn try_from(kvp: KeyValuePair) -> Result<Self, Self::Error> {
220        let deserialized = serde_json::from_value(kvp.value)?;
221        Ok(TypedKeyValuePair {
222            key: kvp.key,
223            value: deserialized,
224        })
225    }
226}
227
228impl<S: Serialize> From<(String, S)> for KeyValuePair {
229    fn from((key, value): (String, S)) -> Self {
230        let value = json!(value);
231        KeyValuePair { key, value }
232    }
233}
234
235impl<S: Serialize> From<(&str, S)> for KeyValuePair {
236    fn from((key, value): (&str, S)) -> Self {
237        let value = json!(value);
238        KeyValuePair {
239            key: key.to_owned(),
240            value,
241        }
242    }
243}
244
/// Text of a single wildcard-free key segment (the payload of
/// `KeySegment::Regular`).
pub type RegularKeySegment = String;
247
248pub fn parse_segments(pattern: &str) -> WorterbuchResult<Vec<RegularKeySegment>> {
249    let mut segments = Vec::new();
250    for segment in pattern.split('/') {
251        let ks: KeySegment = segment.into();
252        match ks {
253            KeySegment::Regular(reg) => segments.push(reg),
254            KeySegment::Wildcard => {
255                return Err(error::WorterbuchError::IllegalWildcard(pattern.to_owned()));
256            }
257            KeySegment::MultiWildcard => {
258                return Err(error::WorterbuchError::IllegalMultiWildcard(
259                    pattern.to_owned(),
260                ));
261            }
262        }
263    }
264    Ok(segments)
265}
266
267#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
268pub enum KeySegment {
269    Regular(RegularKeySegment),
270    Wildcard,
271    MultiWildcard,
272    // RegexWildcard(String),
273}
274
275pub fn format_path(path: &[KeySegment]) -> String {
276    path.iter()
277        .map(|seg| format!("{seg}"))
278        .collect::<Vec<String>>()
279        .join("/")
280}
281
282impl From<RegularKeySegment> for KeySegment {
283    fn from(reg: RegularKeySegment) -> Self {
284        Self::Regular(reg)
285    }
286}
287
288impl Deref for KeySegment {
289    type Target = str;
290
291    fn deref(&self) -> &Self::Target {
292        match self {
293            KeySegment::Regular(reg) => reg,
294            KeySegment::Wildcard => "?",
295            KeySegment::MultiWildcard => "#",
296        }
297    }
298}
299
300impl fmt::Display for KeySegment {
301    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
302        match self {
303            KeySegment::Regular(segment) => segment.fmt(f),
304            KeySegment::Wildcard => write!(f, "?"),
305            KeySegment::MultiWildcard => write!(f, "#"),
306            // PathSegment::RegexWildcard(regex) => write!(f, "?{regex}?"),
307        }
308    }
309}
310
311impl From<&str> for KeySegment {
312    fn from(str: &str) -> Self {
313        match str {
314            "?" => KeySegment::Wildcard,
315            "#" => KeySegment::MultiWildcard,
316            other => KeySegment::Regular(other.to_owned()),
317        }
318    }
319}
320
321impl KeySegment {
322    pub fn parse(pattern: impl AsRef<str>) -> Vec<KeySegment> {
323        let segments = pattern.as_ref().split('/');
324        segments.map(KeySegment::from).collect()
325    }
326}
327
/// Wraps `text` in double quotes unless it already starts and ends with one.
///
/// A lone `"` is not treated as already quoted. It satisfies both
/// `starts_with('"')` and `ends_with('"')`, but contains no quoted content,
/// so it is wrapped like any other string. Inner quotes are not escaped.
pub fn quote(text: impl AsRef<str>) -> String {
    let text = text.as_ref();
    // `len() >= 2` keeps a single `"` from counting as both the opening
    // and the closing quote.
    let already_quoted = text.len() >= 2 && text.starts_with('"') && text.ends_with('"');
    if already_quoted {
        text.to_owned()
    } else {
        format!("\"{text}\"")
    }
}
336
337pub fn digest_token(auth_token: &Option<String>, client_id: String) -> Option<String> {
338    auth_token.as_deref().map(|token| {
339        let salted = client_id + token;
340        let mut hasher = Sha256::new();
341        hasher.update(salted.as_bytes());
342        format!("{:x}", hasher.finalize())
343    })
344}
345
#[cfg(test)]
mod test {
    use crate::{ErrorCode, KeySegment, ProtocolVersion, format_path, parse_segments, quote};
    use serde_json::json;

    #[test]
    fn protocol_versions_are_sorted_correctly() {
        assert!(ProtocolVersion::new(1, 2) < ProtocolVersion::new(3, 2));
        assert!(ProtocolVersion::new(1, 2) == ProtocolVersion::new(1, 2));
        assert!(ProtocolVersion::new(2, 1) > ProtocolVersion::new(1, 9));

        let mut versions = vec![
            ProtocolVersion::new(1, 2),
            ProtocolVersion::new(0, 456),
            ProtocolVersion::new(9, 0),
            ProtocolVersion::new(3, 15),
        ];
        versions.sort();
        assert_eq!(
            vec![
                ProtocolVersion::new(0, 456),
                ProtocolVersion::new(1, 2),
                ProtocolVersion::new(3, 15),
                ProtocolVersion::new(9, 0)
            ],
            versions
        );
    }

    #[test]
    fn topic_macro_generates_topic_correctly() {
        assert_eq!(
            "hello/world/foo/bar",
            topic!("hello", "world", "foo", "bar")
        );
    }

    #[test]
    fn error_codes_are_serialized_as_numbers() {
        assert_eq!(
            "1",
            serde_json::to_string(&ErrorCode::IllegalMultiWildcard).unwrap()
        )
    }

    #[test]
    fn error_codes_are_deserialized_from_numbers() {
        assert_eq!(
            ErrorCode::ProtocolNegotiationFailed,
            serde_json::from_str("7").unwrap()
        )
    }

    #[test]
    fn protocol_version_get_serialized_correctly() {
        assert_eq!(&json!(ProtocolVersion::new(2, 1)).to_string(), "[2,1]")
    }

    #[test]
    fn protocol_version_get_formatted_correctly() {
        assert_eq!(&ProtocolVersion::new(2, 1).to_string(), "2.1")
    }

    #[test]
    fn compatible_version_is_selected_correctly() {
        let client_version = ProtocolVersion::new(1, 2);
        let server_versions = [
            ProtocolVersion::new(0, 11),
            ProtocolVersion::new(1, 6),
            ProtocolVersion::new(2, 0),
        ];
        let compatible_version = server_versions
            .iter()
            .find(|v| client_version.is_compatible_with_server(v));
        assert_eq!(compatible_version, Some(&server_versions[1]))
    }

    /// A server accepts clients with the same major and a minor no greater than its own.
    #[test]
    fn server_accepts_compatible_client_versions() {
        let server = ProtocolVersion::new(1, 6);
        assert!(server.is_compatible_with_client_version(&ProtocolVersion::new(1, 2)));
        assert!(server.is_compatible_with_client_version(&ProtocolVersion::new(1, 6)));
        assert!(!server.is_compatible_with_client_version(&ProtocolVersion::new(1, 7)));
        assert!(!server.is_compatible_with_client_version(&ProtocolVersion::new(2, 0)));
        assert!(!server.is_compatible_with_client_version(&ProtocolVersion::new(0, 1)));
    }

    #[test]
    fn key_segments_round_trip_through_format_path() {
        let segments = KeySegment::parse("a/?/b/#");
        assert_eq!(
            segments,
            vec![
                KeySegment::Regular("a".to_owned()),
                KeySegment::Wildcard,
                KeySegment::Regular("b".to_owned()),
                KeySegment::MultiWildcard,
            ]
        );
        assert_eq!(format_path(&segments), "a/?/b/#");
    }

    #[test]
    fn parse_segments_accepts_only_regular_segments() {
        assert_eq!(
            parse_segments("a/b").ok(),
            Some(vec!["a".to_owned(), "b".to_owned()])
        );
        assert!(parse_segments("a/?/b").is_err());
        assert!(parse_segments("a/#").is_err());
    }

    #[test]
    fn quote_wraps_only_unquoted_strings() {
        assert_eq!(quote("abc"), "\"abc\"");
        assert_eq!(quote("\"abc\""), "\"abc\"");
    }
}