Skip to main content

wp_knowledge/
loader.rs

1use std::collections::HashMap;
2use std::fs;
3use std::io::Read;
4use std::path::{Path, PathBuf};
5
6use orion_conf::EnvTomlLoad;
7use serde::Deserialize;
8use wp_log::info_ctrl;
9
10use crate::error::{KnowReason, KnowledgeResult};
11use crate::mem::memdb::MemDB;
12use orion_error::OperationContext;
13use orion_error::conversion::{SourceErr, SourceRawErr, ToStructError};
14use orion_variate::EnvDict;
15use rusqlite::OpenFlags;
16
/// V2 KnowDB configuration: directory-based table layout with external SQL files.
///
/// Each table reads a single data file: `<table_dir>/data.csv` by default, or the
/// path given by `tables[n].data_file` (relative paths are resolved against
/// `<table_dir>`).
#[derive(Debug, Deserialize)]
pub struct KnowDbConf {
    /// Config format version; `parse_knowdb_conf` rejects anything other than `2`.
    pub version: u32,
    /// Root directory holding the per-table directories. Relative values are
    /// resolved against the config file's directory. Defaults to `"."`.
    #[serde(default = "default_dot")]
    pub base_dir: String,
    /// Load options (`[default]`) passed to every table load.
    #[serde(default)]
    pub default: OptLoadSpec,
    /// CSV parsing options (`[csv]`) shared by all tables.
    #[serde(default)]
    pub csv: CsvSpec,
    /// Result-cache settings (`[cache]`).
    #[serde(default)]
    pub cache: CacheSpec,
    /// Table definitions (`[[tables]]`), loaded in declaration order.
    #[serde(default)]
    pub tables: Vec<TableSpec>,

