worterbuch_common/
lib.rs

1/*
2 *  Worterbuch common modules library
3 *
4 *  Copyright (C) 2024 Michael Bachmann
5 *
6 *  This program is free software: you can redistribute it and/or modify
7 *  it under the terms of the GNU Affero General Public License as published by
8 *  the Free Software Foundation, either version 3 of the License, or
9 *  (at your option) any later version.
10 *
11 *  This program is distributed in the hope that it will be useful,
12 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 *  GNU Affero General Public License for more details.
15 *
16 *  You should have received a copy of the GNU Affero General Public License
17 *  along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 */
19
20#[cfg(feature = "benchmark")]
21pub mod benchmark;
22mod client;
23pub mod error;
24mod server;
25
26pub use client::*;
27use serde_json::json;
28pub use server::*;
29
30use error::WorterbuchResult;
31use serde::{Deserialize, Serialize, de::DeserializeOwned};
32use serde_repr::*;
33use sha2::{Digest, Sha256};
34use std::{fmt, net::SocketAddr, ops::Deref};
35use tokio::sync::{mpsc, oneshot};
36use tracing::Span;
37use uuid::Uuid;
38#[cfg(feature = "jemalloc")]
39mod jemalloc;
40#[cfg(feature = "jemalloc")]
41pub mod profiling;
42
/// The nil UUID.
/// NOTE(review): by its name, the ID reserved for the internal client — confirm against callers.
pub const INTERNAL_CLIENT_ID: Uuid = Uuid::nil();

