Skip to main content

user_store/
user_store.rs

1use std::path::{Path, PathBuf};
2
3use trine_kv::{Bucket, BucketOptions, Db, Error, PrefixExtractor, Result, TransactionOptions};
4
5fn main() -> Result<()> {
6    let path = temp_path("trine-kv-user-store");
7    reset_dir(&path)?;
8
9    let store = UserStore::open(&path)?;
10    store.put_user(&User::new("001", "Ada", "ada@example.test"))?;
11    store.put_user(&User::new("002", "Lin", "lin@example.test"))?;
12
13    let users = store.list_users()?;
14    assert_eq!(
15        users
16            .iter()
17            .map(|user| user.display_name.as_str())
18            .collect::<Vec<_>>(),
19        ["Ada", "Lin"]
20    );
21
22    assert!(store.rename_if_email_matches("001", "ada@example.test", "Ada Lovelace")?);
23    assert!(!store.rename_if_email_matches("002", "other@example.test", "Someone Else")?);
24    store.flush()?;
25    drop(store);
26
27    let reopened = UserStore::open(&path)?;
28    assert_eq!(
29        reopened.get_user("001")?,
30        Some(User::new("001", "Ada Lovelace", "ada@example.test"))
31    );
32
33    drop(reopened);
34    std::fs::remove_dir_all(path)?;
35    Ok(())
36}
37
38struct UserStore {
39    db: Db,
40    users: Bucket,
41}
42
43impl UserStore {
44    fn open(path: &Path) -> Result<Self> {
45        let db = Db::open_sync(path)?;
46        let users = db.bucket_with_options_sync(
47            "users",
48            BucketOptions::default().with_prefix_extractor(PrefixExtractor::Separator(b':')),
49        )?;
50        Ok(Self { db, users })
51    }
52
53    fn put_user(&self, user: &User) -> Result<()> {
54        self.users.put_sync(user_key(&user.id), user.encode()?)
55    }
56
57    fn get_user(&self, id: &str) -> Result<Option<User>> {
58        self.users
59            .get_sync(&user_key(id))?
60            .map(|bytes| User::decode(&bytes))
61            .transpose()
62    }
63
64    fn list_users(&self) -> Result<Vec<User>> {
65        self.users
66            .prefix_sync(b"user:")?
67            .map(|item| item.and_then(|key_value| User::decode(&key_value.value)))
68            .collect()
69    }
70
71    fn rename_if_email_matches(
72        &self,
73        id: &str,
74        expected_email: &str,
75        new_name: &str,
76    ) -> Result<bool> {
77        let key = user_key(id);
78        let mut transaction = self.db.transaction(TransactionOptions::default());
79        let Some(bytes) = transaction.get_bucket_sync("users", &key)? else {
80            return Ok(false);
81        };
82        let mut user = User::decode(&bytes)?;
83        if user.email != expected_email {
84            return Ok(false);
85        }
86
87        new_name.clone_into(&mut user.display_name);
88        transaction.put_bucket("users", key, user.encode()?)?;
89        transaction.commit_sync()?;
90        Ok(true)
91    }
92
93    fn flush(&self) -> Result<()> {
94        self.db.flush_sync()
95    }
96}
97
/// A user record as kept in the `users` bucket.
///
/// Two users compare equal only when all three fields are equal.
#[derive(Clone, Debug, Eq, PartialEq)]
struct User {
    /// Identifier; also used to build the storage key.
    id: String,
    /// Name shown for the user; the field a rename replaces.
    display_name: String,
    /// Email address; compared before a conditional rename.
    email: String,
}
104
105impl User {
106    fn new(id: &str, display_name: &str, email: &str) -> Self {
107        Self {
108            id: id.to_owned(),
109            display_name: display_name.to_owned(),
110            email: email.to_owned(),
111        }
112    }
113
114    fn encode(&self) -> Result<Vec<u8>> {
115        encode_fields(&[&self.id, &self.display_name, &self.email])
116    }
117
118    fn decode(bytes: &[u8]) -> Result<Self> {
119        let mut fields = FieldCursor::new(bytes);
120        let user = Self {
121            id: fields.read_string()?,
122            display_name: fields.read_string()?,
123            email: fields.read_string()?,
124        };
125        fields.finish()?;
126        Ok(user)
127    }
128}
129
/// Builds the storage key for a user: the bytes `user:` followed by `id`.
fn user_key(id: &str) -> Vec<u8> {
    const PREFIX: &[u8] = b"user:";

    let mut key = Vec::with_capacity(PREFIX.len() + id.len());
    key.extend_from_slice(PREFIX);
    key.extend_from_slice(id.as_bytes());
    key
}
133
134fn encode_fields(fields: &[&str]) -> Result<Vec<u8>> {
135    let mut bytes = Vec::new();
136    for field in fields {
137        let len = u32::try_from(field.len())
138            .map_err(|_| Error::invalid_options("user field exceeds u32::MAX"))?;
139        bytes.extend_from_slice(&len.to_le_bytes());
140        bytes.extend_from_slice(field.as_bytes());
141    }
142    Ok(bytes)
143}
144
/// Read position over a borrowed byte slice of length-prefixed fields.
struct FieldCursor<'bytes> {
    /// The complete encoded record being read.
    bytes: &'bytes [u8],
    /// Index into `bytes` of the next unread byte.
    offset: usize,
}
149
150impl<'bytes> FieldCursor<'bytes> {
151    const fn new(bytes: &'bytes [u8]) -> Self {
152        Self { bytes, offset: 0 }
153    }
154
155    fn read_string(&mut self) -> Result<String> {
156        let len_bytes = self
157            .bytes
158            .get(self.offset..self.offset.saturating_add(4))
159            .ok_or_else(|| invalid_user("short field length"))?;
160        self.offset += 4;
161
162        let len =
163            u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], len_bytes[3]]) as usize;
164        let end = self
165            .offset
166            .checked_add(len)
167            .ok_or_else(|| invalid_user("field length overflows usize"))?;
168        let value = self
169            .bytes
170            .get(self.offset..end)
171            .ok_or_else(|| invalid_user("short field bytes"))?;
172        self.offset = end;
173
174        std::str::from_utf8(value)
175            .map(str::to_owned)
176            .map_err(|_| invalid_user("field is not UTF-8"))
177    }
178
179    fn finish(&self) -> Result<()> {
180        if self.offset == self.bytes.len() {
181            return Ok(());
182        }
183        Err(invalid_user("trailing bytes"))
184    }
185}
186
187fn invalid_user(message: &'static str) -> Error {
188    Error::InvalidFormat {
189        message: format!("invalid user record: {message}"),
190    }
191}
192
/// Returns `<system temp dir>/<name>-<pid>`, where `<pid>` is the id of the
/// current process.
///
/// Only builds the path; nothing is created on disk.
fn temp_path(name: &str) -> PathBuf {
    let pid = std::process::id();
    let mut path = std::env::temp_dir();
    path.push(format!("{name}-{pid}"));
    path
}
196
197fn reset_dir(path: &Path) -> Result<()> {
198    match std::fs::remove_dir_all(path) {
199        Ok(()) => {}
200        Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
201        Err(error) => return Err(Error::Io(error)),
202    }
203    Ok(())
204}