1use std::fs;
2use std::io::Read;
3use std::path::{Path, PathBuf};
4
5use orion_conf::EnvTomlLoad;
6use serde::Deserialize;
7use wp_log::info_ctrl;
8
9use crate::mem::memdb::MemDB;
10use orion_error::{ContextRecord, ErrorOwe, OperationContext, ToStructError, UvsFrom};
11use orion_variate::EnvDict;
12use rusqlite::OpenFlags;
13use wp_error::{KnowledgeReason, KnowledgeResult};
14
/// Top-level knowdb configuration, parsed from TOML by `parse_knowdb_conf`.
#[derive(Debug, Deserialize)]
pub struct KnowDbConf {
    /// Config schema version; `parse_knowdb_conf` rejects anything but `2`.
    pub version: u32,
    /// Root directory of the per-table directories. A relative value is
    /// resolved against the directory containing the config file (default `"."`).
    #[serde(default = "default_dot")]
    pub base_dir: String,
    /// Load options applied to every table (`OptLoadSpec::default()` when omitted).
    #[serde(default)]
    pub default: OptLoadSpec,
    /// CSV parsing options shared by all tables (`CsvSpec::default()` when omitted).
    #[serde(default)]
    pub csv: CsvSpec,
    /// Tables to load, processed in this order.
    pub tables: Vec<TableSpec>,
}
28
29#[derive(Debug, Clone, Deserialize)]
30pub struct OptLoadSpec {
31 #[serde(default = "default_true")]
32 pub transaction: bool,
33 #[serde(default = "default_batch")]
34 pub batch_size: usize,
35 #[serde(default = "default_on_error")]
36 pub on_error: OnError,
37}
38impl Default for OptLoadSpec {
39 fn default() -> Self {
40 Self {
41 transaction: true,
42 batch_size: default_batch(),
43 on_error: default_on_error(),
44 }
45 }
46}
47
/// Policy for CSV records that fail to parse or lack a mapped column.
///
/// SQL insert failures are not covered by this policy; they always abort the load.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum OnError {
    /// Abort the table load with an error (`"fail"` in config; the default).
    #[default]
    Fail,
    /// Count the record as bad, skip it and continue (`"skip"` in config).
    Skip,
}
55
56#[derive(Debug, Clone, Deserialize)]
57pub struct CsvSpec {
58 #[serde(default = "default_true")]
59 pub has_header: bool,
60 #[serde(default = "default_comma")]
61 pub delimiter: String,
62 #[serde(default = "default_utf8")]
63 pub encoding: String,
64 #[serde(default = "default_true")]
65 pub trim: bool,
66}
67impl Default for CsvSpec {
68 fn default() -> Self {
69 CsvSpec {
70 has_header: true,
71 delimiter: ",".into(),
72 encoding: "utf-8".into(),
73 trim: true,
74 }
75 }
76}
77
/// One table to create and fill from CSV.
#[derive(Debug, Clone, Deserialize)]
pub struct TableSpec {
    /// Table name; substituted for every `{table}` placeholder in the table's SQL files.
    pub name: String,
    /// Directory under `base_dir` holding `create.sql`, `insert.sql` and the
    /// optional `clean.sql`; defaults to `name`.
    #[serde(default)]
    pub dir: Option<String>,
    /// CSV data file, relative to the table directory unless absolute;
    /// defaults to `data.csv` in the table directory.
    #[serde(default)]
    pub data_file: Option<String>,
    /// How CSV columns map to the insert statement's positional parameters.
    pub columns: ColumnsSpec,
    /// Bounds checked against the number of inserted rows.
    #[serde(default)]
    pub expected_rows: RowExpect,
    /// Disabled tables are skipped entirely (default `true`).
    #[serde(default = "default_true")]
    pub enabled: bool,
}
91
/// Selects CSV columns, in parameter order, for the insert statement.
///
/// `by_header` takes precedence when non-empty; otherwise `by_index` is used.
/// Loading fails when both are empty.
#[derive(Debug, Clone, Deserialize)]
pub struct ColumnsSpec {
    /// Header names to look up (exact match) in the CSV header row.
    #[serde(default)]
    pub by_header: Vec<String>,
    /// Zero-based column indices.
    #[serde(default)]
    pub by_index: Vec<usize>,
}
99
/// Expected bounds on the number of rows inserted for a table.
///
/// Both checks run after the rows have been written and committed.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RowExpect {
    /// Fewer inserted rows than this makes the table load fail.
    pub min: Option<usize>,
    /// More inserted rows than this only logs a warning.
    pub max: Option<usize>,
}
105
/// Serde default for boolean flags that are on unless disabled.
const fn default_true() -> bool {
    !false
}
/// Serde default for `OptLoadSpec::batch_size`: 2000 rows per batch.
const fn default_batch() -> usize {
    2_000
}
/// Serde default for `CsvSpec::delimiter`: a comma.
fn default_comma() -> String {
    String::from(",")
}
/// Serde default for `CsvSpec::encoding`: `"utf-8"`.
fn default_utf8() -> String {
    String::from("utf-8")
}
118fn default_on_error() -> OnError {
119 OnError::Fail
120}
/// Serde default for `KnowDbConf::base_dir`: `"."`, i.e. the directory that
/// contains the config file once resolved.
fn default_dot() -> String {
    String::from(".")
}
124
125fn read_to_string(path: &Path) -> KnowledgeResult<String> {
127 let mut f = fs::File::open(path).owe_res()?;
128 let mut buf = String::new();
129 f.read_to_string(&mut buf).owe_res()?;
130 Ok(buf)
131}
132
/// Substitutes every `{table}` placeholder in `sql` with `table`.
///
/// The name is inserted verbatim (no identifier quoting); callers in this file
/// pass `TableSpec::name` from the config.
fn replace_table(sql: &str, table: &str) -> String {
    const PLACEHOLDER: &str = "{table}";
    str::replace(sql, PLACEHOLDER, table)
}
136
/// Resolves `rel` against `base`.
///
/// An absolute `rel` is returned unchanged; a relative one is appended to `base`.
fn join_rel(base: &Path, rel: &str) -> PathBuf {
    let candidate = Path::new(rel);
    if !candidate.is_absolute() {
        return base.join(candidate);
    }
    candidate.to_path_buf()
}
145
146pub fn build_authority_from_knowdb(
147 root: &Path,
148 conf_path: &Path,
149 authority_uri: &str,
150 dict: &EnvDict,
151) -> KnowledgeResult<Vec<String>> {
152 let mut opx = OperationContext::want("build authority from knowdb").with_auto_log();
153 let (conf, conf_abs, base_dir) = parse_knowdb_conf(root, conf_path, dict)?;
155 opx.record("conf", &conf_abs);
156 opx.record("base_dir", &base_dir);
157 let db = open_authority(authority_uri)?;
159 let mut loaded_names = Vec::new();
161 for t in &conf.tables {
162 if !t.enabled {
163 continue;
164 }
165 load_one_table(&db, &base_dir, t, &conf.csv, &conf.default)?;
166 info_ctrl!("load table {} suc!", base_dir.display(),);
167 loaded_names.push(t.name.clone());
168 }
169 opx.mark_suc();
170 Ok(loaded_names)
171}
172
173fn parse_knowdb_conf(
174 root: &Path,
175 conf_path: &Path,
176 dict: &EnvDict,
177) -> KnowledgeResult<(KnowDbConf, PathBuf, PathBuf)> {
178 let conf_abs = if conf_path.is_absolute() {
179 conf_path.to_path_buf()
180 } else {
181 root.join(conf_path)
182 };
183 let conf_txt = read_to_string(&conf_abs)?;
184 let conf: KnowDbConf =
185 <KnowDbConf as EnvTomlLoad<KnowDbConf>>::env_parse_toml(&conf_txt, dict).owe_conf()?;
186 if conf.version != 2 {
187 return Err(KnowledgeReason::from_conf()
188 .to_err()
189 .with_detail("unsupported knowdb.version"));
190 }
191 let conf_dir = conf_abs.parent().unwrap_or_else(|| Path::new("."));
192 let base_dir = join_rel(conf_dir, &conf.base_dir);
193 Ok((conf, conf_abs, base_dir))
194}
195
196fn open_authority(authority_uri: &str) -> KnowledgeResult<MemDB> {
197 ensure_parent_dir_for_file_uri(authority_uri);
198 let flags = OpenFlags::SQLITE_OPEN_READ_WRITE
199 | OpenFlags::SQLITE_OPEN_CREATE
200 | OpenFlags::SQLITE_OPEN_URI;
201 let db = MemDB::new_file(authority_uri, 1, flags)?;
202 let _ = db.with_conn(|conn| {
204 let _ = crate::sqlite_ext::register_builtin(conn);
205 Ok::<(), anyhow::Error>(())
206 });
207 Ok(db)
208}
209
/// Best-effort: creates the parent directory of the database file named by a
/// SQLite `file:` URI, so that opening it with `SQLITE_OPEN_CREATE` can succeed.
///
/// Strings without the `file:` prefix are ignored. The query (`?...`) and
/// fragment (`#...`) are stripped. An authority component (`file:///path` or
/// `file://localhost/path`) is removed so the directory comes from the real
/// path. Any other authority is left alone, since SQLite rejects it anyway.
/// Directory-creation errors are deliberately ignored; opening the database
/// reports the real problem.
///
/// Percent-escapes in the path are not decoded.
fn ensure_parent_dir_for_file_uri(uri: &str) {
    let Some(rest) = uri.strip_prefix("file:") else {
        return;
    };

    // Cut off `?query` and `#fragment`.
    let path_end = rest.find(|c: char| c == '?' || c == '#').unwrap_or(rest.len());
    let mut path_part = &rest[..path_end];

    // `file://<authority>/path`: only an empty authority or `localhost` is valid.
    if let Some(after_slashes) = path_part.strip_prefix("//") {
        let auth_end = after_slashes.find('/').unwrap_or(after_slashes.len());
        let (authority, path) = after_slashes.split_at(auth_end);
        if !(authority.is_empty() || authority == "localhost") {
            return;
        }
        path_part = path;
    }

    if let Some(parent) = Path::new(path_part)
        .parent()
        .filter(|p| !p.as_os_str().is_empty())
    {
        let _ = fs::create_dir_all(parent);
    }
}
221
222fn load_one_table(
223 db: &MemDB,
224 base_dir: &Path,
225 t: &TableSpec,
226 csvd: &CsvSpec,
227 load: &OptLoadSpec,
228) -> KnowledgeResult<()> {
229 let mut opx = OperationContext::want("load table to kdb")
231 .with_auto_log()
232 .with_mod_path("ctrl");
233 let dir_name: &str = t.dir.as_deref().unwrap_or(&t.name);
234 let table_dir = base_dir.join(dir_name);
235 opx.record("table_dir", &table_dir);
236 let create_sql = replace_table(&read_to_string(&table_dir.join("create.sql"))?, &t.name);
237 let insert_sql = replace_table(&read_to_string(&table_dir.join("insert.sql"))?, &t.name);
238 let clean_path = table_dir.join("clean.sql");
239 let clean_sql = if clean_path.exists() {
240 replace_table(&read_to_string(&clean_path)?, &t.name)
241 } else {
242 format!("DELETE FROM {}", &t.name)
243 };
244
245 db.with_conn(|conn| {
247 let _ = crate::sqlite_ext::register_builtin(conn);
249 conn.execute_batch(&create_sql)?;
250 conn.execute_batch(&clean_sql)?;
251 Ok::<(), anyhow::Error>(())
252 })
253 .owe_res()?;
254
255 let data_path = match &t.data_file {
257 Some(rel) => join_rel(&table_dir, rel),
258 None => table_dir.join("data.csv"),
259 };
260 if !data_path.exists() {
261 return Err(KnowledgeReason::from_conf()
262 .to_err()
263 .with_detail("data.csv not found"));
264 }
265 opx.record("data_path", &data_path);
266
267 let mut rdr = build_csv_reader(csvd, &data_path)?;
269
270 let col_indices: Vec<usize> = if !t.columns.by_header.is_empty() {
272 let headers = rdr.headers().owe_res()?;
273 select_indices_by_header(headers, &t.columns.by_header)?
274 } else if !t.columns.by_index.is_empty() {
275 t.columns.by_index.clone()
276 } else {
277 return Err(KnowledgeReason::from_conf()
278 .to_err()
279 .with_detail("columns mapping required"));
280 };
281
282 let mut inserted: usize = 0;
284 let mut bad: usize = 0;
285 let mut batch_left = load.batch_size.max(1);
286 db.with_conn(|conn| {
287 let _ = crate::sqlite_ext::register_builtin(conn);
289 let mut tx = if load.transaction {
290 Some(conn.unchecked_transaction()?)
291 } else {
292 None
293 };
294 let mut stmt = conn.prepare(&insert_sql)?;
295 for rec in rdr.into_records() {
296 match rec {
297 Ok(record) => {
298 let refs = extract_row_refs(&record, &col_indices, &mut bad, load)?;
299 if let Some(refs) = refs {
300 stmt.execute(rusqlite::params_from_iter(refs))?;
301 inserted += 1;
302 if load.transaction {
303 batch_left -= 1;
304 if batch_left == 0 {
305 tx.take().unwrap().commit()?;
306 tx = Some(conn.unchecked_transaction()?);
307 batch_left = load.batch_size.max(1);
308 }
309 }
310 }
311 }
312 Err(_e) => {
313 if matches!(load.on_error, OnError::Skip) {
314 bad += 1;
315 continue;
316 } else {
317 anyhow::bail!("csv record parse error");
318 }
319 }
320 }
321 }
322 if let Some(tx) = tx {
323 tx.commit()?;
324 }
325 Ok::<(), anyhow::Error>(())
326 })
327 .owe_res()?;
328
329 if let Some(min) = t.expected_rows.min
331 && inserted < min
332 {
333 return Err(KnowledgeReason::from_conf()
334 .to_err()
335 .with_detail("table data less"));
336 }
337 if let Some(max) = t.expected_rows.max
338 && inserted > max
339 {
340 wp_log::warn_kdb!(
341 "table {} loaded rows {} exceed max {}",
342 &t.name,
343 inserted,
344 max
345 );
346 }
347 if bad > 0 {
348 wp_log::warn_kdb!("table {} skipped {} bad rows (on_error=skip)", &t.name, bad);
349 }
350 opx.mark_suc();
351 Ok(())
352}
353
354fn build_csv_reader(
355 csvd: &CsvSpec,
356 data_path: &Path,
357) -> KnowledgeResult<csv::Reader<std::fs::File>> {
358 if csvd.encoding.to_lowercase() != "utf-8" {
359 return Err(KnowledgeReason::from_conf()
360 .to_err()
361 .with_detail("only utf-8 csv is supported"));
362 }
363 let mut rdr_b = csv::ReaderBuilder::new();
364 rdr_b.has_headers(csvd.has_header);
365 if csvd.delimiter.len() == 1 {
366 rdr_b.delimiter(csvd.delimiter.as_bytes()[0]);
367 }
368 if csvd.trim {
369 rdr_b.trim(csv::Trim::All);
370 }
371 rdr_b.from_path(data_path).owe_res()
372}
373
374fn select_indices_by_header(
375 headers: &csv::StringRecord,
376 wanted: &[String],
377) -> KnowledgeResult<Vec<usize>> {
378 let mut out = Vec::with_capacity(wanted.len());
379 for name in wanted {
380 let pos = headers.iter().position(|h| h == name).ok_or_else(|| {
381 KnowledgeReason::from_conf()
382 .to_err()
383 .with_detail("header not found")
384 })?;
385 out.push(pos);
386 }
387 Ok(out)
388}
389
390fn extract_row_refs<'a>(
391 record: &'a csv::StringRecord,
392 col_indices: &[usize],
393 bad: &mut usize,
394 load: &OptLoadSpec,
395) -> anyhow::Result<Option<Vec<&'a str>>> {
396 let mut vs: Vec<&str> = Vec::with_capacity(col_indices.len());
397 for &idx in col_indices {
398 if idx >= record.len() {
399 if matches!(load.on_error, OnError::Skip) {
400 *bad += 1;
401 return Ok(None);
402 } else {
403 anyhow::bail!("missing column at index {}", idx);
405 }
406 }
407 vs.push(record.get(idx).unwrap_or(""));
408 }
409 Ok(Some(vs))
410}