    /// Raw provider config — `[provider.sqldb]` / `[provider.redis]`.
    #[serde(default, rename = "provider")]
    provider_raw: Option<ProviderConfig>,
}
37
38impl KnowDbConf {
39    pub fn provider(&self) -> Option<ProviderConfig> {
40        self.provider_raw.clone()
41    }
42
43    /// Resolve Redis cache config for a given key.
44    /// Per-key overrides (`[[cache.redis_key]]`) take precedence over `[cache].enabled`.
45    pub fn redis_cache_enabled(&self, key: &str) -> bool {
46        if let Some(entry) = self.cache.redis_key_entry(key)
47            && let Some(enabled) = entry.enabled
48        {
49            return enabled;
50        }
51        self.cache.enabled
52    }
53}
54
55// ---------------------------------------------------------------------------
56// Cache config
57// ---------------------------------------------------------------------------
58
/// Result-cache settings from the `[cache]` section.
#[derive(Debug, Clone, Deserialize)]
pub struct CacheSpec {
    /// Global cache switch. Per-key `[[cache.redis_key]]` entries that set
    /// `enabled` override it for their key. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Cache capacity; defaults to 1024. NOTE(review): unit (entries?) is not
    /// visible in this file — confirm with the cache implementation.
    #[serde(default = "default_result_cache_capacity")]
    pub capacity: usize,
    /// Time-to-live in milliseconds; defaults to 30 000.
    #[serde(default = "default_result_cache_ttl_ms")]
    pub ttl_ms: u64,
    /// `[[cache.redis_key]]` — per-key Redis cache overrides.
    #[serde(default, rename = "redis_key")]
    pub redis_key: Vec<RedisKeyCacheSpec>,
}
71
72impl CacheSpec {
73    fn redis_key_entry(&self, key: &str) -> Option<&RedisKeyCacheSpec> {
74        self.redis_key.iter().find(|e| e.key == key)
75    }
76
77    pub fn redis_key_map(&self) -> HashMap<String, bool> {
78        self.redis_key
79            .iter()
80            .filter_map(|e| e.enabled.map(|v| (e.key.clone(), v)))
81            .collect()
82    }
83}
84
85impl Default for CacheSpec {
86    fn default() -> Self {
87        Self {
88            enabled: default_true(),
89            capacity: default_result_cache_capacity(),
90            ttl_ms: default_result_cache_ttl_ms(),
91            redis_key: Vec::new(),
92        }
93    }
94}
95
/// One `[[cache.redis_key]]` entry: a per-key override of the Redis result cache.
#[derive(Debug, Clone, Deserialize)]
pub struct RedisKeyCacheSpec {
    /// Redis key this override applies to.
    pub key: String,
    /// When set, overrides `[cache].enabled` for this key; `None` inherits it.
    #[serde(default)]
    pub enabled: Option<bool>,
    /// Optional per-key capacity. NOTE(review): not read anywhere in this file —
    /// confirm it is consumed elsewhere.
    #[serde(default)]
    pub capacity: Option<usize>,
}
104
105// ---------------------------------------------------------------------------
106// Provider configuration (new format: [provider.sqldb] / [provider.redis])
107// ---------------------------------------------------------------------------
108
/// The `[provider]` section: optional external SQL and Redis backends.
/// Either, both, or neither sub-section may be present.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ProviderConfig {
    /// `[provider.sqldb]`, if configured.
    #[serde(default)]
    pub sqldb: Option<SqlProviderSpec>,
    /// `[provider.redis]`, if configured.
    #[serde(default)]
    pub redis: Option<RedisProviderSpec>,
}
116
/// `[provider.sqldb]`: connection settings for an external SQL database.
#[derive(Debug, Clone, Deserialize)]
pub struct SqlProviderSpec {
    /// Database flavour: `"postgres"` or `"mysql"`.
    #[serde(rename = "kind")]
    pub kind: SqlProviderKind,
    /// Connection URI, e.g. `postgres://user:pass@host/db`.
    pub connection_uri: String,
    // The pool options below are `None` when absent from the config.
    // NOTE(review): how `None` is interpreted (library default?) is decided by
    // the consumer of this struct — confirm there.
    /// Maximum pool size.
    #[serde(default)]
    pub pool_size: Option<u32>,
    /// Minimum number of pooled connections.
    #[serde(default)]
    pub min_connections: Option<u32>,
    /// Connection-acquire timeout in milliseconds.
    #[serde(default)]
    pub acquire_timeout_ms: Option<u64>,
    /// Idle-connection timeout in milliseconds.
    #[serde(default)]
    pub idle_timeout_ms: Option<u64>,
    /// Maximum connection lifetime in milliseconds.
    #[serde(default)]
    pub max_lifetime_ms: Option<u64>,
}
133
/// Config-level SQL provider kind — Postgres and Mysql only.
/// The runtime-level [`ProviderKind`] additionally includes `SqliteAuthority`
/// and `Redis` for the internal provider registry.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SqlProviderKind {
    /// `kind = "postgres"`.
    Postgres,
    /// `kind = "mysql"`.
    Mysql,
}
143
/// `[provider.redis]`: connection settings for an external Redis server.
#[derive(Debug, Clone, Deserialize)]
pub struct RedisProviderSpec {
    /// Connection URI, e.g. `redis://127.0.0.1:6379`.
    pub connection_uri: String,
    /// Optional pool size; `None` when not configured.
    #[serde(default)]
    pub pool_size: Option<usize>,
    /// Connect timeout in milliseconds; defaults to 3000.
    #[serde(default = "default_connect_timeout_ms")]
    pub connect_timeout_ms: u64,
    /// Per-command timeout in milliseconds; defaults to 100.
    #[serde(default = "default_command_timeout_ms")]
    pub command_timeout_ms: u64,
}
154
/// Serde default for `RedisProviderSpec::connect_timeout_ms` (3 s).
///
/// `const fn` for consistency with the other default helpers in this module.
const fn default_connect_timeout_ms() -> u64 {
    3_000
}

