worterbuch_common/
lib.rs

1/*
2 *  Worterbuch common modules library
3 *
4 *  Copyright (C) 2024 Michael Bachmann
5 *
6 *  This program is free software: you can redistribute it and/or modify
7 *  it under the terms of the GNU Affero General Public License as published by
8 *  the Free Software Foundation, either version 3 of the License, or
9 *  (at your option) any later version.
10 *
11 *  This program is distributed in the hope that it will be useful,
12 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *  GNU Affero General Public License for more details.
15 *
16 *  You should have received a copy of the GNU Affero General Public License
17 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20#[cfg(feature = "benchmark")]
21pub mod benchmark;
22mod client;
23pub mod error;
24mod server;
25
26pub use client::*;
27use serde_json::json;
28pub use server::*;
29
30use error::WorterbuchResult;
31use serde::{Deserialize, Serialize, de::DeserializeOwned};
32use serde_repr::*;
33use sha2::{Digest, Sha256};
34use std::{fmt, ops::Deref};
35#[cfg(feature = "jemalloc")]
36mod jemalloc;
37#[cfg(feature = "jemalloc")]
38pub mod profiling;
39
/// Root segment of the server's system topics.
pub const SYSTEM_TOPIC_ROOT: &str = "$SYS";
/// [`SYSTEM_TOPIC_ROOT`] followed by the `/` key separator.
pub const SYSTEM_TOPIC_ROOT_PREFIX: &str = "$SYS/";

/// System topic segment `clients`.
pub const SYSTEM_TOPIC_CLIENTS: &str = "clients";
/// System topic segment `version`.
pub const SYSTEM_TOPIC_VERSION: &str = "version";
/// System topic segment `license`.
pub const SYSTEM_TOPIC_LICENSE: &str = "license";
/// System topic segment `source-code`.
pub const SYSTEM_TOPIC_SOURCES: &str = "source-code";
/// System topic segment `subscriptions`.
pub const SYSTEM_TOPIC_SUBSCRIPTIONS: &str = "subscriptions";
/// System topic segment `protocol` (presumably nested below a client entry).
pub const SYSTEM_TOPIC_CLIENTS_PROTOCOL: &str = "protocol";
/// System topic segment `address` (presumably nested below a client entry).
pub const SYSTEM_TOPIC_CLIENTS_ADDRESS: &str = "address";
/// System topic segment `lastWill`.
pub const SYSTEM_TOPIC_LAST_WILL: &str = "lastWill";
/// System topic segment `graveGoods`.
pub const SYSTEM_TOPIC_GRAVE_GOODS: &str = "graveGoods";
/// System topic segment `clientName`.
pub const SYSTEM_TOPIC_CLIENT_NAME: &str = "clientName";
/// System topic segment `protocolVersion`.
pub const SYSTEM_TOPIC_SUPPORTED_PROTOCOL_VERSION: &str = "protocolVersion";
/// System topic segment `mode`.
pub const SYSTEM_TOPIC_MODE: &str = "mode";
54
55pub type TransactionId = u64;
56pub type RequestPattern = String;
57pub type RequestPatterns = Vec<RequestPattern>;
58pub type Key = String;
59pub type Value = serde_json::Value;
60pub type KeyValuePairs = Vec<KeyValuePair>;
61pub type TypedKeyValuePairs<T> = Vec<TypedKeyValuePair<T>>;
62pub type MetaData = String;
63pub type Path = String;
64pub type ProtocolVersionSegment = u32;
65pub type ProtocolVersions = Vec<ProtocolVersion>;
66pub type LastWill = KeyValuePairs;
67pub type GraveGoods = RequestPatterns;
68pub type UniqueFlag = bool;
69pub type LiveOnlyFlag = bool;
70pub type AuthToken = String;
71pub type CasVersion = u64;
72
73#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
74#[serde(rename_all = "camelCase")]
75pub enum Privilege {
76    Read,
77    Write,
78    Delete,
79    Profile,
80    WebLogin,
81}
82
83impl fmt::Display for Privilege {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        match self {
86            Privilege::Read => "read".fmt(f),
87            Privilege::Write => "write".fmt(f),
88            Privilege::Delete => "delete".fmt(f),
89            Privilege::Profile => "profile".fmt(f),
90            Privilege::WebLogin => "web-login".fmt(f),
91        }
92    }
93}
94
95#[derive(Debug, Serialize, Deserialize)]
96#[serde(rename_all = "camelCase")]
97pub enum AuthCheck<'a> {
98    Pattern(&'a str),
99    Flag,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
103#[serde(rename_all = "camelCase")]
104pub enum AuthCheckOwned {
105    Pattern(String),
106    Flag,
107}
108
109impl<'a> From<AuthCheck<'a>> for AuthCheckOwned {
110    fn from(value: AuthCheck<'a>) -> Self {
111        match value {
112            AuthCheck::Pattern(p) => AuthCheckOwned::Pattern(p.to_owned()),
113            AuthCheck::Flag => AuthCheckOwned::Flag,
114        }
115    }
116}
117
118impl fmt::Display for AuthCheckOwned {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        match self {
121            AuthCheckOwned::Pattern(p) => p.fmt(f),
122            AuthCheckOwned::Flag => true.fmt(f),
123        }
124    }
125}
126
127#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize_repr, Deserialize_repr)]
128#[repr(u8)]
129pub enum ErrorCode {
130    IllegalWildcard = 0b00000000,
131    IllegalMultiWildcard = 0b00000001,
132    MultiWildcardAtIllegalPosition = 0b00000010,
133    IoError = 0b00000011,
134    SerdeError = 0b00000100,
135    NoSuchValue = 0b00000101,
136    NotSubscribed = 0b00000110,
137    ProtocolNegotiationFailed = 0b00000111,
138    InvalidServerResponse = 0b00001000,
139    ReadOnlyKey = 0b00001001,
140    AuthorizationFailed = 0b00001010,
141    AuthorizationRequired = 0b00001011,
142    AlreadyAuthorized = 0b00001100,
143    MissingValue = 0b00001101,
144    Unauthorized = 0b00001110,
145    NoPubStream = 0b00001111,
146    NotLeader = 0b00010000,
147    Cas = 0b00010001,
148    CasVersionMismatch = 0b00010010,
149    NotImplemented = 0b00010011,
150    KeyIsLocked = 0b00010100,
151    KeyIsNotLocked = 0b00010101,
152    LockAcquisitionCancelled = 0b00010110,
153    FeatureDisabled = 0b00010111,
154    Other = 0b11111111,
155}
156
157impl fmt::Display for ErrorCode {
158    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159        (self.to_owned() as u8).fmt(f)
160    }
161}
162
/// Builds a `/`-separated topic string from one or more segments.
///
/// Each argument is converted with `ToString::to_string`, so any `Display`
/// value (string slices, numbers, key segments, ...) may be passed.
/// Arguments are evaluated left to right, and a trailing comma is accepted.
#[macro_export]
macro_rules! topic {
    ($( $x:expr ),+ $(,)?) => {
        [$( $x.to_string() ),+].join("/")
    };
}
175
/// Free-form version string.
pub type Version = String;
177
178#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
179pub struct ProtocolVersion(ProtocolVersionSegment, ProtocolVersionSegment);
180
181impl ProtocolVersion {
182    pub const fn new(major: ProtocolVersionSegment, minor: ProtocolVersionSegment) -> Self {
183        Self(major, minor)
184    }
185
186    pub const fn major(&self) -> ProtocolVersionSegment {
187        self.0
188    }
189
190    pub const fn minor(&self) -> ProtocolVersionSegment {
191        self.1
192    }
193
194    pub fn is_compatible_with_server(&self, server_version: &ProtocolVersion) -> bool {
195        self.major() == server_version.major() && self.minor() <= server_version.minor()
196    }
197
198    pub fn is_compatible_with_client_version(&self, client_version: &ProtocolVersion) -> bool {
199        self.major() == client_version.major() && self.minor() >= client_version.minor()
200    }
201}
202
203impl fmt::Display for ProtocolVersion {
204    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205        write!(f, "{}.{}", self.0, self.1)
206    }
207}
208
209#[derive(Debug, Clone, PartialEq, Eq, Serialize, Hash, Deserialize)]
210pub enum Protocol {
211    TCP,
212    WS,
213    HTTP,
214    UNIX,
215}
216
217#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
218#[serde(rename_all = "camelCase")]
219pub struct KeyValuePair {
220    pub key: Key,
221    pub value: Value,
222}
223
224impl fmt::Display for KeyValuePair {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        write!(f, "{}={}", self.key, self.value)
227    }
228}
229
230impl From<KeyValuePair> for Option<Value> {
231    fn from(kvp: KeyValuePair) -> Self {
232        Some(kvp.value)
233    }
234}
235
236impl From<KeyValuePair> for Value {
237    fn from(kvp: KeyValuePair) -> Self {
238        kvp.value
239    }
240}
241
242impl KeyValuePair {
243    pub fn new(key: String, value: Value) -> Self {
244        KeyValuePair { key, value }
245    }
246
247    pub fn of<S: Serialize>(key: impl Into<String>, value: S) -> Self {
248        KeyValuePair::new(key.into(), json!(value))
249    }
250}
251
252#[derive(Debug, Clone, PartialEq, Eq)]
253pub struct TypedKeyValuePair<T: DeserializeOwned> {
254    pub key: Key,
255    pub value: T,
256}
257
258impl<T: DeserializeOwned> TryFrom<KeyValuePair> for TypedKeyValuePair<T> {
259    type Error = serde_json::Error;
260
261    fn try_from(kvp: KeyValuePair) -> Result<Self, Self::Error> {
262        let deserialized = serde_json::from_value(kvp.value)?;
263        Ok(TypedKeyValuePair {
264            key: kvp.key,
265            value: deserialized,
266        })
267    }
268}
269
270impl<S: Serialize> From<(String, S)> for KeyValuePair {
271    fn from((key, value): (String, S)) -> Self {
272        let value = json!(value);
273        KeyValuePair { key, value }
274    }
275}
276
277impl<S: Serialize> From<(&str, S)> for KeyValuePair {
278    fn from((key, value): (&str, S)) -> Self {
279        let value = json!(value);
280        KeyValuePair {
281            key: key.to_owned(),
282            value,
283        }
284    }
285}
286
/// The text of a literal (non-wildcard) key segment.
pub type RegularKeySegment = String;
289
290pub fn parse_segments(pattern: &str) -> WorterbuchResult<Vec<RegularKeySegment>> {
291    let mut segments = Vec::new();
292    for segment in pattern.split('/') {
293        let ks: KeySegment = segment.into();
294        match ks {
295            KeySegment::Regular(reg) => segments.push(reg),
296            KeySegment::Wildcard => {
297                return Err(error::WorterbuchError::IllegalWildcard(pattern.to_owned()));
298            }
299            KeySegment::MultiWildcard => {
300                return Err(error::WorterbuchError::IllegalMultiWildcard(
301                    pattern.to_owned(),
302                ));
303            }
304        }
305    }
306    Ok(segments)
307}
308
309#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
310pub enum KeySegment {
311    Regular(RegularKeySegment),
312    Wildcard,
313    MultiWildcard,
314    // RegexWildcard(String),
315}
316
317pub fn format_path(path: &[KeySegment]) -> String {
318    path.iter()
319        .map(|seg| format!("{seg}"))
320        .collect::<Vec<String>>()
321        .join("/")
322}
323
324impl From<RegularKeySegment> for KeySegment {
325    fn from(reg: RegularKeySegment) -> Self {
326        Self::Regular(reg)
327    }
328}
329
330impl Deref for KeySegment {
331    type Target = str;
332
333    fn deref(&self) -> &Self::Target {
334        match self {
335            KeySegment::Regular(reg) => reg,
336            KeySegment::Wildcard => "?",
337            KeySegment::MultiWildcard => "#",
338        }
339    }
340}
341
342impl fmt::Display for KeySegment {
343    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344        match self {
345            KeySegment::Regular(segment) => segment.fmt(f),
346            KeySegment::Wildcard => write!(f, "?"),
347            KeySegment::MultiWildcard => write!(f, "#"),
348            // PathSegment::RegexWildcard(regex) => write!(f, "?{regex}?"),
349        }
350    }
351}
352
353impl From<&str> for KeySegment {
354    fn from(str: &str) -> Self {
355        match str {
356            "?" => KeySegment::Wildcard,
357            "#" => KeySegment::MultiWildcard,
358            other => KeySegment::Regular(other.to_owned()),
359        }
360    }
361}
362
363impl KeySegment {
364    pub fn parse(pattern: impl AsRef<str>) -> Vec<KeySegment> {
365        let segments = pattern.as_ref().split('/');
366        segments.map(KeySegment::from).collect()
367    }
368}
369
/// Wraps `str` in double quotes unless it is already enclosed in them.
///
/// "Already enclosed" requires a distinct opening *and* closing `"`, so a lone
/// `"` is wrapped as well instead of being returned unbalanced. Quotes inside
/// the string are not escaped.
pub fn quote(str: impl AsRef<str>) -> String {
    let text = str.as_ref();
    let already_quoted = text
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .is_some();
    if already_quoted {
        text.to_owned()
    } else {
        format!("\"{text}\"")
    }
}
378
379pub fn digest_token(auth_token: &Option<String>, client_id: String) -> Option<String> {
380    auth_token.as_deref().map(|token| {
381        let salted = client_id + token;
382        let mut hasher = Sha256::new();
383        hasher.update(salted.as_bytes());
384        format!("{:x}", hasher.finalize())
385    })
386}
387
#[cfg(test)]
mod test {
    use crate::{ErrorCode, ProtocolVersion};
    use serde_json::json;

    #[test]
    fn protocol_versions_are_sorted_correctly() {
        let pv = ProtocolVersion::new;

        assert!(pv(1, 2) < pv(3, 2));
        assert_eq!(pv(1, 2), pv(1, 2));
        assert!(pv(2, 1) > pv(1, 9));

        let mut versions = vec![pv(1, 2), pv(0, 456), pv(9, 0), pv(3, 15)];
        versions.sort();
        let expected = vec![pv(0, 456), pv(1, 2), pv(3, 15), pv(9, 0)];
        assert_eq!(expected, versions);
    }

    #[test]
    fn topic_macro_generates_topic_correctly() {
        let generated = topic!("hello", "world", "foo", "bar");
        assert_eq!("hello/world/foo/bar", generated);
    }

    #[test]
    fn error_codes_are_serialized_as_numbers() {
        let serialized = serde_json::to_string(&ErrorCode::IllegalMultiWildcard).unwrap();
        assert_eq!("1", serialized)
    }

    #[test]
    fn error_codes_are_deserialized_from_numbers() {
        let parsed: ErrorCode = serde_json::from_str("7").unwrap();
        assert_eq!(ErrorCode::ProtocolNegotiationFailed, parsed)
    }

    #[test]
    fn protocol_version_get_serialized_correctly() {
        let serialized = json!(ProtocolVersion::new(2, 1)).to_string();
        assert_eq!(serialized, "[2,1]")
    }

    #[test]
    fn protocol_version_get_formatted_correctly() {
        let formatted = ProtocolVersion::new(2, 1).to_string();
        assert_eq!(formatted, "2.1")
    }

    #[test]
    fn compatible_version_is_selected_correctly() {
        let client_version = ProtocolVersion::new(1, 2);
        let server_versions = [
            ProtocolVersion::new(0, 11),
            ProtocolVersion::new(1, 6),
            ProtocolVersion::new(2, 0),
        ];
        let selected = server_versions
            .iter()
            .find(|candidate| client_version.is_compatible_with_server(candidate));
        assert_eq!(selected, Some(&server_versions[1]))
    }
}