Skip to main content

subtr_actor/stats/analysis_graph/
graph.rs

1#![allow(dead_code)]
2
3use std::any::{type_name, Any, TypeId};
4use std::collections::{HashMap, HashSet};
5
6use crate::*;
7
8#[derive(Clone, Copy)]
9pub struct AnalysisDependency {
10    state_type_id: TypeId,
11    state_type_name: &'static str,
12    source: AnalysisDependencySource,
13}
14
15#[derive(Clone, Copy)]
16enum AnalysisDependencySource {
17    DefaultFactory(fn() -> Box<dyn AnalysisNodeDyn>),
18    External,
19}
20
21impl AnalysisDependency {
22    pub fn required<T: 'static>() -> Self {
23        Self {
24            state_type_id: TypeId::of::<T>(),
25            state_type_name: type_name::<T>(),
26            source: AnalysisDependencySource::External,
27        }
28    }
29
30    pub fn with_default<T: 'static>(default_factory: fn() -> Box<dyn AnalysisNodeDyn>) -> Self {
31        Self {
32            state_type_id: TypeId::of::<T>(),
33            state_type_name: type_name::<T>(),
34            source: AnalysisDependencySource::DefaultFactory(default_factory),
35        }
36    }
37
38    pub fn state_type_id(&self) -> TypeId {
39        self.state_type_id
40    }
41
42    pub fn state_type_name(&self) -> &'static str {
43        self.state_type_name
44    }
45
46    fn default_factory(&self) -> fn() -> Box<dyn AnalysisNodeDyn> {
47        match self.source {
48            AnalysisDependencySource::DefaultFactory(default_factory) => default_factory,
49            AnalysisDependencySource::External => panic!(
50                "analysis dependency for {} has no default factory",
51                self.state_type_name
52            ),
53        }
54    }
55
56    fn is_external(&self) -> bool {
57        matches!(self.source, AnalysisDependencySource::External)
58    }
59}
60
/// Type-keyed lookup table backing [`AnalysisStateContext`].
type StateLookup<'a> = HashMap<TypeId, &'a dyn Any>;

