worterbuch_common/
lib.rs

1/*
2 *  Worterbuch common modules library
3 *
4 *  Copyright (C) 2024 Michael Bachmann
5 *
6 *  This program is free software: you can redistribute it and/or modify
7 *  it under the terms of the GNU Affero General Public License as published by
8 *  the Free Software Foundation, either version 3 of the License, or
9 *  (at your option) any later version.
10 *
11 *  This program is distributed in the hope that it will be useful,
12 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *  GNU Affero General Public License for more details.
15 *
16 *  You should have received a copy of the GNU Affero General Public License
17 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20#[cfg(feature = "benchmark")]
21pub mod benchmark;
22mod client;
23pub mod error;
24mod server;
25
26pub use client::*;
27use serde_json::json;
28pub use server::*;
29
30use error::WorterbuchResult;
31use serde::{Deserialize, Serialize, de::DeserializeOwned};
32use serde_repr::*;
33use sha2::{Digest, Sha256};
34use std::{fmt, ops::Deref};
35
/// Name of the root segment of the system topic tree.
pub const SYSTEM_TOPIC_ROOT: &str = "$SYS";
/// [`SYSTEM_TOPIC_ROOT`] followed by the `/` key separator.
pub const SYSTEM_TOPIC_ROOT_PREFIX: &str = "$SYS/";

// Well-known key segment names used inside the system topic tree. How they are
// combined into full keys is decided by the code that publishes them.

/// Segment name `clients`.
pub const SYSTEM_TOPIC_CLIENTS: &str = "clients";
/// Segment name `version`.
pub const SYSTEM_TOPIC_VERSION: &str = "version";
/// Segment name `license`.
pub const SYSTEM_TOPIC_LICENSE: &str = "license";
/// Segment name `source-code`.
pub const SYSTEM_TOPIC_SOURCES: &str = "source-code";
/// Segment name `subscriptions`.
pub const SYSTEM_TOPIC_SUBSCRIPTIONS: &str = "subscriptions";
/// Segment name `protocol` (named as a sub-segment of [`SYSTEM_TOPIC_CLIENTS`]).
pub const SYSTEM_TOPIC_CLIENTS_PROTOCOL: &str = "protocol";
/// Segment name `address` (named as a sub-segment of [`SYSTEM_TOPIC_CLIENTS`]).
pub const SYSTEM_TOPIC_CLIENTS_ADDRESS: &str = "address";
/// Segment name `lastWill`.
pub const SYSTEM_TOPIC_LAST_WILL: &str = "lastWill";
/// Segment name `graveGoods`.
pub const SYSTEM_TOPIC_GRAVE_GOODS: &str = "graveGoods";
/// Segment name `clientName`.
pub const SYSTEM_TOPIC_CLIENT_NAME: &str = "clientName";
/// Segment name `protocolVersion`.
pub const SYSTEM_TOPIC_SUPPORTED_PROTOCOL_VERSION: &str = "protocolVersion";
/// Segment name `mode`.
pub const SYSTEM_TOPIC_MODE: &str = "mode";
50
51pub type TransactionId = u64;
52pub type RequestPattern = String;
53pub type RequestPatterns = Vec<RequestPattern>;
54pub type Key = String;
55pub type Value = serde_json::Value;
56pub type KeyValuePairs = Vec<KeyValuePair>;
57pub type TypedKeyValuePairs<T> = Vec<TypedKeyValuePair<T>>;
58pub type MetaData = String;
59pub type Path = String;
60pub type ProtocolVersionSegment = u32;
61pub type ProtocolVersions = Vec<ProtocolVersion>;
62pub type LastWill = KeyValuePairs;
63pub type GraveGoods = RequestPatterns;
64pub type UniqueFlag = bool;
65pub type LiveOnlyFlag = bool;
66pub type AuthToken = String;
67pub type CasVersion = u64;
68
69#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
70#[serde(rename_all = "camelCase")]
71pub enum Privilege {
72    Read,
73    Write,
74    Delete,
75    Profile,
76}
77
78impl fmt::Display for Privilege {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        match self {
81            Privilege::Read => "read".fmt(f),
82            Privilege::Write => "write".fmt(f),
83            Privilege::Delete => "delete".fmt(f),
84            Privilege::Profile => "profile".fmt(f),
85        }
86    }
87}
88
89#[derive(Debug, Serialize, Deserialize)]
90#[serde(rename_all = "camelCase")]
91pub enum AuthCheck<'a> {
92    Pattern(&'a str),
93    Flag,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97#[serde(rename_all = "camelCase")]
98pub enum AuthCheckOwned {
99    Pattern(String),
100    Flag,
101}
102
103impl<'a> From<AuthCheck<'a>> for AuthCheckOwned {
104    fn from(value: AuthCheck<'a>) -> Self {
105        match value {
106            AuthCheck::Pattern(p) => AuthCheckOwned::Pattern(p.to_owned()),
107            AuthCheck::Flag => AuthCheckOwned::Flag,
108        }
109    }
110}
111
112impl fmt::Display for AuthCheckOwned {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        match self {
115            AuthCheckOwned::Pattern(p) => p.fmt(f),
116            AuthCheckOwned::Flag => true.fmt(f),
117        }
118    }
119}
120
121#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize_repr, Deserialize_repr)]
122#[repr(u8)]
123pub enum ErrorCode {
124    IllegalWildcard = 0b00000000,
125    IllegalMultiWildcard = 0b00000001,
126    MultiWildcardAtIllegalPosition = 0b00000010,
127    IoError = 0b00000011,
128    SerdeError = 0b00000100,
129    NoSuchValue = 0b00000101,
130    NotSubscribed = 0b00000110,
131    ProtocolNegotiationFailed = 0b00000111,
132    InvalidServerResponse = 0b00001000,
133    ReadOnlyKey = 0b00001001,
134    AuthorizationFailed = 0b00001010,
135    AuthorizationRequired = 0b00001011,
136    AlreadyAuthorized = 0b00001100,
137    MissingValue = 0b00001101,
138    Unauthorized = 0b00001110,
139    NoPubStream = 0b00001111,
140    NotLeader = 0b00010000,
141    Cas = 0b00010001,
142    CasVersionMismatch = 0b00010010,
143    NotImplemented = 0b00010011,
144    KeyIsLocked = 0b00010100,
145    KeyIsNotLocked = 0b00010101,
146    LockAcquisitionCancelled = 0b00010110,
147    Other = 0b11111111,
148}
149
150impl fmt::Display for ErrorCode {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        (self.to_owned() as u8).fmt(f)
153    }
154}
155
/// Joins one or more segments into a `/`-separated topic string.
///
/// Each argument is converted with `ToString`, so any `Display` value works.
/// For example, `topic!("a", "b", 1)` yields `"a/b/1"`. A trailing comma is
/// accepted, like in std macros such as `vec!`.
#[macro_export]
macro_rules! topic {
    ($( $segment:expr ),+ $(,)?) => {
        [$( $segment.to_string() ),+].join("/")
    };
}
168
169pub type Version = String;
170
171#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
172pub struct ProtocolVersion(ProtocolVersionSegment, ProtocolVersionSegment);
173
174impl ProtocolVersion {
175    pub const fn new(major: ProtocolVersionSegment, minor: ProtocolVersionSegment) -> Self {
176        Self(major, minor)
177    }
178
179    pub const fn major(&self) -> ProtocolVersionSegment {
180        self.0
181    }
182
183    pub const fn minor(&self) -> ProtocolVersionSegment {
184        self.1
185    }
186
187    pub fn is_compatible_with_server(&self, server_version: &ProtocolVersion) -> bool {
188        self.major() == server_version.major() && self.minor() <= server_version.minor()
189    }
190
191    pub fn is_compatible_with_client_version(&self, client_version: &ProtocolVersion) -> bool {
192        self.major() == client_version.major() && self.minor() >= client_version.minor()
193    }
194}
195
196impl fmt::Display for ProtocolVersion {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        write!(f, "{}.{}", self.0, self.1)
199    }
200}
201
202#[derive(Debug, Clone, PartialEq, Eq, Serialize, Hash, Deserialize)]
203pub enum Protocol {
204    TCP,
205    WS,
206    HTTP,
207    UNIX,
208}
209
210#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
211#[serde(rename_all = "camelCase")]
212pub struct KeyValuePair {
213    pub key: Key,
214    pub value: Value,
215}
216
217impl fmt::Display for KeyValuePair {
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        write!(f, "{}={}", self.key, self.value)
220    }
221}
222
223impl From<KeyValuePair> for Option<Value> {
224    fn from(kvp: KeyValuePair) -> Self {
225        Some(kvp.value)
226    }
227}
228
229impl From<KeyValuePair> for Value {
230    fn from(kvp: KeyValuePair) -> Self {
231        kvp.value
232    }
233}
234
235impl KeyValuePair {
236    pub fn new(key: String, value: Value) -> Self {
237        KeyValuePair { key, value }
238    }
239
240    pub fn of<S: Serialize>(key: impl Into<String>, value: S) -> Self {
241        KeyValuePair::new(key.into(), json!(value))
242    }
243}
244
245#[derive(Debug, Clone, PartialEq, Eq)]
246pub struct TypedKeyValuePair<T: DeserializeOwned> {
247    pub key: Key,
248    pub value: T,
249}
250
251impl<T: DeserializeOwned> TryFrom<KeyValuePair> for TypedKeyValuePair<T> {
252    type Error = serde_json::Error;
253
254    fn try_from(kvp: KeyValuePair) -> Result<Self, Self::Error> {
255        let deserialized = serde_json::from_value(kvp.value)?;
256        Ok(TypedKeyValuePair {
257            key: kvp.key,
258            value: deserialized,
259        })
260    }
261}
262
263impl<S: Serialize> From<(String, S)> for KeyValuePair {
264    fn from((key, value): (String, S)) -> Self {
265        let value = json!(value);
266        KeyValuePair { key, value }
267    }
268}
269
270impl<S: Serialize> From<(&str, S)> for KeyValuePair {
271    fn from((key, value): (&str, S)) -> Self {
272        let value = json!(value);
273        KeyValuePair {
274            key: key.to_owned(),
275            value,
276        }
277    }
278}
279
280// #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Tags)]
281pub type RegularKeySegment = String;
282
283pub fn parse_segments(pattern: &str) -> WorterbuchResult<Vec<RegularKeySegment>> {
284    let mut segments = Vec::new();
285    for segment in pattern.split('/') {
286        let ks: KeySegment = segment.into();
287        match ks {
288            KeySegment::Regular(reg) => segments.push(reg),
289            KeySegment::Wildcard => {
290                return Err(error::WorterbuchError::IllegalWildcard(pattern.to_owned()));
291            }
292            KeySegment::MultiWildcard => {
293                return Err(error::WorterbuchError::IllegalMultiWildcard(
294                    pattern.to_owned(),
295                ));
296            }
297        }
298    }
299    Ok(segments)
300}
301
302#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
303pub enum KeySegment {
304    Regular(RegularKeySegment),
305    Wildcard,
306    MultiWildcard,
307    // RegexWildcard(String),
308}
309
310pub fn format_path(path: &[KeySegment]) -> String {
311    path.iter()
312        .map(|seg| format!("{seg}"))
313        .collect::<Vec<String>>()
314        .join("/")
315}
316
317impl From<RegularKeySegment> for KeySegment {
318    fn from(reg: RegularKeySegment) -> Self {
319        Self::Regular(reg)
320    }
321}
322
323impl Deref for KeySegment {
324    type Target = str;
325
326    fn deref(&self) -> &Self::Target {
327        match self {
328            KeySegment::Regular(reg) => reg,
329            KeySegment::Wildcard => "?",
330            KeySegment::MultiWildcard => "#",
331        }
332    }
333}
334
335impl fmt::Display for KeySegment {
336    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
337        match self {
338            KeySegment::Regular(segment) => segment.fmt(f),
339            KeySegment::Wildcard => write!(f, "?"),
340            KeySegment::MultiWildcard => write!(f, "#"),
341            // PathSegment::RegexWildcard(regex) => write!(f, "?{regex}?"),
342        }
343    }
344}
345
346impl From<&str> for KeySegment {
347    fn from(str: &str) -> Self {
348        match str {
349            "?" => KeySegment::Wildcard,
350            "#" => KeySegment::MultiWildcard,
351            other => KeySegment::Regular(other.to_owned()),
352        }
353    }
354}
355
356impl KeySegment {
357    pub fn parse(pattern: impl AsRef<str>) -> Vec<KeySegment> {
358        let segments = pattern.as_ref().split('/');
359        segments.map(KeySegment::from).collect()
360    }
361}
362
/// Wraps `str` in double quotes unless it is already enclosed in them.
///
/// "Already enclosed" means a distinct opening and closing `"`. A lone `"` does
/// not count and gets wrapped like any other text. Inner quotes are not escaped.
pub fn quote(str: impl AsRef<str>) -> String {
    let text = str.as_ref();
    // Strip the opening and closing quote separately so a single `"` (length 1)
    // cannot match as both at once.
    let already_quoted = text
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .is_some();
    if already_quoted {
        text.to_owned()
    } else {
        format!("\"{text}\"")
    }
}
371
372pub fn digest_token(auth_token: &Option<String>, client_id: String) -> Option<String> {
373    auth_token.as_deref().map(|token| {
374        let salted = client_id + token;
375        let mut hasher = Sha256::new();
376        hasher.update(salted.as_bytes());
377        format!("{:x}", hasher.finalize())
378    })
379}
380
#[cfg(test)]
mod test {
    use crate::{ErrorCode, ProtocolVersion};
    use serde_json::json;

    #[test]
    fn protocol_versions_are_sorted_correctly() {
        let pv = ProtocolVersion::new;
        assert!(pv(1, 2) < pv(3, 2));
        assert_eq!(pv(1, 2), pv(1, 2));
        assert!(pv(2, 1) > pv(1, 9));

        let mut actual = [pv(1, 2), pv(0, 456), pv(9, 0), pv(3, 15)];
        actual.sort();
        let expected = [pv(0, 456), pv(1, 2), pv(3, 15), pv(9, 0)];
        assert_eq!(expected, actual);
    }

    #[test]
    fn topic_macro_generates_topic_correctly() {
        let generated = topic!("hello", "world", "foo", "bar");
        assert_eq!(generated, "hello/world/foo/bar");
    }

    #[test]
    fn error_codes_are_serialized_as_numbers() {
        let serialized = serde_json::to_string(&ErrorCode::IllegalMultiWildcard).unwrap();
        assert_eq!(serialized, "1");
    }

    #[test]
    fn error_codes_are_deserialized_from_numbers() {
        let deserialized: ErrorCode = serde_json::from_str("7").unwrap();
        assert_eq!(deserialized, ErrorCode::ProtocolNegotiationFailed);
    }

    #[test]
    fn protocol_version_get_serialized_correctly() {
        let serialized = json!(ProtocolVersion::new(2, 1)).to_string();
        assert_eq!(serialized, "[2,1]");
    }

    #[test]
    fn protocol_version_get_formatted_correctly() {
        let formatted = ProtocolVersion::new(2, 1).to_string();
        assert_eq!(formatted, "2.1");
    }

    #[test]
    fn compatible_version_is_selected_correctly() {
        let client = ProtocolVersion::new(1, 2);
        let servers = [
            ProtocolVersion::new(0, 11),
            ProtocolVersion::new(1, 6),
            ProtocolVersion::new(2, 0),
        ];
        let selected = servers
            .iter()
            .copied()
            .find(|server| client.is_compatible_with_server(server));
        assert_eq!(selected, Some(servers[1]));
    }
}