/// Serde default for `RedisProviderSpec::command_timeout_ms` (100 ms).
const fn default_command_timeout_ms() -> u64 {
    100
}
162
/// Runtime-level provider kind — used by the internal registry to identify
/// the active provider. Includes built-in types (SqliteAuthority) and all
/// external types (Postgres, Mysql, Redis).
///
/// For config-level SQL providers, see [`SqlProviderKind`].
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
    /// Built-in SQLite authority database (serialized as `sqlite_authority`).
    SqliteAuthority,
    /// PostgreSQL (`postgres`).
    Postgres,
    /// MySQL (`mysql`).
    Mysql,
    /// Redis (`redis`).
    Redis,
}
176
177#[derive(Debug, Clone, Deserialize)]
178pub struct OptLoadSpec {
179    #[serde(default = "default_true")]
180    pub transaction: bool,
181    #[serde(default = "default_batch")]
182    pub batch_size: usize,
183    #[serde(default = "default_on_error")]
184    pub on_error: OnError,
185}
186impl Default for OptLoadSpec {
187    fn default() -> Self {
188        Self {
189            transaction: true,
190            batch_size: default_batch(),
191            on_error: default_on_error(),
192        }
193    }
194}
195
/// Policy for malformed CSV rows (`on_error = "fail" | "skip"`).
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum OnError {
    /// Abort the table load on the first malformed row (default).
    #[default]
    Fail,
    /// Skip malformed rows; the loader counts them and logs a warning.
    Skip,
}
203
204#[derive(Debug, Clone, Deserialize)]
205pub struct CsvSpec {
206    #[serde(default = "default_true")]
207    pub has_header: bool,
208    #[serde(default = "default_comma")]
209    pub delimiter: String,
210    #[serde(default = "default_utf8")]
211    pub encoding: String,
212    #[serde(default = "default_true")]
213    pub trim: bool,
214}
215impl Default for CsvSpec {
216    fn default() -> Self {
217        CsvSpec {
218            has_header: true,
219            delimiter: ",".into(),
220            encoding: "utf-8".into(),
221            trim: true,
222        }
223    }
224}
225
/// One `[[tables]]` entry describing a table to load into the authority DB.
#[derive(Debug, Clone, Deserialize)]
pub struct TableSpec {
    /// Table name; substituted for `{table}` in the table's SQL files.
    pub name: String,
    /// Table directory under `base_dir`; defaults to `name`.
    #[serde(default)]
    pub dir: Option<String>,
    /// Data file path, joined onto the table directory unless absolute;
    /// defaults to `data.csv`.
    #[serde(default)]
    pub data_file: Option<String>,
    /// How CSV columns map to the `insert.sql` parameters.
    pub columns: ColumnsSpec,
    /// Optional row-count bounds checked after loading.
    #[serde(default)]
    pub expected_rows: RowExpect,
    /// Disabled tables are skipped entirely. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
}
239
/// Column selection for a table, in `insert.sql` parameter order.
///
/// If both lists are non-empty, `by_header` wins. If both are empty, the table
/// load fails.
#[derive(Debug, Clone, Deserialize)]
pub struct ColumnsSpec {
    /// Select columns by header name (requires a CSV header row).
    #[serde(default)]
    pub by_header: Vec<String>,
    /// Select columns by zero-based index.
    #[serde(default)]
    pub by_index: Vec<usize>,
}
247
/// Expected inserted-row bounds for a table.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RowExpect {
    /// Loading fails if fewer rows than this were inserted.
    pub min: Option<usize>,
    /// Exceeding this only logs a warning.
    pub max: Option<usize>,
}
253
/// Serde default helper: `true`.
const fn default_true() -> bool {
    true
}
/// Default number of rows per committed transaction batch.
const fn default_batch() -> usize {
    2_000
}
/// Default CSV delimiter: a comma.
fn default_comma() -> String {
    String::from(",")
}
/// Default CSV encoding label.
fn default_utf8() -> String {
    String::from("utf-8")
}
266fn default_on_error() -> OnError {
267    OnError::Fail
268}
/// Default `base_dir`: `"."`, i.e. the directory containing the config file.
fn default_dot() -> String {
    String::from(".")
}
/// Default `[cache].capacity`.
const fn default_result_cache_capacity() -> usize {
    1_024
}
/// Default `[cache].ttl_ms`: 30 seconds.
const fn default_result_cache_ttl_ms() -> u64 {
    30 * 1_000
}
278
279/// 读取文本文件,返回字符串
280fn read_to_string(path: &Path) -> KnowledgeResult<String> {
281    let mut f = fs::File::open(path).source_raw_err(KnowReason::from_res(), "source error")?;
282    let mut buf = String::new();
283    f.read_to_string(&mut buf)
284        .source_raw_err(KnowReason::from_res(), "source error")?;
285    Ok(buf)
286}
287
/// Substitutes every `{table}` placeholder in `sql` with `table`.
fn replace_table(sql: &str, table: &str) -> String {
    const PLACEHOLDER: &str = "{table}";
    sql.replace(PLACEHOLDER, table)
}
291
/// Joins `rel` onto `base`, unless `rel` is already absolute, in which case it
/// is returned as-is.
fn join_rel(base: &Path, rel: &str) -> PathBuf {
    match Path::new(rel) {
        absolute if absolute.is_absolute() => absolute.to_path_buf(),
        relative => base.join(relative),
    }
}
300
301pub fn build_authority_from_knowdb(
302    root: &Path,
303    conf_path: &Path,
304    authority_uri: &str,
305    dict: &EnvDict,
306) -> KnowledgeResult<Vec<String>> {
307    let mut opx = OperationContext::doing("build authority from knowdb").with_auto_log();
308    // 1) 解析配置与 base_dir
309    let (conf, conf_abs, base_dir) = parse_knowdb_conf(root, conf_path, dict)?;
310    opx.record("conf", conf_abs.display());
311    opx.record("base_dir", base_dir.display());
312    // 2) 打开权威库
313    let db = open_authority(authority_uri)?;
314    // 3) 逐表加载(按配置顺序);不再处理显式依赖
315    let mut loaded_names = Vec::new();
316    for t in &conf.tables {
317        if !t.enabled {
318            continue;
319        }
320        load_one_table(&db, &base_dir, t, &conf.csv, &conf.default)?;
321        info_ctrl!("load table {} suc!", base_dir.display(),);
322        loaded_names.push(t.name.clone());
323    }
324    opx.mark_suc();
325    Ok(loaded_names)
326}
327
328pub fn parse_knowdb_conf(
329    root: &Path,
330    conf_path: &Path,
331    dict: &EnvDict,
332) -> KnowledgeResult<(KnowDbConf, PathBuf, PathBuf)> {
333    let conf_abs = if conf_path.is_absolute() {
334        conf_path.to_path_buf()
335    } else {
336        root.join(conf_path)
337    };
338    let conf_txt = read_to_string(&conf_abs)?;
339    let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(&conf_txt, dict)
340        .source_err(KnowReason::from_conf(), "parse knowdb config")?;
341    if conf.version != 2 {
342        return Err(KnowReason::from_conf()
343            .to_err()
344            .with_detail("unsupported knowdb.version"));
345    }
346    let conf_dir = conf_abs.parent().unwrap_or_else(|| Path::new("."));
347    let base_dir = join_rel(conf_dir, &conf.base_dir);
348    Ok((conf, conf_abs, base_dir))
349}
350
351fn open_authority(authority_uri: &str) -> KnowledgeResult<MemDB> {
352    ensure_parent_dir_for_file_uri(authority_uri);
353    let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
354        | OpenFlags::SQLITE_OPEN_CREATE
355        | OpenFlags::SQLITE_OPEN_URI;
356    let db = MemDB::new_file(authority_uri, 1, flags)?;
357    // 预注册内置 UDF 至权威库连接(注意:连接池可能返回不同连接,导入时也会再次注册)
358    let _ = db.with_conn(|conn| {
359        let _ = crate::sqlite_ext::register_builtin(conn);
360        Ok::<(), anyhow::Error>(())
361    });
362    Ok(db)
363}
364
/// Best-effort creation of the parent directory of a SQLite `file:` URI.
///
/// The URI is handled as follows:
/// - Non-`file:` URIs are ignored.
/// - The query (`?mode=…`) and fragment (`#…`) are stripped.
/// - A `file://<authority>/path` authority (SQLite accepts only empty or
///   `localhost`) is dropped, so it is not mistaken for a directory.
///
/// Errors from `create_dir_all` are deliberately ignored. If the directory
/// really cannot be created, opening the database afterwards reports it.
fn ensure_parent_dir_for_file_uri(uri: &str) {
    let Some(rest) = uri.strip_prefix("file:") else {
        return;
    };
    let without_query = rest.split(['?', '#']).next().unwrap_or(rest);
    let path_part = match without_query.strip_prefix("//") {
        // `file://authority/path` → keep only `/path`.
        Some(after_slashes) => after_slashes
            .find('/')
            .map_or("", |idx| &after_slashes[idx..]),
        None => without_query,
    };
    if let Some(parent) = Path::new(path_part).parent() {
        let _ = fs::create_dir_all(parent);
    }
}
376
377fn load_one_table(
378    db: &MemDB,
379    base_dir: &Path,
380    t: &TableSpec,
381    csvd: &CsvSpec,
382    load: &OptLoadSpec,
383) -> KnowledgeResult<()> {
384    // 目录与必须文件
385    let mut opx = OperationContext::doing("load table to kdb")
386        .with_auto_log()
387        .with_mod_path("ctrl");
388    let dir_name: &str = t.dir.as_deref().unwrap_or(&t.name);
389    let table_dir = base_dir.join(dir_name);
390    opx.record("table_dir", table_dir.display());
391    let create_sql = replace_table(&read_to_string(&table_dir.join("create.sql"))?, &t.name);
392    let insert_sql = replace_table(&read_to_string(&table_dir.join("insert.sql"))?, &t.name);
393    let clean_path = table_dir.join("clean.sql");
394    let clean_sql = if clean_path.exists() {
395        replace_table(&read_to_string(&clean_path)?, &t.name)
396    } else {
397        format!("DELETE FROM {}", &t.name)
398    };
399
400    // 建表与清理
401    db.with_conn(|conn| {
402        // 注册内置 UDF(导入连接)
403        let _ = crate::sqlite_ext::register_builtin(conn);
404        conn.execute_batch(&create_sql)?;
405        conn.execute_batch(&clean_sql)?;
406        Ok::<(), anyhow::Error>(())
407    })
408    .source_err(KnowReason::from_res(), "prepare authority table")?;
409
410    // 数据源
411    let data_path = match &t.data_file {
412        Some(rel) => join_rel(&table_dir, rel),
413        None => table_dir.join("data.csv"),
414    };
415    if !data_path.exists() {
416        return Err(KnowReason::from_conf()
417            .to_err()
418            .with_detail("data.csv not found"));
419    }
420    opx.record("data_path", data_path.display());
421
422    // CSV 解析器
423    let mut rdr = build_csv_reader(csvd, &data_path)?;
424
425    // 列映射
426    let col_indices: Vec<usize> = if !t.columns.by_header.is_empty() {
427        let headers = rdr
428            .headers()
429            .source_raw_err(KnowReason::from_res(), "source error")?;
430        select_indices_by_header(headers, &t.columns.by_header)?
431    } else if !t.columns.by_index.is_empty() {
432        t.columns.by_index.clone()
433    } else {
434        return Err(KnowReason::from_conf()
435            .to_err()
436            .with_detail("columns mapping required"));
437    };
438
439    // 导入(分批事务)
440    let mut inserted: usize = 0;
441    let mut bad: usize = 0;
442    let mut batch_left = load.batch_size.max(1);
443    db.with_conn(|conn| {
444        // 注册内置 UDF(用于 INSERT 绑定表达式)
445        let _ = crate::sqlite_ext::register_builtin(conn);
446        let mut tx = if load.transaction {
447            Some(conn.unchecked_transaction()?)
448        } else {
449            None
450        };
451        let mut stmt = conn.prepare(&insert_sql)?;
452        for rec in rdr.into_records() {
453            match rec {
454                Ok(record) => {
455                    let refs = extract_row_refs(&record, &col_indices, &mut bad, load)?;
456                    if let Some(refs) = refs {
457                        stmt.execute(rusqlite::params_from_iter(refs))?;
458                        inserted += 1;
459                        if load.transaction {
460                            batch_left -= 1;
461                            if batch_left == 0 {
462                                tx.take().unwrap().commit()?;
463                                tx = Some(conn.unchecked_transaction()?);
464                                batch_left = load.batch_size.max(1);
465                            }
466                        }
467                    }
468                }
469                Err(_e) => {
470                    if matches!(load.on_error, OnError::Skip) {
471                        bad += 1;
472                        continue;
473                    } else {
474                        anyhow::bail!("csv record parse error");
475                    }
476                }
477            }
478        }
479        if let Some(tx) = tx {
480            tx.commit()?;
481        }
482        Ok::<(), anyhow::Error>(())
483    })
484    .source_err(KnowReason::from_res(), "load authority table data")?;
485
486    // 行数校验
487    if let Some(min) = t.expected_rows.min
488        && inserted < min
489    {
490        return Err(KnowReason::from_conf()
491            .to_err()
492            .with_detail("table data less"));
493    }
494    if let Some(max) = t.expected_rows.max
495        && inserted > max
496    {
497        wp_log::warn_kdb!(
498            "table {} loaded rows {} exceed max {}",
499            &t.name,
500            inserted,
501            max
502        );
503    }
504    if bad > 0 {
505        wp_log::warn_kdb!("table {} skipped {} bad rows (on_error=skip)", &t.name, bad);
506    }
507    opx.mark_suc();
508    Ok(())
509}
510
511fn build_csv_reader(
512    csvd: &CsvSpec,
513    data_path: &Path,
514) -> KnowledgeResult<csv::Reader<std::fs::File>> {
515    if csvd.encoding.to_lowercase() != "utf-8" {
516        return Err(KnowReason::from_conf()
517            .to_err()
518            .with_detail("only utf-8 csv is supported"));
519    }
520    let mut rdr_b = csv::ReaderBuilder::new();
521    rdr_b.has_headers(csvd.has_header);
522    if csvd.delimiter.len() == 1 {
523        rdr_b.delimiter(csvd.delimiter.as_bytes()[0]);
524    }
525    if csvd.trim {
526        rdr_b.trim(csv::Trim::All);
527    }
528    rdr_b
529        .from_path(data_path)
530        .source_raw_err(KnowReason::from_res(), "source error")
531}
532
533fn select_indices_by_header(
534    headers: &csv::StringRecord,
535    wanted: &[String],
536) -> KnowledgeResult<Vec<usize>> {
537    let mut out = Vec::with_capacity(wanted.len());
538    for name in wanted {
539        let pos = headers.iter().position(|h| h == name).ok_or_else(|| {
540            KnowReason::from_conf()
541                .to_err()
542                .with_detail("header not found")
543        })?;
544        out.push(pos);
545    }
546    Ok(out)
547}
548
549fn extract_row_refs<'a>(
550    record: &'a csv::StringRecord,
551    col_indices: &[usize],
552    bad: &mut usize,
553    load: &OptLoadSpec,
554) -> anyhow::Result<Option<Vec<&'a str>>> {
555    let mut vs: Vec<&str> = Vec::with_capacity(col_indices.len());
556    for &idx in col_indices {
557        if idx >= record.len() {
558            if matches!(load.on_error, OnError::Skip) {
559                *bad += 1;
560                return Ok(None);
561            } else {
562                anyhow::bail!("missing column at index {}", idx);
563            }
564        }
565        vs.push(record.get(idx).unwrap_or(""));
566    }
567    Ok(Some(vs))
568}
569
// Unit tests for KnowDB config parsing: `[provider.*]` sections, `[cache]`
// defaults, and per-key `[[cache.redis_key]]` overrides.
#[cfg(test)]
mod tests {
    use super::*;