// String constants for the system topic tree.
// NOTE(review): presumably these segments are joined with `/` (e.g. via `topic!`) to form
// full keys below the `$SYS` root — confirm against the code that uses them.
/// Root segment of the system topic tree.
pub const SYSTEM_TOPIC_ROOT: &str = "$SYS";
/// [`SYSTEM_TOPIC_ROOT`] followed by the `/` separator.
pub const SYSTEM_TOPIC_ROOT_PREFIX: &str = "$SYS/";
pub const SYSTEM_TOPIC_CLIENTS: &str = "clients";
pub const SYSTEM_TOPIC_VERSION: &str = "version";
pub const SYSTEM_TOPIC_LICENSE: &str = "license";
pub const SYSTEM_TOPIC_SOURCES: &str = "source-code";
pub const SYSTEM_TOPIC_SUBSCRIPTIONS: &str = "subscriptions";
pub const SYSTEM_TOPIC_LOCKS: &str = "locks";
pub const SYSTEM_TOPIC_CLIENTS_PROTOCOL: &str = "protocol";
/// Has the same string value as [`SYSTEM_TOPIC_SUPPORTED_PROTOCOL_VERSION`].
pub const SYSTEM_TOPIC_CLIENTS_PROTOCOL_VERSION: &str = "protocolVersion";
pub const SYSTEM_TOPIC_CLIENTS_ADDRESS: &str = "address";
pub const SYSTEM_TOPIC_CLIENTS_TIMESTAMP: &str = "connectedSince";
pub const SYSTEM_TOPIC_LAST_WILL: &str = "lastWill";
pub const SYSTEM_TOPIC_GRAVE_GOODS: &str = "graveGoods";
pub const SYSTEM_TOPIC_CLIENT_NAME: &str = "clientName";
/// Has the same string value as [`SYSTEM_TOPIC_CLIENTS_PROTOCOL_VERSION`].
pub const SYSTEM_TOPIC_SUPPORTED_PROTOCOL_VERSION: &str = "protocolVersion";
pub const SYSTEM_TOPIC_MODE: &str = "mode";
62
// Shared type aliases. They introduce no new types: each name is interchangeable with the
// type on its right-hand side.
pub type TransactionId = u64;
pub type RequestPattern = String;
pub type RequestPatterns = Vec<RequestPattern>;
pub type Key = String;
/// Values are arbitrary JSON values.
pub type Value = serde_json::Value;
pub type KeyValuePairs = Vec<KeyValuePair>;
pub type TypedKeyValuePairs<T> = Vec<TypedKeyValuePair<T>>;
pub type MetaData = String;
pub type Path = String;
/// Integer type of a single component of a [`ProtocolVersion`].
pub type ProtocolVersionSegment = u32;
pub type ProtocolMajorVersion = ProtocolVersionSegment;
pub type ProtocolVersions = Vec<ProtocolVersion>;
/// A last will is a list of key/value pairs.
pub type LastWill = KeyValuePairs;
/// Grave goods are a list of request patterns.
pub type GraveGoods = RequestPatterns;
pub type UniqueFlag = bool;
pub type LiveOnlyFlag = bool;
pub type AuthToken = String;
pub type AuthTokenKey = String;
pub type CasVersion = u64;
82
/// A value that is either bare (`Plain`) or accompanied by a `u64` (`Cas`).
///
/// `Plain` carries `#[serde(untagged)]`, so it is (de)serialized as the bare value without a
/// variant tag, while `Cas` keeps serde's default enum representation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ValueEntry {
    /// A value plus a `u64`.
    /// NOTE(review): presumably the compare-and-swap version (cf. `CasVersion`) — confirm.
    Cas(Value, u64),
    /// A bare value.
    #[serde(untagged)]
    Plain(Value),
}
89
90impl AsRef<Value> for ValueEntry {
91    fn as_ref(&self) -> &Value {
92        match self {
93            ValueEntry::Plain(value) => value,
94            ValueEntry::Cas(value, _) => value,
95        }
96    }
97}
98
99impl From<ValueEntry> for Value {
100    fn from(value: ValueEntry) -> Self {
101        match value {
102            ValueEntry::Plain(value) => value,
103            ValueEntry::Cas(value, _) => value,
104        }
105    }
106}
107
108impl From<Value> for ValueEntry {
109    fn from(value: Value) -> Self {
110        ValueEntry::Plain(value)
111    }
112}
113
/// The privileges known to this crate.
///
/// With `rename_all = "camelCase"`, serde uses the camelCase variant names (`WebLogin`
/// becomes `webLogin`). This differs from the `Display` output, which prints `web-login`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Privilege {
    Read,
    Write,
    Delete,
    Profile,
    WebLogin,
}
123
124impl fmt::Display for Privilege {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        match self {
127            Privilege::Read => "read".fmt(f),
128            Privilege::Write => "write".fmt(f),
129            Privilege::Delete => "delete".fmt(f),
130            Privilege::Profile => "profile".fmt(f),
131            Privilege::WebLogin => "web-login".fmt(f),
132        }
133    }
134}
135
/// Borrowed form of an authorization check: either a pattern string or a plain flag.
///
/// See [`AuthCheckOwned`] for the owning counterpart.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum AuthCheck<'a> {
    /// Check against a borrowed pattern string.
    Pattern(&'a str),
    /// Check without a pattern.
    Flag,
}
142
/// Owning form of [`AuthCheck`]: the pattern is stored as a `String`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum AuthCheckOwned {
    /// Check against an owned pattern string.
    Pattern(String),
    /// Check without a pattern.
    Flag,
}
149
150impl<'a> From<AuthCheck<'a>> for AuthCheckOwned {
151    fn from(value: AuthCheck<'a>) -> Self {
152        match value {
153            AuthCheck::Pattern(p) => AuthCheckOwned::Pattern(p.to_owned()),
154            AuthCheck::Flag => AuthCheckOwned::Flag,
155        }
156    }
157}
158
159impl fmt::Display for AuthCheckOwned {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161        match self {
162            AuthCheckOwned::Pattern(p) => p.fmt(f),
163            AuthCheckOwned::Flag => true.fmt(f),
164        }
165    }
166}
167
/// Numeric error codes.
///
/// The enum is `#[repr(u8)]` and derives `Serialize_repr`/`Deserialize_repr`, so each code is
/// (de)serialized as its `u8` discriminant rather than by name. The discriminants are spelled
/// out explicitly; `Other` is pinned to `u8::MAX`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize_repr, Deserialize_repr)]
#[repr(u8)]
pub enum ErrorCode {
    IllegalWildcard = 0,
    IllegalMultiWildcard = 1,
    MultiWildcardAtIllegalPosition = 2,
    IoError = 3,
    SerdeError = 4,
    NoSuchValue = 5,
    NotSubscribed = 6,
    ProtocolNegotiationFailed = 7,
    InvalidServerResponse = 8,
    ReadOnlyKey = 9,
    AuthorizationFailed = 10,
    AuthorizationRequired = 11,
    AlreadyAuthorized = 12,
    MissingValue = 13,
    Unauthorized = 14,
    NoPubStream = 15,
    NotLeader = 16,
    Cas = 17,
    CasVersionMismatch = 18,
    NotImplemented = 19,
    KeyIsLocked = 20,
    KeyIsNotLocked = 21,
    LockAcquisitionCancelled = 22,
    FeatureDisabled = 23,
    ClientIDCollision = 24,
    Other = u8::MAX,
}
198
199impl fmt::Display for ErrorCode {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        (self.to_owned() as u8).fmt(f)
202    }
203}
204
/// Builds a topic string by joining all arguments with `/`.
///
/// Each argument is converted with `.to_string()`, so anything that provides that method
/// (string slices, numbers, ...) can be mixed freely. At least one argument is required; an
/// optional trailing comma is accepted.
///
/// The result is a `String`, e.g. `topic!("a", 1, "c")` yields `"a/1/c"`.
#[macro_export]
macro_rules! topic {
    ($( $x:expr ),+ $(,)?) => {
        {
            // One statement per argument keeps each argument's temporaries short-lived.
            let mut segments = Vec::new();
            $(
                segments.push($x.to_string());
            )+
            segments.join("/")
        }
    };
}
217
/// A version represented as a plain string.
pub type Version = String;

