1use std::{
2 fs,
3 io::{BufRead, BufReader, Read},
4 path::Path,
5};
6
7use serde_json::{json, Map, Value};
8
9use crate::dsl::{parse_dsl, DslResult};
10use sqlitegraph::{
11 pipeline::{ReasoningPipeline, ReasoningStep},
12 safety::{run_deep_safety_checks, run_integrity_sweep, run_safety_checks, SafetyReport},
13 subgraph::{structural_signature, SubgraphRequest},
14 BackendClient, SqliteGraphError,
15};
16
/// Label placed at the start of the serialization-failure message built by `encode`.
const ERR_PREFIX: &str = "cli";
18
19pub fn handle_command(
20 client: &BackendClient,
21 command: &str,
22 args: &[String],
23) -> Result<Option<String>, SqliteGraphError> {
24 match command {
25 "subgraph" => run_subgraph(client, args).map(Some),
26 "pipeline" => run_pipeline(client, args).map(Some),
27 "explain-pipeline" => run_explain_pipeline(client, args).map(Some),
28 "dsl-parse" => run_dsl_parse(args).map(Some),
29 "safety-check" => run_safety_check(client, args).map(Some),
30 "metrics" => run_metrics(client, args).map(Some),
31 _ => Ok(None),
32 }
33}
34
35fn run_subgraph(client: &BackendClient, args: &[String]) -> Result<String, SqliteGraphError> {
36 let root = parse_required_i64(args, "--root")?;
37 let depth = parse_optional_u32(args, "--depth").unwrap_or(1);
38 let (edge_types, node_types) = parse_type_filters(args)?;
39 let mut edge_filters = edge_types.clone();
40 edge_filters.sort();
41 edge_filters.dedup();
42 let mut node_filters = node_types.clone();
43 node_filters.sort();
44 node_filters.dedup();
45 let request = SubgraphRequest {
46 root,
47 depth,
48 allowed_edge_types: edge_types,
49 allowed_node_types: node_types,
50 };
51 let subgraph = client.subgraph(request)?;
52 let edges = subgraph
53 .edges
54 .iter()
55 .map(|(from, to, ty)| json!({"from": from, "to": to, "type": ty}))
56 .collect::<Vec<_>>();
57 let signature = structural_signature(&subgraph);
58 let mut object = Map::new();
59 object.insert("command".into(), Value::String("subgraph".into()));
60 object.insert("root".into(), json!(root));
61 object.insert("depth".into(), json!(depth));
62 object.insert("nodes".into(), json!(subgraph.nodes));
63 object.insert("edges".into(), Value::Array(edges));
64 object.insert("signature".into(), Value::String(signature));
65 object.insert("edge_filters".into(), json!(edge_filters));
66 object.insert("node_filters".into(), json!(node_filters));
67 encode(object)
68}
69
70fn run_pipeline(client: &BackendClient, args: &[String]) -> Result<String, SqliteGraphError> {
71 let expr = pipeline_expression(args)?;
72 let pipeline = pipeline_from_expression(&expr)?;
73 let result = client.run_pipeline(pipeline)?;
74 let scores = result
75 .scores
76 .iter()
77 .map(|(node, score)| json!({"node": node, "score": score}))
78 .collect::<Vec<_>>();
79 let mut object = Map::new();
80 object.insert("command".into(), Value::String("pipeline".into()));
81 object.insert("dsl".into(), Value::String(expr));
82 object.insert("nodes".into(), json!(result.nodes));
83 object.insert("scores".into(), Value::Array(scores));
84 encode(object)
85}
86
87fn run_metrics(client: &BackendClient, args: &[String]) -> Result<String, SqliteGraphError> {
88 let graph = client.backend().graph();
89 if has_flag(args, "--reset-metrics") {
90 graph.reset_metrics();
91 }
92 let snapshot = graph.metrics_snapshot();
93 let mut object = Map::new();
94 object.insert("command".into(), Value::String("metrics".into()));
95 object.insert("prepare_count".into(), json!(snapshot.prepare_count));
96 object.insert("execute_count".into(), json!(snapshot.execute_count));
97 object.insert("tx_begin_count".into(), json!(snapshot.tx_begin_count));
98 object.insert("tx_commit_count".into(), json!(snapshot.tx_commit_count));
99 object.insert(
100 "tx_rollback_count".into(),
101 json!(snapshot.tx_rollback_count),
102 );
103 object.insert(
104 "prepare_cache_hits".into(),
105 json!(snapshot.prepare_cache_hits),
106 );
107 object.insert(
108 "prepare_cache_misses".into(),
109 json!(snapshot.prepare_cache_misses),
110 );
111 encode(object)
112}
113
114fn run_explain_pipeline(
115 client: &BackendClient,
116 args: &[String],
117) -> Result<String, SqliteGraphError> {
118 let expr = pipeline_expression(args)?;
119 let pipeline = pipeline_from_expression(&expr)?;
120 let explanation = client.explain_pipeline(pipeline)?;
121 let mut object = Map::new();
122 object.insert("command".into(), Value::String("explain-pipeline".into()));
123 object.insert("dsl".into(), Value::String(expr));
124 object.insert("steps_summary".into(), json!(explanation.steps_summary));
125 object.insert(
126 "node_counts".into(),
127 json!(explanation.node_counts_per_step),
128 );
129 object.insert("filters".into(), json!(explanation.filters_applied));
130 object.insert("scoring".into(), json!(explanation.scoring_notes));
131 encode(object)
132}
133
134fn run_safety_check(client: &BackendClient, args: &[String]) -> Result<String, SqliteGraphError> {
135 let strict = args.iter().any(|arg| arg == "--strict");
136 let deep = args.iter().any(|arg| arg == "--deep");
137 let sweep = args.iter().any(|arg| arg == "--sweep");
138 let report = if deep {
139 run_deep_safety_checks(client.backend().graph())?
140 } else {
141 run_safety_checks(client.backend().graph())?
142 };
143 let sweep_issues = if sweep {
144 run_integrity_sweep(client.backend().graph())?
145 } else {
146 Vec::new()
147 };
148 if strict && (report.has_issues() || !sweep_issues.is_empty()) {
149 return Err(invalid(format!(
150 "safety violations detected: orphan_edges={} duplicate_edges={} invalid_labels={} invalid_properties={} sweep_issues={}",
151 report.orphan_edges,
152 report.duplicate_edges,
153 report.invalid_labels,
154 report.invalid_properties,
155 sweep_issues.len(),
156 )));
157 }
158 let mut object = Map::new();
159 object.insert("command".into(), Value::String("safety-check".into()));
160 object.insert("report".into(), report_to_value(&report, &sweep_issues));
161 encode(object)
162}
163
164fn report_to_value(report: &SafetyReport, sweep_issues: &[String]) -> Value {
165 let mut inner = Map::new();
166 inner.insert("total_nodes".into(), json!(report.total_nodes));
167 inner.insert("total_edges".into(), json!(report.total_edges));
168 inner.insert("orphan_edges".into(), json!(report.orphan_edges));
169 inner.insert("duplicate_edges".into(), json!(report.duplicate_edges));
170 inner.insert("invalid_labels".into(), json!(report.invalid_labels));
171 inner.insert(
172 "invalid_properties".into(),
173 json!(report.invalid_properties),
174 );
175 inner.insert("integrity_errors".into(), json!(report.integrity_errors));
176 inner.insert(
177 "integrity_messages".into(),
178 json!(report.integrity_messages),
179 );
180 inner.insert("sweep_issues".into(), json!(sweep_issues));
181 Value::Object(inner)
182}
183
184fn run_dsl_parse(args: &[String]) -> Result<String, SqliteGraphError> {
185 let input = required_value(args, "--input")?;
186 let result = parse_dsl(&input);
187 let summary = summarize_dsl(result)?;
188 let mut object = Map::new();
189 object.insert("command".into(), Value::String("dsl-parse".into()));
190 object.insert("result".into(), summary);
191 encode(object)
192}
193
194fn parse_type_filters(args: &[String]) -> Result<(Vec<String>, Vec<String>), SqliteGraphError> {
195 let mut edges = Vec::new();
196 let mut nodes = Vec::new();
197 let mut iter = args.iter();
198 while let Some(arg) = iter.next() {
199 if arg == "--types" {
200 let value = iter
201 .next()
202 .ok_or_else(|| invalid("--types requires key=value"))?
203 .clone();
204 if let Some((key, val)) = value.split_once('=') {
205 match key {
206 "edge" => edges.push(val.trim().to_string()),
207 "node" => nodes.push(val.trim().to_string()),
208 _ => return Err(invalid("--types key must be edge or node")),
209 }
210 } else {
211 return Err(invalid("--types expects key=value"));
212 }
213 }
214 }
215 Ok((edges, nodes))
216}
217
218fn pipeline_expression(args: &[String]) -> Result<String, SqliteGraphError> {
219 let dsl = value(args, "--dsl");
220 let file = value(args, "--file");
221 match (dsl, file) {
222 (Some(expr), None) => Ok(expr),
223 (None, Some(path)) => read_pipeline_file(&path),
224 (Some(_), Some(_)) => Err(invalid("provide only one of --dsl or --file")),
225 _ => Err(invalid("pipeline requires --dsl or --file")),
226 }
227}
228
229fn pipeline_from_expression(expr: &str) -> Result<ReasoningPipeline, SqliteGraphError> {
230 match parse_dsl(expr) {
231 DslResult::Pipeline(pipeline) => Ok(pipeline),
232 DslResult::Pattern(pattern) => Ok(ReasoningPipeline {
233 steps: vec![ReasoningStep::Pattern(pattern)],
234 }),
235 DslResult::Error(msg) => Err(invalid(msg)),
236 DslResult::Subgraph(_) => Err(invalid("DSL describes a subgraph, not a pipeline")),
237 }
238}
239
240fn read_pipeline_file(path: &str) -> Result<String, SqliteGraphError> {
241 let file = fs::File::open(Path::new(path))
242 .map_err(|e| invalid(format!("unable to read pipeline file: {e}")))?;
243 read_pipeline_reader(file)
244}
245
246fn read_pipeline_reader<R: Read>(reader: R) -> Result<String, SqliteGraphError> {
247 let mut buf = BufReader::new(reader);
248 match peek_non_whitespace(&mut buf)? {
249 None => Err(invalid("pipeline file is empty")),
250 Some(b'{') => read_pipeline_json(buf),
251 _ => read_pipeline_plain(buf),
252 }
253}
254
255fn read_pipeline_json<R: Read>(reader: R) -> Result<String, SqliteGraphError> {
256 let mut stream = serde_json::Deserializer::from_reader(reader).into_iter::<Value>();
257 let first = stream
258 .next()
259 .ok_or_else(|| invalid("pipeline json must contain a 'dsl' string"))
260 .and_then(|result| result.map_err(|e| invalid(format!("invalid pipeline json: {e}"))))?;
261 if let Some(expr) = first.get("dsl").and_then(|v| v.as_str()) {
262 Ok(expr.to_string())
263 } else {
264 Err(invalid("pipeline json must contain a 'dsl' string"))
265 }
266}
267
268fn read_pipeline_plain<R: Read>(mut reader: R) -> Result<String, SqliteGraphError> {
269 let mut contents = String::new();
270 reader
271 .read_to_string(&mut contents)
272 .map_err(|e| invalid(format!("unable to read pipeline file: {e}")))?;
273 let trimmed = contents.trim();
274 if trimmed.is_empty() {
275 Err(invalid("pipeline file is empty"))
276 } else {
277 Ok(trimmed.to_string())
278 }
279}
280
281fn peek_non_whitespace<R: BufRead>(reader: &mut R) -> Result<Option<u8>, SqliteGraphError> {
282 loop {
283 let buffer_len = {
284 let buffer = reader
285 .fill_buf()
286 .map_err(|e| invalid(format!("unable to read pipeline file: {e}")))?;
287 if buffer.is_empty() {
288 return Ok(None);
289 }
290 let mut idx = 0;
291 while idx < buffer.len() {
292 let byte = buffer[idx];
293 if byte.is_ascii_whitespace() {
294 idx += 1;
295 continue;
296 }
297 reader.consume(idx);
298 return Ok(Some(byte));
299 }
300 buffer.len()
301 };
302 reader.consume(buffer_len);
303 }
304}
305
306fn summarize_dsl(result: DslResult) -> Result<Value, SqliteGraphError> {
307 match result {
308 DslResult::Pattern(pattern) => Ok(json!({
309 "type": "pattern",
310 "legs": pattern.legs.len(),
311 })),
312 DslResult::Pipeline(pipeline) => Ok(json!({
313 "type": "pipeline",
314 "steps": pipeline.steps.len(),
315 })),
316 DslResult::Subgraph(request) => Ok(json!({
317 "type": "subgraph",
318 "depth": request.depth,
319 "edge_types": request.allowed_edge_types.len(),
320 "node_types": request.allowed_node_types.len(),
321 })),
322 DslResult::Error(msg) => Err(invalid(msg)),
323 }
324}
325
326fn parse_required_i64(args: &[String], flag: &str) -> Result<i64, SqliteGraphError> {
327 let value = required_value(args, flag)?;
328 value
329 .parse::<i64>()
330 .map_err(|_| invalid(format!("{flag} expects an integer")))
331}
332
333fn parse_optional_u32(args: &[String], flag: &str) -> Option<u32> {
334 value(args, flag)?.parse::<u32>().ok()
335}
336
337fn required_value(args: &[String], flag: &str) -> Result<String, SqliteGraphError> {
338 value(args, flag).ok_or_else(|| invalid(format!("missing {flag}")))
339}
340
/// Returns a copy of the argument that directly follows the first occurrence
/// of `flag`, or `None` when the flag is absent or is the last argument.
///
/// Every argument is compared against `flag`, including ones that are
/// themselves the value of an earlier flag.
fn value(args: &[String], flag: &str) -> Option<String> {
    let flag_index = args.iter().position(|candidate| candidate == flag)?;
    args.get(flag_index + 1).cloned()
}
350
/// Reports whether any argument is exactly equal to `flag`.
fn has_flag(args: &[String], flag: &str) -> bool {
    args.iter()
        .map(String::as_str)
        .any(|candidate| candidate == flag)
}
354
355fn encode(object: Map<String, Value>) -> Result<String, SqliteGraphError> {
356 serde_json::to_string(&Value::Object(object))
357 .map_err(|e| invalid(format!("{ERR_PREFIX} serialization failed: {e}")))
358}
359
360fn invalid<T: Into<String>>(message: T) -> SqliteGraphError {
361 SqliteGraphError::invalid_input(message.into())
362}