    // `[provider.sqldb]` with kind = "postgres" parses into SqlProviderKind::Postgres.
    #[test]
    fn parse_new_style_sqldb_provider() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[provider.sqldb]
kind = "postgres"
connection_uri = "postgres://demo:demo@127.0.0.1/demo"
pool_size = 12
"#,
            &dict,
        )
        .expect("parse knowdb with sqldb provider");

        let sqldb = conf
            .provider()
            .expect("provider")
            .sqldb
            .expect("sqldb provider");
        assert!(matches!(sqldb.kind, SqlProviderKind::Postgres));
        assert_eq!(sqldb.pool_size, Some(12));
    }

    // All `[provider.redis]` fields are read when set explicitly.
    #[test]
    fn parse_new_style_redis_provider() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[provider.redis]
connection_uri = "redis://127.0.0.1:6379"
pool_size = 16
connect_timeout_ms = 5000
command_timeout_ms = 200
"#,
            &dict,
        )
        .expect("parse knowdb with redis provider");

        let redis_cfg = conf
            .provider()
            .expect("provider")
            .redis
            .expect("redis provider");
        assert_eq!(redis_cfg.connection_uri, "redis://127.0.0.1:6379");
        assert_eq!(redis_cfg.pool_size, Some(16));
        assert_eq!(redis_cfg.connect_timeout_ms, 5000);
        assert_eq!(redis_cfg.command_timeout_ms, 200);
    }

    // Omitted Redis timeouts fall back to the serde defaults (3000 / 100 ms).
    #[test]
    fn parse_redis_provider_with_default_timeouts() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[provider.redis]
