Skip to main content

wp_knowledge/
loader.rs

1use std::collections::HashMap;
2use std::fs;
3use std::io::Read;
4use std::path::{Path, PathBuf};
5
6use orion_conf::EnvTomlLoad;
7use serde::Deserialize;
8use wp_log::info_ctrl;
9
10use crate::error::{KnowReason, KnowledgeResult};
11use crate::mem::memdb::MemDB;
12use orion_error::OperationContext;
13use orion_error::conversion::{SourceErr, SourceRawErr, ToStructError};
14use orion_variate::EnvDict;
15use rusqlite::OpenFlags;
16
17/// V2 KnowDB 配置:目录式 + 外置 SQL。仅支持单一数据文件:`<table_dir>/data.csv`,
18/// 或通过 `tables[n].data_file` 相对 `<table_dir>` 指定。
19#[derive(Debug, Deserialize)]
20pub struct KnowDbConf {
21    pub version: u32,
22    #[serde(default = "default_dot")]
23    pub base_dir: String,
24    #[serde(default)]
25    pub default: OptLoadSpec,
26    #[serde(default)]
27    pub csv: CsvSpec,
28    #[serde(default)]
29    pub cache: CacheSpec,
30    #[serde(default)]
31    pub tables: Vec<TableSpec>,
32
33    /// `[fun.<name>]` — external named-query definitions.
34    #[serde(default)]
35    pub fun: HashMap<String, FunSpec>,
36
37    /// Raw provider config — `[provider.sqldb]` / `[provider.redis]`.
38    #[serde(default, rename = "provider")]
39    provider_raw: Option<ProviderConfig>,
40}
41
42impl KnowDbConf {
43    pub fn provider(&self) -> Option<ProviderConfig> {
44        self.provider_raw.clone()
45    }
46}
47
48// ---------------------------------------------------------------------------
49// Fun (external named query) config
50// ---------------------------------------------------------------------------
51
52#[derive(Debug, Clone, Deserialize)]
53pub struct FunSpec {
54    pub call: FunCall,
55    #[serde(default)]
56    pub key: Option<String>,
57    #[serde(default = "default_true")]
58    pub cache: bool,
59    #[serde(default)]
60    pub ttl_ms: Option<u64>,
61}
62
63impl FunSpec {
64    /// Derive return type from the call (bf_exists/sismember → bool, hget/get → value).
65    pub fn returns_bool(&self) -> bool {
66        matches!(self.call, FunCall::BfExists | FunCall::Sismember)
67    }
68}
69
70#[derive(Debug, Clone, Deserialize, PartialEq)]
71#[serde(rename_all = "snake_case")]
72pub enum FunCall {
73    BfExists,
74    Sismember,
75    Hget,
76    Get,
77}
78
79// ---------------------------------------------------------------------------
80// Cache config
81// ---------------------------------------------------------------------------
82
83#[derive(Debug, Clone, Deserialize)]
84pub struct CacheSpec {
85    #[serde(default = "default_true")]
86    pub enabled: bool,
87    #[serde(default = "default_result_cache_capacity")]
88    pub capacity: usize,
89    #[serde(default = "default_result_cache_ttl_ms")]
90    pub ttl_ms: u64,
91}
92
93impl Default for CacheSpec {
94    fn default() -> Self {
95        Self {
96            enabled: default_true(),
97            capacity: default_result_cache_capacity(),
98            ttl_ms: default_result_cache_ttl_ms(),
99        }
100    }
101}
102
103// ---------------------------------------------------------------------------
104// Provider configuration (new format: [provider.sqldb] / [provider.redis])
105// ---------------------------------------------------------------------------
106
107#[derive(Debug, Clone, Default, Deserialize)]
108pub struct ProviderConfig {
109    #[serde(default)]
110    pub sqldb: Option<SqlProviderSpec>,
111    #[serde(default)]
112    pub redis: Option<RedisProviderSpec>,
113}
114
115#[derive(Debug, Clone, Deserialize)]
116pub struct SqlProviderSpec {
117    #[serde(rename = "kind")]
118    pub kind: SqlProviderKind,
119    pub connection_uri: String,
120    #[serde(default)]
121    pub pool_size: Option<u32>,
122    #[serde(default)]
123    pub min_connections: Option<u32>,
124    #[serde(default)]
125    pub acquire_timeout_ms: Option<u64>,
126    #[serde(default)]
127    pub idle_timeout_ms: Option<u64>,
128    #[serde(default)]
129    pub max_lifetime_ms: Option<u64>,
130}
131
132/// Config-level SQL provider kind — Postgres and Mysql only.
133/// The runtime-level [`ProviderKind`] additionally includes `SqliteAuthority`
134/// and `Redis` for the internal provider registry.
135#[derive(Debug, Clone, Deserialize)]
136#[serde(rename_all = "snake_case")]
137pub enum SqlProviderKind {
138    Postgres,
139    Mysql,
140}
141
142#[derive(Debug, Clone, Deserialize)]
143pub struct RedisProviderSpec {
144    pub connection_uri: String,
145    #[serde(default)]
146    pub pool_size: Option<usize>,
147    #[serde(default = "default_connect_timeout_ms")]
148    pub connect_timeout_ms: u64,
149    #[serde(default = "default_command_timeout_ms")]
150    pub command_timeout_ms: u64,
151}
152
153fn default_connect_timeout_ms() -> u64 {
154    3_000
155}
156
157fn default_command_timeout_ms() -> u64 {
158    100
159}
160
161/// Runtime-level provider kind — used by the internal registry to identify
162/// the active provider. Includes built-in types (SqliteAuthority) and all
163/// external types (Postgres, Mysql, Redis).
164///
165/// For config-level SQL providers, see [`SqlProviderKind`].
166#[derive(Debug, Clone, Deserialize)]
167#[serde(rename_all = "snake_case")]
168pub enum ProviderKind {
169    SqliteAuthority,
170    Postgres,
171    Mysql,
172    Redis,
173}
174
175#[derive(Debug, Clone, Deserialize)]
176pub struct OptLoadSpec {
177    #[serde(default = "default_true")]
178    pub transaction: bool,
179    #[serde(default = "default_batch")]
180    pub batch_size: usize,
181    #[serde(default = "default_on_error")]
182    pub on_error: OnError,
183}
184impl Default for OptLoadSpec {
185    fn default() -> Self {
186        Self {
187            transaction: true,
188            batch_size: default_batch(),
189            on_error: default_on_error(),
190        }
191    }
192}
193
194#[derive(Debug, Clone, Deserialize, Default)]
195#[serde(rename_all = "lowercase")]
196pub enum OnError {
197    #[default]
198    Fail,
199    Skip,
200}
201
202#[derive(Debug, Clone, Deserialize)]
203pub struct CsvSpec {
204    #[serde(default = "default_true")]
205    pub has_header: bool,
206    #[serde(default = "default_comma")]
207    pub delimiter: String,
208    #[serde(default = "default_utf8")]
209    pub encoding: String,
210    #[serde(default = "default_true")]
211    pub trim: bool,
212}
213impl Default for CsvSpec {
214    fn default() -> Self {
215        CsvSpec {
216            has_header: true,
217            delimiter: ",".into(),
218            encoding: "utf-8".into(),
219            trim: true,
220        }
221    }
222}
223
224#[derive(Debug, Clone, Deserialize)]
225pub struct TableSpec {
226    pub name: String,
227    #[serde(default)]
228    pub dir: Option<String>,
229    #[serde(default)]
230    pub data_file: Option<String>,
231    pub columns: ColumnsSpec,
232    #[serde(default)]
233    pub expected_rows: RowExpect,
234    #[serde(default = "default_true")]
235    pub enabled: bool,
236}
237
238#[derive(Debug, Clone, Deserialize)]
239pub struct ColumnsSpec {
240    #[serde(default)]
241    pub by_header: Vec<String>,
242    #[serde(default)]
243    pub by_index: Vec<usize>,
244}
245
246#[derive(Debug, Clone, Deserialize, Default)]
247pub struct RowExpect {
248    pub min: Option<usize>,
249    pub max: Option<usize>,
250}
251
/// Serde default for boolean switches that are on unless configured otherwise.
const fn default_true() -> bool {
    true
}