/// A protocol version made of two segments, stored as `(major, minor)`.
///
/// The derived ordering compares the first field (major) and then the second (minor).
/// As a tuple struct it serializes as a two-element sequence, e.g. `[2,1]` in JSON.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ProtocolVersion(ProtocolVersionSegment, ProtocolVersionSegment);
222
223impl ProtocolVersion {
224    pub const fn new(major: ProtocolVersionSegment, minor: ProtocolVersionSegment) -> Self {
225        Self(major, minor)
226    }
227
228    pub const fn major(&self) -> ProtocolVersionSegment {
229        self.0
230    }
231
232    pub const fn minor(&self) -> ProtocolVersionSegment {
233        self.1
234    }
235
236    pub fn is_compatible_with_server(&self, server_version: &ProtocolVersion) -> bool {
237        self.major() == server_version.major() && self.minor() <= server_version.minor()
238    }
239
240    pub fn is_compatible_with_client_version(&self, client_version: &ProtocolVersion) -> bool {
241        self.major() == client_version.major() && self.minor() >= client_version.minor()
242    }
243}
244
245impl fmt::Display for ProtocolVersion {
246    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247        write!(f, "{}.{}", self.0, self.1)
248    }
249}
250
/// The protocols distinguished by this crate.
///
/// No serde renaming is applied, so variants are (de)serialized under their exact names
/// (`"TCP"`, `"WS"`, `"HTTP"`, `"UNIX"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Hash, Deserialize)]
pub enum Protocol {
    TCP,
    WS,
    HTTP,
    UNIX,
}
258
/// A key together with its JSON value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct KeyValuePair {
    /// The key.
    pub key: Key,
    /// The JSON value associated with `key`.
    pub value: Value,
}
265
266impl fmt::Display for KeyValuePair {
267    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268        write!(f, "{}={}", self.key, self.value)
269    }
270}
271
272impl From<KeyValuePair> for Option<Value> {
273    fn from(kvp: KeyValuePair) -> Self {
274        Some(kvp.value)
275    }
276}
277
278impl From<KeyValuePair> for Value {
279    fn from(kvp: KeyValuePair) -> Self {
280        kvp.value
281    }
282}
283
284impl KeyValuePair {
285    pub fn new(key: String, value: Value) -> Self {
286        KeyValuePair { key, value }
287    }
288
289    pub fn of<S: Serialize>(key: impl Into<String>, value: S) -> Self {
290        KeyValuePair::new(key.into(), json!(value))
291    }
292}
293
/// A key together with a value of a concrete, deserializable type `T`.
///
/// Can be obtained from a [`KeyValuePair`] via `TryFrom`, which deserializes the JSON value
/// into `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedKeyValuePair<T: DeserializeOwned> {
    /// The key.
    pub key: Key,
    /// The typed value associated with `key`.
    pub value: T,
}
299
300impl<T: DeserializeOwned> TryFrom<KeyValuePair> for TypedKeyValuePair<T> {
301    type Error = serde_json::Error;
302
303    fn try_from(kvp: KeyValuePair) -> Result<Self, Self::Error> {
304        let deserialized = serde_json::from_value(kvp.value)?;
305        Ok(TypedKeyValuePair {
306            key: kvp.key,
307            value: deserialized,
308        })
309    }
310}
311
312impl<S: Serialize> From<(String, S)> for KeyValuePair {
313    fn from((key, value): (String, S)) -> Self {
314        let value = json!(value);
315        KeyValuePair { key, value }
316    }
317}
318
319impl<S: Serialize> From<(&str, S)> for KeyValuePair {
320    fn from((key, value): (&str, S)) -> Self {
321        let value = json!(value);
322        KeyValuePair {
323            key: key.to_owned(),
324            value,
325        }
326    }
327}
328
// #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Tags)]
/// A key segment that is not a wildcard; plain `String` alias used as the payload of
/// `KeySegment::Regular`.
pub type RegularKeySegment = String;
331
332pub fn parse_segments(pattern: &str) -> WorterbuchResult<Vec<RegularKeySegment>> {
333    let mut segments = Vec::new();
334    for segment in pattern.split('/') {
335        let ks: KeySegment = segment.into();
336        match ks {
337            KeySegment::Regular(reg) => segments.push(reg),
338            KeySegment::Wildcard => {
339                return Err(error::WorterbuchError::IllegalWildcard(pattern.to_owned()));
340            }
341            KeySegment::MultiWildcard => {
342                return Err(error::WorterbuchError::IllegalMultiWildcard(
343                    pattern.to_owned(),
344                ));
345            }
346        }
347    }
348    Ok(segments)
349}
350
/// One `/`-separated segment of a key or pattern.
///
/// The textual forms are the segment itself for `Regular`, `?` for `Wildcard` and `#` for
/// `MultiWildcard`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum KeySegment {
    /// An ordinary segment holding its text.
    Regular(RegularKeySegment),
    /// The `?` wildcard.
    Wildcard,
    /// The `#` wildcard.
    MultiWildcard,
    // RegexWildcard(String),
}
358
359impl AsRef<str> for KeySegment {
360    fn as_ref(&self) -> &str {
361        match self {
362            KeySegment::Regular(segment) => segment.as_str(),
363            KeySegment::Wildcard => "?",
364            KeySegment::MultiWildcard => "#",
365        }
366    }
367}
368
/// Joins the given segments into one string, separated by `/`.
///
/// An empty slice yields an empty string; no leading or trailing separator is added.
pub fn format_path(path: &[impl AsRef<str>]) -> String {
    let mut joined = String::new();
    for (index, segment) in path.iter().enumerate() {
        // The separator goes between segments, so it is skipped before the first one.
        if index > 0 {
            joined.push('/');
        }
        joined.push_str(segment.as_ref());
    }
    joined
}
380
381impl From<RegularKeySegment> for KeySegment {
382    fn from(reg: RegularKeySegment) -> Self {
383        Self::Regular(reg)
384    }
385}
386
387impl Deref for KeySegment {
388    type Target = str;
389
390    fn deref(&self) -> &Self::Target {
391        match self {
392            KeySegment::Regular(reg) => reg,
393            KeySegment::Wildcard => "?",
394            KeySegment::MultiWildcard => "#",
395        }
396    }
397}
398
399impl fmt::Display for KeySegment {
400    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401        match self {
402            KeySegment::Regular(segment) => segment.fmt(f),
403            KeySegment::Wildcard => write!(f, "?"),
404            KeySegment::MultiWildcard => write!(f, "#"),
405            // PathSegment::RegexWildcard(regex) => write!(f, "?{regex}?"),
406        }
407    }
408}
409
410impl From<&str> for KeySegment {
411    fn from(str: &str) -> Self {
412        match str {
413            "?" => KeySegment::Wildcard,
414            "#" => KeySegment::MultiWildcard,
415            other => KeySegment::Regular(other.to_owned()),
416        }
417    }
418}
419
420impl KeySegment {
421    pub fn parse(pattern: impl AsRef<str>) -> Vec<KeySegment> {
422        let segments = pattern.as_ref().split('/');
423        segments.map(KeySegment::from).collect()
424    }
425}
426
/// Wraps `str` in double quotes unless it is already enclosed in a pair of them.
///
/// A string counts as already quoted only if it is at least two bytes long and both starts
/// and ends with `"`. A lone `"` is therefore not mistaken for a quoted string (its single
/// character would otherwise serve as both the opening and the closing quote) and gets
/// wrapped like any other input.
///
/// Inner quote characters are neither inspected nor escaped.
pub fn quote(str: impl AsRef<str>) -> String {
    let text = str.as_ref();
    // The length check rules out the single-character string `"`.
    let already_quoted = text.len() >= 2 && text.starts_with('"') && text.ends_with('"');
    if already_quoted {
        text.to_owned()
    } else {
        format!("\"{text}\"")
    }
}
435
436pub fn digest_token(auth_token: &Option<String>, client_id: String) -> Option<String> {
437    auth_token.as_deref().map(|token| {
438        let salted = client_id + token;
439        let mut hasher = Sha256::new();
440        hasher.update(salted.as_bytes());
441        format!("{:x}", hasher.finalize())
442    })
443}
444
/// Identifies a subscription by the pair of a client ID and a transaction ID.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct SubscriptionId {
    /// ID of the client.
    pub client_id: Uuid,
    /// Transaction ID.
    /// NOTE(review): presumably that of the subscribe request — confirm.
    pub transaction_id: TransactionId,
}
450
451impl SubscriptionId {
452    pub fn new(client_id: Uuid, transaction_id: TransactionId) -> Self {
453        SubscriptionId {
454            client_id,
455            transaction_id,
456        }
457    }
458}
459
/// Asynchronous API of Worterbuch as declared by this crate.
///
/// Apart from the two synchronous getters at the top, every method returns a `Send` future
/// that resolves to a `WorterbuchResult`.
///
/// NOTE(review): the one-line summaries below are derived from the method names and
/// signatures only; the exact semantics are defined by the implementations — confirm there.
pub trait WbApi {
    /// Returns the list of supported protocol versions.
    fn supported_protocol_versions(&self) -> Vec<ProtocolVersion>;

