1use std::{
2 fs,
3 path::{Path, PathBuf},
4 sync::{
5 atomic::{AtomicU64, Ordering},
6 Arc,
7 },
8};
9
10use std::{
11 convert::TryInto,
12 io::{Error as IoError, ErrorKind},
13};
14
15use r2d2::PooledConnection;
16use r2d2_sqlite::SqliteConnectionManager;
17use rusqlite::ToSql;
18
19use crate::{
20 config::Config,
21 database::{DATABASE_VERSION, EVENTS_DB_NAME},
22 error::{Error, Result},
23 events::{Event, SerializedEvent},
24 index::{Index, Writer},
25 Connection, Database,
26};
27
28use crate::EventType;
29use serde_json::Value;
30
/// Handle used to rebuild the search index of an existing events database.
///
/// The flow exercised by this file's test is:
/// `new`/`new_with_config` → `delete_the_index` → `open_index` →
/// repeated `load_events*` + `index_events` + `commit` → `commit_and_close`.
pub struct RecoveryDatabase {
    // NOTE: struct fields are dropped in declaration order, so `connection` is
    // released before `pool`, and `index` before `index_writer`. Keep this in
    // mind before reordering fields.
    /// Directory holding the events database and the index files.
    path: PathBuf,
    /// Connection used to load events and to update the `reindex_needed` flag.
    connection: PooledConnection<SqliteConnectionManager>,
    /// Pool `connection` was taken from; `get_connection` hands out more from it.
    pool: r2d2::Pool<SqliteConnectionManager>,
    /// Configuration used to unlock connections and to create the new index.
    config: Config,
    /// Progress information returned by `info()`.
    recovery_info: RecoveryInfo,
    /// Set once `delete_the_index` has removed the old index files;
    /// `open_index` refuses to run until then.
    index_deleted: bool,
    /// The freshly created index, present after `open_index`.
    index: Option<Index>,
    /// Writer for the freshly created index, present after `open_index`.
    index_writer: Option<Writer>,
}
46
/// Progress information for a re-index operation.
///
/// `reindexed_events` lives behind an `Arc`, so every clone of a
/// `RecoveryInfo` shares the same counter. Presumably this lets a clone be
/// used to observe progress from elsewhere while re-indexing runs. Confirm
/// against callers.
#[derive(Debug, Clone)]
pub struct RecoveryInfo {
    /// Number of events stored in the database when it was opened for recovery.
    total_event_count: u64,
    /// Number of events passed to `RecoveryDatabase::index_events` so far.
    reindexed_events: Arc<AtomicU64>,
}
58
59impl RecoveryInfo {
60 pub fn total_events(&self) -> u64 {
62 self.total_event_count
63 }
64
65 pub fn reindexed_events(&self) -> &AtomicU64 {
67 &self.reindexed_events
68 }
69}
70
71impl RecoveryDatabase {
72 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self>
79 where
80 PathBuf: std::convert::From<P>,
81 {
82 Self::new_with_config(path, &Config::new())
83 }
84
85 pub fn new_with_config<P: AsRef<Path>>(path: P, config: &Config) -> Result<Self>
94 where
95 PathBuf: std::convert::From<P>,
96 {
97 let db_path = path.as_ref().join(EVENTS_DB_NAME);
98 let pool = Database::get_pool(&db_path, config)?;
99
100 let mut connection = pool.get()?;
101 Database::unlock(&connection, config)?;
102 connection.pragma_update(None, "foreign_keys", &1 as &dyn ToSql)?;
103
104 let (version, _) = match Database::get_version(&mut connection) {
105 Ok(ret) => ret,
106 Err(e) => return Err(Error::DatabaseOpenError(e.to_string())),
107 };
108
109 Database::create_tables(&connection)?;
110
111 if version != DATABASE_VERSION {
112 return Err(Error::DatabaseVersionError);
113 }
114
115 let event_count = Database::get_event_count(&connection)?;
116
117 let info = RecoveryInfo {
118 total_event_count: event_count as u64,
119 reindexed_events: Arc::new(AtomicU64::new(0)),
120 };
121
122 Ok(Self {
123 path: path.into(),
124 connection,
125 pool,
126 config: config.clone(),
127 recovery_info: info,
128 index_deleted: false,
129 index: None,
130 index_writer: None,
131 })
132 }
133
134 pub fn delete_the_index(&mut self) -> Result<()> {
138 let writer = self.index_writer.take();
139 let index = self.index.take();
140
141 drop(writer);
142 drop(index);
143
144 for entry in fs::read_dir(&self.path)? {
145 let entry = entry?;
146 let path = entry.path();
147
148 if path.is_dir() {
151 continue;
152 }
153
154 if let Some(file_name) = path.file_name() {
155 if file_name.to_string_lossy().starts_with(EVENTS_DB_NAME) {
158 continue;
159 }
160
161 fs::remove_file(path)?
162 }
163 }
164 self.index_deleted = true;
165 Ok(())
166 }
167
168 pub(crate) fn event_from_json(event_source: &str) -> std::io::Result<Event> {
169 let object: Value = serde_json::from_str(event_source)?;
170 let content = &object["content"];
171 let event_type = &object["type"];
172
173 let event_type = match event_type.as_str().unwrap_or_default() {
174 "m.room.message" => EventType::Message,
175 "m.room.name" => EventType::Name,
176 "m.room.topic" => EventType::Topic,
177 _ => return Err(IoError::new(ErrorKind::Other, "Invalid event type.")),
178 };
179
180 let (content_value, msgtype) = match event_type {
181 EventType::Message => (
182 content["body"]
183 .as_str()
184 .ok_or_else(|| IoError::new(ErrorKind::Other, "No content value found"))?,
185 Some("m.text"),
186 ),
187 EventType::Topic => (
188 content["topic"]
189 .as_str()
190 .ok_or_else(|| IoError::new(ErrorKind::Other, "No content value found"))?,
191 None,
192 ),
193 EventType::Name => (
194 content["name"]
195 .as_str()
196 .ok_or_else(|| IoError::new(ErrorKind::Other, "No content value found"))?,
197 None,
198 ),
199 };
200
201 let event_id = object["event_id"]
202 .as_str()
203 .ok_or_else(|| IoError::new(ErrorKind::Other, "No event id found"))?;
204 let sender = object["sender"]
205 .as_str()
206 .ok_or_else(|| IoError::new(ErrorKind::Other, "No sender found"))?;
207 let server_ts = object["origin_server_ts"]
208 .as_u64()
209 .ok_or_else(|| IoError::new(ErrorKind::Other, "No server timestamp found"))?;
210 let room_id = object["room_id"]
211 .as_str()
212 .ok_or_else(|| IoError::new(ErrorKind::Other, "No room id found"))?;
213
214 Ok(Event::new(
215 event_type,
216 content_value,
217 msgtype,
218 event_id,
219 sender,
220 server_ts.try_into().map_err(|_e| {
221 IoError::new(ErrorKind::Other, "Server timestamp out of valid range")
222 })?,
223 room_id,
224 event_source,
225 ))
226 }
227
228 pub fn load_events_deserialized(
235 &self,
236 limit: usize,
237 from_event: Option<&Event>,
238 ) -> Result<Vec<Event>> {
239 let serialized_events = self.load_events(limit, from_event)?;
240
241 let events = serialized_events
242 .iter()
243 .map(|e| RecoveryDatabase::event_from_json(e))
244 .filter_map(std::io::Result::ok)
245 .collect();
246
247 Ok(events)
248 }
249
250 pub fn load_events(
255 &self,
256 limit: usize,
257 from_event: Option<&Event>,
258 ) -> Result<Vec<SerializedEvent>> {
259 Ok(Database::load_all_events(
260 &self.connection,
261 limit,
262 from_event,
263 )?)
264 }
265
266 pub fn open_index(&mut self) -> Result<()> {
270 if !self.index_deleted {
271 return Err(Error::ReindexError);
272 }
273
274 let index = Index::new(&self.path, &self.config)?;
275 let writer = index.get_writer()?;
276 self.index = Some(index);
277 self.index_writer = Some(writer);
278
279 Ok(())
280 }
281
282 pub fn info(&self) -> &RecoveryInfo {
284 &self.recovery_info
285 }
286
287 pub fn get_connection(&self) -> Result<Connection> {
291 let connection = self.pool.get()?;
292 Database::unlock(&connection, &self.config)?;
293 Database::set_pragmas(&connection)?;
294
295 Ok(Connection {
296 inner: connection,
297 path: self.path.clone(),
298 })
299 }
300
301 pub fn index_events(&mut self, events: &[Event]) -> Result<()> {
310 match self.index_writer.as_mut() {
311 Some(writer) => events.iter().map(|e| writer.add_event(e)).collect(),
312 None => panic!("Index wasn't deleted"),
313 }
314
315 self.recovery_info
316 .reindexed_events
317 .fetch_add(events.len() as u64, Ordering::SeqCst);
318
319 Ok(())
320 }
321
322 pub fn commit(&mut self) -> Result<bool> {
330 match self.index_writer.as_mut() {
331 Some(writer) => {
332 let ret = writer.commit()?;
333 Ok(ret)
334 }
335 None => Err(Error::ReindexError),
336 }
337 }
338
339 pub fn commit_and_close(mut self) -> Result<()> {
344 match self.index_writer.as_mut() {
345 Some(writer) => {
346 writer.force_commit()?;
347 self.connection
348 .execute("UPDATE reindex_needed SET reindex_needed = ?1", [false])?;
349 Ok(())
350 }
351 None => Err(Error::ReindexError),
352 }
353 }
354
355 pub fn shutdown(mut self) -> Result<()> {
360 let index_writer = self.index_writer.take();
361 index_writer.map_or(Ok(()), |i| i.wait_merging_threads())?;
362
363 Ok(())
364 }
365}
366
#[cfg(test)]
pub(crate) mod test {
    use crate::{
        database::DATABASE_VERSION, Database, Error, Event, RecoveryDatabase, Result, SearchConfig,
    };

