// wthr/cache.rs

1#![allow(dead_code, unused_variables)]
2use std::{env, fs, path::PathBuf};
3
4use chrono::{DateTime, Utc};
5use rusqlite::{named_params, Connection, ToSql};
6
7use crate::{
8    error::{err, Result},
9    APP,
10};
11
12#[derive(Debug)]
13pub struct Cache {
14    conn: Connection,
15}
16
17#[derive(Debug)]
18pub struct CacheEntry {
19    url: String,
20    pub created_at: DateTime<Utc>,
21    pub max_age: Option<u32>,
22    last_modified: DateTime<Utc>,
23    pub content: String,
24}
25
26impl Cache {
27    /// Connect to a database and return a handle to perform caching
28    /// operations.
29    pub fn new() -> Result<Self> {
30        Self::with_base_dir(None)
31    }
32
33    pub fn with_base_dir(base_dir: Option<PathBuf>) -> Result<Self> {
34        let db_path = match base_dir {
35            Some(dir) => dir,
36            None => match env::var("XDG_DATA_HOME") {
37                Ok(dir) => PathBuf::from(dir).join(APP),
38                Err(_) => match env::var("HOME") {
39                    Ok(dir) => PathBuf::from(dir).join(".local").join("share").join(APP),
40                    Err(_) => PathBuf::from("/tmp").join(APP),
41                },
42            },
43        };
44
45        // Make sure the database directory exists
46        let db_path = match fs::create_dir_all(&db_path) {
47            Ok(_) => Some(db_path),
48            Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => Some(db_path),
49            Err(_) => None,
50        };
51
52        // Open the database, which creates a new db file if needed
53        let conn = match db_path {
54            Some(mut db_path) => {
55                db_path = db_path.join(APP);
56                db_path.set_extension("db");
57                Connection::open(&db_path)?
58            }
59            None => Connection::open_in_memory()?,
60        };
61
62        // Create tables if this is a new database
63        if conn.prepare("select max(id) from version").is_err() {
64            init_db(&conn)?;
65        }
66
67        Ok(Self { conn })
68    }
69
70    pub fn get<T>(&self, url: T) -> Result<Option<CacheEntry>>
71    where
72        T: AsRef<str> + ToSql,
73    {
74        let mut stmt = self.conn.prepare(
75            "select url, created_at, max_age, last_modified, content
76             from cache where url = ?",
77        )?;
78        let mut rows = stmt.query_map([url], |row| {
79            Ok(CacheEntry {
80                url: row.get(0)?,
81                created_at: row.get(1)?,
82                max_age: row.get(2)?,
83                last_modified: row.get(3)?,
84                content: row.get(4)?,
85            })
86        })?;
87
88        if let Some(row) = rows.next() {
89            let row = row?;
90            Ok(Some(row))
91        } else {
92            Ok(None)
93        }
94    }
95
96    pub fn insert<T>(
97        &mut self,
98        url: T,
99        max_age: Option<u32>,
100        last_modified: DateTime<Utc>,
101        content: &str,
102    ) -> Result<()>
103    where
104        T: AsRef<str> + ToSql,
105    {
106        let sql = "\
107            insert into cache(url, created_at, max_age, last_modified, content)
108            values(:url, :created_at, :max_age, :last_modified, :content)
109            on conflict(url) do update set
110                created_at=:created_at, max_age=:max_age,
111                last_modified=:last_modified, content=:content";
112        self.conn.execute(
113            sql,
114            named_params! {
115                ":url": url,
116                ":created_at": Utc::now(),
117                ":max_age": max_age,
118                ":last_modified": last_modified,
119                ":content": content,
120            },
121        )?;
122
123        Ok(())
124    }
125
126    pub fn db_version(&self) -> Result<u32> {
127        let mut stmt = self.conn.prepare("select max(id) from version")?;
128        let mut rows = stmt.query([])?;
129        if let Some(row) = rows.next()? {
130            Ok(row.get(0)?)
131        } else {
132            err("Can't determine database version")
133        }
134    }
135}
136
137fn init_db(conn: &Connection) -> Result<()> {
138    conn.execute_batch(
139        "\
140        create table cache(
141            url text unique,
142            created_at datetime,
143            max_age int,
144            last_modified datetime,
145            content text);
146        create table version(id int);
147        insert into version values(1);
148        ",
149    )?;
150
151    Ok(())
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    use tempfile::{tempdir, TempDir};

    /// Create a cache database in a temporary directory for testing.
    ///
    /// The `TempDir` guard is returned with the cache so the directory
    /// lives for the duration of the test.
    fn tempcache() -> (Cache, TempDir) {
        let temp_dir = tempdir().unwrap();

        // A freshly created database must be at schema version 1.
        let cache = Cache::with_base_dir(Some(temp_dir.path().to_path_buf())).unwrap();
        assert_eq!(cache.db_version().unwrap(), 1);

        (cache, temp_dir)
    }

    #[test]
    fn cache_works() {
        let (mut cache, _temp_dir) = tempcache();
        cache
            .insert("mock.url", Some(888), Utc::now(), "content")
            .unwrap();

        let cached_page = cache.get("mock.url").unwrap().unwrap();
        assert_eq!(cached_page.url, "mock.url");
        assert_eq!(cached_page.content, "content");
    }

    #[test]
    fn missing_url_returns_none() {
        let (cache, _temp_dir) = tempcache();
        assert!(cache.get("not.cached").unwrap().is_none());
    }

    #[test]
    fn insert_replaces_existing_entry() {
        let (mut cache, _temp_dir) = tempcache();
        cache.insert("mock.url", Some(1), Utc::now(), "old").unwrap();
        cache.insert("mock.url", None, Utc::now(), "new").unwrap();

        let cached_page = cache.get("mock.url").unwrap().unwrap();
        assert_eq!(cached_page.content, "new");
        assert_eq!(cached_page.max_age, None);
    }

    #[test]
    fn db_initializes() {
        let (_cache, temp_dir) = tempcache();

        // The database file lives at `<base_dir>/<APP>.db`.
        let db_path = temp_dir.path().join(APP).with_extension("db");
        let conn = Connection::open(&db_path).unwrap();

        // Check the version independently of `Cache::db_version`.
        let version: u32 = conn
            .query_row("select max(id) from version", [], |row| row.get(0))
            .expect("a new database should have a populated version table");
        assert_eq!(version, 1);

        // Bump the version through this side connection, then open a new
        // `Cache`. It must reuse the existing database, not re-initialize it.
        conn.execute("update version set id = ?", [2]).unwrap();
        let cache = Cache::with_base_dir(Some(temp_dir.path().to_path_buf())).unwrap();
        assert_eq!(cache.db_version().unwrap(), 2);
    }

    #[test]
    fn empty_version_table_is_an_error() {
        let (cache, temp_dir) = tempcache();
        let conn = Connection::open(temp_dir.path().join(APP).with_extension("db")).unwrap();
        conn.execute("delete from version", []).unwrap();

        assert!(cache.db_version().is_err());
    }

    #[test]
    fn max_age_works() {
        let (mut cache, _temp_dir) = tempcache();

        // With max-age
        cache
            .insert("url_888", Some(888), Utc::now(), "content")
            .unwrap();
        let cached_page = cache.get("url_888").unwrap().unwrap();
        assert_eq!(Some(888), cached_page.max_age);

        // Without max-age
        cache
            .insert("url_None", None, Utc::now(), "content")
            .unwrap();
        let cached_page = cache.get("url_None").unwrap().unwrap();
        assert_eq!(None, cached_page.max_age);
    }
}