1use std::fs;
2use std::io::Read;
3use std::path::{Path, PathBuf};
4
5use orion_conf::EnvTomlLoad;
6use serde::Deserialize;
7use wp_log::info_ctrl;
8
9use crate::mem::memdb::MemDB;
10use orion_error::{ContextRecord, ErrorOwe, OperationContext, ToStructError, UvsFrom};
11use orion_variate::EnvDict;
12use rusqlite::OpenFlags;
13use wp_error::{KnowledgeReason, KnowledgeResult};
14
/// Top-level `knowdb` TOML configuration (schema version 2).
#[derive(Debug, Deserialize)]
pub struct KnowDbConf {
    /// Config schema version; only `2` is accepted by `parse_knowdb_conf`.
    pub version: u32,
    /// Base directory for table data, resolved relative to the config file's
    /// directory (defaults to ".").
    #[serde(default = "default_dot")]
    pub base_dir: String,
    /// Load options (batching, transactions, on_error) applied to every table.
    #[serde(default)]
    pub default: OptLoadSpec,
    /// CSV parsing settings shared by all tables.
    #[serde(default)]
    pub csv: CsvSpec,
    /// Result-cache tuning.
    #[serde(default)]
    pub cache: CacheSpec,
    /// Optional external database provider.
    #[serde(default)]
    pub provider: Option<ProviderSpec>,
    /// Tables to load into the authority database.
    #[serde(default)]
    pub tables: Vec<TableSpec>,
}
33
/// Result-cache settings (`[cache]` table).
#[derive(Debug, Clone, Deserialize)]
pub struct CacheSpec {
    /// Whether the result cache is active (default: true).
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Maximum number of cached entries (default: 1024).
    #[serde(default = "default_result_cache_capacity")]
    pub capacity: usize,
    /// Entry time-to-live in milliseconds (default: 30_000).
    #[serde(default = "default_result_cache_ttl_ms")]
    pub ttl_ms: u64,
}
43
44impl Default for CacheSpec {
45 fn default() -> Self {
46 Self {
47 enabled: default_true(),
48 capacity: default_result_cache_capacity(),
49 ttl_ms: default_result_cache_ttl_ms(),
50 }
51 }
52}
53
/// Backend flavor for an external provider.
///
/// Deserialized from snake_case strings: `"sqlite_authority"`, `"postgres"`,
/// `"mysql"`.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
    SqliteAuthority,
    Postgres,
    Mysql,
}
61
/// External database provider configuration (`[provider]` table).
///
/// All tuning fields are optional; `None` leaves the backend's own default in
/// effect.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderSpec {
    /// Which backend to connect to.
    pub kind: ProviderKind,
    /// Backend connection string, e.g. `postgres://user:pass@host/db`.
    pub connection_uri: String,
    /// Connection pool size.
    #[serde(default)]
    pub pool_size: Option<u32>,
    /// Minimum number of pooled connections.
    #[serde(default)]
    pub min_connections: Option<u32>,
    /// Connection-acquire timeout, in milliseconds.
    #[serde(default)]
    pub acquire_timeout_ms: Option<u64>,
    /// Idle-connection timeout, in milliseconds.
    #[serde(default)]
    pub idle_timeout_ms: Option<u64>,
    /// Maximum connection lifetime, in milliseconds.
    #[serde(default)]
    pub max_lifetime_ms: Option<u64>,
}
77
/// Table-load options (`[default]` table).
#[derive(Debug, Clone, Deserialize)]
pub struct OptLoadSpec {
    /// Wrap inserts in transactions committed every `batch_size` rows.
    #[serde(default = "default_true")]
    pub transaction: bool,
    /// Rows per transaction batch (default: 2000; treated as at least 1).
    #[serde(default = "default_batch")]
    pub batch_size: usize,
    /// Policy for bad CSV rows: fail the load or skip-and-count.
    #[serde(default = "default_on_error")]
    pub on_error: OnError,
}
87impl Default for OptLoadSpec {
88 fn default() -> Self {
89 Self {
90 transaction: true,
91 batch_size: default_batch(),
92 on_error: default_on_error(),
93 }
94 }
95}
96
/// Policy for CSV rows that fail to parse or are missing columns.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum OnError {
    /// Abort the whole table load on the first bad row (default).
    #[default]
    Fail,
    /// Skip the bad row, count it, and continue.
    Skip,
}
104
/// CSV parsing settings (`[csv]` table).
#[derive(Debug, Clone, Deserialize)]
pub struct CsvSpec {
    /// First row is a header row (default: true).
    #[serde(default = "default_true")]
    pub has_header: bool,
    /// Field delimiter; only takes effect when exactly one byte (default: ",").
    #[serde(default = "default_comma")]
    pub delimiter: String,
    /// Text encoding; only "utf-8" is supported (default: "utf-8").
    #[serde(default = "default_utf8")]
    pub encoding: String,
    /// Trim whitespace around fields and headers (default: true).
    #[serde(default = "default_true")]
    pub trim: bool,
}
116impl Default for CsvSpec {
117 fn default() -> Self {
118 CsvSpec {
119 has_header: true,
120 delimiter: ",".into(),
121 encoding: "utf-8".into(),
122 trim: true,
123 }
124 }
125}
126
/// One table to load (`[[tables]]` entry).
#[derive(Debug, Clone, Deserialize)]
pub struct TableSpec {
    /// Table name; also substituted for `{table}` in the SQL templates.
    pub name: String,
    /// Subdirectory under `base_dir` holding this table's files; defaults to
    /// `name`.
    #[serde(default)]
    pub dir: Option<String>,
    /// Data file path relative to the table dir; defaults to `data.csv`.
    #[serde(default)]
    pub data_file: Option<String>,
    /// How CSV columns map to insert-statement parameters.
    pub columns: ColumnsSpec,
    /// Optional sanity bounds on the number of loaded rows.
    #[serde(default)]
    pub expected_rows: RowExpect,
    /// Disabled tables are skipped entirely (default: enabled).
    #[serde(default = "default_true")]
    pub enabled: bool,
}
140
/// Column selection for a table's CSV data.
///
/// At least one of the two mappings must be non-empty; `by_header` takes
/// precedence when both are set.
#[derive(Debug, Clone, Deserialize)]
pub struct ColumnsSpec {
    /// Select columns by header name (requires a header row).
    #[serde(default)]
    pub by_header: Vec<String>,
    /// Select columns by zero-based position.
    #[serde(default)]
    pub by_index: Vec<usize>,
}
148
/// Expected row-count bounds after loading a table.
///
/// Loading fewer than `min` rows is an error; exceeding `max` only logs a
/// warning.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RowExpect {
    pub min: Option<usize>,
    pub max: Option<usize>,
}
154
/// Serde default: `true` for boolean flags.
const fn default_true() -> bool {
    true
}

