subtr_actor/stats/analysis_graph/
graph.rs1#![allow(dead_code)]
2
3use std::any::{type_name, Any, TypeId};
4use std::collections::{HashMap, HashSet};
5
6use crate::*;
7
8mod render;
9
/// Describes one state type that an analysis node needs in order to evaluate.
///
/// The dependency is keyed by the `TypeId` of the state. The state's type name
/// is stored next to it so that error and panic messages can name the state.
#[derive(Clone, Copy)]
pub struct AnalysisDependency {
    /// `TypeId` of the state this dependency refers to.
    state_type_id: TypeId,
    /// `type_name` of the state, used when reporting problems.
    state_type_name: &'static str,
    /// Whether a default provider node can be built for this state.
    source: AnalysisDependencySource,
}
16
/// Records how a dependency can be satisfied when nothing provides it yet.
#[derive(Clone, Copy)]
enum AnalysisDependencySource {
    /// A factory function that builds a boxed node for the dependency.
    DefaultFactory(fn() -> Box<dyn AnalysisNodeDyn>),
    /// No factory is attached to the dependency.
    External,
}
22
23impl AnalysisDependency {
24 pub fn required<T: 'static>() -> Self {
25 Self {
26 state_type_id: TypeId::of::<T>(),
27 state_type_name: type_name::<T>(),
28 source: AnalysisDependencySource::External,
29 }
30 }
31
32 pub fn with_default<T: 'static>(default_factory: fn() -> Box<dyn AnalysisNodeDyn>) -> Self {
33 Self {
34 state_type_id: TypeId::of::<T>(),
35 state_type_name: type_name::<T>(),
36 source: AnalysisDependencySource::DefaultFactory(default_factory),
37 }
38 }
39
40 pub fn state_type_id(&self) -> TypeId {
41 self.state_type_id
42 }
43
44 pub fn state_type_name(&self) -> &'static str {
45 self.state_type_name
46 }
47
48 fn default_factory(&self) -> fn() -> Box<dyn AnalysisNodeDyn> {
49 match self.source {
50 AnalysisDependencySource::DefaultFactory(default_factory) => default_factory,
51 AnalysisDependencySource::External => panic!(
52 "analysis dependency for {} has no default factory",
53 self.state_type_name
54 ),
55 }
56 }
57
58 fn is_external(&self) -> bool {
59 matches!(self.source, AnalysisDependencySource::External)
60 }
61}
62
/// Read-only lookup of type-erased states, keyed by the state's `TypeId`.
///
/// A reference to this context is what `evaluate` receives on analysis nodes.
pub struct AnalysisStateContext<'a> {
    /// Borrowed states indexed by their concrete type.
    states: HashMap<TypeId, &'a dyn Any>,
}
66
/// A borrowed, type-erased state value tagged with its concrete type.
pub struct AnalysisStateRef<'a> {
    /// `TypeId` of the concrete state type.
    type_id: TypeId,
    /// Type name of the concrete state type, kept for diagnostics.
    type_name: &'static str,
    /// The borrowed state, erased to `dyn Any`.
    state: &'a dyn Any,
}

