Skip to main content

vela_protocol/
counterfactual.rs

1//! v0.45: Pearl level 3 — counterfactual queries over the claim graph.
2//!
3//! Level 1 (v0.40) answered identifiability per finding by 3×4 lookup.
4//! Level 2 (v0.44) answered "is the effect of source on target
5//! identifiable from the link graph?" via back-door / front-door
6//! adjustment.
7//! Level 3 (v0.45) answers "given that we observed Y under X=x, what
8//! would Y have been under X=x'?" via twin-network construction.
9//!
10//! ### Method (Pearl 2009, §7)
11//!
12//! A twin network is two copies of the SCM running in parallel: the
13//! *factual* world (what we actually observed) and the *counterfactual*
14//! world (what we would have observed under the intervention). The two
15//! worlds share the same exogenous noise terms but differ at the
16//! intervened node. Propagating perturbations through both worlds and
17//! comparing the target node yields the counterfactual delta.
18//!
19//! In Vela's claim graph the "values" being propagated are belief
20//! confidences in [0,1]. A `Mechanism` (v0.45) on each edge specifies
21//! how a parent's confidence determines the child's. Edges without
22//! mechanisms are treated as opaque — they block counterfactual
23//! propagation through that edge and surface as
24//! `MechanismUnspecified`.
25//!
26//! Doctrine:
27//! - We only answer counterfactuals along paths whose every edge has a
28//!   mechanism. Partial answers are honest about which edges blocked
29//!   propagation.
30//! - We perturb on the [0,1] confidence axis, not on the underlying
31//!   scientific quantity. Vela does not (and should not) infer real-
32//!   world units from prose; it tracks the kernel's first-class
33//!   quantity, belief.
34//! - Twin-network is overkill for claim graphs that are tree-shaped on
35//!   the observed→intervened path; we still implement it because
36//!   real-world claim graphs converge (multiple supports of one
37//!   finding) and a simple forward propagation would double-count.
38
39use std::collections::{HashMap, HashSet, VecDeque};
40
41use serde::{Deserialize, Serialize};
42
43use crate::bundle::Mechanism;
44use crate::causal_graph::CausalGraph;
45use crate::project::Project;
46
47/// A request: "intervene to set finding `vf_id`'s confidence to
48/// `value`, then ask: what is the counterfactual confidence of
49/// `target`?"
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct CounterfactualQuery {
52    pub intervene_on: String,
53    pub set_to: f64,
54    pub target: String,
55}
56
57/// The verdict for a counterfactual query.
58#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
59#[serde(tag = "kind", rename_all = "snake_case")]
60pub enum CounterfactualVerdict {
61    /// Twin-network propagation succeeded end-to-end.
62    Resolved {
63        /// Factual confidence at target (the observed value).
64        factual: f64,
65        /// Counterfactual confidence at target under the intervention.
66        counterfactual: f64,
67        /// `counterfactual − factual`.
68        delta: f64,
69        /// The directed paths from the intervened node to the target
70        /// that propagated the perturbation, in the order discovered.
71        paths_used: Vec<Vec<String>>,
72    },
73    /// The intervened node and the target are connected, but at least
74    /// one edge on every connecting path lacks a mechanism. We refuse
75    /// to guess.
76    MechanismUnspecified {
77        /// Edges (parent → child) along source→target paths with no
78        /// mechanism declared.
79        unspecified_edges: Vec<(String, String)>,
80    },
81    /// No directed path from the intervened node reaches the target.
82    /// Counterfactual is the same as factual by structural assumption.
83    NoCausalPath { factual: f64 },
84    /// One of the cited findings isn't in the graph.
85    UnknownNode { which: String },
86    /// `set_to` is outside [0, 1] (the confidence axis).
87    InvalidIntervention { reason: String },
88}
89
90/// Run a counterfactual query end-to-end.
91///
92/// 1. Validate inputs (nodes exist; intervention in [0,1]).
93/// 2. Build the directed graph and find directed paths from
94///    `intervene_on` to `target`.
95/// 3. For each path, check that every edge has a mechanism. If any
96///    path does, propagate the perturbation along it via mechanism
97///    composition.
98/// 4. Aggregate path contributions (we use *max-magnitude* to avoid
99///    additive double-counting on diamond graphs — this is the
100///    weakest defensible aggregation and keeps us honest about the
101///    structural causal model's limits).
102/// 5. Bound the result to [0, 1].
103#[must_use]
104pub fn answer_counterfactual(
105    project: &Project,
106    query: &CounterfactualQuery,
107) -> CounterfactualVerdict {
108    if !(0.0..=1.0).contains(&query.set_to) {
109        return CounterfactualVerdict::InvalidIntervention {
110            reason: format!(
111                "intervention must be on the confidence axis [0,1], got {}",
112                query.set_to
113            ),
114        };
115    }
116
117    let confidence_index = build_confidence_index(project);
118    let factual_target = match confidence_index.get(&query.target) {
119        Some(&c) => c,
120        None => {
121            return CounterfactualVerdict::UnknownNode {
122                which: query.target.clone(),
123            };
124        }
125    };
126    let factual_source = match confidence_index.get(&query.intervene_on) {
127        Some(&c) => c,
128        None => {
129            return CounterfactualVerdict::UnknownNode {
130                which: query.intervene_on.clone(),
131            };
132        }
133    };
134
135    let graph = CausalGraph::from_project(project);
136    if !graph.contains(&query.intervene_on) {
137        return CounterfactualVerdict::UnknownNode {
138            which: query.intervene_on.clone(),
139        };
140    }
141    if !graph.contains(&query.target) {
142        return CounterfactualVerdict::UnknownNode {
143            which: query.target.clone(),
144        };
145    }
146
147    // Directed paths from intervene_on (cause) to target (effect),
148    // using the v0.44 graph's child-direction edges.
149    let paths = directed_paths_from_to(&graph, &query.intervene_on, &query.target, 8);
150    if paths.is_empty() {
151        return CounterfactualVerdict::NoCausalPath {
152            factual: factual_target,
153        };
154    }
155
156    // Build a mechanism lookup: (parent, child) -> Option<Mechanism>.
157    let mech_index = build_mechanism_index(project);
158
159    let mut unspecified_edges: HashSet<(String, String)> = HashSet::new();
160    let mut path_deltas: Vec<f64> = Vec::new();
161    let mut paths_used: Vec<Vec<String>> = Vec::new();
162
163    let delta_x = query.set_to - factual_source;
164
165    for path in &paths {
166        // Path is [source, ..., target]; `depends`/`supports` edges in
167        // the v0.44 graph point from the *dependent* to the *parent*.
168        // In our convention, "child depends on parent" means a directed
169        // causal edge from parent → child for level-3 propagation. The
170        // CausalGraph lookup is parents_of(child) -> {parent}; so a
171        // forward path from cause→effect in the graph traverses
172        // children_of edges.
173        let mut delta = delta_x;
174        let mut path_ok = true;
175        for window in path.windows(2) {
176            let parent = &window[0];
177            let child = &window[1];
178            match mech_index.get(&(parent.clone(), child.clone())) {
179                Some(m) => match m.apply(delta) {
180                    Some(next_delta) => delta = next_delta,
181                    None => {
182                        unspecified_edges.insert((parent.clone(), child.clone()));
183                        path_ok = false;
184                        break;
185                    }
186                },
187                None => {
188                    unspecified_edges.insert((parent.clone(), child.clone()));
189                    path_ok = false;
190                    break;
191                }
192            }
193        }
194        if path_ok {
195            path_deltas.push(delta);
196            paths_used.push(path.clone());
197        }
198    }
199
200    if path_deltas.is_empty() {
201        let mut edges: Vec<(String, String)> = unspecified_edges.into_iter().collect();
202        edges.sort();
203        return CounterfactualVerdict::MechanismUnspecified {
204            unspecified_edges: edges,
205        };
206    }
207
208    // Aggregate: pick the path delta with maximum absolute magnitude.
209    // This is intentionally conservative — additive aggregation would
210    // double-count on diamond graphs; max-magnitude reports the
211    // strongest single causal route without inventing structural
212    // assumptions we don't have.
213    let aggregate_delta = path_deltas
214        .iter()
215        .copied()
216        .fold(0.0_f64, |acc, d| if d.abs() > acc.abs() { d } else { acc });
217
218    let counterfactual = (factual_target + aggregate_delta).clamp(0.0, 1.0);
219    CounterfactualVerdict::Resolved {
220        factual: factual_target,
221        counterfactual,
222        delta: counterfactual - factual_target,
223        paths_used,
224    }
225}
226
227/// BFS-enumerate directed paths cause→effect using `children_of` (the
228/// "downstream" direction). Bounded by `max_depth` and `max_paths`.
229fn directed_paths_from_to(
230    graph: &CausalGraph,
231    cause: &str,
232    effect: &str,
233    max_depth: usize,
234) -> Vec<Vec<String>> {
235    const MAX_PATHS: usize = 32;
236    let mut out: Vec<Vec<String>> = Vec::new();
237    let mut queue: VecDeque<Vec<String>> = VecDeque::new();
238    queue.push_back(vec![cause.to_string()]);
239
240    while let Some(path) = queue.pop_front() {
241        if out.len() >= MAX_PATHS {
242            break;
243        }
244        if path.len() > max_depth {
245            continue;
246        }
247        let last = path.last().expect("path non-empty");
248        if last == effect && path.len() > 1 {
249            out.push(path);
250            continue;
251        }
252        for child in graph.children_of(last) {
253            let child_owned = child.to_string();
254            if path.contains(&child_owned) {
255                continue; // no cycles
256            }
257            let mut next = path.clone();
258            next.push(child_owned);
259            queue.push_back(next);
260        }
261    }
262    out
263}
264
265fn build_confidence_index(project: &Project) -> HashMap<String, f64> {
266    let mut idx = HashMap::new();
267    for finding in &project.findings {
268        idx.insert(finding.id.clone(), finding.confidence.score);
269    }
270    idx
271}
272
273/// Build a (parent, child) → Mechanism index from the project's link
274/// graph. In v0.44 graph convention, a `depends`/`supports` link from
275/// finding A to finding B encodes "A depends on B" — i.e. B is the
276/// parent and A is the child. The mechanism on the link describes how
277/// B drives A.
278fn build_mechanism_index(project: &Project) -> HashMap<(String, String), Mechanism> {
279    let mut idx = HashMap::new();
280    for finding in &project.findings {
281        for link in &finding.links {
282            if !matches!(link.link_type.as_str(), "depends" | "supports") {
283                continue;
284            }
285            // a (the dependent / child) → link.target (the parent).
286            // Index by (parent, child).
287            let target = match link.target.split_once(':') {
288                Some((_, rest)) => rest.to_string(),
289                None => link.target.clone(),
290            };
291            if let Some(m) = link.mechanism {
292                idx.insert((target, finding.id.clone()), m);
293            }
294        }
295    }
296    idx
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bundle::{
        Assertion, Conditions, Confidence, Evidence, Extraction, FindingBundle, Flags, Link,
        Mechanism, MechanismSign, Provenance,
    };
    use crate::project;

    /// Blank experimental conditions: empty text, no flags, no ranges.
    fn conditions() -> Conditions {
        Conditions {
            text: String::new(),
            species_verified: vec![],
            species_unverified: vec![],
            in_vitro: false,
            in_vivo: false,
            human_data: false,
            clinical_trial: false,
            concentration_range: None,
            duration: None,
            age_group: None,
            cell_type: None,
        }
    }

    /// Minimal provenance for a published paper titled "Test" (2025).
    fn provenance() -> Provenance {
        Provenance {
            source_type: "published_paper".into(),
            doi: None,
            pmid: None,
            pmc: None,
            openalex_id: None,
            url: None,
            title: "Test".into(),
            authors: vec![],
            year: Some(2025),
            journal: None,
            license: None,
            publisher: None,
            funders: vec![],
            extraction: Extraction::default(),
            review: None,
            citation_count: None,
        }
    }

    /// A finding with the given id, confidence and outgoing links; every
    /// other field is filler.
    fn finding(id: &str, conf: f64, links: Vec<Link>) -> FindingBundle {
        let assertion = Assertion {
            text: format!("claim {id}"),
            assertion_type: "mechanism".into(),
            entities: vec![],
            relation: None,
            direction: None,
            causal_claim: None,
            causal_evidence_grade: None,
        };
        let evidence = Evidence {
            evidence_type: "experimental".into(),
            model_system: String::new(),
            species: None,
            method: String::new(),
            sample_size: None,
            effect_size: None,
            p_value: None,
            replicated: false,
            replication_count: None,
            evidence_spans: vec![],
        };
        let mut bundle = FindingBundle::new(
            assertion,
            evidence,
            conditions(),
            Confidence::raw(conf, "test", 0.85),
            provenance(),
            Flags::default(),
        );
        // Overwrite the id after construction so tests can refer to it.
        bundle.id = id.to_string();
        bundle.links = links;
        bundle
    }

    /// A `depends` link to `target` carrying the given optional mechanism.
    fn link_with_mechanism(target: &str, mech: Option<Mechanism>) -> Link {
        Link {
            target: target.into(),
            link_type: "depends".into(),
            note: String::new(),
            inferred_by: "test".into(),
            created_at: String::new(),
            mechanism: mech,
        }
    }

    /// Three findings A → B → C with confidences 0.9, 0.8, 0.7.
    /// `B depends on A` and `C depends on B`. Mechanisms vary per test.
    fn fixture_chain(ab: Option<Mechanism>, bc: Option<Mechanism>) -> Project {
        let findings = vec![
            finding("vf_aaa", 0.9, vec![]),
            finding("vf_bbb", 0.8, vec![link_with_mechanism("vf_aaa", ab)]),
            finding("vf_ccc", 0.7, vec![link_with_mechanism("vf_bbb", bc)]),
        ];
        project::assemble("test", findings, 1, 0, "test")
    }

    /// Shorthand for an edge carrying a linear mechanism.
    fn linear(sign: MechanismSign, slope: f64) -> Option<Mechanism> {
        Some(Mechanism::Linear { sign, slope })
    }

    /// Ask: "set `intervene_on` to `set_to`; what happens to `target`?"
    fn ask(
        project: &Project,
        intervene_on: &str,
        set_to: f64,
        target: &str,
    ) -> CounterfactualVerdict {
        let query = CounterfactualQuery {
            intervene_on: intervene_on.into(),
            set_to,
            target: target.into(),
        };
        answer_counterfactual(project, &query)
    }

    #[test]
    fn linear_chain_resolves() {
        let project = fixture_chain(
            linear(MechanismSign::Positive, 0.5),
            linear(MechanismSign::Positive, 0.4),
        );
        match ask(&project, "vf_aaa", 0.5, "vf_ccc") {
            CounterfactualVerdict::Resolved {
                factual,
                counterfactual,
                delta,
                ..
            } => {
                assert!((factual - 0.7).abs() < 1e-9);
                // delta_x = 0.5 - 0.9 = -0.4; bc(ab(-0.4)) = 0.4*0.5*-0.4 = -0.08
                assert!((delta - (-0.08)).abs() < 1e-6, "delta = {delta}");
                assert!(counterfactual > 0.0 && counterfactual < 1.0);
            }
            v => panic!("expected Resolved, got {v:?}"),
        }
    }

    #[test]
    fn missing_mechanism_blocks_propagation() {
        // The B → C edge declares no mechanism at all.
        let project = fixture_chain(linear(MechanismSign::Positive, 0.5), None);
        let verdict = ask(&project, "vf_aaa", 0.5, "vf_ccc");
        assert!(matches!(
            verdict,
            CounterfactualVerdict::MechanismUnspecified { .. }
        ));
    }

    #[test]
    fn unknown_mechanism_blocks_propagation() {
        // The B → C edge declares a mechanism, but it is `Unknown`.
        let project = fixture_chain(
            linear(MechanismSign::Positive, 0.5),
            Some(Mechanism::Unknown),
        );
        let verdict = ask(&project, "vf_aaa", 0.5, "vf_ccc");
        assert!(matches!(
            verdict,
            CounterfactualVerdict::MechanismUnspecified { .. }
        ));
    }

    #[test]
    fn out_of_range_intervention_rejected() {
        let project = fixture_chain(None, None);
        let verdict = ask(&project, "vf_aaa", 1.5, "vf_ccc");
        assert!(matches!(
            verdict,
            CounterfactualVerdict::InvalidIntervention { .. }
        ));
    }

    #[test]
    fn no_path_yields_factual() {
        let project = fixture_chain(None, None);
        // C has no descendants, so nothing flows from C back to A.
        match ask(&project, "vf_ccc", 0.5, "vf_aaa") {
            CounterfactualVerdict::NoCausalPath { factual } => {
                assert!((factual - 0.9).abs() < 1e-9);
            }
            v => panic!("expected NoCausalPath, got {v:?}"),
        }
    }

    #[test]
    fn negative_sign_flips_direction() {
        let project = fixture_chain(
            linear(MechanismSign::Negative, 0.5),
            linear(MechanismSign::Positive, 1.0),
        );
        // intervene to bump A from 0.9 -> 1.0 (delta_x = +0.1)
        // ab(+0.1) = -0.5*0.1 = -0.05
        // bc(-0.05) = +1.0*-0.05 = -0.05
        // counterfactual C = 0.7 + (-0.05) = 0.65
        match ask(&project, "vf_aaa", 1.0, "vf_ccc") {
            CounterfactualVerdict::Resolved {
                counterfactual,
                delta,
                ..
            } => {
                assert!((delta - (-0.05)).abs() < 1e-6, "delta = {delta}");
                assert!((counterfactual - 0.65).abs() < 1e-6);
            }
            v => panic!("expected Resolved, got {v:?}"),
        }
    }
}