    /// Returns a version string.
    fn version(&self) -> &str;

    /// Resolves to the value for `key`.
    fn get(&self, key: Key) -> impl Future<Output = WorterbuchResult<Value>> + Send;

    /// Resolves to the value for `key` together with a `CasVersion`.
    fn cget(&self, key: Key) -> impl Future<Output = WorterbuchResult<(Value, CasVersion)>> + Send;

    /// Resolves to the key/value pairs for `pattern`.
    fn pget(
        &self,
        pattern: RequestPattern,
    ) -> impl Future<Output = WorterbuchResult<KeyValuePairs>> + Send;

    /// Sets `key` to `value` on behalf of `client_id`.
    fn set(
        &self,
        key: Key,
        value: Value,
        client_id: Uuid,
    ) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Like `set`, but additionally takes a `CasVersion`.
    fn cset(
        &self,
        key: Key,
        value: Value,
        version: CasVersion,
        client_id: Uuid,
    ) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Locks `key` for `client_id`.
    fn lock(&self, key: Key, client_id: Uuid) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Requests the lock on `key` for `client_id`; resolves to a oneshot receiver.
    /// NOTE(review): presumably the receiver fires once the lock is acquired — confirm.
    fn acquire_lock(
        &self,
        key: Key,
        client_id: Uuid,
    ) -> impl Future<Output = WorterbuchResult<oneshot::Receiver<()>>> + Send;