connection_uri = "redis://127.0.0.1:6379"
"#,
            &dict,
        )
        .expect("parse knowdb with redis provider (no timeout fields)");

        let redis_cfg = conf.provider().expect("provider").redis.expect("redis");
        assert_eq!(redis_cfg.connect_timeout_ms, 3000);
        assert_eq!(redis_cfg.command_timeout_ms, 100);
    }

    // Both provider sub-sections can coexist.
    #[test]
    fn parse_both_sqldb_and_redis_providers() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[provider.sqldb]
kind = "postgres"
connection_uri = "postgres://demo:demo@127.0.0.1/demo"

[provider.redis]
connection_uri = "redis://10.0.0.1:6379"
pool_size = 4
"#,
            &dict,
        )
        .expect("parse knowdb with both sqldb and redis");

        let provider_cfg = conf.provider().expect("provider");
        let sqldb = provider_cfg.sqldb.expect("sqldb");
        let redis_cfg = provider_cfg.redis.expect("redis");
        assert!(matches!(sqldb.kind, SqlProviderKind::Postgres));
        assert_eq!(redis_cfg.connection_uri, "redis://10.0.0.1:6379");
        assert_eq!(redis_cfg.pool_size, Some(4));
    }

    // A Redis-only provider section leaves `sqldb` as None.
    #[test]
    fn parse_redis_only_without_sqldb() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[provider.redis]