/// Serde default: rows per insert batch/transaction.
const fn default_batch() -> usize {
    2000
}

/// Serde default: CSV field delimiter.
fn default_comma() -> String {
    String::from(",")
}

/// Serde default: CSV text encoding.
fn default_utf8() -> String {
    String::from("utf-8")
}
/// Serde default: abort the load on the first bad row.
fn default_on_error() -> OnError {
    OnError::Fail
}
/// Serde default: base directory is the config file's own directory.
fn default_dot() -> String {
    String::from(".")
}

/// Serde default: maximum entries kept in the result cache.
const fn default_result_cache_capacity() -> usize {
    1024
}

/// Serde default: result-cache time-to-live in milliseconds.
const fn default_result_cache_ttl_ms() -> u64 {
    30_000
}
179
180fn read_to_string(path: &Path) -> KnowledgeResult<String> {
182 let mut f = fs::File::open(path).owe_res()?;
183 let mut buf = String::new();
184 f.read_to_string(&mut buf).owe_res()?;
185 Ok(buf)
186}
187
/// Substitute every `{table}` placeholder in an SQL template with the
/// concrete table name.
fn replace_table(sql: &str, table: &str) -> String {
    // split + join rewrites all occurrences, equivalent to `str::replace`.
    sql.split("{table}").collect::<Vec<_>>().join(table)
}
191
/// Resolve `rel` against `base`, passing already-absolute paths through
/// unchanged.
fn join_rel(base: &Path, rel: &str) -> PathBuf {
    let candidate = Path::new(rel);
    match candidate.is_absolute() {
        true => candidate.to_path_buf(),
        false => base.join(candidate),
    }
}
200
201pub fn build_authority_from_knowdb(
202 root: &Path,
203 conf_path: &Path,
204 authority_uri: &str,
205 dict: &EnvDict,
206) -> KnowledgeResult<Vec<String>> {
207 let mut opx = OperationContext::want("build authority from knowdb").with_auto_log();
208 let (conf, conf_abs, base_dir) = parse_knowdb_conf(root, conf_path, dict)?;
210 opx.record("conf", &conf_abs);
211 opx.record("base_dir", &base_dir);
212 let db = open_authority(authority_uri)?;
214 let mut loaded_names = Vec::new();
216 for t in &conf.tables {
217 if !t.enabled {
218 continue;
219 }
220 load_one_table(&db, &base_dir, t, &conf.csv, &conf.default)?;
221 info_ctrl!("load table {} suc!", base_dir.display(),);
222 loaded_names.push(t.name.clone());
223 }
224 opx.mark_suc();
225 Ok(loaded_names)
226}
227
228pub fn parse_knowdb_conf(
229 root: &Path,
230 conf_path: &Path,
231 dict: &EnvDict,
232) -> KnowledgeResult<(KnowDbConf, PathBuf, PathBuf)> {
233 let conf_abs = if conf_path.is_absolute() {
234 conf_path.to_path_buf()
235 } else {
236 root.join(conf_path)
237 };
238 let conf_txt = read_to_string(&conf_abs)?;
239 let conf: KnowDbConf =
240 <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(&conf_txt, dict).owe_conf()?;
241 if conf.version != 2 {
242 return Err(KnowledgeReason::from_conf()
243 .to_err()
244 .with_detail("unsupported knowdb.version"));
245 }
246 let conf_dir = conf_abs.parent().unwrap_or_else(|| Path::new("."));
247 let base_dir = join_rel(conf_dir, &conf.base_dir);
248 Ok((conf, conf_abs, base_dir))
249}
250
/// Open (or create) the authority SQLite database at `authority_uri` and
/// register the crate's built-in SQL functions on a connection.
fn open_authority(authority_uri: &str) -> KnowledgeResult<MemDB> {
    // Make sure the directory for a `file:` URI exists before SQLite tries to
    // create the database file inside it.
    ensure_parent_dir_for_file_uri(authority_uri);
    let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
        | OpenFlags::SQLITE_OPEN_CREATE
        | OpenFlags::SQLITE_OPEN_URI;
    let db = MemDB::new_file(authority_uri, 1, flags)?;
    // Best-effort registration: failures are deliberately ignored here, and
    // `load_one_table` re-registers the built-ins inside its own `with_conn`
    // calls before they are needed.
    let _ = db.with_conn(|conn| {
        let _ = crate::sqlite_ext::register_builtin(conn);
        Ok::<(), anyhow::Error>(())
    });
    Ok(db)
}
264
/// Best-effort: create the parent directory for a SQLite `file:` URI so the
/// database file can be created inside it. Non-`file:` URIs are ignored, as
/// are directory-creation failures (the subsequent open surfaces real errors).
fn ensure_parent_dir_for_file_uri(uri: &str) {
    let Some(rest) = uri.strip_prefix("file:") else {
        return;
    };
    // Drop any query-string portion (e.g. `?mode=rwc`) before treating the
    // remainder as a filesystem path.
    let path_part = rest.split('?').next().unwrap_or(rest);
    if let Some(parent) = Path::new(path_part).parent() {
        let _ = fs::create_dir_all(parent);
    }
}
276
/// Load a single table's schema and CSV data into the authority database.
///
/// Expects `<base_dir>/<dir|name>/` to contain `create.sql`, `insert.sql`,
/// optionally `clean.sql`, and the data file (`data.csv` unless `data_file`
/// is set). `{table}` placeholders in the SQL templates are substituted with
/// the table name.
///
/// # Errors
/// Config errors when the data file is missing, no column mapping is given,
/// or fewer than `expected_rows.min` rows were inserted; wrapped resource
/// errors for SQL/CSV failures.
fn load_one_table(
    db: &MemDB,
    base_dir: &Path,
    t: &TableSpec,
    csvd: &CsvSpec,
    load: &OptLoadSpec,
) -> KnowledgeResult<()> {
    let mut opx = OperationContext::want("load table to kdb")
        .with_auto_log()
        .with_mod_path("ctrl");
    // The table's directory defaults to its name when `dir` is not set.
    let dir_name: &str = t.dir.as_deref().unwrap_or(&t.name);
    let table_dir = base_dir.join(dir_name);
    opx.record("table_dir", &table_dir);
    let create_sql = replace_table(&read_to_string(&table_dir.join("create.sql"))?, &t.name);
    let insert_sql = replace_table(&read_to_string(&table_dir.join("insert.sql"))?, &t.name);
    let clean_path = table_dir.join("clean.sql");
    // Without a custom clean.sql, wipe the table before reloading.
    let clean_sql = if clean_path.exists() {
        replace_table(&read_to_string(&clean_path)?, &t.name)
    } else {
        format!("DELETE FROM {}", &t.name)
    };

    // Apply the schema and clear stale rows.
    db.with_conn(|conn| {
        let _ = crate::sqlite_ext::register_builtin(conn);
        conn.execute_batch(&create_sql)?;
        conn.execute_batch(&clean_sql)?;
        Ok::<(), anyhow::Error>(())
    })
    .owe_res()?;

    let data_path = match &t.data_file {
        Some(rel) => join_rel(&table_dir, rel),
        None => table_dir.join("data.csv"),
    };
    if !data_path.exists() {
        return Err(KnowledgeReason::from_conf()
            .to_err()
            .with_detail("data.csv not found"));
    }
    opx.record("data_path", &data_path);

    let mut rdr = build_csv_reader(csvd, &data_path)?;

    // Resolve which CSV columns feed the insert statement: header names take
    // precedence over positional indices; one of the two is required.
    let col_indices: Vec<usize> = if !t.columns.by_header.is_empty() {
        let headers = rdr.headers().owe_res()?;
        select_indices_by_header(headers, &t.columns.by_header)?
    } else if !t.columns.by_index.is_empty() {
        t.columns.by_index.clone()
    } else {
        return Err(KnowledgeReason::from_conf()
            .to_err()
            .with_detail("columns mapping required"));
    };

    let mut inserted: usize = 0;
    let mut bad: usize = 0;
    // Batch size is clamped to at least 1 so the countdown below terminates.
    let mut batch_left = load.batch_size.max(1);
    db.with_conn(|conn| {
        let _ = crate::sqlite_ext::register_builtin(conn);
        // When enabled, inserts run inside transactions committed every
        // `batch_size` rows.
        let mut tx = if load.transaction {
            Some(conn.unchecked_transaction()?)
        } else {
            None
        };
        let mut stmt = conn.prepare(&insert_sql)?;
        for rec in rdr.into_records() {
            match rec {
                Ok(record) => {
                    // `None` means the row was skipped under on_error = skip.
                    let refs = extract_row_refs(&record, &col_indices, &mut bad, load)?;
                    if let Some(refs) = refs {
                        stmt.execute(rusqlite::params_from_iter(refs))?;
                        inserted += 1;
                        if load.transaction {
                            batch_left -= 1;
                            if batch_left == 0 {
                                // Commit the full batch and open a new one.
                                tx.take().unwrap().commit()?;
                                tx = Some(conn.unchecked_transaction()?);
                                batch_left = load.batch_size.max(1);
                            }
                        }
                    }
                }
                Err(_e) => {
                    // Unparseable CSV record: count-and-skip or abort,
                    // depending on the configured on_error policy.
                    if matches!(load.on_error, OnError::Skip) {
                        bad += 1;
                        continue;
                    } else {
                        anyhow::bail!("csv record parse error");
                    }
                }
            }
        }
        // Commit whatever remains of the final (possibly partial) batch.
        if let Some(tx) = tx {
            tx.commit()?;
        }
        Ok::<(), anyhow::Error>(())
    })
    .owe_res()?;

    // Row-count expectations: too few rows is a hard error ...
    if let Some(min) = t.expected_rows.min
        && inserted < min
    {
        return Err(KnowledgeReason::from_conf()
            .to_err()
            .with_detail("table data less"));
    }
    // ... while exceeding the max only warns.
    if let Some(max) = t.expected_rows.max
        && inserted > max
    {
        wp_log::warn_kdb!(
            "table {} loaded rows {} exceed max {}",
            &t.name,
            inserted,
            max
        );
    }
    if bad > 0 {
        wp_log::warn_kdb!("table {} skipped {} bad rows (on_error=skip)", &t.name, bad);
    }
    opx.mark_suc();
    Ok(())
}
408
/// Construct a CSV reader for `data_path` honoring the `CsvSpec` settings.
///
/// Only UTF-8 input is supported; any other configured encoding is rejected
/// with a config error.
fn build_csv_reader(
    csvd: &CsvSpec,
    data_path: &Path,
) -> KnowledgeResult<csv::Reader<std::fs::File>> {
    if csvd.encoding.to_lowercase() != "utf-8" {
        return Err(KnowledgeReason::from_conf()
            .to_err()
            .with_detail("only utf-8 csv is supported"));
    }
    let mut rdr_b = csv::ReaderBuilder::new();
    rdr_b.has_headers(csvd.has_header);
    // NOTE(review): a delimiter that is not exactly one byte is silently
    // ignored and the builder's default comma applies — confirm this fallback
    // is intentional rather than a config error.
    if csvd.delimiter.len() == 1 {
        rdr_b.delimiter(csvd.delimiter.as_bytes()[0]);
    }
    if csvd.trim {
        rdr_b.trim(csv::Trim::All);
    }
    rdr_b.from_path(data_path).owe_res()
}
428
429fn select_indices_by_header(
430 headers: &csv::StringRecord,
431 wanted: &[String],
432) -> KnowledgeResult<Vec<usize>> {
433 let mut out = Vec::with_capacity(wanted.len());
434 for name in wanted {
435 let pos = headers.iter().position(|h| h == name).ok_or_else(|| {
436 KnowledgeReason::from_conf()
437 .to_err()
438 .with_detail("header not found")
439 })?;
440 out.push(pos);
441 }
442 Ok(out)
443}
444
445fn extract_row_refs<'a>(
446 record: &'a csv::StringRecord,
447 col_indices: &[usize],
448 bad: &mut usize,
449 load: &OptLoadSpec,
450) -> anyhow::Result<Option<Vec<&'a str>>> {
451 let mut vs: Vec<&str> = Vec::with_capacity(col_indices.len());
452 for &idx in col_indices {
453 if idx >= record.len() {
454 if matches!(load.on_error, OnError::Skip) {
455 *bad += 1;
456 return Ok(None);
457 } else {
458 anyhow::bail!("missing column at index {}", idx);
459 }
460 }
461 vs.push(record.get(idx).unwrap_or(""));
462 }
463 Ok(Some(vs))
464}
465
// Unit tests covering TOML parsing of the optional [provider] and [cache]
// sections, including default handling.
#[cfg(test)]
mod tests {
    use super::*;