    /// Releases the lock on `key` held by `client_id`.
    fn release_lock(
        &self,
        key: Key,
        client_id: Uuid,
    ) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Initializes a publish stream for `key`, identified by `transaction_id`.
    fn spub_init(
        &self,
        transaction_id: TransactionId,
        key: Key,
        client_id: Uuid,
    ) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Publishes `value` on the stream identified by `transaction_id`.
    fn spub(
        &self,
        transaction_id: TransactionId,
        value: Value,
        client_id: Uuid,
    ) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Publishes `value` under `key`.
    fn publish(&self, key: Key, value: Value) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Resolves to the key segments below `parent` (or at the top level for `None`).
    fn ls(
        &self,
        parent: Option<Key>,
    ) -> impl Future<Output = WorterbuchResult<Vec<RegularKeySegment>>> + Send;

    /// Like `ls`, but `parent` is a request pattern.
    fn pls(
        &self,
        parent: Option<RequestPattern>,
    ) -> impl Future<Output = WorterbuchResult<Vec<RegularKeySegment>>> + Send;

    /// Subscribes to `key`; resolves to a receiver of `StateEvent`s and the subscription's ID.
    fn subscribe(
        &self,
        client_id: Uuid,
        transaction_id: TransactionId,
        key: Key,
        unique: bool,
        live_only: bool,
    ) -> impl Future<Output = WorterbuchResult<(mpsc::Receiver<StateEvent>, SubscriptionId)>> + Send;