/// Default number of rows per load transaction.
const fn default_batch() -> usize {
    2_000
}

/// Default CSV field delimiter.
fn default_comma() -> String {
    String::from(",")
}

/// Default CSV encoding label.
fn default_utf8() -> String {
    String::from("utf-8")
}
264fn default_on_error() -> OnError {
265    OnError::Fail
266}
/// Default `base_dir`: the directory containing the config file itself.
fn default_dot() -> String {
    String::from(".")
}

/// Default result-cache capacity (1024).
const fn default_result_cache_capacity() -> usize {
    1 << 10
}

/// Default result-cache TTL in milliseconds (30 seconds).
const fn default_result_cache_ttl_ms() -> u64 {
    30 * 1_000
}
276
277/// 读取文本文件,返回字符串
278fn read_to_string(path: &Path) -> KnowledgeResult<String> {
279    let mut f = fs::File::open(path).source_raw_err(KnowReason::from_res(), "source error")?;
280    let mut buf = String::new();
281    f.read_to_string(&mut buf)
282        .source_raw_err(KnowReason::from_res(), "source error")?;
283    Ok(buf)
284}
285
/// Substitutes every `{table}` placeholder in `sql` with `table`.
fn replace_table(sql: &str, table: &str) -> String {
    const PLACEHOLDER: &str = "{table}";
    sql.replace(PLACEHOLDER, table)
}
289
/// Resolves `rel` against `base`. If `rel` is already absolute, it is
/// returned unchanged.
fn join_rel(base: &Path, rel: &str) -> PathBuf {
    let candidate = Path::new(rel);
    if candidate.is_relative() {
        return base.join(candidate);
    }
    candidate.to_path_buf()
}
298
299pub fn build_authority_from_knowdb(
300    root: &Path,
301    conf_path: &Path,
302    authority_uri: &str,
303    dict: &EnvDict,
304) -> KnowledgeResult<Vec<String>> {
305    let mut opx = OperationContext::doing("build authority from knowdb").with_auto_log();
306    // 1) 解析配置与 base_dir
307    let (conf, conf_abs, base_dir) = parse_knowdb_conf(root, conf_path, dict)?;
308    opx.record("conf", conf_abs.display());
309    opx.record("base_dir", base_dir.display());
310    // 2) 打开权威库
311    let db = open_authority(authority_uri)?;
312    // 3) 逐表加载(按配置顺序);不再处理显式依赖
313    let mut loaded_names = Vec::new();
314    for t in &conf.tables {
315        if !t.enabled {
316            continue;
317        }
318        load_one_table(&db, &base_dir, t, &conf.csv, &conf.default)?;
319        info_ctrl!("load table {} suc!", base_dir.display(),);
320        loaded_names.push(t.name.clone());
321    }
322    opx.mark_suc();
323    Ok(loaded_names)
324}
325
326pub fn parse_knowdb_conf(
327    root: &Path,
328    conf_path: &Path,
329    dict: &EnvDict,
330) -> KnowledgeResult<(KnowDbConf, PathBuf, PathBuf)> {
331    let conf_abs = if conf_path.is_absolute() {
332        conf_path.to_path_buf()
333    } else {
334        root.join(conf_path)
335    };
336    let conf_txt = read_to_string(&conf_abs)?;
337    let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(&conf_txt, dict)
338        .source_err(KnowReason::from_conf(), "parse knowdb config")?;
339    if conf.version != 2 {
340        return Err(KnowReason::from_conf()
341            .to_err()
342            .with_detail("unsupported knowdb.version"));
343    }
344    let conf_dir = conf_abs.parent().unwrap_or_else(|| Path::new("."));
345    let base_dir = join_rel(conf_dir, &conf.base_dir);
346    Ok((conf, conf_abs, base_dir))
347}
348
349fn open_authority(authority_uri: &str) -> KnowledgeResult<MemDB> {
350    ensure_parent_dir_for_file_uri(authority_uri);
351    let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
352        | OpenFlags::SQLITE_OPEN_CREATE
353        | OpenFlags::SQLITE_OPEN_URI;
354    let db = MemDB::new_file(authority_uri, 1, flags)?;
355    // 预注册内置 UDF 至权威库连接(注意:连接池可能返回不同连接,导入时也会再次注册)
356    let _ = db.with_conn(|conn| {
357        let _ = crate::sqlite_ext::register_builtin(conn);
358        Ok::<(), anyhow::Error>(())
359    });
360    Ok(db)
361}
362
/// Best-effort creation of the parent directory for a SQLite `file:` URI.
///
/// - URIs that do not start with `file:` are ignored, as are in-memory
///   databases (`file::memory:`). Bare file names need no directory.
/// - The query (`?...`) and fragment (`#...`) parts are stripped.
/// - A `file://<authority>/path` form is reduced to `/path`, since SQLite
///   only accepts an empty or `localhost` authority.
///
/// Directory-creation errors are ignored on purpose; any real problem is
/// expected to surface when the database is opened.
fn ensure_parent_dir_for_file_uri(uri: &str) {
    let Some(rest) = uri.strip_prefix("file:") else {
        return;
    };
    // Keep only the path component: drop `?query` and `#fragment`.
    let end = rest.find(|c: char| c == '?' || c == '#').unwrap_or(rest.len());
    let mut path_part = &rest[..end];
    // `//authority/path` → `/path`; without this, `file://localhost/x/y.db`
    // would create a `//localhost/x` directory.
    if let Some(after_slashes) = path_part.strip_prefix("//") {
        path_part = after_slashes
            .find('/')
            .map_or("", |slash| &after_slashes[slash..]);
    }
    if path_part.is_empty() || path_part == ":memory:" {
        return;
    }
    if let Some(parent) = Path::new(path_part).parent() {
        let _ = fs::create_dir_all(parent);
    }
}
374
375fn load_one_table(
376    db: &MemDB,
377    base_dir: &Path,
378    t: &TableSpec,
379    csvd: &CsvSpec,
380    load: &OptLoadSpec,
381) -> KnowledgeResult<()> {
382    // 目录与必须文件
383    let mut opx = OperationContext::doing("load table to kdb")
384        .with_auto_log()
385        .with_mod_path("ctrl");
386    let dir_name: &str = t.dir.as_deref().unwrap_or(&t.name);
387    let table_dir = base_dir.join(dir_name);
388    opx.record("table_dir", table_dir.display());
389    let create_sql = replace_table(&read_to_string(&table_dir.join("create.sql"))?, &t.name);
390    let insert_sql = replace_table(&read_to_string(&table_dir.join("insert.sql"))?, &t.name);
391    let clean_path = table_dir.join("clean.sql");
392    let clean_sql = if clean_path.exists() {
393        replace_table(&read_to_string(&clean_path)?, &t.name)
394    } else {
395        format!("DELETE FROM {}", &t.name)
396    };
397
398    // 建表与清理
399    db.with_conn(|conn| {
400        // 注册内置 UDF(导入连接)
401        let _ = crate::sqlite_ext::register_builtin(conn);
402        conn.execute_batch(&create_sql)?;
403        conn.execute_batch(&clean_sql)?;
404        Ok::<(), anyhow::Error>(())
405    })
406    .source_err(KnowReason::from_res(), "prepare authority table")?;
407
408    // 数据源
409    let data_path = match &t.data_file {
410        Some(rel) => join_rel(&table_dir, rel),
411        None => table_dir.join("data.csv"),
412    };
413    if !data_path.exists() {
414        return Err(KnowReason::from_conf()
415            .to_err()
416            .with_detail("data.csv not found"));
417    }
418    opx.record("data_path", data_path.display());
419
420    // CSV 解析器
421    let mut rdr = build_csv_reader(csvd, &data_path)?;
422
423    // 列映射
424    let col_indices: Vec<usize> = if !t.columns.by_header.is_empty() {
425        let headers = rdr
426            .headers()
427            .source_raw_err(KnowReason::from_res(), "source error")?;
428        select_indices_by_header(headers, &t.columns.by_header)?
429    } else if !t.columns.by_index.is_empty() {
430        t.columns.by_index.clone()
431    } else {
432        return Err(KnowReason::from_conf()
433            .to_err()
434            .with_detail("columns mapping required"));
435    };
436
437    // 导入(分批事务)
438    let mut inserted: usize = 0;
439    let mut bad: usize = 0;
440    let mut batch_left = load.batch_size.max(1);
441    db.with_conn(|conn| {
442        // 注册内置 UDF(用于 INSERT 绑定表达式)
443        let _ = crate::sqlite_ext::register_builtin(conn);
444        let mut tx = if load.transaction {
445            Some(conn.unchecked_transaction()?)
446        } else {
447            None
448        };
449        let mut stmt = conn.prepare(&insert_sql)?;
450        for rec in rdr.into_records() {
451            match rec {
452                Ok(record) => {
453                    let refs = extract_row_refs(&record, &col_indices, &mut bad, load)?;
454                    if let Some(refs) = refs {
455                        stmt.execute(rusqlite::params_from_iter(refs))?;
456                        inserted += 1;
457                        if load.transaction {
458                            batch_left -= 1;
459                            if batch_left == 0 {
460                                tx.take().unwrap().commit()?;
461                                tx = Some(conn.unchecked_transaction()?);
462                                batch_left = load.batch_size.max(1);
463                            }
464                        }
465                    }
466                }
467                Err(_e) => {
468                    if matches!(load.on_error, OnError::Skip) {
469                        bad += 1;
470                        continue;
471                    } else {
472                        anyhow::bail!("csv record parse error");
473                    }
474                }
475            }
476        }
477        if let Some(tx) = tx {
478            tx.commit()?;
479        }
480        Ok::<(), anyhow::Error>(())
481    })
482    .source_err(KnowReason::from_res(), "load authority table data")?;
483
484    // 行数校验
485    if let Some(min) = t.expected_rows.min
486        && inserted < min
487    {
488        return Err(KnowReason::from_conf()
489            .to_err()
490            .with_detail("table data less"));
491    }
492    if let Some(max) = t.expected_rows.max
493        && inserted > max
494    {
495        wp_log::warn_kdb!(
496            "table {} loaded rows {} exceed max {}",
497            &t.name,
498            inserted,
499            max
500        );
501    }
502    if bad > 0 {
503        wp_log::warn_kdb!("table {} skipped {} bad rows (on_error=skip)", &t.name, bad);
504    }
505    opx.mark_suc();
506    Ok(())
507}
508
509fn build_csv_reader(
510    csvd: &CsvSpec,
511    data_path: &Path,
512) -> KnowledgeResult<csv::Reader<std::fs::File>> {
513    if csvd.encoding.to_lowercase() != "utf-8" {
514        return Err(KnowReason::from_conf()
515            .to_err()
516            .with_detail("only utf-8 csv is supported"));
517    }
518    let mut rdr_b = csv::ReaderBuilder::new();
519    rdr_b.has_headers(csvd.has_header);
520    if csvd.delimiter.len() == 1 {
521        rdr_b.delimiter(csvd.delimiter.as_bytes()[0]);
522    }
523    if csvd.trim {
524        rdr_b.trim(csv::Trim::All);
525    }
526    rdr_b
527        .from_path(data_path)
528        .source_raw_err(KnowReason::from_res(), "source error")
529}
530
531fn select_indices_by_header(
532    headers: &csv::StringRecord,
533    wanted: &[String],
534) -> KnowledgeResult<Vec<usize>> {
535    let mut out = Vec::with_capacity(wanted.len());
536    for name in wanted {
537        let pos = headers.iter().position(|h| h == name).ok_or_else(|| {
538            KnowReason::from_conf()
539                .to_err()
540                .with_detail("header not found")
541        })?;
542        out.push(pos);
543    }
544    Ok(out)
545}
546
547fn extract_row_refs<'a>(
548    record: &'a csv::StringRecord,
549    col_indices: &[usize],
550    bad: &mut usize,
551    load: &OptLoadSpec,
552) -> anyhow::Result<Option<Vec<&'a str>>> {
553    let mut vs: Vec<&str> = Vec::with_capacity(col_indices.len());
554    for &idx in col_indices {
555        if idx >= record.len() {
556            if matches!(load.on_error, OnError::Skip) {
557                *bad += 1;
558                return Ok(None);
559            } else {
560                anyhow::bail!("missing column at index {}", idx);
561            }
562        }
563        vs.push(record.get(idx).unwrap_or(""));
564    }
565    Ok(Some(vs))
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an inline knowdb TOML document with an empty env dictionary,
    /// panicking with `what` if parsing fails.
    fn parse(toml: &str, what: &str) -> KnowDbConf {
        let dict = EnvDict::default();
        <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(toml, &dict).expect(what)
    }

    #[test]
    fn parse_new_style_sqldb_provider() {
        let conf = parse(
            r#"
version = 2

[provider.sqldb]
kind = "postgres"
connection_uri = "postgres://demo:demo@127.0.0.1/demo"
pool_size = 12
"#,
            "parse knowdb with sqldb provider",
        );

        let sqldb = conf
            .provider()
            .expect("provider")
            .sqldb
            .expect("sqldb provider");
        assert!(matches!(sqldb.kind, SqlProviderKind::Postgres));
        assert_eq!(sqldb.pool_size, Some(12));
    }

    #[test]
    fn parse_new_style_redis_provider() {
        let conf = parse(
            r#"
version = 2

[provider.redis]
connection_uri = "redis://127.0.0.1:6379"
pool_size = 16
connect_timeout_ms = 5000
command_timeout_ms = 200
"#,
            "parse knowdb with redis provider",
        );

        let redis_cfg = conf
            .provider()
            .expect("provider")
            .redis
            .expect("redis provider");
        assert_eq!(redis_cfg.connection_uri, "redis://127.0.0.1:6379");
        assert_eq!(redis_cfg.pool_size, Some(16));
        assert_eq!(redis_cfg.connect_timeout_ms, 5000);
        assert_eq!(redis_cfg.command_timeout_ms, 200);
    }

    #[test]
    fn parse_redis_provider_with_default_timeouts() {
        let conf = parse(
            r#"
version = 2

[provider.redis]
connection_uri = "redis://127.0.0.1:6379"
"#,
            "parse knowdb with redis provider (no timeout fields)",
        );

        let redis_cfg = conf.provider().expect("provider").redis.expect("redis");
        assert_eq!(redis_cfg.connect_timeout_ms, 3000);
        assert_eq!(redis_cfg.command_timeout_ms, 100);
    }

    #[test]
    fn parse_both_sqldb_and_redis_providers() {
        let conf = parse(
            r#"
version = 2

[provider.sqldb]
kind = "postgres"
connection_uri = "postgres://demo:demo@127.0.0.1/demo"

[provider.redis]
connection_uri = "redis://10.0.0.1:6379"
pool_size = 4
"#,
            "parse knowdb with both sqldb and redis",
        );

        let provider_cfg = conf.provider().expect("provider");
        let sqldb = provider_cfg.sqldb.expect("sqldb");
        let redis_cfg = provider_cfg.redis.expect("redis");
        assert!(matches!(sqldb.kind, SqlProviderKind::Postgres));
        assert_eq!(redis_cfg.connection_uri, "redis://10.0.0.1:6379");
        assert_eq!(redis_cfg.pool_size, Some(4));
    }

    #[test]
    fn parse_redis_only_without_sqldb() {
        let conf = parse(
            r#"
version = 2

[provider.redis]
connection_uri = "redis://127.0.0.1:6379"
"#,
            "parse knowdb with redis only",
        );

        let provider_cfg = conf.provider().expect("provider");
        assert!(provider_cfg.sqldb.is_none());
        assert!(provider_cfg.redis.is_some());
    }

    #[test]
    fn parse_no_provider_section() {
        let conf = parse(
            r#"
version = 2
"#,
            "parse knowdb without provider",
        );

        assert!(conf.provider().is_none());
    }

    #[test]
    fn new_style_sqldb_mysql_variant() {
        let conf = parse(
            r#"
version = 2

[provider.sqldb]
kind = "mysql"
connection_uri = "mysql://user:pass@127.0.0.1:3306/db"
pool_size = 8
"#,
            "parse new-style mysql sqldb",
        );

        let sqldb = conf.provider().expect("provider").sqldb.expect("sqldb");
        assert!(matches!(sqldb.kind, SqlProviderKind::Mysql));
        assert_eq!(sqldb.pool_size, Some(8));
    }

    #[test]
    fn parse_cache_spec_with_defaults() {
        let conf = parse(
            r#"
version = 2
"#,
            "parse knowdb with default cache spec",
        );

        assert!(conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 1024);
        assert_eq!(conf.cache.ttl_ms, 30_000);
    }

    #[test]
    fn parse_cache_spec_from_toml() {
        let conf = parse(
            r#"
version = 2

[cache]
enabled = false
capacity = 256
ttl_ms = 1500
"#,
            "parse knowdb with cache spec",
        );

        assert!(!conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 256);
        assert_eq!(conf.cache.ttl_ms, 1500);
    }

    #[test]
    fn parse_redis_cache_spec() {
        let conf = parse(
            r#"
version = 2

[cache]
enabled = true
capacity = 512
"#,
            "parse knowdb with cache",
        );

        assert!(conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 512);
    }

    #[test]
    fn parse_redis_cache_defaults() {
        let conf = parse(
            r#"
version = 2
"#,
            "parse knowdb without redis.cache",
        );

        // Without a [cache] section the defaults apply: enabled=true, capacity=1024.
        assert!(conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 1024);
    }

    // -----------------------------------------------------------------------
    // Fun (external named query) config tests
    // -----------------------------------------------------------------------

    #[test]
    fn parse_fun_bool_services() {
        let conf = parse(
            r#"
version = 2

[fun.password_check]
call = "bf_exists"
key = "weak_passwords"

[fun.ip_whitelist]
call = "sismember"
key = "allowed_ips"
"#,
            "parse fun bool services",
        );

        let pw = conf.fun.get("password_check").expect("password_check");
        assert_eq!(pw.call, FunCall::BfExists);
        assert_eq!(pw.key.as_deref(), Some("weak_passwords"));
        assert!(pw.returns_bool());

        let ip = conf.fun.get("ip_whitelist").expect("ip_whitelist");
        assert_eq!(ip.call, FunCall::Sismember);
        assert_eq!(ip.key.as_deref(), Some("allowed_ips"));
        assert!(ip.returns_bool());
    }

    #[test]
    fn parse_fun_value_services() {
        let conf = parse(
            r#"
version = 2

[fun.threat_actor]
call = "hget"
key = "threat_actors"
cache = true
ttl_ms = 60000

[fun.user_tag]
call = "get"
"#,
            "parse fun value services",
        );

        let ta = conf.fun.get("threat_actor").expect("threat_actor");
        assert_eq!(ta.call, FunCall::Hget);
        assert_eq!(ta.key.as_deref(), Some("threat_actors"));
        assert!(ta.cache);
        assert_eq!(ta.ttl_ms, Some(60000));
        assert!(!ta.returns_bool());

        let ut = conf.fun.get("user_tag").expect("user_tag");
        assert_eq!(ut.call, FunCall::Get);
        assert!(ut.key.is_none());
        assert!(ut.cache); // default true
        assert!(!ut.returns_bool());
    }

    #[test]
    fn parse_fun_default_cache() {
        let conf = parse(
            r#"
version = 2

[fun.app_config]
call = "get"
key = "app_config"
"#,
            "parse fun default cache",
        );

        let spec = conf.fun.get("app_config").expect("app_config");
        assert!(spec.cache);
        assert!(spec.ttl_ms.is_none());
    }
}