connection_uri = "redis://127.0.0.1:6379"
"#,
            &dict,
        )
        .expect("parse knowdb with redis only");

        let provider_cfg = conf.provider().expect("provider");
        assert!(provider_cfg.sqldb.is_none());
        assert!(provider_cfg.redis.is_some());
    }

    // No `[provider]` section at all → provider() is None.
    #[test]
    fn parse_no_provider_section() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2
"#,
            &dict,
        )
        .expect("parse knowdb without provider");

        assert!(conf.provider().is_none());
    }

    // kind = "mysql" parses into SqlProviderKind::Mysql.
    #[test]
    fn new_style_sqldb_mysql_variant() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[provider.sqldb]
kind = "mysql"
connection_uri = "mysql://user:pass@127.0.0.1:3306/db"
pool_size = 8
"#,
            &dict,
        )
        .expect("parse new-style mysql sqldb");

        let sqldb = conf.provider().expect("provider").sqldb.expect("sqldb");
        assert!(matches!(sqldb.kind, SqlProviderKind::Mysql));
        assert_eq!(sqldb.pool_size, Some(8));
    }

    // Missing `[cache]` → CacheSpec defaults.
    #[test]
    fn parse_cache_spec_with_defaults() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2
"#,
            &dict,
        )
        .expect("parse knowdb with default cache spec");

        assert!(conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 1024);
        assert_eq!(conf.cache.ttl_ms, 30_000);
    }

    // Explicit `[cache]` values override the defaults.
    #[test]
    fn parse_cache_spec_from_toml() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[cache]
