Skip to main content

subtr_actor/stats/analysis_graph/
graph.rs

1use std::any::{type_name, Any, TypeId};
2use std::collections::{HashMap, HashSet};
3
4use crate::stats::calculators::{event_producers, EmittedEvent};
5use crate::*;
6
7#[derive(Clone, Copy)]
8pub struct AnalysisDependency {
9    state_type_id: TypeId,
10    state_type_name: &'static str,
11    source: AnalysisDependencySource,
12}
13
14#[derive(Clone, Copy)]
15enum AnalysisDependencySource {
16    DefaultFactory(fn() -> Box<dyn AnalysisNodeDyn>),
17    External,
18}
19
20impl AnalysisDependency {
21    pub fn required<T: 'static>() -> Self {
22        Self {
23            state_type_id: TypeId::of::<T>(),
24            state_type_name: type_name::<T>(),
25            source: AnalysisDependencySource::External,
26        }
27    }
28
29    pub fn with_default<T: 'static>(default_factory: fn() -> Box<dyn AnalysisNodeDyn>) -> Self {
30        Self {
31            state_type_id: TypeId::of::<T>(),
32            state_type_name: type_name::<T>(),
33            source: AnalysisDependencySource::DefaultFactory(default_factory),
34        }
35    }
36
37    pub fn state_type_id(&self) -> TypeId {
38        self.state_type_id
39    }
40
41    pub fn state_type_name(&self) -> &'static str {
42        self.state_type_name
43    }
44
45    fn default_factory(&self) -> fn() -> Box<dyn AnalysisNodeDyn> {
46        match self.source {
47            AnalysisDependencySource::DefaultFactory(default_factory) => default_factory,
48            AnalysisDependencySource::External => panic!(
49                "analysis dependency for {} has no default factory",
50                self.state_type_name
51            ),
52        }
53    }
54
55    fn is_external(&self) -> bool {
56        matches!(self.source, AnalysisDependencySource::External)
57    }
58}
59
60pub struct AnalysisStateContext<'a> {
61    states: HashMap<TypeId, &'a dyn Any>,
62}
63
64pub struct AnalysisStateRef<'a> {
65    type_id: TypeId,
66    type_name: &'static str,
67    state: &'a dyn Any,
68}
69
70impl<'a> AnalysisStateRef<'a> {
71    pub fn of<T: 'static>(state: &'a T) -> Self {
72        Self {
73            type_id: TypeId::of::<T>(),
74            type_name: type_name::<T>(),
75            state,
76        }
77    }
78
79    fn type_id(&self) -> TypeId {
80        self.type_id
81    }
82
83    fn type_name(&self) -> &'static str {
84        self.type_name
85    }
86
87    fn state(&self) -> &'a dyn Any {
88        self.state
89    }
90}
91
92impl<'a> AnalysisStateContext<'a> {
93    fn from_parts(
94        root_states: &'a HashMap<TypeId, Box<dyn Any>>,
95        input_states: &'a [AnalysisStateRef<'a>],
96        before: &'a [Box<dyn AnalysisNodeDyn>],
97    ) -> Self {
98        let mut states =
99            HashMap::with_capacity(root_states.len() + input_states.len() + before.len());
100        for (type_id, state) in root_states {
101            states.insert(*type_id, state.as_ref());
102        }
103        for input_state in input_states {
104            states.insert(input_state.type_id(), input_state.state());
105        }
106        for node in before {
107            states.insert(node.provides_state_type_id(), node.state_any());
108        }
109        Self { states }
110    }
111
112    pub fn get<T: 'static>(&self) -> SubtrActorResult<&'a T> {
113        self.maybe_get::<T>().ok_or_else(|| {
114            analysis_node_graph_error(format!(
115                "Missing state {} in analysis context",
116                type_name::<T>()
117            ))
118        })
119    }
120
121    pub fn maybe_get<T: 'static>(&self) -> Option<&'a T> {
122        self.states
123            .get(&TypeId::of::<T>())
124            .and_then(|state| state.downcast_ref::<T>())
125    }
126}
127
128pub trait AnalysisNode: 'static {
129    type State: 'static;
130
131    fn name(&self) -> &'static str;
132
133    fn on_replay_meta(&mut self, _meta: &ReplayMeta) -> SubtrActorResult<()> {
134        Ok(())
135    }
136
137    fn dependencies(&self) -> Vec<AnalysisDependency> {
138        Vec::new()
139    }
140
141    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;
142
143    fn finish(&mut self, _ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()> {
144        Ok(())
145    }
146
147    fn state(&self) -> &Self::State;
148}
149
150pub trait AnalysisNodeDyn: 'static {
151    fn name(&self) -> &'static str;
152
153    fn provides_state_type_id(&self) -> TypeId;
154
155    fn provides_state_type_name(&self) -> &'static str;
156
157    fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()>;
158
159    fn dependencies(&self) -> Vec<AnalysisDependency>;
160
161    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;
162
163    fn finish(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;
164
165    fn state_any(&self) -> &dyn Any;
166}
167
168impl<N> AnalysisNodeDyn for N
169where
170    N: AnalysisNode,
171{
172    fn name(&self) -> &'static str {
173        AnalysisNode::name(self)
174    }
175
176    fn provides_state_type_id(&self) -> TypeId {
177        TypeId::of::<N::State>()
178    }
179
180    fn provides_state_type_name(&self) -> &'static str {
181        type_name::<N::State>()
182    }
183
184    fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()> {
185        AnalysisNode::on_replay_meta(self, meta)
186    }
187
188    fn dependencies(&self) -> Vec<AnalysisDependency> {
189        AnalysisNode::dependencies(self)
190    }
191
192    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()> {
193        AnalysisNode::evaluate(self, ctx)
194    }
195
196    fn finish(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()> {
197        AnalysisNode::finish(self, ctx)
198    }
199
200    fn state_any(&self) -> &dyn Any {
201        self.state()
202    }
203}
204
205#[derive(Default)]
206pub struct AnalysisGraph {
207    nodes: Vec<Box<dyn AnalysisNodeDyn>>,
208    evaluation_order: Vec<usize>,
209    declared_root_states: HashMap<TypeId, &'static str>,
210    declared_input_states: HashMap<TypeId, &'static str>,
211    root_states: HashMap<TypeId, Box<dyn Any>>,
212    resolved: bool,
213}
214
215impl AnalysisGraph {
216    pub fn new() -> Self {
217        Self::default()
218    }
219
220    pub fn with_root_state_type<T: 'static>(mut self) -> Self {
221        self.register_root_state::<T>();
222        self
223    }
224
225    pub fn register_root_state<T: 'static>(&mut self) {
226        self.declared_root_states
227            .insert(TypeId::of::<T>(), type_name::<T>());
228    }
229
230    pub fn with_input_state_type<T: 'static>(mut self) -> Self {
231        self.register_input_state::<T>();
232        self
233    }
234
235    pub fn register_input_state<T: 'static>(&mut self) {
236        self.declared_input_states
237            .insert(TypeId::of::<T>(), type_name::<T>());
238    }
239
240    pub fn set_root_state<T: 'static>(&mut self, value: T) {
241        self.register_root_state::<T>();
242        self.root_states.insert(TypeId::of::<T>(), Box::new(value));
243    }
244
245    pub fn with_node<N>(mut self, node: N) -> Self
246    where
247        N: AnalysisNode,
248    {
249        self.push_node(node);
250        self
251    }
252
253    pub fn with_boxed_node(mut self, node: Box<dyn AnalysisNodeDyn>) -> Self {
254        self.push_boxed_node(node);
255        self
256    }
257
258    pub fn push_node<N>(&mut self, node: N)
259    where
260        N: AnalysisNode,
261    {
262        self.push_boxed_node(Box::new(node));
263    }
264
265    pub fn push_boxed_node(&mut self, node: Box<dyn AnalysisNodeDyn>) {
266        self.nodes.push(node);
267        self.resolved = false;
268    }
269
270    pub fn ensure_dependency(&mut self, dependency: AnalysisDependency) -> SubtrActorResult<()> {
271        let providers = self.provider_index_by_type()?;
272        if providers.contains_key(&dependency.state_type_id())
273            || self
274                .declared_root_states
275                .contains_key(&dependency.state_type_id())
276            || self
277                .declared_input_states
278                .contains_key(&dependency.state_type_id())
279        {
280            return Ok(());
281        }
282        if dependency.is_external() {
283            return Err(analysis_node_graph_error(format!(
284                "Required state {} has no provider",
285                dependency.state_type_name(),
286            )));
287        }
288
289        self.push_boxed_node((dependency.default_factory())());
290        Ok(())
291    }
292
293    pub fn ensure_dependencies<I>(&mut self, dependencies: I) -> SubtrActorResult<()>
294    where
295        I: IntoIterator<Item = AnalysisDependency>,
296    {
297        for dependency in dependencies {
298            self.ensure_dependency(dependency)?;
299        }
300        Ok(())
301    }
302
303    pub fn resolve(&mut self) -> SubtrActorResult<()> {
304        if self.resolved {
305            return Ok(());
306        }
307
308        loop {
309            let providers = self.provider_index_by_type()?;
310            let mut additions = Vec::new();
311            let mut queued_types = HashSet::new();
312
313            for node in &self.nodes {
314                for dependency in node.dependencies() {
315                    if providers.contains_key(&dependency.state_type_id())
316                        || self
317                            .declared_root_states
318                            .contains_key(&dependency.state_type_id())
319                        || self
320                            .declared_input_states
321                            .contains_key(&dependency.state_type_id())
322                    {
323                        continue;
324                    }
325                    if dependency.is_external() {
326                        return Err(analysis_node_graph_error(format!(
327                            "Node '{}' requires state {} with no provider",
328                            node.name(),
329                            dependency.state_type_name(),
330                        )));
331                    }
332                    let default_factory = dependency.default_factory();
333                    if queued_types.insert(dependency.state_type_id()) {
334                        additions.push(default_factory());
335                    }
336                }
337            }
338
339            if additions.is_empty() {
340                break;
341            }
342
343            self.nodes.extend(additions);
344        }
345
346        let providers = self.provider_index_by_type()?;
347        let mut visiting = HashSet::new();
348        let mut visited = HashSet::new();
349        let mut order = Vec::with_capacity(self.nodes.len());
350
351        for index in 0..self.nodes.len() {
352            self.visit_node(
353                index,
354                &providers,
355                &mut visiting,
356                &mut visited,
357                &mut order,
358                &mut Vec::new(),
359            )?;
360        }
361
362        let mut ordered_nodes = Vec::with_capacity(self.nodes.len());
363        let mut original_nodes: Vec<Option<Box<dyn AnalysisNodeDyn>>> =
364            std::mem::take(&mut self.nodes)
365                .into_iter()
366                .map(Some)
367                .collect();
368        for index in order {
369            ordered_nodes.push(
370                original_nodes[index]
371                    .take()
372                    .expect("topological order should only reference each node once"),
373            );
374        }
375
376        self.nodes = ordered_nodes;
377        self.evaluation_order = (0..self.nodes.len()).collect();
378        self.resolved = true;
379        Ok(())
380    }
381
382    pub fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()> {
383        self.resolve()?;
384        for node in &mut self.nodes {
385            node.on_replay_meta(meta)?;
386        }
387        Ok(())
388    }
389
390    pub fn evaluate(&mut self) -> SubtrActorResult<()> {
391        self.evaluate_with_states(&[])
392    }
393
394    pub fn evaluate_with_state<T: 'static>(&mut self, value: &T) -> SubtrActorResult<()> {
395        self.evaluate_with_states(&[AnalysisStateRef::of(value)])
396    }
397
398    pub fn evaluate_with_states<'a>(
399        &mut self,
400        input_states: &'a [AnalysisStateRef<'a>],
401    ) -> SubtrActorResult<()> {
402        self.resolve()?;
403
404        for (type_id, type_name) in &self.declared_root_states {
405            if !self.root_states.contains_key(type_id) {
406                return Err(analysis_node_graph_error(format!(
407                    "Missing root state {type_name} for evaluation"
408                )));
409            }
410        }
411
412        let mut provided_input_types = HashMap::with_capacity(input_states.len());
413        for input_state in input_states {
414            if let Some(existing) =
415                provided_input_types.insert(input_state.type_id(), input_state.type_name())
416            {
417                return Err(analysis_node_graph_error(format!(
418                    "Duplicate input states for {}: {} and {}",
419                    input_state.type_name(),
420                    existing,
421                    input_state.type_name(),
422                )));
423            }
424        }
425        for (type_id, type_name) in self.required_input_states() {
426            if !provided_input_types.contains_key(&type_id) {
427                return Err(analysis_node_graph_error(format!(
428                    "Missing input state {type_name} for evaluation"
429                )));
430            }
431        }
432
433        for node_index in self.evaluation_order.clone() {
434            let (before, current_and_after) = self.nodes.split_at_mut(node_index);
435            let (current, _) = current_and_after
436                .split_first_mut()
437                .expect("evaluation order should contain valid indexes");
438            let ctx = AnalysisStateContext::from_parts(&self.root_states, input_states, before);
439            current.evaluate(&ctx)?;
440        }
441
442        Ok(())
443    }
444
445    pub fn finish(&mut self) -> SubtrActorResult<()> {
446        self.resolve()?;
447        for node_index in self.evaluation_order.clone() {
448            let (before, current_and_after) = self.nodes.split_at_mut(node_index);
449            let (current, _) = current_and_after
450                .split_first_mut()
451                .expect("evaluation order should contain valid indexes");
452            let ctx = AnalysisStateContext::from_parts(&self.root_states, &[], before);
453            current.finish(&ctx)?;
454        }
455        Ok(())
456    }
457
458    pub fn state<T: 'static>(&self) -> Option<&T> {
459        let target = TypeId::of::<T>();
460        self.root_states
461            .get(&target)
462            .and_then(|state| state.downcast_ref::<T>())
463            .or_else(|| {
464                self.nodes
465                    .iter()
466                    .find(|node| node.provides_state_type_id() == target)
467                    .and_then(|node| node.state_any().downcast_ref::<T>())
468            })
469    }
470
471    pub fn node_names(&self) -> impl Iterator<Item = &'static str> + '_ {
472        self.nodes.iter().map(|node| node.name())
473    }
474
475    pub fn emitted_events(&mut self) -> SubtrActorResult<Vec<EmittedEvent>> {
476        self.resolve()?;
477        let node_names = self
478            .nodes
479            .iter()
480            .map(|node| node.name())
481            .collect::<HashSet<_>>();
482        Ok(self
483            .event_producers()
484            .iter()
485            .filter(|producer| node_names.contains(producer.node_name))
486            .flat_map(|producer| producer.emitted_events.iter().copied())
487            .collect())
488    }
489
490    fn event_producers(&self) -> &'static [crate::stats::calculators::EventProducerDefinition] {
491        event_producers()
492    }
493
494    fn provider_index_by_type(&self) -> SubtrActorResult<HashMap<TypeId, usize>> {
495        let mut providers = HashMap::new();
496        for (index, node) in self.nodes.iter().enumerate() {
497            if self
498                .declared_root_states
499                .contains_key(&node.provides_state_type_id())
500            {
501                return SubtrActorError::new_result(
502                    SubtrActorErrorVariant::CallbackError(format!(
503                        "analysis node graph error: Duplicate providers for root state {}: root and '{}'",
504                        node.provides_state_type_name(),
505                        node.name(),
506                    )),
507                );
508            }
509            if self
510                .declared_input_states
511                .contains_key(&node.provides_state_type_id())
512            {
513                return SubtrActorError::new_result(
514                    SubtrActorErrorVariant::CallbackError(format!(
515                        "analysis node graph error: Duplicate providers for input state {}: input and '{}'",
516                        node.provides_state_type_name(),
517                        node.name(),
518                    )),
519                );
520            }
521            if let Some(existing) = providers.insert(node.provides_state_type_id(), index) {
522                return SubtrActorError::new_result(
523                    SubtrActorErrorVariant::CallbackError(format!(
524                        "analysis node graph error: Duplicate providers for state {}: '{}' and '{}'",
525                        node.provides_state_type_name(),
526                        self.nodes[existing].name(),
527                        node.name(),
528                    )),
529                );
530            }
531        }
532        Ok(providers)
533    }
534
535    fn required_input_states(&self) -> HashMap<TypeId, &'static str> {
536        let mut required = HashMap::new();
537        for node in &self.nodes {
538            for dependency in node.dependencies() {
539                let type_id = dependency.state_type_id();
540                if self.declared_input_states.contains_key(&type_id)
541                    && !self.root_states.contains_key(&type_id)
542                {
543                    required.insert(type_id, dependency.state_type_name());
544                }
545            }
546        }
547        required
548    }
549
550    fn visit_node(
551        &self,
552        index: usize,
553        providers: &HashMap<TypeId, usize>,
554        visiting: &mut HashSet<usize>,
555        visited: &mut HashSet<usize>,
556        order: &mut Vec<usize>,
557        stack: &mut Vec<&'static str>,
558    ) -> SubtrActorResult<()> {
559        if visited.contains(&index) {
560            return Ok(());
561        }
562        if !visiting.insert(index) {
563            stack.push(self.nodes[index].name());
564            let cycle = stack.join(" -> ");
565            stack.pop();
566            return Err(analysis_node_graph_error(format!(
567                "Cycle detected in analysis node graph: {cycle}"
568            )));
569        }
570
571        stack.push(self.nodes[index].name());
572        for dependency in self.nodes[index].dependencies() {
573            if self
574                .declared_root_states
575                .contains_key(&dependency.state_type_id())
576                || self
577                    .declared_input_states
578                    .contains_key(&dependency.state_type_id())
579            {
580                continue;
581            }
582
583            let Some(dependency_index) = providers.get(&dependency.state_type_id()).copied() else {
584                stack.pop();
585                return Err(analysis_node_graph_error(format!(
586                    "Node '{}' depends on missing state {}",
587                    self.nodes[index].name(),
588                    dependency.state_type_name(),
589                )));
590            };
591            self.visit_node(dependency_index, providers, visiting, visited, order, stack)?;
592        }
593        stack.pop();
594
595        visiting.remove(&index);
596        visited.insert(index);
597        order.push(index);
598        Ok(())
599    }
600}
601
602impl AnalysisGraph {
603    pub fn render_ascii_dag(&mut self) -> SubtrActorResult<String> {
604        self.resolve()?;
605
606        let providers = self.provider_index_by_type()?;
607        let mut external_labels = Vec::new();
608        let mut external_node_ids = HashMap::new();
609
610        for node in &self.nodes {
611            for dependency in node.dependencies() {
612                let dependency_type_id = dependency.state_type_id();
613                if providers.contains_key(&dependency_type_id) {
614                    continue;
615                }
616
617                let label = if self.declared_root_states.contains_key(&dependency_type_id) {
618                    format!("root:{}", short_type_name(dependency.state_type_name()))
619                } else if self.declared_input_states.contains_key(&dependency_type_id) {
620                    format!("input:{}", short_type_name(dependency.state_type_name()))
621                } else {
622                    return Err(analysis_node_graph_error(format!(
623                        "Node '{}' depends on missing state {}",
624                        node.name(),
625                        dependency.state_type_name(),
626                    )));
627                };
628                ensure_external_render_node(
629                    &mut external_labels,
630                    &mut external_node_ids,
631                    dependency_type_id,
632                    label,
633                );
634            }
635        }
636
637        if self.nodes.is_empty() && external_labels.is_empty() {
638            return Ok("AnalysisGraph\n\\- (empty)".to_owned());
639        }
640
641        let external_count = external_labels.len();
642        let mut lines = Vec::with_capacity(1 + external_count + self.nodes.len());
643        lines.push("AnalysisGraph".to_owned());
644
645        for (display_id, (_, label)) in external_labels.iter().enumerate() {
646            lines.push(format!("[{display_id}] {label}"));
647        }
648
649        for (index, node) in self.nodes.iter().enumerate() {
650            let display_id = external_count + index;
651            let mut dependency_refs = Vec::new();
652            for dependency in node.dependencies() {
653                let dependency_type_id = dependency.state_type_id();
654                let source_id = if let Some(provider_index) = providers.get(&dependency_type_id) {
655                    external_count + *provider_index
656                } else if self.declared_root_states.contains_key(&dependency_type_id) {
657                    *external_node_ids
658                        .get(&dependency_type_id)
659                        .expect("root node should have been prepared")
660                } else if self.declared_input_states.contains_key(&dependency_type_id) {
661                    *external_node_ids
662                        .get(&dependency_type_id)
663                        .expect("input node should have been prepared")
664                } else {
665                    return Err(analysis_node_graph_error(format!(
666                        "Node '{}' depends on missing state {}",
667                        node.name(),
668                        dependency.state_type_name(),
669                    )));
670                };
671                dependency_refs.push(format!("[{source_id}]"));
672            }
673
674            if dependency_refs.is_empty() {
675                lines.push(format!("[{display_id}] {}", node.name()));
676            } else {
677                lines.push(format!(
678                    "[{display_id}] {} <- {}",
679                    node.name(),
680                    dependency_refs.join(", "),
681                ));
682            }
683        }
684
685        Ok(lines.join("\n"))
686    }
687}
688
/// Returns the display id of the external render node for `dependency_type_id`.
///
/// On first sight of a type, this appends `(dependency_type_id, label)` to
/// `labels` and records its position. Later calls return the recorded id and
/// ignore `label`.
fn ensure_external_render_node(
    labels: &mut Vec<(TypeId, Box<str>)>,
    external_node_ids: &mut HashMap<TypeId, usize>,
    dependency_type_id: TypeId,
    label: String,
) -> usize {
    *external_node_ids
        .entry(dependency_type_id)
        .or_insert_with(|| {
            labels.push((dependency_type_id, label.into_boxed_str()));
            labels.len() - 1
        })
}
704
/// Shortens a Rust type name by keeping only the last `::` segment of every
/// path inside it. For example, `alloc::vec::Vec<crate::Foo>` becomes
/// `Vec<Foo>`.
///
/// A path token is a maximal run of alphanumeric characters, `_` and `:`.
/// Every other character is copied through unchanged.
fn short_type_name(type_name: &str) -> String {
    /// Final `::`-separated segment of a path token.
    fn last_segment(path: &str) -> &str {
        path.rsplit("::").next().unwrap_or(path)
    }

    let mut shortened = String::with_capacity(type_name.len());
    let mut token_start: Option<usize> = None;

    for (offset, character) in type_name.char_indices() {
        if character.is_alphanumeric() || matches!(character, '_' | ':') {
            token_start.get_or_insert(offset);
            continue;
        }
        if let Some(start) = token_start.take() {
            shortened.push_str(last_segment(&type_name[start..offset]));
        }
        shortened.push(character);
    }

    if let Some(start) = token_start {
        shortened.push_str(last_segment(&type_name[start..]));
    }

    shortened
}
728
729fn analysis_node_graph_error(message: String) -> SubtrActorError {
730    SubtrActorError::new(SubtrActorErrorVariant::CallbackError(format!(
731        "analysis node graph error: {message}"
732    )))
733}
734
// Unit tests for this module, compiled only under `cfg(test)` and loaded from
// `graph_tests.rs` through the `#[path]` attribute.
#[cfg(test)]
#[path = "graph_tests.rs"]
mod tests;