/// The states a node can read while it is evaluated or finished, keyed by each state's
/// [`TypeId`].
///
/// Built by the graph from its root states, the caller's input states, and the states of
/// the nodes ordered before the node being run.
pub struct AnalysisStateContext<'a> {
    states: StateLookup<'a>,
}
64
/// A borrowed, type-erased state passed to `AnalysisGraph::evaluate_with_states`.
///
/// Alongside the erased reference it keeps the concrete type's [`TypeId`] (for lookup)
/// and type name (for error messages).
pub struct AnalysisStateRef<'a> {
    state: &'a dyn Any,
    type_name: &'static str,
    type_id: TypeId,
}
70
71impl<'a> AnalysisStateRef<'a> {
72    pub fn of<T: 'static>(state: &'a T) -> Self {
73        Self {
74            type_id: TypeId::of::<T>(),
75            type_name: type_name::<T>(),
76            state,
77        }
78    }
79
80    fn type_id(&self) -> TypeId {
81        self.type_id
82    }
83
84    fn type_name(&self) -> &'static str {
85        self.type_name
86    }
87
88    fn state(&self) -> &'a dyn Any {
89        self.state
90    }
91}
92
93impl<'a> AnalysisStateContext<'a> {
94    fn from_parts(
95        root_states: &'a HashMap<TypeId, Box<dyn Any>>,
96        input_states: &'a [AnalysisStateRef<'a>],
97        before: &'a [Box<dyn AnalysisNodeDyn>],
98    ) -> Self {
99        let mut states =
100            HashMap::with_capacity(root_states.len() + input_states.len() + before.len());
101        for (type_id, state) in root_states {
102            states.insert(*type_id, state.as_ref());
103        }
104        for input_state in input_states {
105            states.insert(input_state.type_id(), input_state.state());
106        }
107        for node in before {
108            states.insert(node.provides_state_type_id(), node.state_any());
109        }
110        Self { states }
111    }
112
113    pub fn get<T: 'static>(&self) -> SubtrActorResult<&'a T> {
114        self.maybe_get::<T>().ok_or_else(|| {
115            analysis_node_graph_error(format!(
116                "Missing state {} in analysis context",
117                type_name::<T>()
118            ))
119        })
120    }
121
122    pub fn maybe_get<T: 'static>(&self) -> Option<&'a T> {
123        self.states
124            .get(&TypeId::of::<T>())
125            .and_then(|state| state.downcast_ref::<T>())
126    }
127}
128
129pub trait AnalysisNode: 'static {
130    type State: 'static;
131
132    fn name(&self) -> &'static str;
133
134    fn on_replay_meta(&mut self, _meta: &ReplayMeta) -> SubtrActorResult<()> {
135        Ok(())
136    }
137
138    fn dependencies(&self) -> Vec<AnalysisDependency> {
139        Vec::new()
140    }
141
142    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;
143
144    fn finish(&mut self, _ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()> {
145        Ok(())
146    }
147
148    fn state(&self) -> &Self::State;
149}
150
151pub trait AnalysisNodeDyn: 'static {
152    fn name(&self) -> &'static str;
153
154    fn provides_state_type_id(&self) -> TypeId;
155
156    fn provides_state_type_name(&self) -> &'static str;
157
158    fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()>;
159
160    fn dependencies(&self) -> Vec<AnalysisDependency>;
161
162    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;
163
164    fn finish(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;
165
166    fn state_any(&self) -> &dyn Any;
167}
168
169impl<N> AnalysisNodeDyn for N
170where
171    N: AnalysisNode,
172{
173    fn name(&self) -> &'static str {
174        AnalysisNode::name(self)
175    }
176
177    fn provides_state_type_id(&self) -> TypeId {
178        TypeId::of::<N::State>()
179    }
180
181    fn provides_state_type_name(&self) -> &'static str {
182        type_name::<N::State>()
183    }
184
185    fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()> {
186        AnalysisNode::on_replay_meta(self, meta)
187    }
188
189    fn dependencies(&self) -> Vec<AnalysisDependency> {
190        AnalysisNode::dependencies(self)
191    }
192
193    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()> {
194        AnalysisNode::evaluate(self, ctx)
195    }
196
197    fn finish(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()> {
198        AnalysisNode::finish(self, ctx)
199    }
200
201    fn state_any(&self) -> &dyn Any {
202        self.state()
203    }
204}
205
206#[derive(Default)]
207pub struct AnalysisGraph {
208    nodes: Vec<Box<dyn AnalysisNodeDyn>>,
209    evaluation_order: Vec<usize>,
210    declared_root_states: HashMap<TypeId, &'static str>,
211    declared_input_states: HashMap<TypeId, &'static str>,
212    root_states: HashMap<TypeId, Box<dyn Any>>,
213    resolved: bool,
214}
215
216impl AnalysisGraph {
217    pub fn new() -> Self {
218        Self::default()
219    }
220
221    pub fn with_root_state_type<T: 'static>(mut self) -> Self {
222        self.register_root_state::<T>();
223        self
224    }
225
226    pub fn register_root_state<T: 'static>(&mut self) {
227        self.declared_root_states
228            .insert(TypeId::of::<T>(), type_name::<T>());
229    }
230
231    pub fn with_input_state_type<T: 'static>(mut self) -> Self {
232        self.register_input_state::<T>();
233        self
234    }
235
236    pub fn register_input_state<T: 'static>(&mut self) {
237        self.declared_input_states
238            .insert(TypeId::of::<T>(), type_name::<T>());
239    }
240
241    pub fn set_root_state<T: 'static>(&mut self, value: T) {
242        self.register_root_state::<T>();
243        self.root_states.insert(TypeId::of::<T>(), Box::new(value));
244    }
245
246    pub fn with_node<N>(mut self, node: N) -> Self
247    where
248        N: AnalysisNode,
249    {
250        self.push_node(node);
251        self
252    }
253
254    pub fn with_boxed_node(mut self, node: Box<dyn AnalysisNodeDyn>) -> Self {
255        self.push_boxed_node(node);
256        self
257    }
258
259    pub fn push_node<N>(&mut self, node: N)
260    where
261        N: AnalysisNode,
262    {
263        self.push_boxed_node(Box::new(node));
264    }
265
266    pub fn push_boxed_node(&mut self, node: Box<dyn AnalysisNodeDyn>) {
267        self.nodes.push(node);
268        self.resolved = false;
269    }
270
271    pub fn resolve(&mut self) -> SubtrActorResult<()> {
272        if self.resolved {
273            return Ok(());
274        }
275
276        loop {
277            let providers = self.provider_index_by_type()?;
278            let mut additions = Vec::new();
279            let mut queued_types = HashSet::new();
280
281            for node in &self.nodes {
282                for dependency in node.dependencies() {
283                    if providers.contains_key(&dependency.state_type_id())
284                        || self
285                            .declared_root_states
286                            .contains_key(&dependency.state_type_id())
287                        || self
288                            .declared_input_states
289                            .contains_key(&dependency.state_type_id())
290                    {
291                        continue;
292                    }
293                    if dependency.is_external() {
294                        return Err(analysis_node_graph_error(format!(
295                            "Node '{}' requires state {} with no provider",
296                            node.name(),
297                            dependency.state_type_name(),
298                        )));
299                    }
300                    let default_factory = dependency.default_factory();
301                    if queued_types.insert(dependency.state_type_id()) {
302                        additions.push(default_factory());
303                    }
304                }
305            }
306
307            if additions.is_empty() {
308                break;
309            }
310
311            self.nodes.extend(additions);
312        }
313
314        let providers = self.provider_index_by_type()?;
315        let mut visiting = HashSet::new();
316        let mut visited = HashSet::new();
317        let mut order = Vec::with_capacity(self.nodes.len());
318
319        for index in 0..self.nodes.len() {
320            self.visit_node(
321                index,
322                &providers,
323                &mut visiting,
324                &mut visited,
325                &mut order,
326                &mut Vec::new(),
327            )?;
328        }
329
330        let mut ordered_nodes = Vec::with_capacity(self.nodes.len());
331        let mut original_nodes: Vec<Option<Box<dyn AnalysisNodeDyn>>> =
332            std::mem::take(&mut self.nodes)
333                .into_iter()
334                .map(Some)
335                .collect();
336        for index in order {
337            ordered_nodes.push(
338                original_nodes[index]
339                    .take()
340                    .expect("topological order should only reference each node once"),
341            );
342        }
343
344        self.nodes = ordered_nodes;
345        self.evaluation_order = (0..self.nodes.len()).collect();
346        self.resolved = true;
347        Ok(())
348    }
349
350    pub fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()> {
351        self.resolve()?;
352        for node in &mut self.nodes {
353            node.on_replay_meta(meta)?;
354        }
355        Ok(())
356    }
357
358    pub fn evaluate(&mut self) -> SubtrActorResult<()> {
359        self.evaluate_with_states(&[])
360    }
361
362    pub fn evaluate_with_state<T: 'static>(&mut self, value: &T) -> SubtrActorResult<()> {
363        self.evaluate_with_states(&[AnalysisStateRef::of(value)])
364    }
365
366    pub fn evaluate_with_states<'a>(
367        &mut self,
368        input_states: &'a [AnalysisStateRef<'a>],
369    ) -> SubtrActorResult<()> {
370        self.resolve()?;
371
372        for (type_id, type_name) in &self.declared_root_states {
373            if !self.root_states.contains_key(type_id) {
374                return Err(analysis_node_graph_error(format!(
375                    "Missing root state {type_name} for evaluation"
376                )));
377            }
378        }
379
380        let mut provided_input_types = HashMap::with_capacity(input_states.len());
381        for input_state in input_states {
382            if let Some(existing) =
383                provided_input_types.insert(input_state.type_id(), input_state.type_name())
384            {
385                return Err(analysis_node_graph_error(format!(
386                    "Duplicate input states for {}: {} and {}",
387                    input_state.type_name(),
388                    existing,
389                    input_state.type_name(),
390                )));
391            }
392        }
393        for (type_id, type_name) in self.required_input_states() {
394            if !provided_input_types.contains_key(&type_id) {
395                return Err(analysis_node_graph_error(format!(
396                    "Missing input state {type_name} for evaluation"
397                )));
398            }
399        }
400
401        for node_index in self.evaluation_order.clone() {
402            let (before, current_and_after) = self.nodes.split_at_mut(node_index);
403            let (current, _) = current_and_after
404                .split_first_mut()
405                .expect("evaluation order should contain valid indexes");
406            let ctx = AnalysisStateContext::from_parts(&self.root_states, input_states, before);
407            current.evaluate(&ctx)?;
408        }
409
410        Ok(())
411    }
412
413    pub fn finish(&mut self) -> SubtrActorResult<()> {
414        self.resolve()?;
415        for node_index in self.evaluation_order.clone() {
416            let (before, current_and_after) = self.nodes.split_at_mut(node_index);
417            let (current, _) = current_and_after
418                .split_first_mut()
419                .expect("evaluation order should contain valid indexes");
420            let ctx = AnalysisStateContext::from_parts(&self.root_states, &[], before);
421            current.finish(&ctx)?;
422        }
423        Ok(())
424    }
425
426    pub fn state<T: 'static>(&self) -> Option<&T> {
427        let target = TypeId::of::<T>();
428        self.root_states
429            .get(&target)
430            .and_then(|state| state.downcast_ref::<T>())
431            .or_else(|| {
432                self.nodes
433                    .iter()
434                    .find(|node| node.provides_state_type_id() == target)
435                    .and_then(|node| node.state_any().downcast_ref::<T>())
436            })
437    }
438
439    pub fn node_names(&self) -> impl Iterator<Item = &'static str> + '_ {
440        self.nodes.iter().map(|node| node.name())
441    }
442
443    fn provider_index_by_type(&self) -> SubtrActorResult<HashMap<TypeId, usize>> {
444        let mut providers = HashMap::new();
445        for (index, node) in self.nodes.iter().enumerate() {
446            if self
447                .declared_root_states
448                .contains_key(&node.provides_state_type_id())
449            {
450                return SubtrActorError::new_result(
451                    SubtrActorErrorVariant::CallbackError(format!(
452                        "analysis node graph error: Duplicate providers for root state {}: root and '{}'",
453                        node.provides_state_type_name(),
454                        node.name(),
455                    )),
456                );
457            }
458            if self
459                .declared_input_states
460                .contains_key(&node.provides_state_type_id())
461            {
462                return SubtrActorError::new_result(
463                    SubtrActorErrorVariant::CallbackError(format!(
464                        "analysis node graph error: Duplicate providers for input state {}: input and '{}'",
465                        node.provides_state_type_name(),
466                        node.name(),
467                    )),
468                );
469            }
470            if let Some(existing) = providers.insert(node.provides_state_type_id(), index) {
471                return SubtrActorError::new_result(
472                    SubtrActorErrorVariant::CallbackError(format!(
473                        "analysis node graph error: Duplicate providers for state {}: '{}' and '{}'",
474                        node.provides_state_type_name(),
475                        self.nodes[existing].name(),
476                        node.name(),
477                    )),
478                );
479            }
480        }
481        Ok(providers)
482    }
483
484    fn required_input_states(&self) -> HashMap<TypeId, &'static str> {
485        let mut required = HashMap::new();
486        for node in &self.nodes {
487            for dependency in node.dependencies() {
488                let type_id = dependency.state_type_id();
489                if self.declared_input_states.contains_key(&type_id)
490                    && !self.root_states.contains_key(&type_id)
491                {
492                    required.insert(type_id, dependency.state_type_name());
493                }
494            }
495        }
496        required
497    }
498
499    fn visit_node(
500        &self,
501        index: usize,
502        providers: &HashMap<TypeId, usize>,
503        visiting: &mut HashSet<usize>,
504        visited: &mut HashSet<usize>,
505        order: &mut Vec<usize>,
506        stack: &mut Vec<&'static str>,
507    ) -> SubtrActorResult<()> {
508        if visited.contains(&index) {
509            return Ok(());
510        }
511        if !visiting.insert(index) {
512            stack.push(self.nodes[index].name());
513            let cycle = stack.join(" -> ");
514            stack.pop();
515            return Err(analysis_node_graph_error(format!(
516                "Cycle detected in analysis node graph: {cycle}"
517            )));
518        }
519
520        stack.push(self.nodes[index].name());
521        for dependency in self.nodes[index].dependencies() {
522            if self
523                .declared_root_states
524                .contains_key(&dependency.state_type_id())
525                || self
526                    .declared_input_states
527                    .contains_key(&dependency.state_type_id())
528            {
529                continue;
530            }
531
532            let Some(dependency_index) = providers.get(&dependency.state_type_id()).copied() else {
533                stack.pop();
534                return Err(analysis_node_graph_error(format!(
535                    "Node '{}' depends on missing state {}",
536                    self.nodes[index].name(),
537                    dependency.state_type_name(),
538                )));
539            };
540            self.visit_node(dependency_index, providers, visiting, visited, order, stack)?;
541        }
542        stack.pop();
543
544        visiting.remove(&index);
545        visited.insert(index);
546        order.push(index);
547        Ok(())
548    }
549}
550
551impl AnalysisGraph {
552    pub fn render_ascii_dag(&mut self) -> SubtrActorResult<String> {
553        self.resolve()?;
554
555        let providers = self.provider_index_by_type()?;
556        let mut external_labels = Vec::new();
557        let mut external_node_ids = HashMap::new();
558
559        for node in &self.nodes {
560            for dependency in node.dependencies() {
561                let dependency_type_id = dependency.state_type_id();
562                if providers.contains_key(&dependency_type_id) {
563                    continue;
564                }
565
566                let label = if self.declared_root_states.contains_key(&dependency_type_id) {
567                    format!("root:{}", short_type_name(dependency.state_type_name()))
568                } else if self.declared_input_states.contains_key(&dependency_type_id) {
569                    format!("input:{}", short_type_name(dependency.state_type_name()))
570                } else {
571                    return Err(analysis_node_graph_error(format!(
572                        "Node '{}' depends on missing state {}",
573                        node.name(),
574                        dependency.state_type_name(),
575                    )));
576                };
577                ensure_external_render_node(
578                    &mut external_labels,
579                    &mut external_node_ids,
580                    dependency_type_id,
581                    label,
582                );
583            }
584        }
585
586        if self.nodes.is_empty() && external_labels.is_empty() {
587            return Ok("AnalysisGraph\n\\- (empty)".to_owned());
588        }
589
590        let external_count = external_labels.len();
591        let mut lines = Vec::with_capacity(1 + external_count + self.nodes.len());
592        lines.push("AnalysisGraph".to_owned());
593
594        for (display_id, (_, label)) in external_labels.iter().enumerate() {
595            lines.push(format!("[{display_id}] {label}"));
596        }
597
598        for (index, node) in self.nodes.iter().enumerate() {
599            let display_id = external_count + index;
600            let mut dependency_refs = Vec::new();
601            for dependency in node.dependencies() {
602                let dependency_type_id = dependency.state_type_id();
603                let source_id = if let Some(provider_index) = providers.get(&dependency_type_id) {
604                    external_count + *provider_index
605                } else if self.declared_root_states.contains_key(&dependency_type_id) {
606                    *external_node_ids
607                        .get(&dependency_type_id)
608                        .expect("root node should have been prepared")
609                } else if self.declared_input_states.contains_key(&dependency_type_id) {
610                    *external_node_ids
611                        .get(&dependency_type_id)
612                        .expect("input node should have been prepared")
613                } else {
614                    return Err(analysis_node_graph_error(format!(
615                        "Node '{}' depends on missing state {}",
616                        node.name(),
617                        dependency.state_type_name(),
618                    )));
619                };
620                dependency_refs.push(format!("[{source_id}]"));
621            }
622
623            if dependency_refs.is_empty() {
624                lines.push(format!("[{display_id}] {}", node.name()));
625            } else {
626                lines.push(format!(
627                    "[{display_id}] {} <- {}",
628                    node.name(),
629                    dependency_refs.join(", "),
630                ));
631            }
632        }
633
634        Ok(lines.join("\n"))
635    }
636}
637
/// Returns the display id for the external state `dependency_type_id`.
///
/// The first time a type is seen, `label` is appended to `labels` and the new id is
/// recorded. Ids are positions in `labels`, so externals are numbered in first-seen order.
/// On later calls for the same type, the existing id is returned and `label` is dropped.
fn ensure_external_render_node(
    labels: &mut Vec<(TypeId, Box<str>)>,
    external_node_ids: &mut HashMap<TypeId, usize>,
    dependency_type_id: TypeId,
    label: String,
) -> usize {
    *external_node_ids
        .entry(dependency_type_id)
        .or_insert_with(|| {
            labels.push((dependency_type_id, label.into_boxed_str()));
            labels.len() - 1
        })
}
653
/// Shortens every path inside a type name to its last segment.
///
/// For example, `alloc::vec::Vec<core::option::Option<u8>>` becomes `Vec<Option<u8>>`.
/// A "path" is any maximal run of alphanumerics, `_`, and `:`. All other characters
/// (brackets, punctuation, spaces) are copied through unchanged.
fn short_type_name(type_name: &str) -> String {
    fn last_segment(path: &str) -> &str {
        path.rsplit("::").next().unwrap_or(path)
    }

    let mut out = String::with_capacity(type_name.len());
    // Byte offset where the current path run began; path runs are contiguous, so they can
    // be sliced straight out of `type_name`.
    let mut run_start: Option<usize> = None;

    for (offset, ch) in type_name.char_indices() {
        let in_path = ch.is_alphanumeric() || ch == '_' || ch == ':';
        match (in_path, run_start) {
            (true, None) => run_start = Some(offset),
            (true, Some(_)) => {}
            (false, pending) => {
                if let Some(start) = pending {
                    out.push_str(last_segment(&type_name[start..offset]));
                    run_start = None;
                }
                out.push(ch);
            }
        }
    }

    if let Some(start) = run_start {
        out.push_str(last_segment(&type_name[start..]));
    }

    out
}
677
678fn analysis_node_graph_error(message: String) -> SubtrActorError {
679    SubtrActorError::new(SubtrActorErrorVariant::CallbackError(format!(
680        "analysis node graph error: {message}"
681    )))
682}
683
// Unit tests for this module live in `graph_tests.rs` (loaded via `#[path]`) and are
// compiled only under `cfg(test)`.
#[cfg(test)]
#[path = "graph_tests.rs"]
mod tests;