    use std::{path::PathBuf, sync::atomic::Ordering};

    /// Re-index the remaining events of `db` in batches of 10, committing after each batch.
    ///
    /// Loading continues from the last event of `initial_events`, then from the
    /// last successfully parsed event of each batch. Events that fail to parse
    /// are skipped. The loop ends when no more events are returned, or when a
    /// whole batch fails to parse (see the comment in the loop).
    pub(crate) fn reindex_loop(
        db: &mut RecoveryDatabase,
        initial_events: Vec<Event>,
    ) -> Result<()> {
        let mut events = initial_events;

        loop {
            let serialized_events = db.load_events(10, events.last())?;
            if serialized_events.is_empty() {
                break;
            }

            events = serialized_events
                .iter()
                .map(|e| RecoveryDatabase::event_from_json(e))
                .filter_map(std::io::Result::ok)
                .collect();

            // If nothing in this batch parsed, `events.last()` would be `None`
            // and the next `load_events` call would restart from the first
            // event, looping indefinitely. Stop instead.
            if events.is_empty() {
                break;
            }

            db.index_events(&events)?;
            db.commit()?;
        }
        Ok(())
    }

    #[test]
    fn test_recovery() {
        // Fixture database shipped with the crate that requires a reindex.
        let mut path = PathBuf::from(file!());
        path.pop();
        path.pop();
        path.pop();
        path.push("data/database/v2");
        let db = Database::new(&path);

        match db {
            Ok(_) => panic!("Database doesn't need a reindex."),
            Err(e) => match e {
                Error::ReindexError => (),
                e => panic!("Database doesn't need a reindex: {}", e),
            },
        }

        let mut recovery_db = RecoveryDatabase::new(&path).expect("Can't open recovery db");
        assert_ne!(recovery_db.info().total_events(), 0);
        recovery_db
            .delete_the_index()
            .expect("Can't delete the index");
        recovery_db
            .open_index()
            .expect("Can't open the new the index");

        let events = recovery_db
            .load_events_deserialized(10, None)
            .expect("Can't load events");

        assert!(!events.is_empty());
        assert_eq!(events.len(), 10);

        recovery_db.index_events(&events).unwrap();
        assert_eq!(
            recovery_db
                .info()
                .reindexed_events()
                .load(Ordering::Relaxed),
            10
        );

        reindex_loop(&mut recovery_db, events).expect("Can't reindex the db");

        assert_eq!(
            recovery_db
                .info()
                .reindexed_events()
                .load(Ordering::Relaxed),
            999
        );

        recovery_db.commit_and_close().unwrap();

        // After a successful recovery the database opens normally again.
        let db = Database::new(&path).unwrap();
        let mut connection = db.get_connection().unwrap();

        let (version, reindex_needed) = Database::get_version(&mut connection).unwrap();

        assert_eq!(version, DATABASE_VERSION);
        assert!(!reindex_needed);

        let result = db.search("Hello", &SearchConfig::new()).unwrap().results;
        assert!(!result.is_empty())
    }
}