Skip to main content

wp_knowledge/
loader.rs

1use std::fs;
2use std::io::Read;
3use std::path::{Path, PathBuf};
4
5use orion_conf::EnvTomlLoad;
6use serde::Deserialize;
7use wp_log::info_ctrl;
8
9use crate::mem::memdb::MemDB;
10use orion_error::{ContextRecord, ErrorOwe, OperationContext, ToStructError, UvsFrom};
11use orion_variate::EnvDict;
12use rusqlite::OpenFlags;
13use wp_error::{KnowledgeReason, KnowledgeResult};
14
/// V2 KnowDB configuration: one directory per table plus external SQL files.
///
/// Each table has exactly one data file: `<table_dir>/data.csv` by default, or
/// `tables[n].data_file` resolved relative to `<table_dir>` (an absolute path is used as-is).
#[derive(Debug, Deserialize)]
pub struct KnowDbConf {
    /// Config schema version; only `2` is accepted when the config is parsed.
    pub version: u32,
    /// Root of the per-table directories, resolved relative to the directory that holds the
    /// config file (an absolute path is used as-is). Defaults to `"."`.
    #[serde(default = "default_dot")]
    pub base_dir: String,
    /// Load policy (transactions, batch size, error handling) applied to every table.
    #[serde(default)]
    pub default: OptLoadSpec,
    /// CSV parsing options shared by all tables.
    #[serde(default)]
    pub csv: CsvSpec,
    /// Tables to load; enabled entries are processed in the listed order.
    pub tables: Vec<TableSpec>,
}
28
29#[derive(Debug, Clone, Deserialize)]
30pub struct OptLoadSpec {
31    #[serde(default = "default_true")]
32    pub transaction: bool,
33    #[serde(default = "default_batch")]
34    pub batch_size: usize,
35    #[serde(default = "default_on_error")]
36    pub on_error: OnError,
37}
38impl Default for OptLoadSpec {
39    fn default() -> Self {
40        Self {
41            transaction: true,
42            batch_size: default_batch(),
43            on_error: default_on_error(),
44        }
45    }
46}
47
/// Policy for malformed CSV records and rows that lack a mapped column.
///
/// SQL failures while inserting a row are not covered by this policy; they always abort.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum OnError {
    /// Abort the table load with an error (`"fail"` in config; the default).
    #[default]
    Fail,
    /// Skip the offending row and count it; a warning with the count is logged after the
    /// load (`"skip"` in config).
    Skip,
}
55
56#[derive(Debug, Clone, Deserialize)]
57pub struct CsvSpec {
58    #[serde(default = "default_true")]
59    pub has_header: bool,
60    #[serde(default = "default_comma")]
61    pub delimiter: String,
62    #[serde(default = "default_utf8")]
63    pub encoding: String,
64    #[serde(default = "default_true")]
65    pub trim: bool,
66}
67impl Default for CsvSpec {
68    fn default() -> Self {
69        CsvSpec {
70            has_header: true,
71            delimiter: ",".into(),
72            encoding: "utf-8".into(),
73            trim: true,
74        }
75    }
76}
77
/// One table entry (`tables[n]`) of the KnowDB config.
#[derive(Debug, Clone, Deserialize)]
pub struct TableSpec {
    /// Table name; substituted for every `{table}` placeholder in the table's SQL files.
    pub name: String,
    /// Table directory under `base_dir`; defaults to `name` when absent.
    #[serde(default)]
    pub dir: Option<String>,
    /// Data file relative to the table directory (an absolute path is used as-is);
    /// defaults to `data.csv` when absent.
    #[serde(default)]
    pub data_file: Option<String>,
    /// Which CSV columns are bound, in order, to the parameters of `insert.sql`.
    pub columns: ColumnsSpec,
    /// Optional bounds on the number of inserted rows.
    #[serde(default)]
    pub expected_rows: RowExpect,
    /// Disabled tables are skipped entirely. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
}
91
/// Selects the CSV columns that feed the insert statement, in parameter order.
///
/// `by_header` takes precedence when non-empty; otherwise `by_index` is used. If both are
/// empty the table load fails with "columns mapping required".
#[derive(Debug, Clone, Deserialize)]
pub struct ColumnsSpec {
    /// Column names looked up in the CSV header row.
    #[serde(default)]
    pub by_header: Vec<String>,
    /// Zero-based column positions.
    #[serde(default)]
    pub by_index: Vec<usize>,
}
99
/// Optional bounds on the number of rows inserted for a table.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RowExpect {
    /// Inserting fewer rows than this fails the table load.
    pub min: Option<usize>,
    /// Inserting more rows than this only logs a warning.
    pub max: Option<usize>,
}
105
/// Serde default for boolean flags that are on unless configured otherwise.
const fn default_true() -> bool {
    true
}

