seshat/database/
recovery.rs

1use std::{
2    fs,
3    path::{Path, PathBuf},
4    sync::{
5        atomic::{AtomicU64, Ordering},
6        Arc,
7    },
8};
9
10use std::{
11    convert::TryInto,
12    io::{Error as IoError, ErrorKind},
13};
14
15use r2d2::PooledConnection;
16use r2d2_sqlite::SqliteConnectionManager;
17use rusqlite::ToSql;
18
19use crate::{
20    config::Config,
21    database::{DATABASE_VERSION, EVENTS_DB_NAME},
22    error::{Error, Result},
23    events::{Event, SerializedEvent},
24    index::{Index, Writer},
25    Connection, Database,
26};
27
28use crate::EventType;
29use serde_json::Value;
30
/// Database that can be used to reindex the events.
///
/// Reindexing the database may be needed if the index schema changes. This may
/// happen occasionally on upgrades or if language settings for the database
/// change.
pub struct RecoveryDatabase {
    // Directory holding both the events database file and the index files.
    path: PathBuf,
    // Dedicated connection used for reading events and for clearing the
    // reindex_needed flag on close.
    connection: PooledConnection<SqliteConnectionManager>,
    // Pool from which additional read-only connections are handed out.
    pool: r2d2::Pool<SqliteConnectionManager>,
    config: Config,
    // Progress tracker; cloneable and shareable across threads.
    recovery_info: RecoveryInfo,
    // Set to true once delete_the_index() succeeded; open_index() refuses to
    // run before that.
    index_deleted: bool,
    // Populated by open_index(); both stay None until the old index has been
    // deleted and a fresh one created.
    index: Option<Index>,
    index_writer: Option<Writer>,
}
46
#[derive(Debug, Clone)]
/// Info about the recovery process.
///
/// This can be used to track the progress of the reindex.
///
/// `RecoveryInfo` implements `Send` and `Sync` so it can be shared between
/// threads if for example the UI is in a separate thread.
pub struct RecoveryInfo {
    // Snapshot of the total event count taken when the recovery database was
    // opened; it does not change during the reindex.
    total_event_count: u64,
    // Shared, atomically updated count of events reindexed so far.
    reindexed_events: Arc<AtomicU64>,
}
58
59impl RecoveryInfo {
60    /// The total number of events that the database holds.
61    pub fn total_events(&self) -> u64 {
62        self.total_event_count
63    }
64
65    /// The number of events that are processed and reindexed.
66    pub fn reindexed_events(&self) -> &AtomicU64 {
67        &self.reindexed_events
68    }
69}
70
71impl RecoveryDatabase {
    /// Open a recovery Seshat database using the default configuration.
    ///
    /// # Arguments
    ///
    /// * `path` - The directory where the database will be stored in. This
    ///   should be an empty directory if a new database should be created.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self>
    where
        PathBuf: std::convert::From<P>,
    {
        // Delegate to new_with_config() with a default Config.
        Self::new_with_config(path, &Config::new())
    }
84
    /// Open a recovery Seshat database with the provided config.
    ///
    /// # Arguments
    ///
    /// * `path` - The directory where the database will be stored in. This
    ///   should be an empty directory if a new database should be created.
    ///
    /// * `config` - Configuration that changes the behaviour of the database.
    ///
    /// # Errors
    ///
    /// Returns `DatabaseOpenError` if the version can't be read and
    /// `DatabaseVersionError` if the stored version doesn't match
    /// `DATABASE_VERSION`.
    pub fn new_with_config<P: AsRef<Path>>(path: P, config: &Config) -> Result<Self>
    where
        PathBuf: std::convert::From<P>,
    {
        let db_path = path.as_ref().join(EVENTS_DB_NAME);
        let pool = Database::get_pool(&db_path, config)?;

        let mut connection = pool.get()?;
        // The database may be encrypted; unlock must run before any other
        // statement is executed on this connection.
        Database::unlock(&connection, config)?;
        connection.pragma_update(None, "foreign_keys", &1 as &dyn ToSql)?;

        let (version, _) = match Database::get_version(&mut connection) {
            Ok(ret) => ret,
            Err(e) => return Err(Error::DatabaseOpenError(e.to_string())),
        };

        // NOTE(review): tables are created even when the version check below
        // fails — presumably to leave a partially created database in a
        // consistent state; confirm against Database::new.
        Database::create_tables(&connection)?;

        if version != DATABASE_VERSION {
            return Err(Error::DatabaseVersionError);
        }

        // Snapshot the total event count once; RecoveryInfo reports progress
        // against this fixed total.
        let event_count = Database::get_event_count(&connection)?;

        let info = RecoveryInfo {
            total_event_count: event_count as u64,
            reindexed_events: Arc::new(AtomicU64::new(0)),
        };

        Ok(Self {
            path: path.into(),
            connection,
            pool,
            config: config.clone(),
            recovery_info: info,
            index_deleted: false,
            index: None,
            index_writer: None,
        })
    }
133
134    /// Delete the Seshat index, leaving only the events database.
135    ///
136    /// After this operation is done, the index can be rebuilt.
137    pub fn delete_the_index(&mut self) -> Result<()> {
138        let writer = self.index_writer.take();
139        let index = self.index.take();
140
141        drop(writer);
142        drop(index);
143
144        for entry in fs::read_dir(&self.path)? {
145            let entry = entry?;
146            let path = entry.path();
147
148            // Skip removing directories, we don't create subdirs in our
149            // database dir.
150            if path.is_dir() {
151                continue;
152            }
153
154            if let Some(file_name) = path.file_name() {
155                // Skip removing the events database, those will be needed for
156                // reindexing.
157                if file_name.to_string_lossy().starts_with(EVENTS_DB_NAME) {
158                    continue;
159                }
160
161                fs::remove_file(path)?
162            }
163        }
164        self.index_deleted = true;
165        Ok(())
166    }
167
168    pub(crate) fn event_from_json(event_source: &str) -> std::io::Result<Event> {
169        let object: Value = serde_json::from_str(event_source)?;
170        let content = &object["content"];
171        let event_type = &object["type"];
172
173        let event_type = match event_type.as_str().unwrap_or_default() {
174            "m.room.message" => EventType::Message,
175            "m.room.name" => EventType::Name,
176            "m.room.topic" => EventType::Topic,
177            _ => return Err(IoError::new(ErrorKind::Other, "Invalid event type.")),
178        };
179
180        let (content_value, msgtype) = match event_type {
181            EventType::Message => (
182                content["body"]
183                    .as_str()
184                    .ok_or_else(|| IoError::new(ErrorKind::Other, "No content value found"))?,
185                Some("m.text"),
186            ),
187            EventType::Topic => (
188                content["topic"]
189                    .as_str()
190                    .ok_or_else(|| IoError::new(ErrorKind::Other, "No content value found"))?,
191                None,
192            ),
193            EventType::Name => (
194                content["name"]
195                    .as_str()
196                    .ok_or_else(|| IoError::new(ErrorKind::Other, "No content value found"))?,
197                None,
198            ),
199        };
200
201        let event_id = object["event_id"]
202            .as_str()
203            .ok_or_else(|| IoError::new(ErrorKind::Other, "No event id found"))?;
204        let sender = object["sender"]
205            .as_str()
206            .ok_or_else(|| IoError::new(ErrorKind::Other, "No sender found"))?;
207        let server_ts = object["origin_server_ts"]
208            .as_u64()
209            .ok_or_else(|| IoError::new(ErrorKind::Other, "No server timestamp found"))?;
210        let room_id = object["room_id"]
211            .as_str()
212            .ok_or_else(|| IoError::new(ErrorKind::Other, "No room id found"))?;
213
214        Ok(Event::new(
215            event_type,
216            content_value,
217            msgtype,
218            event_id,
219            sender,
220            server_ts.try_into().map_err(|_e| {
221                IoError::new(ErrorKind::Other, "Server timestamp out of valid range")
222            })?,
223            room_id,
224            event_source,
225        ))
226    }
227
228    /// Load deserialized events from the database.
229    ///
230    /// * `limit` - The number of events to load.
231    /// * `from_event` - The event where to continue loading from.
232    ///
233    /// Events that fail to be deserialized will be filtered out.
234    pub fn load_events_deserialized(
235        &self,
236        limit: usize,
237        from_event: Option<&Event>,
238    ) -> Result<Vec<Event>> {
239        let serialized_events = self.load_events(limit, from_event)?;
240
241        let events = serialized_events
242            .iter()
243            .map(|e| RecoveryDatabase::event_from_json(e))
244            .filter_map(std::io::Result::ok)
245            .collect();
246
247        Ok(events)
248    }
249
250    /// Load serialized events from the database.
251    ///
252    /// * `limit` - The number of events to load.
253    /// * `from_event` - The event where to continue loading from.
254    pub fn load_events(
255        &self,
256        limit: usize,
257        from_event: Option<&Event>,
258    ) -> Result<Vec<SerializedEvent>> {
259        Ok(Database::load_all_events(
260            &self.connection,
261            limit,
262            from_event,
263        )?)
264    }
265
266    /// Create and open a new index.
267    ///
268    /// Returns `ReindexError` if the index wasn't deleted first.
269    pub fn open_index(&mut self) -> Result<()> {
270        if !self.index_deleted {
271            return Err(Error::ReindexError);
272        }
273
274        let index = Index::new(&self.path, &self.config)?;
275        let writer = index.get_writer()?;
276        self.index = Some(index);
277        self.index_writer = Some(writer);
278
279        Ok(())
280    }
281
    /// Get the recovery info for the database.
    ///
    /// The returned info can be cloned and shared between threads to observe
    /// the reindexing progress.
    pub fn info(&self) -> &RecoveryInfo {
        &self.recovery_info
    }
286
287    /// Get a database connection.
288    ///
289    /// Note that this connection should only be used for reading.
290    pub fn get_connection(&self) -> Result<Connection> {
291        let connection = self.pool.get()?;
292        Database::unlock(&connection, &self.config)?;
293        Database::set_pragmas(&connection)?;
294
295        Ok(Connection {
296            inner: connection,
297            path: self.path.clone(),
298        })
299    }
300
301    /// Re-index a batch of events.
302    ///
303    /// # Arguments
304    ///
305    /// * `events` - The events that should be reindexed.
306    ///
307    /// Returns `ReindexError` if the index wasn't previously deleted and
308    /// opened.
309    pub fn index_events(&mut self, events: &[Event]) -> Result<()> {
310        match self.index_writer.as_mut() {
311            Some(writer) => events.iter().map(|e| writer.add_event(e)).collect(),
312            None => panic!("Index wasn't deleted"),
313        }
314
315        self.recovery_info
316            .reindexed_events
317            .fetch_add(events.len() as u64, Ordering::SeqCst);
318
319        Ok(())
320    }
321
322    /// Commit to the index.
323    ///
324    /// Returns true if the commit was forwarded, false if not enough events are
325    /// queued up.
326    ///
327    /// Returns `ReindexError` if the index wasn't previously deleted and
328    /// opened.
329    pub fn commit(&mut self) -> Result<bool> {
330        match self.index_writer.as_mut() {
331            Some(writer) => {
332                let ret = writer.commit()?;
333                Ok(ret)
334            }
335            None => Err(Error::ReindexError),
336        }
337    }
338
339    /// Commit the remaining added events and mark the reindex as done.
340    ///
341    /// Returns `ReindexError` if the index wasn't previously deleted and
342    /// opened.
343    pub fn commit_and_close(mut self) -> Result<()> {
344        match self.index_writer.as_mut() {
345            Some(writer) => {
346                writer.force_commit()?;
347                self.connection
348                    .execute("UPDATE reindex_needed SET reindex_needed = ?1", [false])?;
349                Ok(())
350            }
351            None => Err(Error::ReindexError),
352        }
353    }
354
355    /// Shut the database down.
356    ///
357    /// This will terminate the writer thread making sure that no writes will
358    /// happen after this operation.
359    pub fn shutdown(mut self) -> Result<()> {
360        let index_writer = self.index_writer.take();
361        index_writer.map_or(Ok(()), |i| i.wait_merging_threads())?;
362
363        Ok(())
364    }
365}
366
367#[cfg(test)]
368pub(crate) mod test {
369    use crate::{
370        database::DATABASE_VERSION, Database, Error, Event, RecoveryDatabase, Result, SearchConfig,
371    };
372
373    use std::{path::PathBuf, sync::atomic::Ordering};
374
375    pub(crate) fn reindex_loop(
376        db: &mut RecoveryDatabase,
377        initial_events: Vec<Event>,
378    ) -> Result<()> {
379        let mut events = initial_events;
380
381        loop {
382            let serialized_events = db.load_events(10, events.last())?;
383            if serialized_events.is_empty() {
384                break;
385            }
386
387            events = serialized_events
388                .iter()
389                .map(|e| RecoveryDatabase::event_from_json(e))
390                .filter_map(std::io::Result::ok)
391                .collect();
392
393            db.index_events(&events)?;
394            db.commit()?;
395        }
396        Ok(())
397    }
398
    #[test]
    fn test_recovery() {
        // Build the path to the checked-in v2 database fixture
        // (data/database/v2 relative to the crate root, derived from file!()).
        let mut path = PathBuf::from(file!());
        path.pop();
        path.pop();
        path.pop();
        path.push("data/database/v2");
        let db = Database::new(&path);

        // Opening the old-schema database normally must fail with
        // ReindexError.
        match db {
            Ok(_) => panic!("Database doesn't need a reindex."),
            Err(e) => match e {
                Error::ReindexError => (),
                e => panic!("Database doesn't need a reindex: {}", e),
            },
        }

        // The recovery flow: open recovery db, delete the index, reopen it.
        let mut recovery_db = RecoveryDatabase::new(&path).expect("Can't open recovery db");
        assert_ne!(recovery_db.info().total_events(), 0);
        recovery_db
            .delete_the_index()
            .expect("Can't delete the index");
        recovery_db
            .open_index()
            .expect("Can't open the new the index");

        // Index the first batch by hand so progress can be asserted after
        // exactly one batch of ten.
        let events = recovery_db
            .load_events_deserialized(10, None)
            .expect("Can't load events");

        assert!(!events.is_empty());
        assert_eq!(events.len(), 10);

        recovery_db.index_events(&events).unwrap();
        assert_eq!(
            recovery_db
                .info()
                .reindexed_events()
                .load(Ordering::Relaxed),
            10
        );

        reindex_loop(&mut recovery_db, events).expect("Can't reindex the db");

        // The fixture is expected to contain 999 reindexable events in total.
        assert_eq!(
            recovery_db
                .info()
                .reindexed_events()
                .load(Ordering::Relaxed),
            999
        );

        recovery_db.commit_and_close().unwrap();

        // After the reindex the database must open normally, report the
        // current version and no longer flag a reindex.
        let db = Database::new(&path).unwrap();
        let mut connection = db.get_connection().unwrap();

        let (version, reindex_needed) = Database::get_version(&mut connection).unwrap();

        assert_eq!(version, DATABASE_VERSION);
        assert!(!reindex_needed);

        // The rebuilt index must be searchable again.
        let result = db.search("Hello", &SearchConfig::new()).unwrap().results;
        assert!(!result.is_empty())
    }
464}