1#[cfg(feature = "self-substrate-adapters")]
8use vyre_self_substrate::do_calculus_change_impact::{
9 predict_impact_via_into, DoCalculusImpactScratch,
10};
11#[cfg(feature = "self-substrate-adapters")]
12use vyre_self_substrate::optimizer::dispatcher::{
13 DispatchError as SelfSubstrateDispatchError, OptimizerDispatcher,
14};
15#[cfg(feature = "self-substrate-adapters")]
16use vyre_self_substrate::scallop_provenance::provenance_closure_via_into;
17
/// Error produced when pipeline cache invalidation cannot be computed.
///
/// The error is a thin wrapper around one human-readable message; its
/// `Display` output is exactly that message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheInvalidationError {
    /// Diagnostic text, emitted verbatim by the `Display` impl.
    message: String,
}

impl CacheInvalidationError {
    /// Builds an error from anything convertible into an owned `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for CacheInvalidationError {
    /// Writes the stored message unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for CacheInvalidationError {}
39
#[cfg(feature = "self-substrate-adapters")]
impl From<SelfSubstrateDispatchError> for CacheInvalidationError {
    /// Converts a dispatcher error by capturing its `Display` text as the
    /// message; the original error value is not retained.
    fn from(error: SelfSubstrateDispatchError) -> Self {
        let message = error.to_string();
        Self { message }
    }
}
46
/// Reusable working storage passed to `impacted_entries_into`.
///
/// Both fields exist only when the `self-substrate-adapters` feature is
/// enabled; without it the struct has no fields.
#[derive(Default, Debug)]
pub struct CacheInvalidationScratch {
    /// Scratch handed to `predict_impact_via_into`; its `impact_mask()` is
    /// read back after that call.
    #[cfg(feature = "self-substrate-adapters")]
    impact: DoCalculusImpactScratch,
    /// Buffer handed to `provenance_closure_via_into` as its output.
    #[cfg(feature = "self-substrate-adapters")]
    closure: Vec<u32>,
}
55
56pub fn impacted_entries_into(
62 #[cfg(feature = "self-substrate-adapters")] dispatcher: &dyn OptimizerDispatcher,
63 intervention_mask: &[u32],
64 rule_adj: &[u32],
65 state: &[u32],
66 join_rules: &[u32],
67 n: u32,
68 max_iterations: u32,
69 lineage_cells: &[u32],
70 out: &mut Vec<u32>,
71 _scratch: &mut CacheInvalidationScratch,
72) -> Result<(), CacheInvalidationError> {
73 out.clear();
74 reserve_impact_mask(out, lineage_cells.len())?;
75 out.resize(lineage_cells.len(), 0);
76
77 #[cfg(not(feature = "self-substrate-adapters"))]
78 {
79 let _ = (
80 intervention_mask,
81 rule_adj,
82 state,
83 join_rules,
84 n,
85 max_iterations,
86 lineage_cells,
87 _scratch,
88 );
89 return Err(CacheInvalidationError::new(
90 "vyre-driver cache invalidation requires the `self-substrate-adapters` feature. Fix: enable the feature; production builds must not run the reference cache-invalidation oracle.",
91 ));
92 }
93
94 #[cfg(feature = "self-substrate-adapters")]
95 {
96 let n_us = n as usize;
97 let Some(matrix_len) = n_us.checked_mul(n_us) else {
98 return Err(CacheInvalidationError::new(format!(
99 "Fix: cache invalidation n*n overflows usize for n={n}."
100 )));
101 };
102 if intervention_mask.len() != n_us {
103 return Err(CacheInvalidationError::new(format!(
104 "Fix: cache invalidation requires intervention_mask.len() == n ({n_us}), got {}.",
105 intervention_mask.len()
106 )));
107 }
108 if rule_adj.len() != matrix_len {
109 return Err(CacheInvalidationError::new(format!(
110 "Fix: cache invalidation requires rule_adj.len() == n*n ({matrix_len}), got {}.",
111 rule_adj.len()
112 )));
113 }
114 if state.len() != matrix_len {
115 return Err(CacheInvalidationError::new(format!(
116 "Fix: cache invalidation requires state.len() == n*n ({matrix_len}), got {}.",
117 state.len()
118 )));
119 }
120 if join_rules.len() != matrix_len {
121 return Err(CacheInvalidationError::new(format!(
122 "Fix: cache invalidation requires join_rules.len() == n*n ({matrix_len}), got {}.",
123 join_rules.len()
124 )));
125 }
126
127 predict_impact_via_into(
128 dispatcher,
129 rule_adj,
130 intervention_mask,
131 n,
132 &mut _scratch.impact,
133 )
134 .map_err(CacheInvalidationError::from)?;
135 provenance_closure_via_into(
136 dispatcher,
137 state,
138 join_rules,
139 n,
140 max_iterations,
141 &mut _scratch.closure,
142 )
143 .map_err(CacheInvalidationError::from)?;
144
145 let impacted_rules = _scratch.impact.impact_mask();
146 let closure = &_scratch.closure;
147 if impacted_rules.len() < n_us || closure.len() < matrix_len {
148 return Err(CacheInvalidationError::new(format!(
149 "Fix: cache invalidation GPU outputs were undersized: impact_mask={}, closure={}, required n={n_us}, matrix={matrix_len}.",
150 impacted_rules.len(),
151 closure.len()
152 )));
153 }
154
155 for (entry_idx, &cell) in lineage_cells.iter().enumerate() {
156 let cell = cell as usize;
157 if cell >= n_us {
158 continue;
159 }
160 let row_start = cell * n_us;
161 let row = &closure[row_start..row_start + n_us];
162 let directly_impacted = impacted_rules.get(cell).is_some_and(|&v| v != 0);
169 if directly_impacted
170 || row
171 .iter()
172 .zip(impacted_rules.iter())
173 .any(|(&bitset, &impacted)| bitset != 0 && impacted != 0)
174 {
175 out[entry_idx] = 1;
176 }
177 }
178 Ok(())
179 }
180}
181
182#[must_use]
184pub fn impacted_entries(
185 #[cfg(feature = "self-substrate-adapters")] dispatcher: &dyn OptimizerDispatcher,
186 intervention_mask: &[u32],
187 rule_adj: &[u32],
188 state: &[u32],
189 join_rules: &[u32],
190 n: u32,
191 max_iterations: u32,
192 lineage_cells: &[u32],
193) -> Result<Vec<u32>, CacheInvalidationError> {
194 let mut out = reserved_impact_mask(lineage_cells.len())?;
195 let mut scratch = CacheInvalidationScratch::default();
196 impacted_entries_into(
197 #[cfg(feature = "self-substrate-adapters")]
198 dispatcher,
199 intervention_mask,
200 rule_adj,
201 state,
202 join_rules,
203 n,
204 max_iterations,
205 lineage_cells,
206 &mut out,
207 &mut scratch,
208 )?;
209 Ok(out)
210}
211
212fn reserve_impact_mask(out: &mut Vec<u32>, len: usize) -> Result<(), CacheInvalidationError> {
213 crate::allocation::try_reserve_vec_to_capacity(out, len).map_err(|error| {
214 CacheInvalidationError::new(format!(
215 "pipeline cache invalidation could not reserve {len} impact-mask slot(s): {error}. Fix: split lineage cells across smaller cache-invalidation shards."
216 ))
217 })
218}
219
220fn reserved_impact_mask(len: usize) -> Result<Vec<u32>, CacheInvalidationError> {
221 let mut out = Vec::new();
222 reserve_impact_mask(&mut out, len)?;
223 Ok(out)
224}
225
#[cfg(all(test, feature = "self-substrate-adapters"))]
mod tests {
    use super::*;
    use vyre_foundation::ir::Program;