    // A postgres provider with pool-tuning fields; pool_size is left unset.
    #[test]
    fn parse_external_provider_spec() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[provider]
kind = "postgres"
connection_uri = "postgres://demo:demo@127.0.0.1/demo"
min_connections = 2
acquire_timeout_ms = 1500
idle_timeout_ms = 30000
max_lifetime_ms = 60000
"#,
            &dict,
        )
        .expect("parse knowdb with provider");

        assert!(conf.tables.is_empty());
        let provider = conf.provider.expect("provider");
        assert!(matches!(provider.kind, ProviderKind::Postgres));
        assert_eq!(
            provider.connection_uri,
            "postgres://demo:demo@127.0.0.1/demo"
        );
        assert_eq!(provider.min_connections, Some(2));
        assert_eq!(provider.acquire_timeout_ms, Some(1500));
        assert_eq!(provider.idle_timeout_ms, Some(30000));
        assert_eq!(provider.max_lifetime_ms, Some(60000));
    }

    // A mysql provider with every optional pool field populated.
    #[test]
    fn parse_mysql_provider_spec() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[provider]
kind = "mysql"
connection_uri = "mysql://demo:demo@127.0.0.1:3306/demo"
pool_size = 12
min_connections = 3
acquire_timeout_ms = 2500
idle_timeout_ms = 45000
max_lifetime_ms = 120000
"#,
            &dict,
        )
        .expect("parse knowdb with mysql provider");

        let provider = conf.provider.expect("provider");
        assert!(matches!(provider.kind, ProviderKind::Mysql));
        assert_eq!(
            provider.connection_uri,
            "mysql://demo:demo@127.0.0.1:3306/demo"
        );
        assert_eq!(provider.pool_size, Some(12));
        assert_eq!(provider.min_connections, Some(3));
        assert_eq!(provider.acquire_timeout_ms, Some(2500));
        assert_eq!(provider.idle_timeout_ms, Some(45000));
        assert_eq!(provider.max_lifetime_ms, Some(120000));
    }

    // Omitting [cache] entirely must yield the serde defaults.
    #[test]
    fn parse_cache_spec_with_defaults() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2
"#,
            &dict,
        )
        .expect("parse knowdb with default cache spec");

        assert!(conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 1024);
        assert_eq!(conf.cache.ttl_ms, 30_000);
    }

    // Explicit [cache] values must override every default.
    #[test]
    fn parse_cache_spec_from_toml() {
        let dict = EnvDict::default();
        let conf: KnowDbConf = <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(
            r#"
version = 2

[cache]
enabled = false
capacity = 256
ttl_ms = 1500
"#,
            &dict,
        )
        .expect("parse knowdb with cache spec");

        assert!(!conf.cache.enabled);
        assert_eq!(conf.cache.capacity, 256);
        assert_eq!(conf.cache.ttl_ms, 1500);
    }
}