enabled = false
capacity = 256
ttl_ms = 1500
"#,
            &dict,
        )
        .expect("parse knowdb with cache spec");

        assert!(!conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 256);
        assert_eq!(conf.cache.ttl_ms, 1500);
    }

    // Without per-key entries, redis_cache_enabled follows `[cache].enabled`.
    #[test]
    fn parse_redis_cache_spec() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[cache]
enabled = true
capacity = 512
"#,
            &dict,
        )
        .expect("parse knowdb with cache");

        assert!(conf.redis_cache_enabled("k"));
        assert_eq!(conf.cache.capacity, 512);
    }

    // Defaults apply to the Redis cache lookup too.
    #[test]
    fn parse_redis_cache_defaults() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2
"#,
            &dict,
        )
        .expect("parse knowdb without redis.cache");

        // No [cache] → defaults enabled=true, capacity=1024
        assert!(conf.redis_cache_enabled("k"));
        assert_eq!(conf.cache.capacity, 1024);
    }

    // Per-key `enabled` wins; entries without `enabled` and unlisted keys inherit the global flag.
    #[test]
    fn parse_redis_key_cache_overrides() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[cache]
enabled = true
capacity = 512

[[cache.redis_key]]
key = "weak_passwords"
enabled = false

[[cache.redis_key]]
key = "black_domains"
"#,
            &dict,
        )
        .expect("parse knowdb with redis_key");

        // Per-key disabled
        assert!(!conf.redis_cache_enabled("weak_passwords"));
        // Not listed → inherits from [cache].enabled
        assert!(conf.redis_cache_enabled("black_domains"));
        assert!(conf.redis_cache_enabled("other_key"));
        assert_eq!(conf.cache.capacity, 512);
    }
}