    /// Test double that ignores the program and grid override and returns its
    /// first input buffer as the only output (an empty buffer when there are
    /// no inputs).
    struct EchoStateDispatcher;

    impl OptimizerDispatcher for EchoStateDispatcher {
        fn dispatch(
            &self,
            _program: &Program,
            inputs: &[Vec<u8>],
            _grid_override: Option<[u32; 3]>,
        ) -> Result<Vec<Vec<u8>>, SelfSubstrateDispatchError> {
            Ok(vec![inputs.first().cloned().unwrap_or_default()])
        }
    }

    /// Flat index `row * side + col` into a `side * side` matrix.
    ///
    /// Using a helper instead of literal arithmetic such as `0 * 3 + 1` keeps
    /// the intent readable without tripping `clippy::erasing_op` /
    /// `clippy::identity_op`.
    fn flat_index(side: usize, row: usize, col: usize) -> usize {
        row * side + col
    }

    #[test]
    fn impact_mask_marks_lineage_intersection() {
        let dispatcher = EchoStateDispatcher;
        let n: u32 = 3;
        let side = n as usize;

        let mut rule_adj = vec![0u32; side * side];
        rule_adj[flat_index(side, 0, 1)] = 1;
        let intervention_mask = vec![1, 0, 0];

        let mut state = vec![0u32; side * side];
        state[flat_index(side, 1, 0)] = 1;
        let join_rules = vec![0u32; side * side];

        let mask = impacted_entries(
            &dispatcher,
            &intervention_mask,
            &rule_adj,
            &state,
            &join_rules,
            n,
            16,
            &[1, 2],
        )
        .expect("Fix: test dispatcher must return one state output");
        assert_eq!(mask, vec![1, 0]);
    }

    #[test]
    fn malformed_dimensions_do_not_panic() {
        let dispatcher = EchoStateDispatcher;
        // `n = 32` with a one-element mask and empty matrices must be rejected by validation.
        let err = impacted_entries(&dispatcher, &[1], &[], &[], &[], 32, 16, &[0, 1])
            .expect_err("malformed dimensions must fail loudly");
        assert!(
            err.to_string().contains("Fix:"),
            "cache invalidation dimension errors must be actionable"
        );
    }
}
279}