trust_dns_server/store/sqlite/
persistence.rs1use std::iter::Iterator;
11use std::path::Path;
12use std::sync::{Mutex, MutexGuard};
13
14use rusqlite::types::ToSql;
15use rusqlite::{self, Connection};
16use time;
17use tracing::error;
18
19use crate::error::{PersistenceErrorKind, PersistenceResult};
20use crate::proto::rr::Record;
21use crate::proto::serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder};
22
/// Schema version this code expects; [`Journal::schema_up`] migrates a database until its stored
/// version reaches this value.
pub const CURRENT_VERSION: i64 = 1;
25
/// Journal of [`Record`]s persisted in a SQLite database.
pub struct Journal {
    /// The SQLite connection, behind a `Mutex` so `&self` methods can lock and use it.
    conn: Mutex<Connection>,
    /// Schema version read from the database, updated as migrations are applied.
    version: i64,
}
31
32impl Journal {
33 pub fn new(conn: Connection) -> PersistenceResult<Self> {
35 let version = Self::select_schema_version(&conn)?;
36 Ok(Self {
37 conn: Mutex::new(conn),
38 version,
39 })
40 }
41
42 pub fn from_file(journal_file: &Path) -> PersistenceResult<Self> {
44 let result = Self::new(Connection::open(journal_file)?);
45 let mut journal = result?;
46 journal.schema_up()?;
47 Ok(journal)
48 }
49
50 pub fn conn(&self) -> MutexGuard<'_, Connection> {
52 self.conn.lock().expect("conn poisoned")
53 }
54
55 pub fn schema_version(&self) -> i64 {
57 self.version
58 }
59
60 pub fn iter(&self) -> JournalIter<'_> {
62 JournalIter::new(self)
63 }
64
65 pub fn insert_record(&self, soa_serial: u32, record: &Record) -> PersistenceResult<()> {
75 assert!(
76 self.version == CURRENT_VERSION,
77 "schema version mismatch, schema_up() resolves this"
78 );
79
80 let mut serial_record: Vec<u8> = Vec::with_capacity(512);
81 {
82 let mut encoder = BinEncoder::new(&mut serial_record);
83 record.emit(&mut encoder)?;
84 }
85
86 let timestamp = time::OffsetDateTime::now_utc();
87 let client_id: i64 = 0; let soa_serial: i64 = i64::from(soa_serial);
89
90 let count = self.conn.lock().expect("conn poisoned").execute(
91 "INSERT
92 \
93 INTO records (client_id, soa_serial, timestamp, \
94 record)
95 \
96 VALUES ($1, $2, $3, $4)",
97 [
98 &client_id as &dyn ToSql,
99 &soa_serial,
100 ×tamp,
101 &serial_record,
102 ],
103 )?;
104 if count != 1 {
106 return Err(PersistenceErrorKind::WrongInsertCount {
107 got: count,
108 expect: 1,
109 }
110 .into());
111 };
112
113 Ok(())
114 }
115
116 pub fn insert_records(&self, soa_serial: u32, records: &[Record]) -> PersistenceResult<()> {
118 for record in records {
120 self.insert_record(soa_serial, record)?;
121 }
122
123 Ok(())
124 }
125
126 pub fn select_record(&self, row_id: i64) -> PersistenceResult<Option<(i64, Record)>> {
136 assert!(
137 self.version == CURRENT_VERSION,
138 "schema version mismatch, schema_up() resolves this"
139 );
140
141 let conn = self.conn.lock().expect("conn poisoned");
142 let mut stmt = conn.prepare(
143 "SELECT _rowid_, record
144 \
145 FROM records
146 \
147 WHERE _rowid_ >= $1
148 \
149 LIMIT 1",
150 )?;
151
152 let record_opt: Option<Result<(i64, Record), rusqlite::Error>> = stmt
153 .query_and_then([&row_id], |row| -> Result<(i64, Record), rusqlite::Error> {
154 let row_id: i64 = row.get(0)?;
155 let record_bytes: Vec<u8> = row.get(1)?;
156 let mut decoder = BinDecoder::new(&record_bytes);
157
158 match Record::read(&mut decoder) {
160 Ok(record) => Ok((row_id, record)),
161 Err(decode_error) => Err(rusqlite::Error::InvalidParameterName(format!(
162 "could not decode: {decode_error}"
163 ))),
164 }
165 })?
166 .next();
167
168 match record_opt {
170 Some(Ok((row_id, record))) => Ok(Some((row_id, record))),
171 Some(Err(err)) => Err(err.into()),
172 None => Ok(None),
173 }
174 }
175
176 pub fn select_schema_version(conn: &Connection) -> PersistenceResult<i64> {
183 let mut stmt = conn.prepare(
185 "SELECT name
186 \
187 FROM sqlite_master
188 \
189 WHERE type='table'
190 \
191 AND name='tdns_schema'",
192 )?;
193
194 let tdns_schema_opt: Option<Result<String, _>> =
195 stmt.query_map([], |row| row.get(0))?.next();
196
197 let tdns_schema = match tdns_schema_opt {
198 Some(Ok(string)) => string,
199 Some(Err(err)) => return Err(err.into()),
200 None => return Ok(-1),
201 };
202
203 assert_eq!(&tdns_schema, "tdns_schema");
204
205 let version: i64 = conn.query_row(
206 "SELECT version
207 \
208 FROM tdns_schema",
209 [],
210 |row| row.get(0),
211 )?;
212
213 Ok(version)
214 }
215
216 fn update_schema_version(&self, new_version: i64) -> PersistenceResult<()> {
218 assert!(new_version <= CURRENT_VERSION);
220
221 let count = self
222 .conn
223 .lock()
224 .expect("conn poisoned")
225 .execute("UPDATE tdns_schema SET version = $1", [&new_version])?;
226
227 assert_eq!(count, 1);
229 Ok(())
230 }
231
232 pub fn schema_up(&mut self) -> PersistenceResult<i64> {
234 while self.version < CURRENT_VERSION {
235 match self.version + 1 {
236 0 => self.version = self.init_up()?,
237 1 => self.version = self.records_up()?,
238 _ => panic!("incorrect version somewhere"), }
240
241 self.update_schema_version(self.version)?;
242 }
243
244 Ok(self.version)
245 }
246
247 fn init_up(&self) -> PersistenceResult<i64> {
249 let count = self.conn.lock().expect("conn poisoned").execute(
250 "CREATE TABLE tdns_schema (
251 \
252 version INTEGER NOT NULL
253 \
254 )",
255 [],
256 )?;
257 assert_eq!(count, 0);
259
260 let count = self
261 .conn
262 .lock()
263 .expect("conn poisoned")
264 .execute("INSERT INTO tdns_schema (version) VALUES (0)", [])?;
265 assert_eq!(count, 1);
267
268 Ok(0)
269 }
270
271 fn records_up(&self) -> PersistenceResult<i64> {
274 let count = self.conn.lock().expect("conn poisoned").execute(
276 "CREATE TABLE records (
277 \
278 client_id INTEGER NOT NULL,
279 \
280 soa_serial INTEGER NOT NULL,
281 \
282 timestamp TEXT NOT NULL,
283 \
284 record BLOB NOT NULL
285 \
286 )",
287 [],
288 )?;
289 assert_eq!(count, 1);
291
292 Ok(1)
293 }
294}
295
/// Iterator over the records of a [`Journal`]; each yielded record has a larger row id than the
/// one before it.
pub struct JournalIter<'j> {
    /// Row id of the most recently yielded record (`0` before the first one).
    current_row_id: i64,
    /// The journal being read.
    journal: &'j Journal,
}
303
304impl<'j> JournalIter<'j> {
305 fn new(journal: &'j Journal) -> Self {
306 JournalIter {
307 current_row_id: 0,
308 journal,
309 }
310 }
311}
312
313impl<'j> Iterator for JournalIter<'j> {
314 type Item = Record;
315
316 fn next(&mut self) -> Option<Self::Item> {
317 let next: PersistenceResult<Option<(i64, Record)>> =
318 self.journal.select_record(self.current_row_id + 1);
319
320 match next {
321 Ok(Some((row_id, record))) => {
322 self.current_row_id = row_id;
323 Some(record)
324 }
325 Ok(None) => None,
326 Err(err) => {
327 error!("persistence error while iterating over journal: {}", err);
328 None
329 }
330 }
331 }
332}