    /// Subscribes to `pattern`; resolves to a receiver of `PStateEvent`s and the
    /// subscription's ID.
    fn psubscribe(
        &self,
        client_id: Uuid,
        transaction_id: TransactionId,
        pattern: RequestPattern,
        unique: bool,
        live_only: bool,
    ) -> impl Future<Output = WorterbuchResult<(mpsc::Receiver<PStateEvent>, SubscriptionId)>> + Send;

    /// Subscribes to the `ls` result for `parent`; resolves to a receiver of segment lists
    /// and the subscription's ID.
    fn subscribe_ls(
        &self,
        client_id: Uuid,
        transaction_id: TransactionId,
        parent: Option<Key>,
    ) -> impl Future<
        Output = WorterbuchResult<(mpsc::Receiver<Vec<RegularKeySegment>>, SubscriptionId)>,
    > + Send;

    /// Cancels the subscription identified by `client_id` and `transaction_id`.
    fn unsubscribe(
        &self,
        client_id: Uuid,
        transaction_id: TransactionId,
    ) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Cancels the `ls` subscription identified by `client_id` and `transaction_id`.
    fn unsubscribe_ls(
        &self,
        client_id: Uuid,
        transaction_id: TransactionId,
    ) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Deletes `key`; resolves to a value.
    /// NOTE(review): presumably the deleted value — confirm.
    fn delete(
        &self,
        key: Key,
        client_id: Uuid,
    ) -> impl Future<Output = WorterbuchResult<Value>> + Send;

    /// Deletes everything matching `pattern`; resolves to key/value pairs.
    /// NOTE(review): presumably the deleted pairs — confirm.
    fn pdelete(
        &self,
        pattern: RequestPattern,
        client_id: Uuid,
    ) -> impl Future<Output = WorterbuchResult<KeyValuePairs>> + Send;

