Skip to main content

wp_knowledge/
loader.rs

1use std::fs;
2use std::io::Read;
3use std::path::{Path, PathBuf};
4
5use orion_conf::EnvTomlLoad;
6use serde::Deserialize;
7use wp_log::info_ctrl;
8
9use crate::mem::memdb::MemDB;
10use orion_error::{ContextRecord, ErrorOwe, OperationContext, ToStructError, UvsFrom};
11use orion_variate::EnvDict;
12use rusqlite::OpenFlags;
13use wp_error::{KnowledgeReason, KnowledgeResult};
14
15/// V2 KnowDB 配置:目录式 + 外置 SQL。仅支持单一数据文件:`<table_dir>/data.csv`,
16/// 或通过 `tables[n].data_file` 相对 `<table_dir>` 指定。
17#[derive(Debug, Deserialize)]
18pub struct KnowDbConf {
19    pub version: u32,
20    #[serde(default = "default_dot")]
21    pub base_dir: String,
22    #[serde(default)]
23    pub default: OptLoadSpec,
24    #[serde(default)]
25    pub csv: CsvSpec,
26    #[serde(default)]
27    pub cache: CacheSpec,
28    #[serde(default)]
29    pub provider: Option<ProviderSpec>,
30    #[serde(default)]
31    pub tables: Vec<TableSpec>,
32}
33
34#[derive(Debug, Clone, Deserialize)]
35pub struct CacheSpec {
36    #[serde(default = "default_true")]
37    pub enabled: bool,
38    #[serde(default = "default_result_cache_capacity")]
39    pub capacity: usize,
40    #[serde(default = "default_result_cache_ttl_ms")]
41    pub ttl_ms: u64,
42}
43
44impl Default for CacheSpec {
45    fn default() -> Self {
46        Self {
47            enabled: default_true(),
48            capacity: default_result_cache_capacity(),
49            ttl_ms: default_result_cache_ttl_ms(),
50        }
51    }
52}
53
54#[derive(Debug, Clone, Deserialize)]
55#[serde(rename_all = "snake_case")]
56pub enum ProviderKind {
57    SqliteAuthority,
58    Postgres,
59    Mysql,
60}
61
/// External provider settings (`[provider]` table of the config).
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderSpec {
    /// Backend kind (`sqlite_authority`, `postgres` or `mysql`).
    pub kind: ProviderKind,
    /// Connection URI handed to the backend.
    pub connection_uri: String,
    /// Optional connection-pool size; `None` when the key is omitted.
    #[serde(default)]
    pub pool_size: Option<u32>,
}
69
70#[derive(Debug, Clone, Deserialize)]
71pub struct OptLoadSpec {
72    #[serde(default = "default_true")]
73    pub transaction: bool,
74    #[serde(default = "default_batch")]
75    pub batch_size: usize,
76    #[serde(default = "default_on_error")]
77    pub on_error: OnError,
78}
79impl Default for OptLoadSpec {
80    fn default() -> Self {
81        Self {
82            transaction: true,
83            batch_size: default_batch(),
84            on_error: default_on_error(),
85        }
86    }
87}
88
89#[derive(Debug, Clone, Deserialize, Default)]
90#[serde(rename_all = "lowercase")]
91pub enum OnError {
92    #[default]
93    Fail,
94    Skip,
95}
96
97#[derive(Debug, Clone, Deserialize)]
98pub struct CsvSpec {
99    #[serde(default = "default_true")]
100    pub has_header: bool,
101    #[serde(default = "default_comma")]
102    pub delimiter: String,
103    #[serde(default = "default_utf8")]
104    pub encoding: String,
105    #[serde(default = "default_true")]
106    pub trim: bool,
107}
108impl Default for CsvSpec {
109    fn default() -> Self {
110        CsvSpec {
111            has_header: true,
112            delimiter: ",".into(),
113            encoding: "utf-8".into(),
114            trim: true,
115        }
116    }
117}
118
/// One table entry (`[[tables]]`) of the KnowDB config.
#[derive(Debug, Clone, Deserialize)]
pub struct TableSpec {
    /// Table name; substituted for `{table}` in the SQL scripts and used as the
    /// directory name when `dir` is not set.
    pub name: String,
    /// Directory (under `base_dir`) holding `create.sql`, `insert.sql`, optional
    /// `clean.sql` and the data file; defaults to `name`.
    #[serde(default)]
    pub dir: Option<String>,
    /// Data file, relative to the table directory unless absolute; defaults to `data.csv`.
    #[serde(default)]
    pub data_file: Option<String>,
    /// Mapping from CSV columns to insert parameters.
    pub columns: ColumnsSpec,
    /// Row-count expectations checked after the import.
    #[serde(default)]
    pub expected_rows: RowExpect,
    /// Disabled tables are skipped by `build_authority_from_knowdb` (default `true`).
    #[serde(default = "default_true")]
    pub enabled: bool,
}
132
/// Which CSV columns feed the INSERT parameters, in parameter order.
///
/// `by_header` wins when non-empty; otherwise `by_index` is used. The loader rejects a
/// table where both are empty.
#[derive(Debug, Clone, Deserialize)]
pub struct ColumnsSpec {
    /// Column names looked up in the CSV header row.
    #[serde(default)]
    pub by_header: Vec<String>,
    /// Zero-based column positions.
    #[serde(default)]
    pub by_index: Vec<usize>,
}
140
/// Expected number of inserted rows for a table.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RowExpect {
    /// Fewer inserted rows than this fails the table load.
    pub min: Option<usize>,
    /// More inserted rows than this only logs a warning.
    pub max: Option<usize>,
}
146
/// Serde default helper returning `true`.
const fn default_true() -> bool {
    true
}
/// Serde default for `default.batch_size`: 2000 rows per batch.
const fn default_batch() -> usize {
    2_000
}
/// Serde default for `csv.delimiter`: a comma.
fn default_comma() -> String {
    String::from(",")
}
/// Serde default for `csv.encoding`: `"utf-8"`.
fn default_utf8() -> String {
    String::from("utf-8")
}
/// Serde default for `default.on_error`: abort the table load on the first bad record.
fn default_on_error() -> OnError {
    OnError::Fail
}
/// Serde default for `base_dir`: the directory that contains the config file.
fn default_dot() -> String {
    String::from(".")
}
/// Serde default for `cache.capacity`.
const fn default_result_cache_capacity() -> usize {
    1_024
}
/// Serde default for `cache.ttl_ms` (30 000).
const fn default_result_cache_ttl_ms() -> u64 {
    30 * 1_000
}
171
172/// 读取文本文件,返回字符串
173fn read_to_string(path: &Path) -> KnowledgeResult<String> {
174    let mut f = fs::File::open(path).owe_res()?;
175    let mut buf = String::new();
176    f.read_to_string(&mut buf).owe_res()?;
177    Ok(buf)
178}
179
/// Substitutes every `{table}` placeholder in `sql` with `table`.
fn replace_table(sql: &str, table: &str) -> String {
    const PLACEHOLDER: &str = "{table}";
    sql.replace(PLACEHOLDER, table)
}
183
/// Resolves `rel` against `base`; an absolute `rel` is returned unchanged.
fn join_rel(base: &Path, rel: &str) -> PathBuf {
    let rel_path = Path::new(rel);
    if !rel_path.is_absolute() {
        return base.join(rel_path);
    }
    rel_path.to_path_buf()
}
192
193pub fn build_authority_from_knowdb(
194    root: &Path,
195    conf_path: &Path,
196    authority_uri: &str,
197    dict: &EnvDict,
198) -> KnowledgeResult<Vec<String>> {
199    let mut opx = OperationContext::want("build authority from knowdb").with_auto_log();
200    // 1) 解析配置与 base_dir
201    let (conf, conf_abs, base_dir) = parse_knowdb_conf(root, conf_path, dict)?;
202    opx.record("conf", &conf_abs);
203    opx.record("base_dir", &base_dir);
204    // 2) 打开权威库
205    let db = open_authority(authority_uri)?;
206    // 3) 逐表加载(按配置顺序);不再处理显式依赖
207    let mut loaded_names = Vec::new();
208    for t in &conf.tables {
209        if !t.enabled {
210            continue;
211        }
212        load_one_table(&db, &base_dir, t, &conf.csv, &conf.default)?;
213        info_ctrl!("load table {} suc!", base_dir.display(),);
214        loaded_names.push(t.name.clone());
215    }
216    opx.mark_suc();
217    Ok(loaded_names)
218}
219
220pub fn parse_knowdb_conf(
221    root: &Path,
222    conf_path: &Path,
223    dict: &EnvDict,
224) -> KnowledgeResult<(KnowDbConf, PathBuf, PathBuf)> {
225    let conf_abs = if conf_path.is_absolute() {
226        conf_path.to_path_buf()
227    } else {
228        root.join(conf_path)
229    };
230    let conf_txt = read_to_string(&conf_abs)?;
231    let conf: KnowDbConf =
232        <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(&conf_txt, dict).owe_conf()?;
233    if conf.version != 2 {
234        return Err(KnowledgeReason::from_conf()
235            .to_err()
236            .with_detail("unsupported knowdb.version"));
237    }
238    let conf_dir = conf_abs.parent().unwrap_or_else(|| Path::new("."));
239    let base_dir = join_rel(conf_dir, &conf.base_dir);
240    Ok((conf, conf_abs, base_dir))
241}
242
243fn open_authority(authority_uri: &str) -> KnowledgeResult<MemDB> {
244    ensure_parent_dir_for_file_uri(authority_uri);
245    let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
246        | OpenFlags::SQLITE_OPEN_CREATE
247        | OpenFlags::SQLITE_OPEN_URI;
248    let db = MemDB::new_file(authority_uri, 1, flags)?;
249    // 预注册内置 UDF 至权威库连接(注意:连接池可能返回不同连接,导入时也会再次注册)
250    let _ = db.with_conn(|conn| {
251        let _ = crate::sqlite_ext::register_builtin(conn);
252        Ok::<(), anyhow::Error>(())
253    });
254    Ok(db)
255}
256
/// Best-effort creation of the parent directory of a SQLite `file:` URI.
///
/// Only strings starting with `file:` are handled; plain paths and other inputs are left
/// untouched. The query string (`?…`) and fragment (`#…`) are removed, and for the
/// `file://<authority>/path` form the authority (empty or `localhost`) is skipped, so only
/// the filesystem path remains. Directory-creation errors are deliberately ignored; the
/// subsequent open of the database is expected to fail if the location is unusable.
///
/// NOTE(review): percent-escapes (`%20`, …) are not decoded here, although SQLite decodes
/// them when it opens the URI.
fn ensure_parent_dir_for_file_uri(uri: &str) {
    let Some(rest) = uri.strip_prefix("file:") else {
        return;
    };
    // Query parameters and fragment are not part of the path.
    let mut path_part = rest.split(['?', '#']).next().unwrap_or(rest);
    // `file://localhost/abs` or `file:///abs`: drop the authority up to the next '/'.
    // Without this, `//localhost/...` would be treated as a filesystem path.
    if let Some(after_slashes) = path_part.strip_prefix("//") {
        path_part = after_slashes
            .find('/')
            .map_or("", |idx| &after_slashes[idx..]);
    }
    if let Some(parent) = Path::new(path_part).parent() {
        if !parent.as_os_str().is_empty() {
            let _ = fs::create_dir_all(parent);
        }
    }
}
268
269fn load_one_table(
270    db: &MemDB,
271    base_dir: &Path,
272    t: &TableSpec,
273    csvd: &CsvSpec,
274    load: &OptLoadSpec,
275) -> KnowledgeResult<()> {
276    // 目录与必须文件
277    let mut opx = OperationContext::want("load table to kdb")
278        .with_auto_log()
279        .with_mod_path("ctrl");
280    let dir_name: &str = t.dir.as_deref().unwrap_or(&t.name);
281    let table_dir = base_dir.join(dir_name);
282    opx.record("table_dir", &table_dir);
283    let create_sql = replace_table(&read_to_string(&table_dir.join("create.sql"))?, &t.name);
284    let insert_sql = replace_table(&read_to_string(&table_dir.join("insert.sql"))?, &t.name);
285    let clean_path = table_dir.join("clean.sql");
286    let clean_sql = if clean_path.exists() {
287        replace_table(&read_to_string(&clean_path)?, &t.name)
288    } else {
289        format!("DELETE FROM {}", &t.name)
290    };
291
292    // 建表与清理
293    db.with_conn(|conn| {
294        // 注册内置 UDF(导入连接)
295        let _ = crate::sqlite_ext::register_builtin(conn);
296        conn.execute_batch(&create_sql)?;
297        conn.execute_batch(&clean_sql)?;
298        Ok::<(), anyhow::Error>(())
299    })
300    .owe_res()?;
301
302    // 数据源
303    let data_path = match &t.data_file {
304        Some(rel) => join_rel(&table_dir, rel),
305        None => table_dir.join("data.csv"),
306    };
307    if !data_path.exists() {
308        return Err(KnowledgeReason::from_conf()
309            .to_err()
310            .with_detail("data.csv not found"));
311    }
312    opx.record("data_path", &data_path);
313
314    // CSV 解析器
315    let mut rdr = build_csv_reader(csvd, &data_path)?;
316
317    // 列映射
318    let col_indices: Vec<usize> = if !t.columns.by_header.is_empty() {
319        let headers = rdr.headers().owe_res()?;
320        select_indices_by_header(headers, &t.columns.by_header)?
321    } else if !t.columns.by_index.is_empty() {
322        t.columns.by_index.clone()
323    } else {
324        return Err(KnowledgeReason::from_conf()
325            .to_err()
326            .with_detail("columns mapping required"));
327    };
328
329    // 导入(分批事务)
330    let mut inserted: usize = 0;
331    let mut bad: usize = 0;
332    let mut batch_left = load.batch_size.max(1);
333    db.with_conn(|conn| {
334        // 注册内置 UDF(用于 INSERT 绑定表达式)
335        let _ = crate::sqlite_ext::register_builtin(conn);
336        let mut tx = if load.transaction {
337            Some(conn.unchecked_transaction()?)
338        } else {
339            None
340        };
341        let mut stmt = conn.prepare(&insert_sql)?;
342        for rec in rdr.into_records() {
343            match rec {
344                Ok(record) => {
345                    let refs = extract_row_refs(&record, &col_indices, &mut bad, load)?;
346                    if let Some(refs) = refs {
347                        stmt.execute(rusqlite::params_from_iter(refs))?;
348                        inserted += 1;
349                        if load.transaction {
350                            batch_left -= 1;
351                            if batch_left == 0 {
352                                tx.take().unwrap().commit()?;
353                                tx = Some(conn.unchecked_transaction()?);
354                                batch_left = load.batch_size.max(1);
355                            }
356                        }
357                    }
358                }
359                Err(_e) => {
360                    if matches!(load.on_error, OnError::Skip) {
361                        bad += 1;
362                        continue;
363                    } else {
364                        anyhow::bail!("csv record parse error");
365                    }
366                }
367            }
368        }
369        if let Some(tx) = tx {
370            tx.commit()?;
371        }
372        Ok::<(), anyhow::Error>(())
373    })
374    .owe_res()?;
375
376    // 行数校验
377    if let Some(min) = t.expected_rows.min
378        && inserted < min
379    {
380        return Err(KnowledgeReason::from_conf()
381            .to_err()
382            .with_detail("table data less"));
383    }
384    if let Some(max) = t.expected_rows.max
385        && inserted > max
386    {
387        wp_log::warn_kdb!(
388            "table {} loaded rows {} exceed max {}",
389            &t.name,
390            inserted,
391            max
392        );
393    }
394    if bad > 0 {
395        wp_log::warn_kdb!("table {} skipped {} bad rows (on_error=skip)", &t.name, bad);
396    }
397    opx.mark_suc();
398    Ok(())
399}
400
401fn build_csv_reader(
402    csvd: &CsvSpec,
403    data_path: &Path,
404) -> KnowledgeResult<csv::Reader<std::fs::File>> {
405    if csvd.encoding.to_lowercase() != "utf-8" {
406        return Err(KnowledgeReason::from_conf()
407            .to_err()
408            .with_detail("only utf-8 csv is supported"));
409    }
410    let mut rdr_b = csv::ReaderBuilder::new();
411    rdr_b.has_headers(csvd.has_header);
412    if csvd.delimiter.len() == 1 {
413        rdr_b.delimiter(csvd.delimiter.as_bytes()[0]);
414    }
415    if csvd.trim {
416        rdr_b.trim(csv::Trim::All);
417    }
418    rdr_b.from_path(data_path).owe_res()
419}
420
421fn select_indices_by_header(
422    headers: &csv::StringRecord,
423    wanted: &[String],
424) -> KnowledgeResult<Vec<usize>> {
425    let mut out = Vec::with_capacity(wanted.len());
426    for name in wanted {
427        let pos = headers.iter().position(|h| h == name).ok_or_else(|| {
428            KnowledgeReason::from_conf()
429                .to_err()
430                .with_detail("header not found")
431        })?;
432        out.push(pos);
433    }
434    Ok(out)
435}
436
437fn extract_row_refs<'a>(
438    record: &'a csv::StringRecord,
439    col_indices: &[usize],
440    bad: &mut usize,
441    load: &OptLoadSpec,
442) -> anyhow::Result<Option<Vec<&'a str>>> {
443    let mut vs: Vec<&str> = Vec::with_capacity(col_indices.len());
444    for &idx in col_indices {
445        if idx >= record.len() {
446            if matches!(load.on_error, OnError::Skip) {
447                *bad += 1;
448                return Ok(None);
449            } else {
450                anyhow::bail!("missing column at index {}", idx);
451            }
452        }
453        vs.push(record.get(idx).unwrap_or(""));
454    }
455    Ok(Some(vs))
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an inline TOML snippet as a `KnowDbConf` with an empty env dict,
    /// panicking with `what` on failure.
    fn parse_conf(toml: &str, what: &str) -> KnowDbConf {
        let dict = EnvDict::default();
        <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(toml, &dict).expect(what)
    }

    #[test]
    fn parse_external_provider_spec() {
        let conf = parse_conf(
            r#"
version = 2

[provider]
kind = "postgres"
connection_uri = "postgres://demo:demo@127.0.0.1/demo"
"#,
            "parse knowdb with provider",
        );

        assert!(conf.tables.is_empty());
        let provider = conf.provider.expect("provider");
        assert!(matches!(provider.kind, ProviderKind::Postgres));
        assert_eq!(
            provider.connection_uri,
            "postgres://demo:demo@127.0.0.1/demo"
        );
    }

    #[test]
    fn parse_mysql_provider_spec() {
        let conf = parse_conf(
            r#"
version = 2

[provider]
kind = "mysql"
connection_uri = "mysql://demo:demo@127.0.0.1:3306/demo"
pool_size = 12
"#,
            "parse knowdb with mysql provider",
        );

        let provider = conf.provider.expect("provider");
        assert!(matches!(provider.kind, ProviderKind::Mysql));
        assert_eq!(
            provider.connection_uri,
            "mysql://demo:demo@127.0.0.1:3306/demo"
        );
        assert_eq!(provider.pool_size, Some(12));
    }

    #[test]
    fn parse_cache_spec_with_defaults() {
        let conf = parse_conf(
            r#"
version = 2
"#,
            "parse knowdb with default cache spec",
        );

        let cache = &conf.cache;
        assert!(cache.enabled);
        assert_eq!(cache.capacity, 1024);
        assert_eq!(cache.ttl_ms, 30_000);
    }

    #[test]
    fn parse_cache_spec_from_toml() {
        let conf = parse_conf(
            r#"
version = 2

[cache]
enabled = false
capacity = 256
ttl_ms = 1500
"#,
            "parse knowdb with cache spec",
        );

        let cache = &conf.cache;
        assert!(!cache.enabled);
        assert_eq!(cache.capacity, 256);
        assert_eq!(cache.ttl_ms, 1500);
    }
}
543        assert!(!conf.cache.enabled);
544        assert_eq!(conf.cache.capacity, 256);
545        assert_eq!(conf.cache.ttl_ms, 1500);
546    }
547}