1#![allow(dead_code, unused_variables)]
2use std::{env, fs, path::PathBuf};
3
4use chrono::{DateTime, Utc};
5use rusqlite::{named_params, Connection, ToSql};
6
7use crate::{
8 error::{err, Result},
9 APP,
10};
11
12#[derive(Debug)]
13pub struct Cache {
14 conn: Connection,
15}
16
17#[derive(Debug)]
18pub struct CacheEntry {
19 url: String,
20 pub created_at: DateTime<Utc>,
21 pub max_age: Option<u32>,
22 last_modified: DateTime<Utc>,
23 pub content: String,
24}
25
26impl Cache {
27 pub fn new() -> Result<Self> {
30 Self::with_base_dir(None)
31 }
32
33 pub fn with_base_dir(base_dir: Option<PathBuf>) -> Result<Self> {
34 let db_path = match base_dir {
35 Some(dir) => dir,
36 None => match env::var("XDG_DATA_HOME") {
37 Ok(dir) => PathBuf::from(dir).join(APP),
38 Err(_) => match env::var("HOME") {
39 Ok(dir) => PathBuf::from(dir).join(".local").join("share").join(APP),
40 Err(_) => PathBuf::from("/tmp").join(APP),
41 },
42 },
43 };
44
45 let db_path = match fs::create_dir_all(&db_path) {
47 Ok(_) => Some(db_path),
48 Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => Some(db_path),
49 Err(_) => None,
50 };
51
52 let conn = match db_path {
54 Some(mut db_path) => {
55 db_path = db_path.join(APP);
56 db_path.set_extension("db");
57 Connection::open(&db_path)?
58 }
59 None => Connection::open_in_memory()?,
60 };
61
62 if conn.prepare("select max(id) from version").is_err() {
64 init_db(&conn)?;
65 }
66
67 Ok(Self { conn })
68 }
69
70 pub fn get<T>(&self, url: T) -> Result<Option<CacheEntry>>
71 where
72 T: AsRef<str> + ToSql,
73 {
74 let mut stmt = self.conn.prepare(
75 "select url, created_at, max_age, last_modified, content
76 from cache where url = ?",
77 )?;
78 let mut rows = stmt.query_map([url], |row| {
79 Ok(CacheEntry {
80 url: row.get(0)?,
81 created_at: row.get(1)?,
82 max_age: row.get(2)?,
83 last_modified: row.get(3)?,
84 content: row.get(4)?,
85 })
86 })?;
87
88 if let Some(row) = rows.next() {
89 let row = row?;
90 Ok(Some(row))
91 } else {
92 Ok(None)
93 }
94 }
95
96 pub fn insert<T>(
97 &mut self,
98 url: T,
99 max_age: Option<u32>,
100 last_modified: DateTime<Utc>,
101 content: &str,
102 ) -> Result<()>
103 where
104 T: AsRef<str> + ToSql,
105 {
106 let sql = "\
107 insert into cache(url, created_at, max_age, last_modified, content)
108 values(:url, :created_at, :max_age, :last_modified, :content)
109 on conflict(url) do update set
110 created_at=:created_at, max_age=:max_age,
111 last_modified=:last_modified, content=:content";
112 self.conn.execute(
113 sql,
114 named_params! {
115 ":url": url,
116 ":created_at": Utc::now(),
117 ":max_age": max_age,
118 ":last_modified": last_modified,
119 ":content": content,
120 },
121 )?;
122
123 Ok(())
124 }
125
126 pub fn db_version(&self) -> Result<u32> {
127 let mut stmt = self.conn.prepare("select max(id) from version")?;
128 let mut rows = stmt.query([])?;
129 if let Some(row) = rows.next()? {
130 Ok(row.get(0)?)
131 } else {
132 err("Can't determine database version")
133 }
134 }
135}
136
137fn init_db(conn: &Connection) -> Result<()> {
138 conn.execute_batch(
139 "\
140 create table cache(
141 url text unique,
142 created_at datetime,
143 max_age int,
144 last_modified datetime,
145 content text);
146 create table version(id int);
147 insert into version values(1);
148 ",
149 )?;
150
151 Ok(())
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    use tempfile::{tempdir, TempDir};

    /// Builds a cache inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is handed back as well so the directory stays
    /// alive for as long as the test keeps it.
    fn tempcache() -> (Cache, TempDir) {
        let dir = tempdir().unwrap();
        let cache = Cache::with_base_dir(Some(dir.path().to_path_buf())).unwrap();
        assert_eq!(cache.db_version().unwrap(), 1);

        (cache, dir)
    }

    /// An inserted entry can be read back under the same URL.
    #[test]
    fn cache_works() {
        let (mut cache, _temp_dir) = tempcache();
        cache
            .insert("mock.url", Some(888), Utc::now(), "content")
            .unwrap();

        let entry = cache.get("mock.url").unwrap().unwrap();
        assert_eq!(entry.url, "mock.url");
    }

    /// A new database starts at version 1 and is not re-initialised when reopened.
    #[test]
    fn db_initializes() {
        let (_cache, temp_dir) = tempcache();
        let db_file = temp_dir.path().join(APP).with_extension("db");
        let conn = Connection::open(&db_file).unwrap();

        let version: u32 = conn
            .query_row("select max(id) from version", [], |row| row.get(0))
            .unwrap();
        assert_eq!(version, 1);

        // Bump the stored version; reopening must keep it instead of
        // running the initialisation again.
        conn.execute("update version set id = ?", [2]).unwrap();
        let reopened = Cache::with_base_dir(Some(temp_dir.path().to_path_buf())).unwrap();
        assert_eq!(reopened.db_version().unwrap(), 2);
    }

    /// `max_age` round-trips both as `Some(_)` and as `None`.
    #[test]
    fn max_age_works() {
        let (mut cache, _temp_dir) = tempcache();

        cache
            .insert("url_888", Some(888), Utc::now(), "content")
            .unwrap();
        let with_age = cache.get("url_888").unwrap().unwrap();
        assert_eq!(Some(888), with_age.max_age);

        cache
            .insert("url_None", None, Utc::now(), "content")
            .unwrap();
        let without_age = cache.get("url_None").unwrap().unwrap();
        assert_eq!(None, without_age.max_age);
    }
}