worterbuch_common/
lib.rs

1/*
2 *  Worterbuch common modules library
3 *
4 *  Copyright (C) 2024 Michael Bachmann
5 *
6 *  This program is free software: you can redistribute it and/or modify
7 *  it under the terms of the GNU Affero General Public License as published by
8 *  the Free Software Foundation, either version 3 of the License, or
9 *  (at your option) any later version.
10 *
11 *  This program is distributed in the hope that it will be useful,
12 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *  GNU Affero General Public License for more details.
15 *
16 *  You should have received a copy of the GNU Affero General Public License
17 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20#[cfg(feature = "benchmark")]
21pub mod benchmark;
22mod client;
23pub mod error;
24mod server;
25pub mod tcp;
26
27pub use client::*;
28use serde_json::json;
29pub use server::*;
30
31use error::WorterbuchResult;
32use serde::{Deserialize, Serialize, de::DeserializeOwned};
33use serde_repr::*;
34use sha2::{Digest, Sha256};
35use std::{fmt, ops::Deref};
36
// Names of the key segments used for system topics.
// NOTE(review): presumably these are joined with `/` (e.g. via `topic!`) into keys
// below `SYSTEM_TOPIC_ROOT` — confirm against the code that publishes them.
pub const SYSTEM_TOPIC_ROOT: &str = "$SYS";
/// [`SYSTEM_TOPIC_ROOT`] followed by the `/` segment separator.
pub const SYSTEM_TOPIC_ROOT_PREFIX: &str = "$SYS/";
pub const SYSTEM_TOPIC_CLIENTS: &str = "clients";
pub const SYSTEM_TOPIC_VERSION: &str = "version";
pub const SYSTEM_TOPIC_LICENSE: &str = "license";
pub const SYSTEM_TOPIC_SOURCES: &str = "source-code";
pub const SYSTEM_TOPIC_SUBSCRIPTIONS: &str = "subscriptions";
pub const SYSTEM_TOPIC_CLIENTS_PROTOCOL: &str = "protocol";
pub const SYSTEM_TOPIC_CLIENTS_ADDRESS: &str = "address";
pub const SYSTEM_TOPIC_LAST_WILL: &str = "lastWill";
pub const SYSTEM_TOPIC_GRAVE_GOODS: &str = "graveGoods";
pub const SYSTEM_TOPIC_CLIENT_NAME: &str = "clientName";
pub const SYSTEM_TOPIC_SUPPORTED_PROTOCOL_VERSION: &str = "protocolVersion";
pub const SYSTEM_TOPIC_MODE: &str = "mode";
51
// Type aliases that give domain names to the primitive and container types
// used in this file's public API.
pub type TransactionId = u64;
pub type RequestPattern = String;
pub type RequestPatterns = Vec<RequestPattern>;
pub type Key = String;
/// Values are arbitrary JSON.
pub type Value = serde_json::Value;
pub type KeyValuePairs = Vec<KeyValuePair>;
pub type TypedKeyValuePairs<T> = Vec<TypedKeyValuePair<T>>;
pub type MetaData = String;
pub type Path = String;
/// Type of the major and minor components of a [`ProtocolVersion`].
pub type ProtocolVersionSegment = u32;
pub type ProtocolVersions = Vec<ProtocolVersion>;
/// A list of key/value pairs.
/// NOTE(review): presumably applied when a client disconnects — confirm.
pub type LastWill = KeyValuePairs;
/// A list of request patterns.
/// NOTE(review): presumably deleted when a client disconnects — confirm.
pub type GraveGoods = RequestPatterns;
pub type UniqueFlag = bool;
pub type LiveOnlyFlag = bool;
pub type AuthToken = String;
pub type CasVersion = u64;
69
/// A privilege: read, write or delete.
///
/// Serialized by serde in camelCase, i.e. as `"read"`, `"write"` and `"delete"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Privilege {
    Read,
    Write,
    Delete,
}
77
78impl fmt::Display for Privilege {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        match self {
81            Privilege::Read => "read".fmt(f),
82            Privilege::Write => "write".fmt(f),
83            Privilege::Delete => "delete".fmt(f),
84        }
85    }
86}
87
/// Numeric error codes.
///
/// Because of the `serde_repr` derives, a code is serialized and deserialized as
/// its plain `u8` discriminant (e.g. `1` for `IllegalMultiWildcard`), not as the
/// variant name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize_repr, Deserialize_repr)]
#[repr(u8)]
pub enum ErrorCode {
    IllegalWildcard = 0b00000000,
    IllegalMultiWildcard = 0b00000001,
    MultiWildcardAtIllegalPosition = 0b00000010,
    IoError = 0b00000011,
    SerdeError = 0b00000100,
    NoSuchValue = 0b00000101,
    NotSubscribed = 0b00000110,
    ProtocolNegotiationFailed = 0b00000111,
    InvalidServerResponse = 0b00001000,
    ReadOnlyKey = 0b00001001,
    AuthorizationFailed = 0b00001010,
    AuthorizationRequired = 0b00001011,
    AlreadyAuthorized = 0b00001100,
    MissingValue = 0b00001101,
    Unauthorized = 0b00001110,
    NoPubStream = 0b00001111,
    NotLeader = 0b00010000,
    Cas = 0b00010001,
    CasVersionMismatch = 0b00010010,
    NotImplemented = 0b00010011,
    KeyIsLocked = 0b00010100,
    KeyIsNotLocked = 0b00010101,
    // 255, the largest value a `u8` can hold.
    Other = 0b11111111,
}
115
116impl fmt::Display for ErrorCode {
117    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118        (self.to_owned() as u8).fmt(f)
119    }
120}
121
/// Builds a key by joining one or more segments with `/`.
///
/// Every argument is converted with `to_string()`, so any value that has that
/// method (for example every `Display` type) can be used as a segment. Arguments
/// are evaluated once, in order. A trailing comma after the last segment is
/// accepted.
///
/// `topic!("hello", "world")` evaluates to the `String` `"hello/world"`.
#[macro_export]
macro_rules! topic {
    ($( $x:expr ),+ $(,)?) => {
        // A fixed-size array is enough here: the number of segments is known at
        // expansion time, so no intermediate `Vec` has to be grown.
        [$( $x.to_string() ),+].join("/")
    };
}
134
/// A version held as a plain string.
/// NOTE(review): presumably the application/server version (cf. `SYSTEM_TOPIC_VERSION`) — confirm.
pub type Version = String;
136
/// A protocol version made of two numeric segments, major first and minor second.
///
/// The derived ordering compares the first (major) segment and then the second
/// (minor) one. With serde the value is a two-element sequence, e.g. `[2,1]` in JSON.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ProtocolVersion(ProtocolVersionSegment, ProtocolVersionSegment);
139
140impl ProtocolVersion {
141    pub const fn new(major: ProtocolVersionSegment, minor: ProtocolVersionSegment) -> Self {
142        Self(major, minor)
143    }
144
145    pub const fn major(&self) -> ProtocolVersionSegment {
146        self.0
147    }
148
149    pub const fn minor(&self) -> ProtocolVersionSegment {
150        self.1
151    }
152
153    pub fn is_compatible_with_server(&self, server_version: &ProtocolVersion) -> bool {
154        self.major() == server_version.major() && self.minor() <= server_version.minor()
155    }
156
157    pub fn is_compatible_with_client_version(&self, client_version: &ProtocolVersion) -> bool {
158        self.major() == client_version.major() && self.minor() >= client_version.minor()
159    }
160}
161
162impl fmt::Display for ProtocolVersion {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        write!(f, "{}.{}", self.0, self.1)
165    }
166}
167
/// Transport protocols.
///
/// No serde renaming is applied, so each variant is (de)serialized under its
/// name as written here (e.g. `"TCP"`).
/// NOTE(review): presumably the protocol a client is connected over
/// (cf. `SYSTEM_TOPIC_CLIENTS_PROTOCOL`) — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Hash, Deserialize)]
pub enum Protocol {
    TCP,
    WS,
    HTTP,
    UNIX,
}
175
/// A key together with its JSON value.
///
/// (De)serialized by serde with camelCase field names (`key`, `value`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct KeyValuePair {
    /// The key.
    pub key: Key,
    /// The JSON value paired with `key`.
    pub value: Value,
}
182
183impl fmt::Display for KeyValuePair {
184    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
185        write!(f, "{}={}", self.key, self.value)
186    }
187}
188
189impl From<KeyValuePair> for Option<Value> {
190    fn from(kvp: KeyValuePair) -> Self {
191        Some(kvp.value)
192    }
193}
194
195impl From<KeyValuePair> for Value {
196    fn from(kvp: KeyValuePair) -> Self {
197        kvp.value
198    }
199}
200
201impl KeyValuePair {
202    pub fn new<S: Serialize>(key: String, value: S) -> Self {
203        (key, value).into()
204    }
205
206    pub fn of<S: Serialize>(key: String, value: S) -> Self {
207        KeyValuePair::new(key, value)
208    }
209}
210
/// A key/value pair whose value is held as a concrete type `T` instead of raw JSON.
///
/// Can be obtained from a [`KeyValuePair`] via `TryFrom`, which deserializes the
/// JSON value into `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedKeyValuePair<T: DeserializeOwned> {
    /// The key.
    pub key: Key,
    /// The typed value paired with `key`.
    pub value: T,
}
216
217impl<T: DeserializeOwned> TryFrom<KeyValuePair> for TypedKeyValuePair<T> {
218    type Error = serde_json::Error;
219
220    fn try_from(kvp: KeyValuePair) -> Result<Self, Self::Error> {
221        let deserialized = serde_json::from_value(kvp.value)?;
222        Ok(TypedKeyValuePair {
223            key: kvp.key,
224            value: deserialized,
225        })
226    }
227}
228
229impl<S: Serialize> From<(String, S)> for KeyValuePair {
230    fn from((key, value): (String, S)) -> Self {
231        let value = json!(value);
232        KeyValuePair { key, value }
233    }
234}
235
236impl<S: Serialize> From<(&str, S)> for KeyValuePair {
237    fn from((key, value): (&str, S)) -> Self {
238        let value = json!(value);
239        KeyValuePair {
240            key: key.to_owned(),
241            value,
242        }
243    }
244}
245
// #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Tags)]
/// A key segment that is not a wildcard, held as a plain `String`.
pub type RegularKeySegment = String;
248
249pub fn parse_segments(pattern: &str) -> WorterbuchResult<Vec<RegularKeySegment>> {
250    let mut segments = Vec::new();
251    for segment in pattern.split('/') {
252        let ks: KeySegment = segment.into();
253        match ks {
254            KeySegment::Regular(reg) => segments.push(reg),
255            KeySegment::Wildcard => {
256                return Err(error::WorterbuchError::IllegalWildcard(pattern.to_owned()));
257            }
258            KeySegment::MultiWildcard => {
259                return Err(error::WorterbuchError::IllegalMultiWildcard(
260                    pattern.to_owned(),
261                ));
262            }
263        }
264    }
265    Ok(segments)
266}
267
/// One `/`-separated segment of a key or pattern.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum KeySegment {
    /// Any segment that is neither `?` nor `#`.
    Regular(RegularKeySegment),
    /// The segment `?`.
    Wildcard,
    /// The segment `#`.
    MultiWildcard,
    // RegexWildcard(String),
}
275
276pub fn format_path(path: &[KeySegment]) -> String {
277    path.iter()
278        .map(|seg| format!("{seg}"))
279        .collect::<Vec<String>>()
280        .join("/")
281}
282
283impl From<RegularKeySegment> for KeySegment {
284    fn from(reg: RegularKeySegment) -> Self {
285        Self::Regular(reg)
286    }
287}
288
289impl Deref for KeySegment {
290    type Target = str;
291
292    fn deref(&self) -> &Self::Target {
293        match self {
294            KeySegment::Regular(reg) => reg,
295            KeySegment::Wildcard => "?",
296            KeySegment::MultiWildcard => "#",
297        }
298    }
299}
300
301impl fmt::Display for KeySegment {
302    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303        match self {
304            KeySegment::Regular(segment) => segment.fmt(f),
305            KeySegment::Wildcard => write!(f, "?"),
306            KeySegment::MultiWildcard => write!(f, "#"),
307            // PathSegment::RegexWildcard(regex) => write!(f, "?{regex}?"),
308        }
309    }
310}
311
312impl From<&str> for KeySegment {
313    fn from(str: &str) -> Self {
314        match str {
315            "?" => KeySegment::Wildcard,
316            "#" => KeySegment::MultiWildcard,
317            other => KeySegment::Regular(other.to_owned()),
318        }
319    }
320}
321
322impl KeySegment {
323    pub fn parse(pattern: impl AsRef<str>) -> Vec<KeySegment> {
324        let segments = pattern.as_ref().split('/');
325        segments.map(KeySegment::from).collect()
326    }
327}
328
/// Wraps `str` in double quotes unless it is already enclosed in them.
///
/// A string counts as already quoted only if it has at least two characters
/// and both starts and ends with `"`. A lone `"` is therefore wrapped like any
/// other unquoted text, because one character cannot be both the opening and
/// the closing quote. Quotes inside the text are not escaped.
pub fn quote(str: impl AsRef<str>) -> String {
    let text = str.as_ref();
    // `"` is a single byte, so a length of two guarantees distinct opening and
    // closing quote characters.
    let already_quoted = text.len() >= 2 && text.starts_with('"') && text.ends_with('"');
    if already_quoted {
        text.to_owned()
    } else {
        format!("\"{text}\"")
    }
}
337
338pub fn digest_token(auth_token: &Option<String>, client_id: String) -> Option<String> {
339    auth_token.as_deref().map(|token| {
340        let salted = client_id + token;
341        let mut hasher = Sha256::new();
342        hasher.update(salted.as_bytes());
343        format!("{:x}", hasher.finalize())
344    })
345}
346
// Unit tests for the protocol version type, the `topic!` macro and the
// serialized form of error codes.
#[cfg(test)]
mod test {
    use crate::{ErrorCode, ProtocolVersion};
    use serde_json::json;

    // The derived ordering compares the major segment first, then the minor one.
    #[test]
    fn protocol_versions_are_sorted_correctly() {
        assert!(ProtocolVersion::new(1, 2) < ProtocolVersion::new(3, 2));
        assert!(ProtocolVersion::new(1, 2) == ProtocolVersion::new(1, 2));
        assert!(ProtocolVersion::new(2, 1) > ProtocolVersion::new(1, 9));

        let mut versions = vec![
            ProtocolVersion::new(1, 2),
            ProtocolVersion::new(0, 456),
            ProtocolVersion::new(9, 0),
            ProtocolVersion::new(3, 15),
        ];
        versions.sort();
        assert_eq!(
            vec![
                ProtocolVersion::new(0, 456),
                ProtocolVersion::new(1, 2),
                ProtocolVersion::new(3, 15),
                ProtocolVersion::new(9, 0)
            ],
            versions
        );
    }

    // `topic!` joins its arguments with `/`.
    #[test]
    fn topic_macro_generates_topic_correctly() {
        assert_eq!(
            "hello/world/foo/bar",
            topic!("hello", "world", "foo", "bar")
        );
    }

    // Error codes are written as their numeric discriminant.
    #[test]
    fn error_codes_are_serialized_as_numbers() {
        assert_eq!(
            "1",
            serde_json::to_string(&ErrorCode::IllegalMultiWildcard).unwrap()
        )
    }

    // Error codes are read back from their numeric discriminant.
    #[test]
    fn error_codes_are_deserialized_from_numbers() {
        assert_eq!(
            ErrorCode::ProtocolNegotiationFailed,
            serde_json::from_str("7").unwrap()
        )
    }

    // A protocol version serializes to a two-element JSON array.
    #[test]
    fn protocol_version_get_serialized_correctly() {
        assert_eq!(&json!(ProtocolVersion::new(2, 1)).to_string(), "[2,1]")
    }

    // `Display` renders a protocol version as `<major>.<minor>`.
    #[test]
    fn protocol_version_get_formatted_correctly() {
        assert_eq!(&ProtocolVersion::new(2, 1).to_string(), "2.1")
    }

    // Of the offered server versions only 1.6 has the client's major version (1)
    // and a minor version that is at least the client's (2).
    #[test]
    fn compatible_version_is_selected_correctly() {
        let client_version = ProtocolVersion::new(1, 2);
        let server_versions = [
            ProtocolVersion::new(0, 11),
            ProtocolVersion::new(1, 6),
            ProtocolVersion::new(2, 0),
        ];
        let compatible_version = server_versions
            .iter()
            .find(|v| client_version.is_compatible_with_server(v));
        assert_eq!(compatible_version, Some(&server_versions[1]))
    }
}
423}