// subtr_actor/stats/analysis_graph/graph.rs

1#![allow(dead_code)]
2
3use std::any::{type_name, Any, TypeId};
4use std::collections::{HashMap, HashSet};
5
6use crate::*;
7
8mod render;
9
10#[derive(Clone, Copy)]
11pub struct AnalysisDependency {
12    state_type_id: TypeId,
13    state_type_name: &'static str,
14    source: AnalysisDependencySource,
15}
16
17#[derive(Clone, Copy)]
18enum AnalysisDependencySource {
19    DefaultFactory(fn() -> Box<dyn AnalysisNodeDyn>),
20    External,
21}
22
23impl AnalysisDependency {
24    pub fn required<T: 'static>() -> Self {
25        Self {
26            state_type_id: TypeId::of::<T>(),
27            state_type_name: type_name::<T>(),
28            source: AnalysisDependencySource::External,
29        }
30    }
31
32    pub fn with_default<T: 'static>(default_factory: fn() -> Box<dyn AnalysisNodeDyn>) -> Self {
33        Self {
34            state_type_id: TypeId::of::<T>(),
35            state_type_name: type_name::<T>(),
36            source: AnalysisDependencySource::DefaultFactory(default_factory),
37        }
38    }
39
40    pub fn state_type_id(&self) -> TypeId {
41        self.state_type_id
42    }
43
44    pub fn state_type_name(&self) -> &'static str {
45        self.state_type_name
46    }
47
48    fn default_factory(&self) -> fn() -> Box<dyn AnalysisNodeDyn> {
49        match self.source {
50            AnalysisDependencySource::DefaultFactory(default_factory) => default_factory,
51            AnalysisDependencySource::External => panic!(
52                "analysis dependency for {} has no default factory",
53                self.state_type_name
54            ),
55        }
56    }
57
58    fn is_external(&self) -> bool {
59        matches!(self.source, AnalysisDependencySource::External)
60    }
61}
62
/// Read-only, type-erased view of every state a node may read while it
/// evaluates, keyed by each state's `TypeId`.
pub struct AnalysisStateContext<'a> {
    states: HashMap<TypeId, &'a dyn Any>,
}

/// A borrowed, type-erased input state handed to the graph for one evaluation.
pub struct AnalysisStateRef<'a> {
    type_id: TypeId,
    type_name: &'static str,
    state: &'a dyn Any,
}

