worterbuch_common/
lib.rs

1/*
2 *  Worterbuch common modules library
3 *
4 *  Copyright (C) 2024 Michael Bachmann
5 *
6 *  This program is free software: you can redistribute it and/or modify
7 *  it under the terms of the GNU Affero General Public License as published by
8 *  the Free Software Foundation, either version 3 of the License, or
9 *  (at your option) any later version.
10 *
11 *  This program is distributed in the hope that it will be useful,
12 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *  GNU Affero General Public License for more details.
15 *
16 *  You should have received a copy of the GNU Affero General Public License
17 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20#[cfg(feature = "benchmark")]
21pub mod benchmark;
22mod client;
23pub mod error;
24mod server;
25
26pub use client::*;
27use serde_json::json;
28pub use server::*;
29
30use error::WorterbuchResult;
31use serde::{Deserialize, Serialize, de::DeserializeOwned};
32use serde_repr::*;
33use sha2::{Digest, Sha256};
34use std::{fmt, ops::Deref};
35#[cfg(feature = "jemalloc")]
36mod jemalloc;
37#[cfg(feature = "jemalloc")]
38pub mod profiling;
39
/// Root segment of the system topic tree.
pub const SYSTEM_TOPIC_ROOT: &str = "$SYS";
/// [`SYSTEM_TOPIC_ROOT`] including the trailing `/` key separator.
pub const SYSTEM_TOPIC_ROOT_PREFIX: &str = "$SYS/";

// Plain segment names used when building system topics (presumably joined with
// `topic!`). Where each one sits below `$SYS` is decided by the code that builds the
// topics, not here.
pub const SYSTEM_TOPIC_CLIENTS: &str = "clients";
pub const SYSTEM_TOPIC_VERSION: &str = "version";
pub const SYSTEM_TOPIC_LICENSE: &str = "license";
/// Note: the segment text is `source-code`.
pub const SYSTEM_TOPIC_SOURCES: &str = "source-code";
pub const SYSTEM_TOPIC_SUBSCRIPTIONS: &str = "subscriptions";
pub const SYSTEM_TOPIC_MODE: &str = "mode";
/// Same text as [`SYSTEM_TOPIC_CLIENTS_PROTOCOL_VERSION`]; the two names only
/// distinguish where the segment is used.
pub const SYSTEM_TOPIC_SUPPORTED_PROTOCOL_VERSION: &str = "protocolVersion";

// Segment names whose identifiers suggest per-client data
// (NOTE(review): grouping inferred from the names only).
pub const SYSTEM_TOPIC_CLIENTS_PROTOCOL: &str = "protocol";
/// Same text as [`SYSTEM_TOPIC_SUPPORTED_PROTOCOL_VERSION`].
pub const SYSTEM_TOPIC_CLIENTS_PROTOCOL_VERSION: &str = "protocolVersion";
pub const SYSTEM_TOPIC_CLIENTS_ADDRESS: &str = "address";
/// Note: the segment text is `connectedSince`.
pub const SYSTEM_TOPIC_CLIENTS_TIMESTAMP: &str = "connectedSince";
pub const SYSTEM_TOPIC_LAST_WILL: &str = "lastWill";
pub const SYSTEM_TOPIC_GRAVE_GOODS: &str = "graveGoods";
pub const SYSTEM_TOPIC_CLIENT_NAME: &str = "clientName";
56
57pub type TransactionId = u64;
58pub type RequestPattern = String;
59pub type RequestPatterns = Vec<RequestPattern>;
60pub type Key = String;
61pub type Value = serde_json::Value;
62pub type KeyValuePairs = Vec<KeyValuePair>;
63pub type TypedKeyValuePairs<T> = Vec<TypedKeyValuePair<T>>;
64pub type MetaData = String;
65pub type Path = String;
66pub type ProtocolVersionSegment = u32;
67pub type ProtocolMajorVersion = ProtocolVersionSegment;
68pub type ProtocolVersions = Vec<ProtocolVersion>;
69pub type LastWill = KeyValuePairs;
70pub type GraveGoods = RequestPatterns;
71pub type UniqueFlag = bool;
72pub type LiveOnlyFlag = bool;
73pub type AuthToken = String;
74pub type CasVersion = u64;
75
76#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
77#[serde(rename_all = "camelCase")]
78pub enum Privilege {
79    Read,
80    Write,
81    Delete,
82    Profile,
83    WebLogin,
84}
85
86impl fmt::Display for Privilege {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        match self {
89            Privilege::Read => "read".fmt(f),
90            Privilege::Write => "write".fmt(f),
91            Privilege::Delete => "delete".fmt(f),
92            Privilege::Profile => "profile".fmt(f),
93            Privilege::WebLogin => "web-login".fmt(f),
94        }
95    }
96}
97
98#[derive(Debug, Serialize, Deserialize)]
99#[serde(rename_all = "camelCase")]
100pub enum AuthCheck<'a> {
101    Pattern(&'a str),
102    Flag,
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
106#[serde(rename_all = "camelCase")]
107pub enum AuthCheckOwned {
108    Pattern(String),
109    Flag,
110}
111
112impl<'a> From<AuthCheck<'a>> for AuthCheckOwned {
113    fn from(value: AuthCheck<'a>) -> Self {
114        match value {
115            AuthCheck::Pattern(p) => AuthCheckOwned::Pattern(p.to_owned()),
116            AuthCheck::Flag => AuthCheckOwned::Flag,
117        }
118    }
119}
120
121impl fmt::Display for AuthCheckOwned {
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        match self {
124            AuthCheckOwned::Pattern(p) => p.fmt(f),
125            AuthCheckOwned::Flag => true.fmt(f),
126        }
127    }
128}
129
130#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize_repr, Deserialize_repr)]
131#[repr(u8)]
132pub enum ErrorCode {
133    IllegalWildcard = 0b00000000,
134    IllegalMultiWildcard = 0b00000001,
135    MultiWildcardAtIllegalPosition = 0b00000010,
136    IoError = 0b00000011,
137    SerdeError = 0b00000100,
138    NoSuchValue = 0b00000101,
139    NotSubscribed = 0b00000110,
140    ProtocolNegotiationFailed = 0b00000111,
141    InvalidServerResponse = 0b00001000,
142    ReadOnlyKey = 0b00001001,
143    AuthorizationFailed = 0b00001010,
144    AuthorizationRequired = 0b00001011,
145    AlreadyAuthorized = 0b00001100,
146    MissingValue = 0b00001101,
147    Unauthorized = 0b00001110,
148    NoPubStream = 0b00001111,
149    NotLeader = 0b00010000,
150    Cas = 0b00010001,
151    CasVersionMismatch = 0b00010010,
152    NotImplemented = 0b00010011,
153    KeyIsLocked = 0b00010100,
154    KeyIsNotLocked = 0b00010101,
155    LockAcquisitionCancelled = 0b00010110,
156    FeatureDisabled = 0b00010111,
157    Other = 0b11111111,
158}
159
160impl fmt::Display for ErrorCode {
161    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162        (self.to_owned() as u8).fmt(f)
163    }
164}
165
/// Joins one or more expressions into a `/`-separated topic string.
///
/// Each argument is converted with `ToString::to_string`, so any `Display` value
/// (string slices, `String`s, numbers, …) can be used as a segment. Arguments are
/// borrowed, not moved, and are evaluated left to right. A trailing comma is
/// accepted, e.g. `topic!("a", 1,)` yields `"a/1"`.
#[macro_export]
macro_rules! topic {
    ($( $x:expr ),+ $(,)?) => {
        [$( $x.to_string() ),+].join("/")
    };
}
178
/// A version string, held as an owned `String`.
pub type Version = String;
180
181#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
182pub struct ProtocolVersion(ProtocolVersionSegment, ProtocolVersionSegment);
183
184impl ProtocolVersion {
185    pub const fn new(major: ProtocolVersionSegment, minor: ProtocolVersionSegment) -> Self {
186        Self(major, minor)
187    }
188
189    pub const fn major(&self) -> ProtocolVersionSegment {
190        self.0
191    }
192
193    pub const fn minor(&self) -> ProtocolVersionSegment {
194        self.1
195    }
196
197    pub fn is_compatible_with_server(&self, server_version: &ProtocolVersion) -> bool {
198        self.major() == server_version.major() && self.minor() <= server_version.minor()
199    }
200
201    pub fn is_compatible_with_client_version(&self, client_version: &ProtocolVersion) -> bool {
202        self.major() == client_version.major() && self.minor() >= client_version.minor()
203    }
204}
205
206impl fmt::Display for ProtocolVersion {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        write!(f, "{}.{}", self.0, self.1)
209    }
210}
211
212#[derive(Debug, Clone, PartialEq, Eq, Serialize, Hash, Deserialize)]
213pub enum Protocol {
214    TCP,
215    WS,
216    HTTP,
217    UNIX,
218}
219
220#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
221#[serde(rename_all = "camelCase")]
222pub struct KeyValuePair {
223    pub key: Key,
224    pub value: Value,
225}
226
227impl fmt::Display for KeyValuePair {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        write!(f, "{}={}", self.key, self.value)
230    }
231}
232
233impl From<KeyValuePair> for Option<Value> {
234    fn from(kvp: KeyValuePair) -> Self {
235        Some(kvp.value)
236    }
237}
238
239impl From<KeyValuePair> for Value {
240    fn from(kvp: KeyValuePair) -> Self {
241        kvp.value
242    }
243}
244
245impl KeyValuePair {
246    pub fn new(key: String, value: Value) -> Self {
247        KeyValuePair { key, value }
248    }
249
250    pub fn of<S: Serialize>(key: impl Into<String>, value: S) -> Self {
251        KeyValuePair::new(key.into(), json!(value))
252    }
253}
254
255#[derive(Debug, Clone, PartialEq, Eq)]
256pub struct TypedKeyValuePair<T: DeserializeOwned> {
257    pub key: Key,
258    pub value: T,
259}
260
261impl<T: DeserializeOwned> TryFrom<KeyValuePair> for TypedKeyValuePair<T> {
262    type Error = serde_json::Error;
263
264    fn try_from(kvp: KeyValuePair) -> Result<Self, Self::Error> {
265        let deserialized = serde_json::from_value(kvp.value)?;
266        Ok(TypedKeyValuePair {
267            key: kvp.key,
268            value: deserialized,
269        })
270    }
271}
272
273impl<S: Serialize> From<(String, S)> for KeyValuePair {
274    fn from((key, value): (String, S)) -> Self {
275        let value = json!(value);
276        KeyValuePair { key, value }
277    }
278}
279
280impl<S: Serialize> From<(&str, S)> for KeyValuePair {
281    fn from((key, value): (&str, S)) -> Self {
282        let value = json!(value);
283        KeyValuePair {
284            key: key.to_owned(),
285            value,
286        }
287    }
288}
289
/// Text of a plain (non-wildcard) key segment.
///
/// NOTE(review): a newtype deriving `Debug, Clone, PartialEq, Eq, Hash, Serialize,
/// Deserialize, PartialOrd, Ord, Tags` was sketched here in a commented-out
/// attribute. It is not in effect; this is a plain `String` alias.
pub type RegularKeySegment = String;
292
293pub fn parse_segments(pattern: &str) -> WorterbuchResult<Vec<RegularKeySegment>> {
294    let mut segments = Vec::new();
295    for segment in pattern.split('/') {
296        let ks: KeySegment = segment.into();
297        match ks {
298            KeySegment::Regular(reg) => segments.push(reg),
299            KeySegment::Wildcard => {
300                return Err(error::WorterbuchError::IllegalWildcard(pattern.to_owned()));
301            }
302            KeySegment::MultiWildcard => {
303                return Err(error::WorterbuchError::IllegalMultiWildcard(
304                    pattern.to_owned(),
305                ));
306            }
307        }
308    }
309    Ok(segments)
310}
311
312#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
313pub enum KeySegment {
314    Regular(RegularKeySegment),
315    Wildcard,
316    MultiWildcard,
317    // RegexWildcard(String),
318}
319
320pub fn format_path(path: &[KeySegment]) -> String {
321    path.iter()
322        .map(|seg| format!("{seg}"))
323        .collect::<Vec<String>>()
324        .join("/")
325}
326
327impl From<RegularKeySegment> for KeySegment {
328    fn from(reg: RegularKeySegment) -> Self {
329        Self::Regular(reg)
330    }
331}
332
333impl Deref for KeySegment {
334    type Target = str;
335
336    fn deref(&self) -> &Self::Target {
337        match self {
338            KeySegment::Regular(reg) => reg,
339            KeySegment::Wildcard => "?",
340            KeySegment::MultiWildcard => "#",
341        }
342    }
343}
344
345impl fmt::Display for KeySegment {
346    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347        match self {
348            KeySegment::Regular(segment) => segment.fmt(f),
349            KeySegment::Wildcard => write!(f, "?"),
350            KeySegment::MultiWildcard => write!(f, "#"),
351            // PathSegment::RegexWildcard(regex) => write!(f, "?{regex}?"),
352        }
353    }
354}
355
356impl From<&str> for KeySegment {
357    fn from(str: &str) -> Self {
358        match str {
359            "?" => KeySegment::Wildcard,
360            "#" => KeySegment::MultiWildcard,
361            other => KeySegment::Regular(other.to_owned()),
362        }
363    }
364}
365
366impl KeySegment {
367    pub fn parse(pattern: impl AsRef<str>) -> Vec<KeySegment> {
368        let segments = pattern.as_ref().split('/');
369        segments.map(KeySegment::from).collect()
370    }
371}
372
/// Wraps `str` in double quotes unless it is already enclosed in them.
///
/// Only the first and last characters are checked, and inner quotes are not
/// escaped. A lone `"` is not treated as already quoted: it is one character, not
/// an opening plus a closing quote. It is wrapped like any other input, so the
/// result always starts and ends with a quote and has length ≥ 2.
pub fn quote(str: impl AsRef<str>) -> String {
    let str_ref = str.as_ref();
    let already_quoted =
        str_ref.len() >= 2 && str_ref.starts_with('"') && str_ref.ends_with('"');
    if already_quoted {
        str_ref.to_owned()
    } else {
        format!("\"{str_ref}\"")
    }
}
381
382pub fn digest_token(auth_token: &Option<String>, client_id: String) -> Option<String> {
383    auth_token.as_deref().map(|token| {
384        let salted = client_id + token;
385        let mut hasher = Sha256::new();
386        hasher.update(salted.as_bytes());
387        format!("{:x}", hasher.finalize())
388    })
389}
390
#[cfg(test)]
mod test {
    // Unit tests for protocol versions, error-code (de)serialization and `topic!`.

    use crate::{ErrorCode, ProtocolVersion};
    use serde_json::json;

    /// Shorthand for building a `ProtocolVersion` in assertions.
    fn pv(major: u32, minor: u32) -> ProtocolVersion {
        ProtocolVersion::new(major, minor)
    }

    #[test]
    fn protocol_versions_are_sorted_correctly() {
        assert!(pv(1, 2) < pv(3, 2));
        assert!(pv(1, 2) == pv(1, 2));
        assert!(pv(2, 1) > pv(1, 9));

        let mut versions = vec![pv(1, 2), pv(0, 456), pv(9, 0), pv(3, 15)];
        versions.sort();
        let expected = vec![pv(0, 456), pv(1, 2), pv(3, 15), pv(9, 0)];
        assert_eq!(expected, versions);
    }

    #[test]
    fn topic_macro_generates_topic_correctly() {
        let generated = topic!("hello", "world", "foo", "bar");
        assert_eq!("hello/world/foo/bar", generated);
    }

    #[test]
    fn error_codes_are_serialized_as_numbers() {
        let serialized = serde_json::to_string(&ErrorCode::IllegalMultiWildcard).unwrap();
        assert_eq!("1", serialized);
    }

    #[test]
    fn error_codes_are_deserialized_from_numbers() {
        let deserialized: ErrorCode = serde_json::from_str("7").unwrap();
        assert_eq!(ErrorCode::ProtocolNegotiationFailed, deserialized);
    }

    #[test]
    fn protocol_version_get_serialized_correctly() {
        let serialized = json!(pv(2, 1)).to_string();
        assert_eq!(serialized, "[2,1]");
    }

    #[test]
    fn protocol_version_get_formatted_correctly() {
        assert_eq!(pv(2, 1).to_string(), "2.1");
    }

    #[test]
    fn compatible_version_is_selected_correctly() {
        let client_version = pv(1, 2);
        let server_versions = [pv(0, 11), pv(1, 6), pv(2, 0)];
        let compatible_version = server_versions
            .iter()
            .copied()
            .find(|server_version| client_version.is_compatible_with_server(server_version));
        assert_eq!(compatible_version, Some(server_versions[1]));
    }
}