Skip to main content

wp_knowledge/mem/
thread_clone.rs

1use std::cell::RefCell;
2use std::time::Duration;
3
4use crate::DBQuery;
5use crate::mem::RowData;
6use orion_error::{ErrorOwe, ErrorWith};
7use rusqlite::backup::Backup;
8use rusqlite::{Connection, Params};
9use wp_error::KnowledgeResult;
10use wp_model_core::model::DataField;
11
12thread_local! {
13    // clippy: use const init for thread_local value
14    static TLS_DB: RefCell<Option<Connection>> = const { RefCell::new(None) };
15}
16
/// Handle to a per-thread, in-memory clone of an authority SQLite database.
///
/// The handle itself only stores the location of the authority database, so it
/// is cheap to clone and carries no connection. Each thread lazily builds its
/// own in-memory `Connection` from the authority file (opened read-only) via
/// the SQLite backup API; connections are never shared across threads.
#[derive(Clone, Debug)]
pub struct ThreadClonedMDB {
    /// Path (or SQLite URI) of the authority database to copy from.
    authority_path: String,
}
23
24impl ThreadClonedMDB {
25    pub fn from_authority(path: &str) -> Self {
26        Self {
27            authority_path: path.to_string(),
28        }
29    }
30
31    pub fn with_tls_conn<T, F: FnOnce(&Connection) -> KnowledgeResult<T>>(
32        &self,
33        f: F,
34    ) -> KnowledgeResult<T> {
35        let path = self.authority_path.clone();
36        TLS_DB.with(|cell| {
37            // make sure a thread-local in-memory db exists
38            if cell.borrow().is_none() {
39                // source: authority file; dest: in-memory
40                let src = Connection::open_with_flags(
41                    &path,
42                    rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
43                        | rusqlite::OpenFlags::SQLITE_OPEN_URI,
44                )
45                .owe_res()
46                .want("connect db")?;
47                let mut dst = Connection::open_in_memory().owe_res().want("oepn conn")?;
48                {
49                    let bk = Backup::new(&src, &mut dst).owe_conf().want("backup")?;
50                    // Copy all pages with small sleep to yield
51                    bk.run_to_completion(50, Duration::from_millis(0), None)
52                        .owe_res()
53                        .want("backup run")?;
54                }
55                // 为查询连接注册内置 UDF(只读场景也可用在 SQL/OML 查询中)
56                let _ = crate::sqlite_ext::register_builtin(&dst);
57                *cell.borrow_mut() = Some(dst);
58            }
59            // safe to unwrap since ensured above
60            let conn = cell.borrow();
61            f(conn.as_ref().unwrap())
62        })
63    }
64}
65
66impl DBQuery for ThreadClonedMDB {
67    fn query(&self, sql: &str) -> KnowledgeResult<Vec<RowData>> {
68        self.with_tls_conn(|conn| super::query_util::query(conn, sql, []))
69    }
70    fn query_row(&self, sql: &str) -> KnowledgeResult<RowData> {
71        self.with_tls_conn(|conn| super::query_util::query_first_row(conn, sql, []))
72    }
73
74    fn query_row_params<P: Params>(&self, sql: &str, params: P) -> KnowledgeResult<RowData> {
75        self.with_tls_conn(|conn| super::query_util::query_first_row(conn, sql, params))
76    }
77
78    fn query_row_tdos<P: Params>(
79        &self,
80        _sql: &str,
81        _params: &[DataField; 2],
82    ) -> KnowledgeResult<RowData> {
83        // not used in current benchmarks
84        Ok(vec![])
85    }
86
87    fn query_cipher(&self, table: &str) -> KnowledgeResult<Vec<String>> {
88        self.with_tls_conn(|conn| {
89            let sql = format!("select value from {}", table);
90            let mut stmt = conn.prepare(&sql).owe_rule()?;
91            let mut rows = stmt.query([]).owe_rule()?;
92            let mut result = Vec::new();
93            while let Some(row) = rows.next().owe_rule()? {
94                let x = row.get_ref(0).owe_rule()?;
95                if let rusqlite::types::ValueRef::Text(val) = x {
96                    result.push(String::from_utf8(val.to_vec()).owe_rule()?);
97                }
98            }
99            Ok(result)
100        })
101    }
102}