impl<'a> AnalysisStateRef<'a> {
    /// Erases `state` while remembering its concrete type `T`.
    pub fn of<T: 'static>(state: &'a T) -> Self {
        let erased: &'a dyn Any = state;
        Self {
            state: erased,
            type_name: type_name::<T>(),
            type_id: TypeId::of::<T>(),
        }
    }

    /// `TypeId` of the wrapped state.
    fn type_id(&self) -> TypeId {
        self.type_id
    }

    /// Type name of the wrapped state, for error messages.
    fn type_name(&self) -> &'static str {
        self.type_name
    }

    /// The wrapped state itself, still borrowed for `'a`.
    fn state(&self) -> &'a dyn Any {
        self.state
    }
}
94
95impl<'a> AnalysisStateContext<'a> {
96    fn from_parts(
97        root_states: &'a HashMap<TypeId, Box<dyn Any>>,
98        input_states: &'a [AnalysisStateRef<'a>],
99        before: &'a [Box<dyn AnalysisNodeDyn>],
100    ) -> Self {
101        let mut states =
102            HashMap::with_capacity(root_states.len() + input_states.len() + before.len());
103        for (type_id, state) in root_states {
104            states.insert(*type_id, state.as_ref());
105        }
106        for input_state in input_states {
107            states.insert(input_state.type_id(), input_state.state());
108        }
109        for node in before {
110            states.insert(node.provides_state_type_id(), node.state_any());
111        }
112        Self { states }
113    }
114
115    pub fn get<T: 'static>(&self) -> SubtrActorResult<&'a T> {
116        self.maybe_get::<T>().ok_or_else(|| {
117            analysis_node_graph_error(format!(
118                "Missing state {} in analysis context",
119                type_name::<T>()
120            ))
121        })
122    }
123
124    pub fn maybe_get<T: 'static>(&self) -> Option<&'a T> {
125        self.states
126            .get(&TypeId::of::<T>())
127            .and_then(|state| state.downcast_ref::<T>())
128    }
129}
130
131pub trait AnalysisNode: 'static {
132    type State: 'static;
133
134    fn name(&self) -> &'static str;
135
136    fn on_replay_meta(&mut self, _meta: &ReplayMeta) -> SubtrActorResult<()> {
137        Ok(())
138    }
139
140    fn dependencies(&self) -> Vec<AnalysisDependency> {
141        Vec::new()
142    }
143
144    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;
145
146    fn finish(&mut self) -> SubtrActorResult<()> {
147        Ok(())
148    }
149
150    fn state(&self) -> &Self::State;
151}
152
153pub trait AnalysisNodeDyn: 'static {
154    fn name(&self) -> &'static str;
155
156    fn provides_state_type_id(&self) -> TypeId;
157
158    fn provides_state_type_name(&self) -> &'static str;
159
160    fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()>;
161
162    fn dependencies(&self) -> Vec<AnalysisDependency>;
163
164    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;
165
166    fn finish(&mut self) -> SubtrActorResult<()>;
167
168    fn state_any(&self) -> &dyn Any;
169}
170
171impl<N> AnalysisNodeDyn for N
172where
173    N: AnalysisNode,
174{
175    fn name(&self) -> &'static str {
176        AnalysisNode::name(self)
177    }
178
179    fn provides_state_type_id(&self) -> TypeId {
180        TypeId::of::<N::State>()
181    }
182
183    fn provides_state_type_name(&self) -> &'static str {
184        type_name::<N::State>()
185    }
186
187    fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()> {
188        AnalysisNode::on_replay_meta(self, meta)
189    }
190
191    fn dependencies(&self) -> Vec<AnalysisDependency> {
192        AnalysisNode::dependencies(self)
193    }
194
195    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()> {
196        AnalysisNode::evaluate(self, ctx)
197    }
198
199    fn finish(&mut self) -> SubtrActorResult<()> {
200        AnalysisNode::finish(self)
201    }
202
203    fn state_any(&self) -> &dyn Any {
204        self.state()
205    }
206}
207
208#[derive(Default)]
209pub struct AnalysisGraph {
210    nodes: Vec<Box<dyn AnalysisNodeDyn>>,
211    evaluation_order: Vec<usize>,
212    declared_root_states: HashMap<TypeId, &'static str>,
213    declared_input_states: HashMap<TypeId, &'static str>,
214    root_states: HashMap<TypeId, Box<dyn Any>>,
215    resolved: bool,
216}
217
218impl AnalysisGraph {
219    pub fn new() -> Self {
220        Self::default()
221    }
222
223    pub fn with_root_state_type<T: 'static>(mut self) -> Self {
224        self.register_root_state::<T>();
225        self
226    }
227
228    pub fn register_root_state<T: 'static>(&mut self) {
229        self.declared_root_states
230            .insert(TypeId::of::<T>(), type_name::<T>());
231    }
232
233    pub fn with_input_state_type<T: 'static>(mut self) -> Self {
234        self.register_input_state::<T>();
235        self
236    }
237
238    pub fn register_input_state<T: 'static>(&mut self) {
239        self.declared_input_states
240            .insert(TypeId::of::<T>(), type_name::<T>());
241    }
242
243    pub fn set_root_state<T: 'static>(&mut self, value: T) {
244        self.register_root_state::<T>();
245        self.root_states.insert(TypeId::of::<T>(), Box::new(value));
246    }
247
248    pub fn with_node<N>(mut self, node: N) -> Self
249    where
250        N: AnalysisNode,
251    {
252        self.push_node(node);
253        self
254    }
255
256    pub fn with_boxed_node(mut self, node: Box<dyn AnalysisNodeDyn>) -> Self {
257        self.push_boxed_node(node);
258        self
259    }
260
261    pub fn push_node<N>(&mut self, node: N)
262    where
263        N: AnalysisNode,
264    {
265        self.push_boxed_node(Box::new(node));
266    }
267
268    pub fn push_boxed_node(&mut self, node: Box<dyn AnalysisNodeDyn>) {
269        self.nodes.push(node);
270        self.resolved = false;
271    }
272
273    pub fn resolve(&mut self) -> SubtrActorResult<()> {
274        if self.resolved {
275            return Ok(());
276        }
277
278        loop {
279            let providers = self.provider_index_by_type()?;
280            let mut additions = Vec::new();
281            let mut queued_types = HashSet::new();
282
283            for node in &self.nodes {
284                for dependency in node.dependencies() {
285                    if providers.contains_key(&dependency.state_type_id())
286                        || self
287                            .declared_root_states
288                            .contains_key(&dependency.state_type_id())
289                        || self
290                            .declared_input_states
291                            .contains_key(&dependency.state_type_id())
292                    {
293                        continue;
294                    }
295                    if dependency.is_external() {
296                        return Err(analysis_node_graph_error(format!(
297                            "Node '{}' requires state {} with no provider",
298                            node.name(),
299                            dependency.state_type_name(),
300                        )));
301                    }
302                    let default_factory = dependency.default_factory();
303                    if queued_types.insert(dependency.state_type_id()) {
304                        additions.push(default_factory());
305                    }
306                }
307            }
308
309            if additions.is_empty() {
310                break;
311            }
312
313            self.nodes.extend(additions);
314        }
315
316        let providers = self.provider_index_by_type()?;
317        let mut visiting = HashSet::new();
318        let mut visited = HashSet::new();
319        let mut order = Vec::with_capacity(self.nodes.len());
320
321        for index in 0..self.nodes.len() {
322            self.visit_node(
323                index,
324                &providers,
325                &mut visiting,
326                &mut visited,
327                &mut order,
328                &mut Vec::new(),
329            )?;
330        }
331
332        let mut ordered_nodes = Vec::with_capacity(self.nodes.len());
333        let mut original_nodes: Vec<Option<Box<dyn AnalysisNodeDyn>>> =
334            std::mem::take(&mut self.nodes)
335                .into_iter()
336                .map(Some)
337                .collect();
338        for index in order {
339            ordered_nodes.push(
340                original_nodes[index]
341                    .take()
342                    .expect("topological order should only reference each node once"),
343            );
344        }
345
346        self.nodes = ordered_nodes;
347        self.evaluation_order = (0..self.nodes.len()).collect();
348        self.resolved = true;
349        Ok(())
350    }
351
352    pub fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()> {
353        self.resolve()?;
354        for node in &mut self.nodes {
355            node.on_replay_meta(meta)?;
356        }
357        Ok(())
358    }
359
360    pub fn evaluate(&mut self) -> SubtrActorResult<()> {
361        self.evaluate_with_states(&[])
362    }
363
364    pub fn evaluate_with_state<T: 'static>(&mut self, value: &T) -> SubtrActorResult<()> {
365        self.evaluate_with_states(&[AnalysisStateRef::of(value)])
366    }
367
368    pub fn evaluate_with_states<'a>(
369        &mut self,
370        input_states: &'a [AnalysisStateRef<'a>],
371    ) -> SubtrActorResult<()> {
372        self.resolve()?;
373
374        for (type_id, type_name) in &self.declared_root_states {
375            if !self.root_states.contains_key(type_id) {
376                return Err(analysis_node_graph_error(format!(
377                    "Missing root state {type_name} for evaluation"
378                )));
379            }
380        }
381
382        let mut provided_input_types = HashMap::with_capacity(input_states.len());
383        for input_state in input_states {
384            if let Some(existing) =
385                provided_input_types.insert(input_state.type_id(), input_state.type_name())
386            {
387                return Err(analysis_node_graph_error(format!(
388                    "Duplicate input states for {}: {} and {}",
389                    input_state.type_name(),
390                    existing,
391                    input_state.type_name(),
392                )));
393            }
394        }
395        for (type_id, type_name) in self.required_input_states() {
396            if !provided_input_types.contains_key(&type_id) {
397                return Err(analysis_node_graph_error(format!(
398                    "Missing input state {type_name} for evaluation"
399                )));
400            }
401        }
402
403        for node_index in self.evaluation_order.clone() {
404            let (before, current_and_after) = self.nodes.split_at_mut(node_index);
405            let (current, _) = current_and_after
406                .split_first_mut()
407                .expect("evaluation order should contain valid indexes");
408            let ctx = AnalysisStateContext::from_parts(&self.root_states, input_states, before);
409            current.evaluate(&ctx)?;
410        }
411
412        Ok(())
413    }
414
415    pub fn finish(&mut self) -> SubtrActorResult<()> {
416        for node in &mut self.nodes {
417            node.finish()?;
418        }
419        Ok(())
420    }
421
422    pub fn state<T: 'static>(&self) -> Option<&T> {
423        let target = TypeId::of::<T>();
424        self.root_states
425            .get(&target)
426            .and_then(|state| state.downcast_ref::<T>())
427            .or_else(|| {
428                self.nodes
429                    .iter()
430                    .find(|node| node.provides_state_type_id() == target)
431                    .and_then(|node| node.state_any().downcast_ref::<T>())
432            })
433    }
434
435    pub fn node_names(&self) -> impl Iterator<Item = &'static str> + '_ {
436        self.nodes.iter().map(|node| node.name())
437    }
438
439    fn provider_index_by_type(&self) -> SubtrActorResult<HashMap<TypeId, usize>> {
440        let mut providers = HashMap::new();
441        for (index, node) in self.nodes.iter().enumerate() {
442            if self
443                .declared_root_states
444                .contains_key(&node.provides_state_type_id())
445            {
446                return SubtrActorError::new_result(
447                    SubtrActorErrorVariant::CallbackError(format!(
448                        "analysis node graph error: Duplicate providers for root state {}: root and '{}'",
449                        node.provides_state_type_name(),
450                        node.name(),
451                    )),
452                );
453            }
454            if self
455                .declared_input_states
456                .contains_key(&node.provides_state_type_id())
457            {
458                return SubtrActorError::new_result(
459                    SubtrActorErrorVariant::CallbackError(format!(
460                        "analysis node graph error: Duplicate providers for input state {}: input and '{}'",
461                        node.provides_state_type_name(),
462                        node.name(),
463                    )),
464                );
465            }
466            if let Some(existing) = providers.insert(node.provides_state_type_id(), index) {
467                return SubtrActorError::new_result(
468                    SubtrActorErrorVariant::CallbackError(format!(
469                        "analysis node graph error: Duplicate providers for state {}: '{}' and '{}'",
470                        node.provides_state_type_name(),
471                        self.nodes[existing].name(),
472                        node.name(),
473                    )),
474                );
475            }
476        }
477        Ok(providers)
478    }
479
480    fn required_input_states(&self) -> HashMap<TypeId, &'static str> {
481        let mut required = HashMap::new();
482        for node in &self.nodes {
483            for dependency in node.dependencies() {
484                let type_id = dependency.state_type_id();
485                if self.declared_input_states.contains_key(&type_id)
486                    && !self.root_states.contains_key(&type_id)
487                {
488                    required.insert(type_id, dependency.state_type_name());
489                }
490            }
491        }
492        required
493    }
494
495    fn visit_node(
496        &self,
497        index: usize,
498        providers: &HashMap<TypeId, usize>,
499        visiting: &mut HashSet<usize>,
500        visited: &mut HashSet<usize>,
501        order: &mut Vec<usize>,
502        stack: &mut Vec<&'static str>,
503    ) -> SubtrActorResult<()> {
504        if visited.contains(&index) {
505            return Ok(());
506        }
507        if !visiting.insert(index) {
508            stack.push(self.nodes[index].name());
509            let cycle = stack.join(" -> ");
510            stack.pop();
511            return Err(analysis_node_graph_error(format!(
512                "Cycle detected in analysis node graph: {cycle}"
513            )));
514        }
515
516        stack.push(self.nodes[index].name());
517        for dependency in self.nodes[index].dependencies() {
518            if self
519                .declared_root_states
520                .contains_key(&dependency.state_type_id())
521                || self
522                    .declared_input_states
523                    .contains_key(&dependency.state_type_id())
524            {
525                continue;
526            }
527
528            let Some(dependency_index) = providers.get(&dependency.state_type_id()).copied() else {
529                stack.pop();
530                return Err(analysis_node_graph_error(format!(
531                    "Node '{}' depends on missing state {}",
532                    self.nodes[index].name(),
533                    dependency.state_type_name(),
534                )));
535            };
536            self.visit_node(dependency_index, providers, visiting, visited, order, stack)?;
537        }
538        stack.pop();
539
540        visiting.remove(&index);
541        visited.insert(index);
542        order.push(index);
543        Ok(())
544    }
545}
546
547fn analysis_node_graph_error(message: String) -> SubtrActorError {
548    SubtrActorError::new(SubtrActorErrorVariant::CallbackError(format!(
549        "analysis node graph error: {message}"
550    )))
551}
552
553#[cfg(test)]
554#[path = "graph_tests.rs"]
555mod tests;