1use super::*;
2use rand::{thread_rng, Rng};
3use sqlite_vfs::{OpenKind, OpenOptions, Vfs};
4use std::{
5 io::{Error, ErrorKind},
6 time::Duration,
7};
8
9pub struct HttpVfs {
10 pub(crate) client: Option<Client>,
11 pub(crate) block_size: usize,
12 pub(crate) download_threshold: usize,
13}
14
15impl Vfs for HttpVfs {
16 type Handle = Connection;
17
18 fn open(&self, db: &str, opts: OpenOptions) -> Result<Self::Handle, Error> {
19 if opts.kind != OpenKind::MainDb {
20 return Err(Error::new(
21 ErrorKind::ReadOnlyFilesystem,
22 "only main database supported",
23 ));
24 }
25
26 Ok(Connection::new(
27 db,
28 self.client.clone(),
29 self.block_size,
30 self.download_threshold,
31 )?)
32 }
33
34 fn delete(&self, _db: &str) -> Result<(), Error> {
35 Err(Error::new(
36 ErrorKind::ReadOnlyFilesystem,
37 "delete operation is not supported",
38 ))
39 }
40
41 fn exists(&self, _db: &str) -> Result<bool, Error> {
42 Ok(false)
43 }
44
45 fn temporary_name(&self) -> String {
46 String::from("main.db")
47 }
48
49 fn random(&self, buffer: &mut [i8]) {
50 Rng::fill(&mut thread_rng(), buffer);
51 }
52
53 fn sleep(&self, duration: Duration) -> Duration {
54 std::thread::sleep(duration);
55 duration
56 }
57}
58
#[cfg(test)]
mod tests {
    use std::future::Future;

    use super::*;
    use rusqlite::{Connection, OpenFlags};
    use tokio::time::sleep;

    const QUERY_SQLITE_MASTER: &str = "SELECT count(1) FROM sqlite_master WHERE type = 'table'";
    const QUERY_TEST: &str = "SELECT name FROM test";

    /// Delay between readiness probes while the fixture server starts up.
    const READY_POLL_INTERVAL: Duration = Duration::from_millis(100);
    /// Maximum number of readiness probes (roughly 10 s at `READY_POLL_INTERVAL`)
    /// before giving up, so a server that never comes up fails the test
    /// instead of hanging it.
    const READY_MAX_POLLS: u32 = 100;

    /// In-process HTTP server that serves the fixture databases.
    mod server {
        use rocket::{custom, figment::Figment, get, routes, Config, Shutdown, State};
        use rocket_seek_stream::SeekStream;
        use rusqlite::Connection;
        use std::{collections::HashMap, fs::read, io::Cursor, thread::JoinHandle};
        use tempfile::tempdir;
        use tokio::runtime::Runtime;

        /// TCP port the fixture server listens on.
        pub const PORT: u16 = 4096;

        /// Handle of the server thread; joining it yields Rocket's launch result.
        pub type ServerHandle = JoinHandle<Result<(), rocket::Error>>;

        /// Builds the fixture databases, keyed by the id used in the URL path.
        ///
        /// * `0`: two empty tables, `test1` and `test2`.
        /// * `1`: table `test` holding the rows `Alice` and `Bob`.
        ///
        /// Each database is written into a temporary directory (removed when
        /// `temp` is dropped) and read back as raw file bytes.
        fn init_database() -> HashMap<i64, Vec<u8>> {
            let schemas = [
                vec![
                    // Keep the rollback journal in memory rather than in a side file.
                    "PRAGMA journal_mode = MEMORY;",
                    "CREATE TABLE test1 (id INTEGER PRIMARY KEY, name TEXT);",
                    "CREATE TABLE test2 (id INTEGER PRIMARY KEY, name TEXT);",
                ],
                vec![
                    "PRAGMA journal_mode = MEMORY;",
                    "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT);",
                    "INSERT INTO test (name) VALUES ('Alice');",
                    "INSERT INTO test (name) VALUES ('Bob');",
                ],
            ];
            let mut database = HashMap::new();

            let temp = tempdir().unwrap();

            for (i, schema) in schemas.into_iter().enumerate() {
                let path = temp.path().join(format!("{i}.db"));
                let conn = Connection::open(&path).unwrap();
                conn.execute_batch(&schema.join("\n")).unwrap();
                conn.close().unwrap();
                database.insert(i as i64, read(&path).unwrap());
            }

            database
        }