/// Serde default for `OptLoadSpec::batch_size`: rows per committed transaction.
const fn default_batch() -> usize {
    2_000
}

/// Serde default for `CsvSpec::delimiter`.
fn default_comma() -> String {
    String::from(",")
}

/// Serde default for `CsvSpec::encoding`.
fn default_utf8() -> String {
    String::from("utf-8")
}
118fn default_on_error() -> OnError {
119    OnError::Fail
120}
/// Serde default for `KnowDbConf::base_dir`: the config file's own directory.
fn default_dot() -> String {
    String::from(".")
}
124
125/// 读取文本文件,返回字符串
126fn read_to_string(path: &Path) -> KnowledgeResult<String> {
127    let mut f = fs::File::open(path).owe_res()?;
128    let mut buf = String::new();
129    f.read_to_string(&mut buf).owe_res()?;
130    Ok(buf)
131}
132
/// Substitutes every `{table}` placeholder in `sql` with `table`.
fn replace_table(sql: &str, table: &str) -> String {
    const PLACEHOLDER: &str = "{table}";
    sql.replace(PLACEHOLDER, table)
}
136
/// Resolves `rel` against `base`; an absolute `rel` is returned unchanged.
fn join_rel(base: &Path, rel: &str) -> PathBuf {
    let candidate = Path::new(rel);
    if !candidate.is_absolute() {
        return base.join(candidate);
    }
    candidate.to_path_buf()
}
145
146pub fn build_authority_from_knowdb(
147    root: &Path,
148    conf_path: &Path,
149    authority_uri: &str,
150    dict: &EnvDict,
151) -> KnowledgeResult<Vec<String>> {
152    let mut opx = OperationContext::want("build authority from knowdb").with_auto_log();
153    // 1) 解析配置与 base_dir
154    let (conf, conf_abs, base_dir) = parse_knowdb_conf(root, conf_path, dict)?;
155    opx.record("conf", &conf_abs);
156    opx.record("base_dir", &base_dir);
157    // 2) 打开权威库
158    let db = open_authority(authority_uri)?;
159    // 3) 逐表加载(按配置顺序);不再处理显式依赖
160    let mut loaded_names = Vec::new();
161    for t in &conf.tables {
162        if !t.enabled {
163            continue;
164        }
165        load_one_table(&db, &base_dir, t, &conf.csv, &conf.default)?;
166        info_ctrl!("load table {} suc!", base_dir.display(),);
167        loaded_names.push(t.name.clone());
168    }
169    opx.mark_suc();
170    Ok(loaded_names)
171}
172
173fn parse_knowdb_conf(
174    root: &Path,
175    conf_path: &Path,
176    dict: &EnvDict,
177) -> KnowledgeResult<(KnowDbConf, PathBuf, PathBuf)> {
178    let conf_abs = if conf_path.is_absolute() {
179        conf_path.to_path_buf()
180    } else {
181        root.join(conf_path)
182    };
183    let conf_txt = read_to_string(&conf_abs)?;
184    let conf: KnowDbConf =
185        <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(&conf_txt, dict).owe_conf()?;
186    if conf.version != 2 {
187        return Err(KnowledgeReason::from_conf()
188            .to_err()
189            .with_detail("unsupported knowdb.version"));
190    }
191    let conf_dir = conf_abs.parent().unwrap_or_else(|| Path::new("."));
192    let base_dir = join_rel(conf_dir, &conf.base_dir);
193    Ok((conf, conf_abs, base_dir))
194}
195
196fn open_authority(authority_uri: &str) -> KnowledgeResult<MemDB> {
197    ensure_parent_dir_for_file_uri(authority_uri);
198    let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
199        | OpenFlags::SQLITE_OPEN_CREATE
200        | OpenFlags::SQLITE_OPEN_URI;
201    let db = MemDB::new_file(authority_uri, 1, flags)?;
202    // 预注册内置 UDF 至权威库连接(注意:连接池可能返回不同连接,导入时也会再次注册)
203    let _ = db.with_conn(|conn| {
204        let _ = crate::sqlite_ext::register_builtin(conn);
205        Ok::<(), anyhow::Error>(())
206    });
207    Ok(db)
208}
209
/// Best-effort creation of the parent directory of an SQLite `file:` URI.
///
/// SQLite's `SQLITE_OPEN_CREATE` creates the database file but not missing directories, so
/// they are created here. Strings without the `file:` prefix (plain paths, `:memory:`) are
/// ignored. The query string (`?...`) and fragment (`#...`) are removed. An optional
/// `//authority` part is also removed; SQLite only accepts an empty authority or
/// `localhost`. As a result, `file:///x/db` and `file://localhost/x/db` both resolve to `/x/db`.
///
/// NOTE: percent-escapes (e.g. `%20`) are not decoded.
///
/// Directory-creation errors are ignored on purpose: if the directory is still missing,
/// opening the database reports the real failure.
fn ensure_parent_dir_for_file_uri(uri: &str) {
    let Some(rest) = uri.strip_prefix("file:") else {
        return;
    };
    // The path ends at the first '?' (query) or '#' (fragment).
    let end = rest.find(['?', '#']).unwrap_or(rest.len());
    let mut path_part = &rest[..end];
    // Strip `//authority`, keeping the path that starts at the next '/'.
    if let Some(after_slashes) = path_part.strip_prefix("//") {
        path_part = match after_slashes.find('/') {
            Some(idx) => &after_slashes[idx..],
            // An authority with no path names no directory to create.
            None => return,
        };
    }
    if let Some(parent) = Path::new(path_part).parent() {
        // A bare file name has an empty parent: the current directory, nothing to create.
        if !parent.as_os_str().is_empty() {
            let _ = fs::create_dir_all(parent);
        }
    }
}
221
222fn load_one_table(
223    db: &MemDB,
224    base_dir: &Path,
225    t: &TableSpec,
226    csvd: &CsvSpec,
227    load: &OptLoadSpec,
228) -> KnowledgeResult<()> {
229    // 目录与必须文件
230    let mut opx = OperationContext::want("load table to kdb")
231        .with_auto_log()
232        .with_mod_path("ctrl");
233    let dir_name: &str = t.dir.as_deref().unwrap_or(&t.name);
234    let table_dir = base_dir.join(dir_name);
235    opx.record("table_dir", &table_dir);
236    let create_sql = replace_table(&read_to_string(&table_dir.join("create.sql"))?, &t.name);
237    let insert_sql = replace_table(&read_to_string(&table_dir.join("insert.sql"))?, &t.name);
238    let clean_path = table_dir.join("clean.sql");
239    let clean_sql = if clean_path.exists() {
240        replace_table(&read_to_string(&clean_path)?, &t.name)
241    } else {
242        format!("DELETE FROM {}", &t.name)
243    };
244
245    // 建表与清理
246    db.with_conn(|conn| {
247        // 注册内置 UDF(导入连接)
248        let _ = crate::sqlite_ext::register_builtin(conn);
249        conn.execute_batch(&create_sql)?;
250        conn.execute_batch(&clean_sql)?;
251        Ok::<(), anyhow::Error>(())
252    })
253    .owe_res()?;
254
255    // 数据源
256    let data_path = match &t.data_file {
257        Some(rel) => join_rel(&table_dir, rel),
258        None => table_dir.join("data.csv"),
259    };
260    if !data_path.exists() {
261        return Err(KnowledgeReason::from_conf()
262            .to_err()
263            .with_detail("data.csv not found"));
264    }
265    opx.record("data_path", &data_path);
266
267    // CSV 解析器
268    let mut rdr = build_csv_reader(csvd, &data_path)?;
269
270    // 列映射
271    let col_indices: Vec<usize> = if !t.columns.by_header.is_empty() {
272        let headers = rdr.headers().owe_res()?;
273        select_indices_by_header(headers, &t.columns.by_header)?
274    } else if !t.columns.by_index.is_empty() {
275        t.columns.by_index.clone()
276    } else {
277        return Err(KnowledgeReason::from_conf()
278            .to_err()
279            .with_detail("columns mapping required"));
280    };
281
282    // 导入(分批事务)
283    let mut inserted: usize = 0;
284    let mut bad: usize = 0;
285    let mut batch_left = load.batch_size.max(1);
286    db.with_conn(|conn| {
287        // 注册内置 UDF(用于 INSERT 绑定表达式)
288        let _ = crate::sqlite_ext::register_builtin(conn);
289        let mut tx = if load.transaction {
290            Some(conn.unchecked_transaction()?)
291        } else {
292            None
293        };
294        let mut stmt = conn.prepare(&insert_sql)?;
295        for rec in rdr.into_records() {
296            match rec {
297                Ok(record) => {
298                    let refs = extract_row_refs(&record, &col_indices, &mut bad, load)?;
299                    if let Some(refs) = refs {
300                        stmt.execute(rusqlite::params_from_iter(refs))?;
301                        inserted += 1;
302                        if load.transaction {
303                            batch_left -= 1;
304                            if batch_left == 0 {
305                                tx.take().unwrap().commit()?;
306                                tx = Some(conn.unchecked_transaction()?);
307                                batch_left = load.batch_size.max(1);
308                            }
309                        }
310                    }
311                }
312                Err(_e) => {
313                    if matches!(load.on_error, OnError::Skip) {
314                        bad += 1;
315                        continue;
316                    } else {
317                        anyhow::bail!("csv record parse error");
318                    }
319                }
320            }
321        }
322        if let Some(tx) = tx {
323            tx.commit()?;
324        }
325        Ok::<(), anyhow::Error>(())
326    })
327    .owe_res()?;
328
329    // 行数校验
330    if let Some(min) = t.expected_rows.min
331        && inserted < min
332    {
333        return Err(KnowledgeReason::from_conf()
334            .to_err()
335            .with_detail("table data less"));
336    }
337    if let Some(max) = t.expected_rows.max
338        && inserted > max
339    {
340        wp_log::warn_kdb!(
341            "table {} loaded rows {} exceed max {}",
342            &t.name,
343            inserted,
344            max
345        );
346    }
347    if bad > 0 {
348        wp_log::warn_kdb!("table {} skipped {} bad rows (on_error=skip)", &t.name, bad);
349    }
350    opx.mark_suc();
351    Ok(())
352}
353
354fn build_csv_reader(
355    csvd: &CsvSpec,
356    data_path: &Path,
357) -> KnowledgeResult<csv::Reader<std::fs::File>> {
358    if csvd.encoding.to_lowercase() != "utf-8" {
359        return Err(KnowledgeReason::from_conf()
360            .to_err()
361            .with_detail("only utf-8 csv is supported"));
362    }
363    let mut rdr_b = csv::ReaderBuilder::new();
364    rdr_b.has_headers(csvd.has_header);
365    if csvd.delimiter.len() == 1 {
366        rdr_b.delimiter(csvd.delimiter.as_bytes()[0]);
367    }
368    if csvd.trim {
369        rdr_b.trim(csv::Trim::All);
370    }
371    rdr_b.from_path(data_path).owe_res()
372}
373
374fn select_indices_by_header(
375    headers: &csv::StringRecord,
376    wanted: &[String],
377) -> KnowledgeResult<Vec<usize>> {
378    let mut out = Vec::with_capacity(wanted.len());
379    for name in wanted {
380        let pos = headers.iter().position(|h| h == name).ok_or_else(|| {
381            KnowledgeReason::from_conf()
382                .to_err()
383                .with_detail("header not found")
384        })?;
385        out.push(pos);
386    }
387    Ok(out)
388}
389
390fn extract_row_refs<'a>(
391    record: &'a csv::StringRecord,
392    col_indices: &[usize],
393    bad: &mut usize,
394    load: &OptLoadSpec,
395) -> anyhow::Result<Option<Vec<&'a str>>> {
396    let mut vs: Vec<&str> = Vec::with_capacity(col_indices.len());
397    for &idx in col_indices {
398        if idx >= record.len() {
399            if matches!(load.on_error, OnError::Skip) {
400                *bad += 1;
401                return Ok(None);
402            } else {
403                // 将错误在调用方 bail(构建 anyhow)
404                anyhow::bail!("missing column at index {}", idx);
405            }
406        }
407        vs.push(record.get(idx).unwrap_or(""));
408    }
409    Ok(Some(vs))
410}