    /// Reports that `client_id` connected, with its optional remote address and protocol.
    fn connected(
        &self,
        client_id: Uuid,
        remote_addr: Option<SocketAddr>,
        protocol: Protocol,
    ) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Reports that `client_id` switched to the given protocol major version.
    fn protocol_switched(
        &self,
        client_id: Uuid,
        protocol: ProtocolMajorVersion,
    ) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Reports that `client_id` disconnected.
    fn disconnected(
        &self,
        client_id: Uuid,
        remote_addr: Option<SocketAddr>,
    ) -> impl Future<Output = WorterbuchResult<()>> + Send;

    /// Resolves to an exported JSON value together with grave goods and last will.
    fn export(
        &self,
        span: Span,
    ) -> impl Future<Output = WorterbuchResult<(Value, GraveGoods, LastWill)>> + Send;

    /// Imports from a JSON string; resolves to a list of `(String, (ValueEntry, bool))` items.
    fn import(
        &self,
        json: String,
    ) -> impl Future<Output = WorterbuchResult<Vec<(String, (ValueEntry, bool))>>> + Send;

    /// Resolves to a number of entries.
    fn entries(&self) -> impl Future<Output = WorterbuchResult<usize>> + Send;
}
611
// Unit tests for the version type, the `topic!` macro and error code (de)serialization.
#[cfg(test)]
mod test {
    use crate::{ErrorCode, ProtocolVersion};
    use serde_json::json;

    // The derived ordering compares the major segment first, then the minor segment.
    #[test]
    fn protocol_versions_are_sorted_correctly() {
        assert!(ProtocolVersion::new(1, 2) < ProtocolVersion::new(3, 2));
        assert!(ProtocolVersion::new(1, 2) == ProtocolVersion::new(1, 2));
        assert!(ProtocolVersion::new(2, 1) > ProtocolVersion::new(1, 9));

        let mut versions = vec![
            ProtocolVersion::new(1, 2),
            ProtocolVersion::new(0, 456),
            ProtocolVersion::new(9, 0),
            ProtocolVersion::new(3, 15),
        ];
        versions.sort();
        assert_eq!(
            vec![
                ProtocolVersion::new(0, 456),
                ProtocolVersion::new(1, 2),
                ProtocolVersion::new(3, 15),
                ProtocolVersion::new(9, 0)
            ],
            versions
        );
    }

    // `topic!` joins its arguments with `/`.
    #[test]
    fn topic_macro_generates_topic_correctly() {
        assert_eq!(
            "hello/world/foo/bar",
            topic!("hello", "world", "foo", "bar")
        );
    }

    // Error codes serialize as their numeric discriminant, not as the variant name.
    #[test]
    fn error_codes_are_serialized_as_numbers() {
        assert_eq!(
            "1",
            serde_json::to_string(&ErrorCode::IllegalMultiWildcard).unwrap()
        )
    }

    // Error codes deserialize from their numeric discriminant.
    #[test]
    fn error_codes_are_deserialized_from_numbers() {
        assert_eq!(
            ErrorCode::ProtocolNegotiationFailed,
            serde_json::from_str("7").unwrap()
        )
    }

    // As a tuple struct, a protocol version serializes as a two-element JSON array.
    #[test]
    fn protocol_version_get_serialized_correctly() {
        assert_eq!(&json!(ProtocolVersion::new(2, 1)).to_string(), "[2,1]")
    }

    // `Display` prints `<major>.<minor>`.
    #[test]
    fn protocol_version_get_formatted_correctly() {
        assert_eq!(&ProtocolVersion::new(2, 1).to_string(), "2.1")
    }

    // Of the offered server versions, only 1.6 shares the client's major version and has a
    // minor version that is not lower than the client's.
    #[test]
    fn compatible_version_is_selected_correctly() {
        let client_version = ProtocolVersion::new(1, 2);
        let server_versions = [
            ProtocolVersion::new(0, 11),
            ProtocolVersion::new(1, 6),
            ProtocolVersion::new(2, 0),
        ];
        let compatible_version = server_versions
            .iter()
            .find(|v| client_version.is_compatible_with_server(v));
        assert_eq!(compatible_version, Some(&server_versions[1]))
    }
}