Skip to main content

wp_knowledge/mem/
thread_clone.rs

1use std::cell::RefCell;
2use std::time::Duration;
3
4use crate::DBQuery;
5use crate::mem::RowData;
6use orion_error::{ErrorOwe, ErrorWith};
7use rusqlite::ToSql;
8use rusqlite::backup::Backup;
9use rusqlite::{Connection, Params};
10use wp_error::KnowledgeResult;
11use wp_log::debug_kdb;
12use wp_model_core::model::DataField;
13
14use super::SqlNamedParam;
15
16thread_local! {
17    // clippy: use const init for thread_local value
18    static TLS_DB: RefCell<Option<ThreadLocalState>> = const { RefCell::new(None) };
19}
20
21struct ThreadLocalState {
22    authority_path: String,
23    generation: u64,
24    conn: Connection,
25}
26
/// Thread-cloned, read-only in-memory DB built from an authority file DB via the SQLite
/// backup API.
///
/// The handle stores only the authority path and a generation number. Each thread lazily
/// creates its own in-memory `Connection` from the authority file, so no connection is
/// shared across threads. Cloning the handle copies only the path and generation, not
/// any database.
#[derive(Clone, Debug)]
pub struct ThreadClonedMDB {
    /// Filesystem path (or SQLite URI) of the authority database.
    authority_path: String,
    /// Snapshot generation; a thread rebuilds its copy when this differs from its cache.
    generation: u64,
}
34
35impl ThreadClonedMDB {
36    pub fn from_authority(path: &str) -> Self {
37        Self {
38            authority_path: path.to_string(),
39            generation: 0,
40        }
41    }
42
43    pub fn from_authority_with_generation(path: &str, generation: u64) -> Self {
44        Self {
45            authority_path: path.to_string(),
46            generation,
47        }
48    }
49
50    pub fn with_tls_conn<T, F: FnOnce(&Connection) -> KnowledgeResult<T>>(
51        &self,
52        f: F,
53    ) -> KnowledgeResult<T> {
54        let path = self.authority_path.clone();
55        let generation = self.generation;
56        TLS_DB.with(|cell| {
57            // make sure a thread-local in-memory db exists
58            let should_rebuild = cell
59                .borrow()
60                .as_ref()
61                .map(|state| state.authority_path != path || state.generation != generation)
62                .unwrap_or(true);
63            if should_rebuild {
64                debug_kdb!(
65                    "[kdb] rebuild thread-local sqlite snapshot generation={} path={}",
66                    generation,
67                    path
68                );
69                // source: authority file; dest: in-memory
70                let src = Connection::open_with_flags(
71                    &path,
72                    rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
73                        | rusqlite::OpenFlags::SQLITE_OPEN_URI,
74                )
75                .owe_res()
76                .want("connect db")?;
77                let mut dst = Connection::open_in_memory().owe_res().want("oepn conn")?;
78                {
79                    let bk = Backup::new(&src, &mut dst).owe_conf().want("backup")?;
80                    // Copy all pages with small sleep to yield
81                    bk.run_to_completion(50, Duration::from_millis(0), None)
82                        .owe_res()
83                        .want("backup run")?;
84                }
85                // 为查询连接注册内置 UDF(只读场景也可用在 SQL/OML 查询中)
86                let _ = crate::sqlite_ext::register_builtin(&dst);
87                *cell.borrow_mut() = Some(ThreadLocalState {
88                    authority_path: path.clone(),
89                    generation,
90                    conn: dst,
91                });
92            }
93            // safe to unwrap since ensured above
94            let conn = cell.borrow();
95            f(&conn.as_ref().unwrap().conn)
96        })
97    }
98
99    pub fn query_fields(&self, sql: &str, params: &[DataField]) -> KnowledgeResult<Vec<RowData>> {
100        self.with_tls_conn(|conn| {
101            let named_params = params
102                .iter()
103                .cloned()
104                .map(SqlNamedParam)
105                .collect::<Vec<_>>();
106            let refs: Vec<(&str, &dyn ToSql)> = named_params
107                .iter()
108                .map(|param| (param.0.get_name(), param as &dyn ToSql))
109                .collect();
110            super::query_util::query_cached(conn, sql, refs.as_slice())
111        })
112    }
113
114    pub fn query_named_fields(&self, sql: &str, params: &[DataField]) -> KnowledgeResult<RowData> {
115        self.query_fields(sql, params)
116            .map(|rows| rows.into_iter().next().unwrap_or_default())
117    }
118}
119
120impl DBQuery for ThreadClonedMDB {
121    fn query(&self, sql: &str) -> KnowledgeResult<Vec<RowData>> {
122        self.with_tls_conn(|conn| super::query_util::query(conn, sql, []))
123    }
124    fn query_row(&self, sql: &str) -> KnowledgeResult<RowData> {
125        self.with_tls_conn(|conn| super::query_util::query_first_row(conn, sql, []))
126    }
127
128    fn query_row_params<P: Params>(&self, sql: &str, params: P) -> KnowledgeResult<RowData> {
129        self.with_tls_conn(|conn| super::query_util::query_first_row(conn, sql, params))
130    }
131
132    fn query_row_tdos<P: Params>(
133        &self,
134        _sql: &str,
135        _params: &[DataField; 2],
136    ) -> KnowledgeResult<RowData> {
137        // not used in current benchmarks
138        Ok(vec![])
139    }
140}