sqlitegraph_cli/
reasoning.rs

use std::{
    fs,
    io::{BufRead, BufReader, Read},
    path::Path,
};

use serde_json::{json, Map, Value};

use crate::dsl::{parse_dsl, DslResult};
use sqlitegraph::{
    pipeline::{ReasoningPipeline, ReasoningStep},
    safety::{run_deep_safety_checks, run_integrity_sweep, run_safety_checks, SafetyReport},
    subgraph::{structural_signature, SubgraphRequest},
    BackendClient, SqliteGraphError,
};

const ERR_PREFIX: &str = "cli";

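/// Dispatches a reasoning-related CLI command to its handler. Returns
/// `Ok(None)` when the command is not one handled here, so the caller can
/// fall through to other command groups.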
pub fn handle_command(
    client: &BackendClient,
    command: &str,
    args: &[String],
) -> Result<Option<String>, SqliteGraphError> {
    match command {
        "subgraph" => run_subgraph(client, args).map(Some),
        "pipeline" => run_pipeline(client, args).map(Some),
        "explain-pipeline" => run_explain_pipeline(client, args).map(Some),
        "dsl-parse" => run_dsl_parse(args).map(Some),
        "safety-check" => run_safety_check(client, args).map(Some),
        "metrics" => run_metrics(client, args).map(Some),
        _ => Ok(None),
    }
}

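/// Handles `subgraph`: expands the graph from `--root` to `--depth` (default 1),
/// applying any `--types edge=...` / `--types node=...` filters, and emits the
/// nodes, edges, structural signature, and the deduplicated filter lists as JSON.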
fn run_subgraph(client: &BackendClient, args: &[String]) -> Result<String, SqliteGraphError> {
    let root = parse_required_i64(args, "--root")?;
    let depth = parse_optional_u32(args, "--depth").unwrap_or(1);
    let (edge_types, node_types) = parse_type_filters(args)?;
    // Sorted, deduplicated copies are only reported in the output; the request
    // itself carries the filters exactly as given on the command line.
    let mut edge_filters = edge_types.clone();
    edge_filters.sort();
    edge_filters.dedup();
    let mut node_filters = node_types.clone();
    node_filters.sort();
    node_filters.dedup();
    let request = SubgraphRequest {
        root,
        depth,
        allowed_edge_types: edge_types,
        allowed_node_types: node_types,
    };
    let subgraph = client.subgraph(request)?;
    let edges = subgraph
        .edges
        .iter()
        .map(|(from, to, ty)| json!({"from": from, "to": to, "type": ty}))
        .collect::<Vec<_>>();
    let signature = structural_signature(&subgraph);
    let mut object = Map::new();
    object.insert("command".into(), Value::String("subgraph".into()));
    object.insert("root".into(), json!(root));
    object.insert("depth".into(), json!(depth));
    object.insert("nodes".into(), json!(subgraph.nodes));
    object.insert("edges".into(), Value::Array(edges));
    object.insert("signature".into(), Value::String(signature));
    object.insert("edge_filters".into(), json!(edge_filters));
    object.insert("node_filters".into(), json!(node_filters));
    encode(object)
}

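/// Handles `pipeline`: resolves the DSL expression from `--dsl` or `--file`,
/// runs it through the backend, and emits the resulting nodes and per-node
/// scores as JSON.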
fn run_pipeline(client: &BackendClient, args: &[String]) -> Result<String, SqliteGraphError> {
    let expr = pipeline_expression(args)?;
    let pipeline = pipeline_from_expression(&expr)?;
    let result = client.run_pipeline(pipeline)?;
    let scores = result
        .scores
        .iter()
        .map(|(node, score)| json!({"node": node, "score": score}))
        .collect::<Vec<_>>();
    let mut object = Map::new();
    object.insert("command".into(), Value::String("pipeline".into()));
    object.insert("dsl".into(), Value::String(expr));
    object.insert("nodes".into(), json!(result.nodes));
    object.insert("scores".into(), Value::Array(scores));
    encode(object)
}

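/// Handles `metrics`: optionally resets the backend counters when
/// `--reset-metrics` is passed, then emits a snapshot of prepare/execute,
/// transaction, and prepare-cache counters as JSON.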
fn run_metrics(client: &BackendClient, args: &[String]) -> Result<String, SqliteGraphError> {
    let graph = client.backend().graph();
    if has_flag(args, "--reset-metrics") {
        graph.reset_metrics();
    }
    let snapshot = graph.metrics_snapshot();
    let mut object = Map::new();
    object.insert("command".into(), Value::String("metrics".into()));
    object.insert("prepare_count".into(), json!(snapshot.prepare_count));
    object.insert("execute_count".into(), json!(snapshot.execute_count));
    object.insert("tx_begin_count".into(), json!(snapshot.tx_begin_count));
    object.insert("tx_commit_count".into(), json!(snapshot.tx_commit_count));
    object.insert(
        "tx_rollback_count".into(),
        json!(snapshot.tx_rollback_count),
    );
    object.insert(
        "prepare_cache_hits".into(),
        json!(snapshot.prepare_cache_hits),
    );
    object.insert(
        "prepare_cache_misses".into(),
        json!(snapshot.prepare_cache_misses),
    );
    encode(object)
}

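/// Handles `explain-pipeline`: parses the DSL just like `pipeline`, but asks
/// the backend to explain the pipeline (step summary, per-step node counts,
/// filters applied, scoring notes) rather than returning result scores.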
fn run_explain_pipeline(
    client: &BackendClient,
    args: &[String],
) -> Result<String, SqliteGraphError> {
    let expr = pipeline_expression(args)?;
    let pipeline = pipeline_from_expression(&expr)?;
    let explanation = client.explain_pipeline(pipeline)?;
    let mut object = Map::new();
    object.insert("command".into(), Value::String("explain-pipeline".into()));
    object.insert("dsl".into(), Value::String(expr));
    object.insert("steps_summary".into(), json!(explanation.steps_summary));
    object.insert(
        "node_counts".into(),
        json!(explanation.node_counts_per_step),
    );
    object.insert("filters".into(), json!(explanation.filters_applied));
    object.insert("scoring".into(), json!(explanation.scoring_notes));
    encode(object)
}

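/// Handles `safety-check`: runs the standard checks, or the deep checks when
/// `--deep` is passed, plus an optional `--sweep` integrity sweep. With
/// `--strict`, any detected issue becomes an error instead of a JSON report.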
fn run_safety_check(client: &BackendClient, args: &[String]) -> Result<String, SqliteGraphError> {
    let strict = args.iter().any(|arg| arg == "--strict");
    let deep = args.iter().any(|arg| arg == "--deep");
    let sweep = args.iter().any(|arg| arg == "--sweep");
    let report = if deep {
        run_deep_safety_checks(client.backend().graph())?
    } else {
        run_safety_checks(client.backend().graph())?
    };
    let sweep_issues = if sweep {
        run_integrity_sweep(client.backend().graph())?
    } else {
        Vec::new()
    };
    if strict && (report.has_issues() || !sweep_issues.is_empty()) {
        return Err(invalid(format!(
            "safety violations detected: orphan_edges={} duplicate_edges={} invalid_labels={} invalid_properties={} sweep_issues={}",
            report.orphan_edges,
            report.duplicate_edges,
            report.invalid_labels,
            report.invalid_properties,
            sweep_issues.len(),
        )));
    }
    let mut object = Map::new();
    object.insert("command".into(), Value::String("safety-check".into()));
    object.insert("report".into(), report_to_value(&report, &sweep_issues));
    encode(object)
}

fn report_to_value(report: &SafetyReport, sweep_issues: &[String]) -> Value {
    let mut inner = Map::new();
    inner.insert("total_nodes".into(), json!(report.total_nodes));
    inner.insert("total_edges".into(), json!(report.total_edges));
    inner.insert("orphan_edges".into(), json!(report.orphan_edges));
    inner.insert("duplicate_edges".into(), json!(report.duplicate_edges));
    inner.insert("invalid_labels".into(), json!(report.invalid_labels));
    inner.insert(
        "invalid_properties".into(),
        json!(report.invalid_properties),
    );
    inner.insert("integrity_errors".into(), json!(report.integrity_errors));
    inner.insert(
        "integrity_messages".into(),
        json!(report.integrity_messages),
    );
    inner.insert("sweep_issues".into(), json!(sweep_issues));
    Value::Object(inner)
}

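/// Handles `dsl-parse`: parses the `--input` expression and reports what kind
/// of DSL result it produced (pattern, pipeline, or subgraph) without running
/// anything against the backend.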
fn run_dsl_parse(args: &[String]) -> Result<String, SqliteGraphError> {
    let input = required_value(args, "--input")?;
    let result = parse_dsl(&input);
    let summary = summarize_dsl(result)?;
    let mut object = Map::new();
    object.insert("command".into(), Value::String("dsl-parse".into()));
    object.insert("result".into(), summary);
    encode(object)
}

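/// Collects repeated `--types edge=NAME` and `--types node=NAME` arguments
/// into separate edge-type and node-type filter lists, rejecting any other
/// key or a missing `key=value` pair.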
fn parse_type_filters(args: &[String]) -> Result<(Vec<String>, Vec<String>), SqliteGraphError> {
    let mut edges = Vec::new();
    let mut nodes = Vec::new();
    let mut iter = args.iter();
    while let Some(arg) = iter.next() {
        if arg == "--types" {
            let value = iter
                .next()
                .ok_or_else(|| invalid("--types requires key=value"))?
                .clone();
            if let Some((key, val)) = value.split_once('=') {
                match key {
                    "edge" => edges.push(val.trim().to_string()),
                    "node" => nodes.push(val.trim().to_string()),
                    _ => return Err(invalid("--types key must be edge or node")),
                }
            } else {
                return Err(invalid("--types expects key=value"));
            }
        }
    }
    Ok((edges, nodes))
}

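/// Resolves the pipeline DSL text from exactly one of `--dsl` (inline
/// expression) or `--file` (path to a plain-text or JSON pipeline file).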
fn pipeline_expression(args: &[String]) -> Result<String, SqliteGraphError> {
    let dsl = value(args, "--dsl");
    let file = value(args, "--file");
    match (dsl, file) {
        (Some(expr), None) => Ok(expr),
        (None, Some(path)) => read_pipeline_file(&path),
        (Some(_), Some(_)) => Err(invalid("provide only one of --dsl or --file")),
        _ => Err(invalid("pipeline requires --dsl or --file")),
    }
}

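/// Parses a DSL expression into a `ReasoningPipeline`, wrapping a bare
/// pattern in a single-step pipeline and rejecting subgraph expressions.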
fn pipeline_from_expression(expr: &str) -> Result<ReasoningPipeline, SqliteGraphError> {
    match parse_dsl(expr) {
        DslResult::Pipeline(pipeline) => Ok(pipeline),
        DslResult::Pattern(pattern) => Ok(ReasoningPipeline {
            steps: vec![ReasoningStep::Pattern(pattern)],
        }),
        DslResult::Error(msg) => Err(invalid(msg)),
        DslResult::Subgraph(_) => Err(invalid("DSL describes a subgraph, not a pipeline")),
    }
}

fn read_pipeline_file(path: &str) -> Result<String, SqliteGraphError> {
    let file = fs::File::open(Path::new(path))
        .map_err(|e| invalid(format!("unable to read pipeline file: {e}")))?;
    read_pipeline_reader(file)
}

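/// Reads a pipeline expression from a reader, sniffing the format: if the
/// first non-whitespace byte is `{`, the content is treated as JSON with a
/// top-level `"dsl"` string; otherwise the trimmed content is the DSL text.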
fn read_pipeline_reader<R: Read>(reader: R) -> Result<String, SqliteGraphError> {
    let mut buf = BufReader::new(reader);
    match peek_non_whitespace(&mut buf)? {
        None => Err(invalid("pipeline file is empty")),
        Some(b'{') => read_pipeline_json(buf),
        _ => read_pipeline_plain(buf),
    }
}

fn read_pipeline_json<R: Read>(reader: R) -> Result<String, SqliteGraphError> {
    let mut stream = serde_json::Deserializer::from_reader(reader).into_iter::<Value>();
    let first = stream
        .next()
        .ok_or_else(|| invalid("pipeline json must contain a 'dsl' string"))
        .and_then(|result| result.map_err(|e| invalid(format!("invalid pipeline json: {e}"))))?;
    if let Some(expr) = first.get("dsl").and_then(|v| v.as_str()) {
        Ok(expr.to_string())
    } else {
        Err(invalid("pipeline json must contain a 'dsl' string"))
    }
}

fn read_pipeline_plain<R: Read>(mut reader: R) -> Result<String, SqliteGraphError> {
    let mut contents = String::new();
    reader
        .read_to_string(&mut contents)
        .map_err(|e| invalid(format!("unable to read pipeline file: {e}")))?;
    let trimmed = contents.trim();
    if trimmed.is_empty() {
        Err(invalid("pipeline file is empty"))
    } else {
        Ok(trimmed.to_string())
    }
}

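/// Returns the first non-whitespace byte without consuming it, so the caller
/// can still read it from the buffered reader; only the leading whitespace is
/// consumed. Returns `None` when the reader holds nothing but whitespace.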
fn peek_non_whitespace<R: BufRead>(reader: &mut R) -> Result<Option<u8>, SqliteGraphError> {
    loop {
        // Scope the borrow taken by `fill_buf`; `buffer_len` carries the chunk
        // length out so the all-whitespace case can be consumed below.
        let buffer_len = {
            let buffer = reader
                .fill_buf()
                .map_err(|e| invalid(format!("unable to read pipeline file: {e}")))?;
            if buffer.is_empty() {
                return Ok(None);
            }
            let mut idx = 0;
            while idx < buffer.len() {
                let byte = buffer[idx];
                if byte.is_ascii_whitespace() {
                    idx += 1;
                    continue;
                }
                // Consume only the whitespace before this byte, leaving the
                // byte itself unread for the caller.
                reader.consume(idx);
                return Ok(Some(byte));
            }
            buffer.len()
        };
        // The whole chunk was whitespace: discard it and refill.
        reader.consume(buffer_len);
    }
}

fn summarize_dsl(result: DslResult) -> Result<Value, SqliteGraphError> {
    match result {
        DslResult::Pattern(pattern) => Ok(json!({
            "type": "pattern",
            "legs": pattern.legs.len(),
        })),
        DslResult::Pipeline(pipeline) => Ok(json!({
            "type": "pipeline",
            "steps": pipeline.steps.len(),
        })),
        DslResult::Subgraph(request) => Ok(json!({
            "type": "subgraph",
            "depth": request.depth,
            "edge_types": request.allowed_edge_types.len(),
            "node_types": request.allowed_node_types.len(),
        })),
        DslResult::Error(msg) => Err(invalid(msg)),
    }
}

fn parse_required_i64(args: &[String], flag: &str) -> Result<i64, SqliteGraphError> {
    let value = required_value(args, flag)?;
    value
        .parse::<i64>()
        .map_err(|_| invalid(format!("{flag} expects an integer")))
}

fn parse_optional_u32(args: &[String], flag: &str) -> Option<u32> {
    value(args, flag)?.parse::<u32>().ok()
}

fn required_value(args: &[String], flag: &str) -> Result<String, SqliteGraphError> {
    value(args, flag).ok_or_else(|| invalid(format!("missing {flag}")))
}

fn value(args: &[String], flag: &str) -> Option<String> {
    let mut iter = args.iter();
    while let Some(arg) = iter.next() {
        if arg == flag {
            return iter.next().cloned();
        }
    }
    None
}

fn has_flag(args: &[String], flag: &str) -> bool {
    args.iter().any(|arg| arg == flag)
}

fn encode(object: Map<String, Value>) -> Result<String, SqliteGraphError> {
    serde_json::to_string(&Value::Object(object))
        .map_err(|e| invalid(format!("{ERR_PREFIX} serialization failed: {e}")))
}

fn invalid<T: Into<String>>(message: T) -> SqliteGraphError {
    SqliteGraphError::invalid_input(message.into())
}