impl<'a> AnalysisStateRef<'a> {
    /// Wraps a reference to `state`, remembering the id and name of `T`.
    pub fn of<T: 'static>(state: &'a T) -> Self {
        let erased: &'a dyn Any = state;
        AnalysisStateRef {
            type_id: TypeId::of::<T>(),
            type_name: type_name::<T>(),
            state: erased,
        }
    }

    /// `TypeId` of the wrapped state's concrete type.
    fn type_id(&self) -> TypeId {
        let AnalysisStateRef { type_id, .. } = self;
        *type_id
    }

    /// Type name of the wrapped state's concrete type.
    fn type_name(&self) -> &'static str {
        let AnalysisStateRef { type_name, .. } = self;
        type_name
    }

    /// The wrapped state as `dyn Any`, borrowed for the original lifetime.
    fn state(&self) -> &'a dyn Any {
        let AnalysisStateRef { state, .. } = self;
        *state
    }
}
94
95impl<'a> AnalysisStateContext<'a> {
96 fn from_parts(
97 root_states: &'a HashMap<TypeId, Box<dyn Any>>,
98 input_states: &'a [AnalysisStateRef<'a>],
99 before: &'a [Box<dyn AnalysisNodeDyn>],
100 ) -> Self {
101 let mut states =
102 HashMap::with_capacity(root_states.len() + input_states.len() + before.len());
103 for (type_id, state) in root_states {
104 states.insert(*type_id, state.as_ref());
105 }
106 for input_state in input_states {
107 states.insert(input_state.type_id(), input_state.state());
108 }
109 for node in before {
110 states.insert(node.provides_state_type_id(), node.state_any());
111 }
112 Self { states }
113 }
114
115 pub fn get<T: 'static>(&self) -> SubtrActorResult<&'a T> {
116 self.maybe_get::<T>().ok_or_else(|| {
117 analysis_node_graph_error(format!(
118 "Missing state {} in analysis context",
119 type_name::<T>()
120 ))
121 })
122 }
123
124 pub fn maybe_get<T: 'static>(&self) -> Option<&'a T> {
125 self.states
126 .get(&TypeId::of::<T>())
127 .and_then(|state| state.downcast_ref::<T>())
128 }
129}
130
/// A statically typed analysis node.
///
/// A node declares the states it depends on, is evaluated against an
/// [`AnalysisStateContext`], and exposes its own state of type `Self::State`.
pub trait AnalysisNode: 'static {
    /// The state type this node exposes through [`AnalysisNode::state`].
    type State: 'static;

    /// Name of the node, used in graph error messages and node listings.
    fn name(&self) -> &'static str;

    /// Hook that receives the replay metadata. The default does nothing.
    fn on_replay_meta(&mut self, _meta: &ReplayMeta) -> SubtrActorResult<()> {
        Ok(())
    }

    /// States this node depends on. The default declares none.
    fn dependencies(&self) -> Vec<AnalysisDependency> {
        Vec::new()
    }

    /// Updates the node using the states available in `ctx`.
    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;

    /// Hook for finalization work. The default does nothing.
    fn finish(&mut self) -> SubtrActorResult<()> {
        Ok(())
    }

    /// Borrows the node's current state.
    fn state(&self) -> &Self::State;
}
152
/// Type-erased counterpart of [`AnalysisNode`].
///
/// It has no associated type: the state type is described by a `TypeId`
/// and a type name, and the state itself is exposed as `&dyn Any`.
pub trait AnalysisNodeDyn: 'static {
    /// Name of the node.
    fn name(&self) -> &'static str;

    /// `TypeId` of the state this node provides.
    fn provides_state_type_id(&self) -> TypeId;

    /// Type name of the state this node provides.
    fn provides_state_type_name(&self) -> &'static str;

    /// Hook that receives the replay metadata.
    fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()>;

    /// States this node depends on.
    fn dependencies(&self) -> Vec<AnalysisDependency>;

    /// Updates the node using the states available in `ctx`.
    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;

    /// Hook for finalization work.
    fn finish(&mut self) -> SubtrActorResult<()>;

    /// Borrows the node's current state as `dyn Any`.
    fn state_any(&self) -> &dyn Any;
}
170
171impl<N> AnalysisNodeDyn for N
172where
173 N: AnalysisNode,
174{
175 fn name(&self) -> &'static str {
176 AnalysisNode::name(self)
177 }
178
179 fn provides_state_type_id(&self) -> TypeId {
180 TypeId::of::<N::State>()
181 }
182
183 fn provides_state_type_name(&self) -> &'static str {
184 type_name::<N::State>()
185 }
186
187 fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()> {
188 AnalysisNode::on_replay_meta(self, meta)
189 }
190
191 fn dependencies(&self) -> Vec<AnalysisDependency> {
192 AnalysisNode::dependencies(self)
193 }
194
195 fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()> {
196 AnalysisNode::evaluate(self, ctx)
197 }
198
199 fn finish(&mut self) -> SubtrActorResult<()> {
200 AnalysisNode::finish(self)
201 }
202
203 fn state_any(&self) -> &dyn Any {
204 self.state()
205 }
206}
207
/// A set of analysis nodes together with the state types supplied from
/// outside the nodes (root states and per-evaluation input states).
#[derive(Default)]
pub struct AnalysisGraph {
    /// Nodes owned by the graph; `resolve` stores them in dependency order.
    nodes: Vec<Box<dyn AnalysisNodeDyn>>,
    /// Indexes into `nodes` in the order they are evaluated.
    evaluation_order: Vec<usize>,
    /// Root state types that have been declared, mapped to their type names.
    declared_root_states: HashMap<TypeId, &'static str>,
    /// Input state types that have been declared, mapped to their type names.
    declared_input_states: HashMap<TypeId, &'static str>,
    /// Root state values currently stored in the graph.
    root_states: HashMap<TypeId, Box<dyn Any>>,
    /// `true` once `resolve` has succeeded; reset when a node is pushed.
    resolved: bool,
}
217
218impl AnalysisGraph {
219 pub fn new() -> Self {
220 Self::default()
221 }
222
223 pub fn with_root_state_type<T: 'static>(mut self) -> Self {
224 self.register_root_state::<T>();
225 self
226 }
227
228 pub fn register_root_state<T: 'static>(&mut self) {
229 self.declared_root_states
230 .insert(TypeId::of::<T>(), type_name::<T>());
231 }
232
233 pub fn with_input_state_type<T: 'static>(mut self) -> Self {
234 self.register_input_state::<T>();
235 self
236 }
237
238 pub fn register_input_state<T: 'static>(&mut self) {
239 self.declared_input_states
240 .insert(TypeId::of::<T>(), type_name::<T>());
241 }
242
243 pub fn set_root_state<T: 'static>(&mut self, value: T) {
244 self.register_root_state::<T>();
245 self.root_states.insert(TypeId::of::<T>(), Box::new(value));
246 }
247
248 pub fn with_node<N>(mut self, node: N) -> Self
249 where
250 N: AnalysisNode,
251 {
252 self.push_node(node);
253 self
254 }
255
256 pub fn with_boxed_node(mut self, node: Box<dyn AnalysisNodeDyn>) -> Self {
257 self.push_boxed_node(node);
258 self
259 }
260
261 pub fn push_node<N>(&mut self, node: N)
262 where
263 N: AnalysisNode,
264 {
265 self.push_boxed_node(Box::new(node));
266 }
267
268 pub fn push_boxed_node(&mut self, node: Box<dyn AnalysisNodeDyn>) {
269 self.nodes.push(node);
270 self.resolved = false;
271 }
272
273 pub fn resolve(&mut self) -> SubtrActorResult<()> {
274 if self.resolved {
275 return Ok(());
276 }
277
278 loop {
279 let providers = self.provider_index_by_type()?;
280 let mut additions = Vec::new();
281 let mut queued_types = HashSet::new();
282
283 for node in &self.nodes {
284 for dependency in node.dependencies() {
285 if providers.contains_key(&dependency.state_type_id())
286 || self
287 .declared_root_states
288 .contains_key(&dependency.state_type_id())
289 || self
290 .declared_input_states
291 .contains_key(&dependency.state_type_id())
292 {
293 continue;
294 }
295 if dependency.is_external() {
296 return Err(analysis_node_graph_error(format!(
297 "Node '{}' requires state {} with no provider",
298 node.name(),
299 dependency.state_type_name(),
300 )));
301 }
302 let default_factory = dependency.default_factory();
303 if queued_types.insert(dependency.state_type_id()) {
304 additions.push(default_factory());
305 }
306 }
307 }
308
309 if additions.is_empty() {
310 break;
311 }
312
313 self.nodes.extend(additions);
314 }
315
316 let providers = self.provider_index_by_type()?;
317 let mut visiting = HashSet::new();
318 let mut visited = HashSet::new();
319 let mut order = Vec::with_capacity(self.nodes.len());
320
321 for index in 0..self.nodes.len() {
322 self.visit_node(
323 index,
324 &providers,
325 &mut visiting,
326 &mut visited,
327 &mut order,
328 &mut Vec::new(),
329 )?;
330 }
331
332 let mut ordered_nodes = Vec::with_capacity(self.nodes.len());
333 let mut original_nodes: Vec<Option<Box<dyn AnalysisNodeDyn>>> =
334 std::mem::take(&mut self.nodes)
335 .into_iter()
336 .map(Some)
337 .collect();
338 for index in order {
339 ordered_nodes.push(
340 original_nodes[index]
341 .take()
342 .expect("topological order should only reference each node once"),
343 );
344 }
345
346 self.nodes = ordered_nodes;
347 self.evaluation_order = (0..self.nodes.len()).collect();
348 self.resolved = true;
349 Ok(())
350 }
351
352 pub fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()> {
353 self.resolve()?;
354 for node in &mut self.nodes {
355 node.on_replay_meta(meta)?;
356 }
357 Ok(())
358 }
359
360 pub fn evaluate(&mut self) -> SubtrActorResult<()> {
361 self.evaluate_with_states(&[])
362 }
363
364 pub fn evaluate_with_state<T: 'static>(&mut self, value: &T) -> SubtrActorResult<()> {
365 self.evaluate_with_states(&[AnalysisStateRef::of(value)])
366 }
367
368 pub fn evaluate_with_states<'a>(
369 &mut self,
370 input_states: &'a [AnalysisStateRef<'a>],
371 ) -> SubtrActorResult<()> {
372 self.resolve()?;
373
374 for (type_id, type_name) in &self.declared_root_states {
375 if !self.root_states.contains_key(type_id) {
376 return Err(analysis_node_graph_error(format!(
377 "Missing root state {type_name} for evaluation"
378 )));
379 }
380 }
381
382 let mut provided_input_types = HashMap::with_capacity(input_states.len());
383 for input_state in input_states {
384 if let Some(existing) =
385 provided_input_types.insert(input_state.type_id(), input_state.type_name())
386 {
387 return Err(analysis_node_graph_error(format!(
388 "Duplicate input states for {}: {} and {}",
389 input_state.type_name(),
390 existing,
391 input_state.type_name(),
392 )));
393 }
394 }
395 for (type_id, type_name) in self.required_input_states() {
396 if !provided_input_types.contains_key(&type_id) {
397 return Err(analysis_node_graph_error(format!(
398 "Missing input state {type_name} for evaluation"
399 )));
400 }
401 }
402
403 for node_index in self.evaluation_order.clone() {
404 let (before, current_and_after) = self.nodes.split_at_mut(node_index);
405 let (current, _) = current_and_after
406 .split_first_mut()
407 .expect("evaluation order should contain valid indexes");
408 let ctx = AnalysisStateContext::from_parts(&self.root_states, input_states, before);
409 current.evaluate(&ctx)?;
410 }
411
412 Ok(())
413 }
414
415 pub fn finish(&mut self) -> SubtrActorResult<()> {
416 for node in &mut self.nodes {
417 node.finish()?;
418 }
419 Ok(())
420 }
421
422 pub fn state<T: 'static>(&self) -> Option<&T> {
423 let target = TypeId::of::<T>();
424 self.root_states
425 .get(&target)
426 .and_then(|state| state.downcast_ref::<T>())
427 .or_else(|| {
428 self.nodes
429 .iter()
430 .find(|node| node.provides_state_type_id() == target)
431 .and_then(|node| node.state_any().downcast_ref::<T>())
432 })
433 }
434
435 pub fn node_names(&self) -> impl Iterator<Item = &'static str> + '_ {
436 self.nodes.iter().map(|node| node.name())
437 }
438
439 fn provider_index_by_type(&self) -> SubtrActorResult<HashMap<TypeId, usize>> {
440 let mut providers = HashMap::new();
441 for (index, node) in self.nodes.iter().enumerate() {
442 if self
443 .declared_root_states
444 .contains_key(&node.provides_state_type_id())
445 {
446 return SubtrActorError::new_result(
447 SubtrActorErrorVariant::CallbackError(format!(
448 "analysis node graph error: Duplicate providers for root state {}: root and '{}'",
449 node.provides_state_type_name(),
450 node.name(),
451 )),
452 );
453 }
454 if self
455 .declared_input_states
456 .contains_key(&node.provides_state_type_id())
457 {
458 return SubtrActorError::new_result(
459 SubtrActorErrorVariant::CallbackError(format!(
460 "analysis node graph error: Duplicate providers for input state {}: input and '{}'",
461 node.provides_state_type_name(),
462 node.name(),
463 )),
464 );
465 }
466 if let Some(existing) = providers.insert(node.provides_state_type_id(), index) {
467 return SubtrActorError::new_result(
468 SubtrActorErrorVariant::CallbackError(format!(
469 "analysis node graph error: Duplicate providers for state {}: '{}' and '{}'",
470 node.provides_state_type_name(),
471 self.nodes[existing].name(),
472 node.name(),
473 )),
474 );
475 }
476 }
477 Ok(providers)
478 }
479
480 fn required_input_states(&self) -> HashMap<TypeId, &'static str> {
481 let mut required = HashMap::new();
482 for node in &self.nodes {
483 for dependency in node.dependencies() {
484 let type_id = dependency.state_type_id();
485 if self.declared_input_states.contains_key(&type_id)
486 && !self.root_states.contains_key(&type_id)
487 {
488 required.insert(type_id, dependency.state_type_name());
489 }
490 }
491 }
492 required
493 }
494
495 fn visit_node(
496 &self,
497 index: usize,
498 providers: &HashMap<TypeId, usize>,
499 visiting: &mut HashSet<usize>,
500 visited: &mut HashSet<usize>,
501 order: &mut Vec<usize>,
502 stack: &mut Vec<&'static str>,
503 ) -> SubtrActorResult<()> {
504 if visited.contains(&index) {
505 return Ok(());
506 }
507 if !visiting.insert(index) {
508 stack.push(self.nodes[index].name());
509 let cycle = stack.join(" -> ");
510 stack.pop();
511 return Err(analysis_node_graph_error(format!(
512 "Cycle detected in analysis node graph: {cycle}"
513 )));
514 }
515
516 stack.push(self.nodes[index].name());
517 for dependency in self.nodes[index].dependencies() {
518 if self
519 .declared_root_states
520 .contains_key(&dependency.state_type_id())
521 || self
522 .declared_input_states
523 .contains_key(&dependency.state_type_id())
524 {
525 continue;
526 }
527
528 let Some(dependency_index) = providers.get(&dependency.state_type_id()).copied() else {
529 stack.pop();
530 return Err(analysis_node_graph_error(format!(
531 "Node '{}' depends on missing state {}",
532 self.nodes[index].name(),
533 dependency.state_type_name(),
534 )));
535 };
536 self.visit_node(dependency_index, providers, visiting, visited, order, stack)?;
537 }
538 stack.pop();
539
540 visiting.remove(&index);
541 visited.insert(index);
542 order.push(index);
543 Ok(())
544 }
545}
546
547fn analysis_node_graph_error(message: String) -> SubtrActorError {
548 SubtrActorError::new(SubtrActorErrorVariant::CallbackError(format!(
549 "analysis node graph error: {message}"
550 )))
551}
552
553#[cfg(test)]
554#[path = "graph_tests.rs"]
555mod tests;