        /// Serves fixture database `id` as a seekable stream, or 404 for an unknown id.
        #[get("/<id>")]
        pub async fn database(
            db: &State<HashMap<i64, Vec<u8>>>,
            id: i64,
        ) -> Option<SeekStream<'static>> {
            let buffer = db.get(&id)?;
            // `SeekStream<'static>` must own its data, so each request copies the bytes.
            let cursor = Cursor::new(buffer.clone());
            Some(SeekStream::with_opts(cursor, buffer.len() as u64, None))
        }

        /// Asks Rocket to shut down gracefully.
        #[get("/shutdown")]
        pub async fn shutdown(shutdown: Shutdown) -> &'static str {
            shutdown.notify();
            "Shutting down..."
        }

        /// Spawns a thread that runs the fixture server on its own Tokio runtime.
        ///
        /// The thread returns `Ok(())` after a clean shutdown, or Rocket's error
        /// if launching failed (for example when `PORT` is already in use).
        pub fn launch() -> ServerHandle {
            std::thread::spawn(|| {
                let rt = Runtime::new().unwrap();
                rt.block_on(async {
                    custom(Figment::from(Config::default()).merge(("port", PORT)))
                        .manage(init_database())
                        .mount("/", routes![database, shutdown])
                        .launch()
                        .await?;

                    Ok(())
                })
            })
        }
    }

    /// Starts the fixture server, runs `future` with its base URL, then stops the server.
    ///
    /// The server is stopped even when `future` returns an error, so the port
    /// is released. If both `future` and shutdown fail, the error from
    /// `future` is the one returned.
    ///
    /// # Errors
    /// Fails if the server exits or does not become ready in time, if `future`
    /// fails, or if the shutdown request or the server thread fails.
    async fn init_server<C, F>(future: C) -> anyhow::Result<()>
    where
        C: FnOnce(String) -> F,
        F: Future<Output = anyhow::Result<()>>,
    {
        let base = format!("http://127.0.0.1:{}", server::PORT);
        let handle = wait_until_ready(&base, server::launch()).await?;

        let outcome = future(base.clone()).await;
        let stopped = stop_server(&base, handle).await;
        outcome.and(stopped)
    }

    /// Polls `base` until the fixture server responds, returning its still-running handle.
    ///
    /// No route matches `/`, so a 404 response means Rocket is up and routing.
    async fn wait_until_ready(
        base: &str,
        handle: server::ServerHandle,
    ) -> anyhow::Result<server::ServerHandle> {
        for _ in 0..READY_MAX_POLLS {
            if handle.is_finished() {
                // The thread exited without ever serving (e.g. the port is taken):
                // surface Rocket's error instead of polling forever.
                handle.join().unwrap()?;
                anyhow::bail!("test server at {} stopped before becoming ready", base);
            }
            if let Ok(resp) = reqwest::get(base).await {
                if resp.status() == 404 {
                    return Ok(handle);
                }
            }
            sleep(READY_POLL_INTERVAL).await;
        }
        anyhow::bail!(
            "test server at {} not ready after {} polls",
            base,
            READY_MAX_POLLS
        )
    }

    /// Requests a graceful shutdown and waits for the server thread to exit.
    ///
    /// The thread is joined only after the shutdown request succeeded. A
    /// server that never received the request would keep running, and `join`
    /// would block forever.
    async fn stop_server(base: &str, handle: server::ServerHandle) -> anyhow::Result<()> {
        reqwest::get(format!("{base}/shutdown")).await?;
        handle.join().unwrap()?;
        Ok(())
    }

    /// Opens `url` through the HTTP VFS with the flags shared by every test connection.
    fn open_remote(url: &str) -> rusqlite::Result<Connection> {
        Connection::open_with_flags_and_vfs(
            url,
            OpenFlags::SQLITE_OPEN_READ_WRITE
                | OpenFlags::SQLITE_OPEN_CREATE
                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
            HTTP_VFS,
        )
    }

    #[tokio::test]
    async fn test_http_vfs() {
        init_server(|base| async move {
            vfs::register_http_vfs();

            {
                // Database 0 contains exactly two tables.
                let conn = open_remote(&format!("{base}/0"))?;
                let tables =
                    conn.query_row::<usize, _, _>(QUERY_SQLITE_MASTER, [], |row| row.get(0))?;
                assert_eq!(tables, 2);
            }

            {
                // Database 1 contains `test` with the rows Alice and Bob.
                let conn = open_remote(&format!("{base}/1"))?;
                let mut stmt = conn.prepare(QUERY_TEST)?;
                let names = stmt
                    .query_map([], |row| row.get::<_, String>(0))?
                    .collect::<Result<Vec<_>, _>>()?;
                assert_eq!(names, vec!["Alice".to_string(), "Bob".to_string()]);
            }

            Ok(())
        })
        .await
        .unwrap();
    }
}