1use std::collections::{HashMap, HashSet, VecDeque};
40
41use serde::{Deserialize, Serialize};
42
43use crate::bundle::Mechanism;
44use crate::causal_graph::CausalGraph;
45use crate::project::Project;
46
/// A single-intervention counterfactual question: "if the finding
/// `intervene_on` had confidence `set_to`, what would the confidence of
/// `target` be?"
///
/// Both ids are matched against finding ids in the project and node ids of the
/// project's causal graph by `answer_counterfactual`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CounterfactualQuery {
    /// Id of the finding whose confidence is intervened on.
    pub intervene_on: String,
    /// Confidence forced onto `intervene_on`; `answer_counterfactual` rejects
    /// values outside `[0, 1]` (and NaN) as an invalid intervention.
    pub set_to: f64,
    /// Id of the finding whose confidence is being asked about.
    pub target: String,
}
56
/// Outcome of a counterfactual query.
///
/// Serialized as an internally tagged object: the snake_case variant name is
/// stored under the `kind` key next to the variant's fields, e.g.
/// `{"kind": "no_causal_path", "factual": 0.9}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum CounterfactualVerdict {
    /// The intervention's effect on the target could be computed.
    Resolved {
        /// Current confidence of the target finding.
        factual: f64,
        /// Confidence the target would have under the intervention, clamped
        /// to `[0, 1]`.
        counterfactual: f64,
        /// `counterfactual - factual`, i.e. the change after clamping.
        delta: f64,
        /// Directed paths (node ids, intervened node first, target last)
        /// along which the change was successfully propagated.
        paths_used: Vec<Vec<String>>,
    },
    /// Every directed path from the intervened node to the target crosses at
    /// least one edge whose mechanism is missing or cannot propagate a change.
    MechanismUnspecified {
        /// The blocking `(parent, child)` edges, deduplicated and sorted.
        unspecified_edges: Vec<(String, String)>,
    },
    /// No directed path leads from the intervened node to the target; `factual`
    /// is the target's current (and therefore unchanged) confidence.
    NoCausalPath { factual: f64 },
    /// `which` is a query id that is not a known finding / graph node.
    UnknownNode { which: String },
    /// The requested intervention value is unusable; `reason` explains why.
    InvalidIntervention { reason: String },
}
89
90#[must_use]
104pub fn answer_counterfactual(
105 project: &Project,
106 query: &CounterfactualQuery,
107) -> CounterfactualVerdict {
108 if !(0.0..=1.0).contains(&query.set_to) {
109 return CounterfactualVerdict::InvalidIntervention {
110 reason: format!(
111 "intervention must be on the confidence axis [0,1], got {}",
112 query.set_to
113 ),
114 };
115 }
116
117 let confidence_index = build_confidence_index(project);
118 let factual_target = match confidence_index.get(&query.target) {
119 Some(&c) => c,
120 None => {
121 return CounterfactualVerdict::UnknownNode {
122 which: query.target.clone(),
123 };
124 }
125 };
126 let factual_source = match confidence_index.get(&query.intervene_on) {
127 Some(&c) => c,
128 None => {
129 return CounterfactualVerdict::UnknownNode {
130 which: query.intervene_on.clone(),
131 };
132 }
133 };
134
135 let graph = CausalGraph::from_project(project);
136 if !graph.contains(&query.intervene_on) {
137 return CounterfactualVerdict::UnknownNode {
138 which: query.intervene_on.clone(),
139 };
140 }
141 if !graph.contains(&query.target) {
142 return CounterfactualVerdict::UnknownNode {
143 which: query.target.clone(),
144 };
145 }
146
147 let paths = directed_paths_from_to(&graph, &query.intervene_on, &query.target, 8);
150 if paths.is_empty() {
151 return CounterfactualVerdict::NoCausalPath {
152 factual: factual_target,
153 };
154 }
155
156 let mech_index = build_mechanism_index(project);
158
159 let mut unspecified_edges: HashSet<(String, String)> = HashSet::new();
160 let mut path_deltas: Vec<f64> = Vec::new();
161 let mut paths_used: Vec<Vec<String>> = Vec::new();
162
163 let delta_x = query.set_to - factual_source;
164
165 for path in &paths {
166 let mut delta = delta_x;
174 let mut path_ok = true;
175 for window in path.windows(2) {
176 let parent = &window[0];
177 let child = &window[1];
178 match mech_index.get(&(parent.clone(), child.clone())) {
179 Some(m) => match m.apply(delta) {
180 Some(next_delta) => delta = next_delta,
181 None => {
182 unspecified_edges.insert((parent.clone(), child.clone()));
183 path_ok = false;
184 break;
185 }
186 },
187 None => {
188 unspecified_edges.insert((parent.clone(), child.clone()));
189 path_ok = false;
190 break;
191 }
192 }
193 }
194 if path_ok {
195 path_deltas.push(delta);
196 paths_used.push(path.clone());
197 }
198 }
199
200 if path_deltas.is_empty() {
201 let mut edges: Vec<(String, String)> = unspecified_edges.into_iter().collect();
202 edges.sort();
203 return CounterfactualVerdict::MechanismUnspecified {
204 unspecified_edges: edges,
205 };
206 }
207
208 let aggregate_delta = path_deltas
214 .iter()
215 .copied()
216 .fold(0.0_f64, |acc, d| if d.abs() > acc.abs() { d } else { acc });
217
218 let counterfactual = (factual_target + aggregate_delta).clamp(0.0, 1.0);
219 CounterfactualVerdict::Resolved {
220 factual: factual_target,
221 counterfactual,
222 delta: counterfactual - factual_target,
223 paths_used,
224 }
225}
226
227fn directed_paths_from_to(
230 graph: &CausalGraph,
231 cause: &str,
232 effect: &str,
233 max_depth: usize,
234) -> Vec<Vec<String>> {
235 const MAX_PATHS: usize = 32;
236 let mut out: Vec<Vec<String>> = Vec::new();
237 let mut queue: VecDeque<Vec<String>> = VecDeque::new();
238 queue.push_back(vec![cause.to_string()]);
239
240 while let Some(path) = queue.pop_front() {
241 if out.len() >= MAX_PATHS {
242 break;
243 }
244 if path.len() > max_depth {
245 continue;
246 }
247 let last = path.last().expect("path non-empty");
248 if last == effect && path.len() > 1 {
249 out.push(path);
250 continue;
251 }
252 for child in graph.children_of(last) {
253 let child_owned = child.to_string();
254 if path.contains(&child_owned) {
255 continue; }
257 let mut next = path.clone();
258 next.push(child_owned);
259 queue.push_back(next);
260 }
261 }
262 out
263}
264
265fn build_confidence_index(project: &Project) -> HashMap<String, f64> {
266 let mut idx = HashMap::new();
267 for finding in &project.findings {
268 idx.insert(finding.id.clone(), finding.confidence.score);
269 }
270 idx
271}
272
273fn build_mechanism_index(project: &Project) -> HashMap<(String, String), Mechanism> {
279 let mut idx = HashMap::new();
280 for finding in &project.findings {
281 for link in &finding.links {
282 if !matches!(link.link_type.as_str(), "depends" | "supports") {
283 continue;
284 }
285 let target = match link.target.split_once(':') {
288 Some((_, rest)) => rest.to_string(),
289 None => link.target.clone(),
290 };
291 if let Some(m) = link.mechanism {
292 idx.insert((target, finding.id.clone()), m);
293 }
294 }
295 }
296 idx
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bundle::{
        Assertion, Conditions, Confidence, Evidence, Extraction, FindingBundle, Flags, Link,
        Mechanism, MechanismSign, Provenance,
    };
    use crate::project;

    /// Empty `Conditions` for fixture findings.
    fn conditions() -> Conditions {
        Conditions {
            text: String::new(),
            species_verified: vec![],
            species_unverified: vec![],
            in_vitro: false,
            in_vivo: false,
            human_data: false,
            clinical_trial: false,
            concentration_range: None,
            duration: None,
            age_group: None,
            cell_type: None,
        }
    }

    /// Minimal provenance for fixture findings.
    fn provenance() -> Provenance {
        Provenance {
            source_type: "published_paper".into(),
            doi: None,
            pmid: None,
            pmc: None,
            openalex_id: None,
            url: None,
            title: "Test".into(),
            authors: vec![],
            year: Some(2025),
            journal: None,
            license: None,
            publisher: None,
            funders: vec![],
            extraction: Extraction::default(),
            review: None,
            citation_count: None,
        }
    }

    /// Builds a finding with id `id`, raw confidence `conf`, and the given links.
    fn finding(id: &str, conf: f64, links: Vec<Link>) -> FindingBundle {
        let mut b = FindingBundle::new(
            Assertion {
                text: format!("claim {id}"),
                assertion_type: "mechanism".into(),
                entities: vec![],
                relation: None,
                direction: None,
                causal_claim: None,
                causal_evidence_grade: None,
            },
            Evidence {
                evidence_type: "experimental".into(),
                model_system: String::new(),
                species: None,
                method: String::new(),
                sample_size: None,
                effect_size: None,
                p_value: None,
                replicated: false,
                replication_count: None,
                evidence_spans: vec![],
            },
            conditions(),
            Confidence::raw(conf, "test", 0.85),
            provenance(),
            Flags::default(),
        );
        b.id = id.to_string();
        b.links = links;
        b
    }

    /// A `depends` link to `target` carrying an optional mechanism.
    fn link_with_mechanism(target: &str, mech: Option<Mechanism>) -> Link {
        Link {
            target: target.into(),
            link_type: "depends".into(),
            note: String::new(),
            inferred_by: "test".into(),
            created_at: String::new(),
            mechanism: mech,
        }
    }

    /// Chain a -> b -> c: `vf_bbb` depends on `vf_aaa` (mechanism `ab`) and
    /// `vf_ccc` depends on `vf_bbb` (mechanism `bc`). Confidences 0.9 / 0.8 / 0.7.
    fn fixture_chain(ab: Option<Mechanism>, bc: Option<Mechanism>) -> Project {
        let a = finding("vf_aaa", 0.9, vec![]);
        let b = finding("vf_bbb", 0.8, vec![link_with_mechanism("vf_aaa", ab)]);
        let c = finding("vf_ccc", 0.7, vec![link_with_mechanism("vf_bbb", bc)]);
        project::assemble("test", vec![a, b, c], 1, 0, "test")
    }

    /// Diamond a -> {b, c} -> d with positive linear mechanisms:
    /// a->b slope 0.5, a->c slope 1.0, b->d slope 1.0, c->d slope 1.0.
    /// Confidences: a 0.9, b 0.8, c 0.7, d 0.5.
    fn fixture_diamond() -> Project {
        let a = finding("vf_aaa", 0.9, vec![]);
        let b = finding(
            "vf_bbb",
            0.8,
            vec![link_with_mechanism(
                "vf_aaa",
                Some(Mechanism::Linear {
                    sign: MechanismSign::Positive,
                    slope: 0.5,
                }),
            )],
        );
        let c = finding(
            "vf_ccc",
            0.7,
            vec![link_with_mechanism(
                "vf_aaa",
                Some(Mechanism::Linear {
                    sign: MechanismSign::Positive,
                    slope: 1.0,
                }),
            )],
        );
        let d = finding(
            "vf_ddd",
            0.5,
            vec![
                link_with_mechanism(
                    "vf_bbb",
                    Some(Mechanism::Linear {
                        sign: MechanismSign::Positive,
                        slope: 1.0,
                    }),
                ),
                link_with_mechanism(
                    "vf_ccc",
                    Some(Mechanism::Linear {
                        sign: MechanismSign::Positive,
                        slope: 1.0,
                    }),
                ),
            ],
        );
        project::assemble("test", vec![a, b, c, d], 1, 0, "test")
    }

    fn edge(parent: &str, child: &str) -> (String, String) {
        (parent.to_string(), child.to_string())
    }

    #[test]
    fn linear_chain_resolves() {
        let project = fixture_chain(
            Some(Mechanism::Linear {
                sign: MechanismSign::Positive,
                slope: 0.5,
            }),
            Some(Mechanism::Linear {
                sign: MechanismSign::Positive,
                slope: 0.4,
            }),
        );
        let q = CounterfactualQuery {
            intervene_on: "vf_aaa".into(),
            set_to: 0.5,
            target: "vf_ccc".into(),
        };
        let v = answer_counterfactual(&project, &q);
        match v {
            CounterfactualVerdict::Resolved {
                factual,
                counterfactual,
                delta,
                paths_used,
            } => {
                // (0.5 - 0.9) * 0.5 * 0.4 = -0.08
                assert!((factual - 0.7).abs() < 1e-9);
                assert!((delta - (-0.08)).abs() < 1e-6, "delta = {delta}");
                assert!(
                    (counterfactual - 0.62).abs() < 1e-6,
                    "counterfactual = {counterfactual}"
                );
                assert_eq!(
                    paths_used,
                    vec![vec![
                        "vf_aaa".to_string(),
                        "vf_bbb".to_string(),
                        "vf_ccc".to_string(),
                    ]]
                );
            }
            _ => panic!("expected Resolved, got {v:?}"),
        }
    }

    #[test]
    fn missing_mechanism_blocks_propagation() {
        let project = fixture_chain(
            Some(Mechanism::Linear {
                sign: MechanismSign::Positive,
                slope: 0.5,
            }),
            None,
        );
        let q = CounterfactualQuery {
            intervene_on: "vf_aaa".into(),
            set_to: 0.5,
            target: "vf_ccc".into(),
        };
        assert_eq!(
            answer_counterfactual(&project, &q),
            CounterfactualVerdict::MechanismUnspecified {
                unspecified_edges: vec![edge("vf_bbb", "vf_ccc")],
            }
        );
    }

    #[test]
    fn unknown_mechanism_blocks_propagation() {
        let project = fixture_chain(
            Some(Mechanism::Linear {
                sign: MechanismSign::Positive,
                slope: 0.5,
            }),
            Some(Mechanism::Unknown),
        );
        let q = CounterfactualQuery {
            intervene_on: "vf_aaa".into(),
            set_to: 0.5,
            target: "vf_ccc".into(),
        };
        assert_eq!(
            answer_counterfactual(&project, &q),
            CounterfactualVerdict::MechanismUnspecified {
                unspecified_edges: vec![edge("vf_bbb", "vf_ccc")],
            }
        );
    }

    #[test]
    fn out_of_range_intervention_rejected() {
        let project = fixture_chain(None, None);
        for set_to in [1.5, -0.1, f64::NAN] {
            let q = CounterfactualQuery {
                intervene_on: "vf_aaa".into(),
                set_to,
                target: "vf_ccc".into(),
            };
            assert!(
                matches!(
                    answer_counterfactual(&project, &q),
                    CounterfactualVerdict::InvalidIntervention { .. }
                ),
                "set_to = {set_to} should be rejected"
            );
        }
    }

    #[test]
    fn unknown_nodes_reported() {
        let project = fixture_chain(None, None);
        let unknown_source = CounterfactualQuery {
            intervene_on: "vf_zzz".into(),
            set_to: 0.5,
            target: "vf_ccc".into(),
        };
        assert_eq!(
            answer_counterfactual(&project, &unknown_source),
            CounterfactualVerdict::UnknownNode {
                which: "vf_zzz".into(),
            }
        );
        let unknown_target = CounterfactualQuery {
            intervene_on: "vf_aaa".into(),
            set_to: 0.5,
            target: "vf_zzz".into(),
        };
        assert_eq!(
            answer_counterfactual(&project, &unknown_target),
            CounterfactualVerdict::UnknownNode {
                which: "vf_zzz".into(),
            }
        );
    }

    #[test]
    fn no_path_yields_factual() {
        let project = fixture_chain(None, None);
        // Edges point a -> b -> c, so nothing flows from c back to a.
        let q = CounterfactualQuery {
            intervene_on: "vf_ccc".into(),
            set_to: 0.5,
            target: "vf_aaa".into(),
        };
        match answer_counterfactual(&project, &q) {
            CounterfactualVerdict::NoCausalPath { factual } => {
                assert!((factual - 0.9).abs() < 1e-9);
            }
            v => panic!("expected NoCausalPath, got {v:?}"),
        }
    }

    #[test]
    fn negative_sign_flips_direction() {
        let project = fixture_chain(
            Some(Mechanism::Linear {
                sign: MechanismSign::Negative,
                slope: 0.5,
            }),
            Some(Mechanism::Linear {
                sign: MechanismSign::Positive,
                slope: 1.0,
            }),
        );
        // Raising a (0.9 -> 1.0) through a negative edge lowers c:
        // 0.1 * -0.5 * 1.0 = -0.05.
        let q = CounterfactualQuery {
            intervene_on: "vf_aaa".into(),
            set_to: 1.0,
            target: "vf_ccc".into(),
        };
        match answer_counterfactual(&project, &q) {
            CounterfactualVerdict::Resolved {
                counterfactual,
                delta,
                ..
            } => {
                assert!((delta - (-0.05)).abs() < 1e-6, "delta = {delta}");
                assert!((counterfactual - 0.65).abs() < 1e-6);
            }
            v => panic!("expected Resolved, got {v:?}"),
        }
    }

    #[test]
    fn strongest_path_wins_over_sum() {
        let project = fixture_diamond();
        // delta_x = 0.5 - 0.9 = -0.4. Via b: -0.4 * 0.5 * 1.0 = -0.2;
        // via c: -0.4 * 1.0 * 1.0 = -0.4. The largest magnitude (-0.4) is used,
        // not the sum (-0.6, which would clamp to 0.0).
        let q = CounterfactualQuery {
            intervene_on: "vf_aaa".into(),
            set_to: 0.5,
            target: "vf_ddd".into(),
        };
        match answer_counterfactual(&project, &q) {
            CounterfactualVerdict::Resolved {
                factual,
                counterfactual,
                paths_used,
                ..
            } => {
                assert!((factual - 0.5).abs() < 1e-9);
                assert!(
                    (counterfactual - 0.1).abs() < 1e-6,
                    "counterfactual = {counterfactual}"
                );
                assert_eq!(paths_used.len(), 2);
            }
            v => panic!("expected Resolved, got {v:?}"),
        }
    }
}