Skip to main content

subtr_actor/stats/analysis_graph/
graph.rs

1#![allow(dead_code)]
2
3use std::any::{type_name, Any, TypeId};
4use std::collections::{HashMap, HashSet};
5
6use crate::*;
7
8mod render;
9
10#[derive(Clone, Copy)]
11pub struct AnalysisDependency {
12    state_type_id: TypeId,
13    state_type_name: &'static str,
14    source: AnalysisDependencySource,
15}
16
17#[derive(Clone, Copy)]
18enum AnalysisDependencySource {
19    DefaultFactory(fn() -> Box<dyn AnalysisNodeDyn>),
20    External,
21}
22
23impl AnalysisDependency {
24    pub fn required<T: 'static>() -> Self {
25        Self {
26            state_type_id: TypeId::of::<T>(),
27            state_type_name: type_name::<T>(),
28            source: AnalysisDependencySource::External,
29        }
30    }
31
32    pub fn with_default<T: 'static>(default_factory: fn() -> Box<dyn AnalysisNodeDyn>) -> Self {
33        Self {
34            state_type_id: TypeId::of::<T>(),
35            state_type_name: type_name::<T>(),
36            source: AnalysisDependencySource::DefaultFactory(default_factory),
37        }
38    }
39
40    pub fn state_type_id(&self) -> TypeId {
41        self.state_type_id
42    }
43
44    pub fn state_type_name(&self) -> &'static str {
45        self.state_type_name
46    }
47
48    fn default_factory(&self) -> fn() -> Box<dyn AnalysisNodeDyn> {
49        match self.source {
50            AnalysisDependencySource::DefaultFactory(default_factory) => default_factory,
51            AnalysisDependencySource::External => panic!(
52                "analysis dependency for {} has no default factory",
53                self.state_type_name
54            ),
55        }
56    }
57
58    fn is_external(&self) -> bool {
59        matches!(self.source, AnalysisDependencySource::External)
60    }
61}
62
63pub struct AnalysisStateContext<'a> {
64    states: HashMap<TypeId, &'a dyn Any>,
65}
66
67pub struct AnalysisStateRef<'a> {
68    type_id: TypeId,
69    type_name: &'static str,
70    state: &'a dyn Any,
71}
72
73impl<'a> AnalysisStateRef<'a> {
74    pub fn of<T: 'static>(state: &'a T) -> Self {
75        Self {
76            type_id: TypeId::of::<T>(),
77            type_name: type_name::<T>(),
78            state,
79        }
80    }
81
82    fn type_id(&self) -> TypeId {
83        self.type_id
84    }
85
86    fn type_name(&self) -> &'static str {
87        self.type_name
88    }
89
90    fn state(&self) -> &'a dyn Any {
91        self.state
92    }
93}
94
95impl<'a> AnalysisStateContext<'a> {
96    fn from_parts(
97        root_states: &'a HashMap<TypeId, Box<dyn Any>>,
98        input_states: &'a [AnalysisStateRef<'a>],
99        before: &'a [Box<dyn AnalysisNodeDyn>],
100    ) -> Self {
101        let mut states =
102            HashMap::with_capacity(root_states.len() + input_states.len() + before.len());
103        for (type_id, state) in root_states {
104            states.insert(*type_id, state.as_ref());
105        }
106        for input_state in input_states {
107            states.insert(input_state.type_id(), input_state.state());
108        }
109        for node in before {
110            states.insert(node.provides_state_type_id(), node.state_any());
111        }
112        Self { states }
113    }
114
115    pub fn get<T: 'static>(&self) -> SubtrActorResult<&'a T> {
116        self.maybe_get::<T>().ok_or_else(|| {
117            analysis_node_graph_error(format!(
118                "Missing state {} in analysis context",
119                type_name::<T>()
120            ))
121        })
122    }
123
124    pub fn maybe_get<T: 'static>(&self) -> Option<&'a T> {
125        self.states
126            .get(&TypeId::of::<T>())
127            .and_then(|state| state.downcast_ref::<T>())
128    }
129}
130
131pub trait AnalysisNode: 'static {
132    type State: 'static;
133
134    fn name(&self) -> &'static str;
135
136    fn on_replay_meta(&mut self, _meta: &ReplayMeta) -> SubtrActorResult<()> {
137        Ok(())
138    }
139
140    fn dependencies(&self) -> Vec<AnalysisDependency> {
141        Vec::new()
142    }
143
144    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;
145
146    fn finish(&mut self, _ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()> {
147        Ok(())
148    }
149
150    fn state(&self) -> &Self::State;
151}
152
153pub trait AnalysisNodeDyn: 'static {
154    fn name(&self) -> &'static str;
155
156    fn provides_state_type_id(&self) -> TypeId;
157
158    fn provides_state_type_name(&self) -> &'static str;
159
160    fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()>;
161
162    fn dependencies(&self) -> Vec<AnalysisDependency>;
163
164    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;
165
166    fn finish(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()>;
167
168    fn state_any(&self) -> &dyn Any;
169}
170
171impl<N> AnalysisNodeDyn for N
172where
173    N: AnalysisNode,
174{
175    fn name(&self) -> &'static str {
176        AnalysisNode::name(self)
177    }
178
179    fn provides_state_type_id(&self) -> TypeId {
180        TypeId::of::<N::State>()
181    }
182
183    fn provides_state_type_name(&self) -> &'static str {
184        type_name::<N::State>()
185    }
186
187    fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()> {
188        AnalysisNode::on_replay_meta(self, meta)
189    }
190
191    fn dependencies(&self) -> Vec<AnalysisDependency> {
192        AnalysisNode::dependencies(self)
193    }
194
195    fn evaluate(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()> {
196        AnalysisNode::evaluate(self, ctx)
197    }
198
199    fn finish(&mut self, ctx: &AnalysisStateContext<'_>) -> SubtrActorResult<()> {
200        AnalysisNode::finish(self, ctx)
201    }
202
203    fn state_any(&self) -> &dyn Any {
204        self.state()
205    }
206}
207
208#[derive(Default)]
209pub struct AnalysisGraph {
210    nodes: Vec<Box<dyn AnalysisNodeDyn>>,
211    evaluation_order: Vec<usize>,
212    declared_root_states: HashMap<TypeId, &'static str>,
213    declared_input_states: HashMap<TypeId, &'static str>,
214    root_states: HashMap<TypeId, Box<dyn Any>>,
215    resolved: bool,
216}
217
218impl AnalysisGraph {
219    pub fn new() -> Self {
220        Self::default()
221    }
222
223    pub fn with_root_state_type<T: 'static>(mut self) -> Self {
224        self.register_root_state::<T>();
225        self
226    }
227
228    pub fn register_root_state<T: 'static>(&mut self) {
229        self.declared_root_states
230            .insert(TypeId::of::<T>(), type_name::<T>());
231    }
232
233    pub fn with_input_state_type<T: 'static>(mut self) -> Self {
234        self.register_input_state::<T>();
235        self
236    }
237
238    pub fn register_input_state<T: 'static>(&mut self) {
239        self.declared_input_states
240            .insert(TypeId::of::<T>(), type_name::<T>());
241    }
242
243    pub fn set_root_state<T: 'static>(&mut self, value: T) {
244        self.register_root_state::<T>();
245        self.root_states.insert(TypeId::of::<T>(), Box::new(value));
246    }
247
248    pub fn with_node<N>(mut self, node: N) -> Self
249    where
250        N: AnalysisNode,
251    {
252        self.push_node(node);
253        self
254    }
255
256    pub fn with_boxed_node(mut self, node: Box<dyn AnalysisNodeDyn>) -> Self {
257        self.push_boxed_node(node);
258        self
259    }
260
261    pub fn push_node<N>(&mut self, node: N)
262    where
263        N: AnalysisNode,
264    {
265        self.push_boxed_node(Box::new(node));
266    }
267
268    pub fn push_boxed_node(&mut self, node: Box<dyn AnalysisNodeDyn>) {
269        self.nodes.push(node);
270        self.resolved = false;
271    }
272
273    pub fn resolve(&mut self) -> SubtrActorResult<()> {
274        if self.resolved {
275            return Ok(());
276        }
277
278        loop {
279            let providers = self.provider_index_by_type()?;
280            let mut additions = Vec::new();
281            let mut queued_types = HashSet::new();
282
283            for node in &self.nodes {
284                for dependency in node.dependencies() {
285                    if providers.contains_key(&dependency.state_type_id())
286                        || self
287                            .declared_root_states
288                            .contains_key(&dependency.state_type_id())
289                        || self
290                            .declared_input_states
291                            .contains_key(&dependency.state_type_id())
292                    {
293                        continue;
294                    }
295                    if dependency.is_external() {
296                        return Err(analysis_node_graph_error(format!(
297                            "Node '{}' requires state {} with no provider",
298                            node.name(),
299                            dependency.state_type_name(),
300                        )));
301                    }
302                    let default_factory = dependency.default_factory();
303                    if queued_types.insert(dependency.state_type_id()) {
304                        additions.push(default_factory());
305                    }
306                }
307            }
308
309            if additions.is_empty() {
310                break;
311            }
312
313            self.nodes.extend(additions);
314        }
315
316        let providers = self.provider_index_by_type()?;
317        let mut visiting = HashSet::new();
318        let mut visited = HashSet::new();
319        let mut order = Vec::with_capacity(self.nodes.len());
320
321        for index in 0..self.nodes.len() {
322            self.visit_node(
323                index,
324                &providers,
325                &mut visiting,
326                &mut visited,
327                &mut order,
328                &mut Vec::new(),
329            )?;
330        }
331
332        let mut ordered_nodes = Vec::with_capacity(self.nodes.len());
333        let mut original_nodes: Vec<Option<Box<dyn AnalysisNodeDyn>>> =
334            std::mem::take(&mut self.nodes)
335                .into_iter()
336                .map(Some)
337                .collect();
338        for index in order {
339            ordered_nodes.push(
340                original_nodes[index]
341                    .take()
342                    .expect("topological order should only reference each node once"),
343            );
344        }
345
346        self.nodes = ordered_nodes;
347        self.evaluation_order = (0..self.nodes.len()).collect();
348        self.resolved = true;
349        Ok(())
350    }
351
352    pub fn on_replay_meta(&mut self, meta: &ReplayMeta) -> SubtrActorResult<()> {
353        self.resolve()?;
354        for node in &mut self.nodes {
355            node.on_replay_meta(meta)?;
356        }
357        Ok(())
358    }
359
360    pub fn evaluate(&mut self) -> SubtrActorResult<()> {
361        self.evaluate_with_states(&[])
362    }
363
364    pub fn evaluate_with_state<T: 'static>(&mut self, value: &T) -> SubtrActorResult<()> {
365        self.evaluate_with_states(&[AnalysisStateRef::of(value)])
366    }
367
368    pub fn evaluate_with_states<'a>(
369        &mut self,
370        input_states: &'a [AnalysisStateRef<'a>],
371    ) -> SubtrActorResult<()> {
372        self.resolve()?;
373
374        for (type_id, type_name) in &self.declared_root_states {
375            if !self.root_states.contains_key(type_id) {
376                return Err(analysis_node_graph_error(format!(
377                    "Missing root state {type_name} for evaluation"
378                )));
379            }
380        }
381
382        let mut provided_input_types = HashMap::with_capacity(input_states.len());
383        for input_state in input_states {
384            if let Some(existing) =
385                provided_input_types.insert(input_state.type_id(), input_state.type_name())
386            {
387                return Err(analysis_node_graph_error(format!(
388                    "Duplicate input states for {}: {} and {}",
389                    input_state.type_name(),
390                    existing,
391                    input_state.type_name(),
392                )));
393            }
394        }
395        for (type_id, type_name) in self.required_input_states() {
396            if !provided_input_types.contains_key(&type_id) {
397                return Err(analysis_node_graph_error(format!(
398                    "Missing input state {type_name} for evaluation"
399                )));
400            }
401        }
402
403        for node_index in self.evaluation_order.clone() {
404            let (before, current_and_after) = self.nodes.split_at_mut(node_index);
405            let (current, _) = current_and_after
406                .split_first_mut()
407                .expect("evaluation order should contain valid indexes");
408            let ctx = AnalysisStateContext::from_parts(&self.root_states, input_states, before);
409            current.evaluate(&ctx)?;
410        }
411
412        Ok(())
413    }
414
415    pub fn finish(&mut self) -> SubtrActorResult<()> {
416        self.resolve()?;
417        for node_index in self.evaluation_order.clone() {
418            let (before, current_and_after) = self.nodes.split_at_mut(node_index);
419            let (current, _) = current_and_after
420                .split_first_mut()
421                .expect("evaluation order should contain valid indexes");
422            let ctx = AnalysisStateContext::from_parts(&self.root_states, &[], before);
423            current.finish(&ctx)?;
424        }
425        Ok(())
426    }
427
428    pub fn state<T: 'static>(&self) -> Option<&T> {
429        let target = TypeId::of::<T>();
430        self.root_states
431            .get(&target)
432            .and_then(|state| state.downcast_ref::<T>())
433            .or_else(|| {
434                self.nodes
435                    .iter()
436                    .find(|node| node.provides_state_type_id() == target)
437                    .and_then(|node| node.state_any().downcast_ref::<T>())
438            })
439    }
440
441    pub fn node_names(&self) -> impl Iterator<Item = &'static str> + '_ {
442        self.nodes.iter().map(|node| node.name())
443    }
444
445    fn provider_index_by_type(&self) -> SubtrActorResult<HashMap<TypeId, usize>> {
446        let mut providers = HashMap::new();
447        for (index, node) in self.nodes.iter().enumerate() {
448            if self
449                .declared_root_states
450                .contains_key(&node.provides_state_type_id())
451            {
452                return SubtrActorError::new_result(
453                    SubtrActorErrorVariant::CallbackError(format!(
454                        "analysis node graph error: Duplicate providers for root state {}: root and '{}'",
455                        node.provides_state_type_name(),
456                        node.name(),
457                    )),
458                );
459            }
460            if self
461                .declared_input_states
462                .contains_key(&node.provides_state_type_id())
463            {
464                return SubtrActorError::new_result(
465                    SubtrActorErrorVariant::CallbackError(format!(
466                        "analysis node graph error: Duplicate providers for input state {}: input and '{}'",
467                        node.provides_state_type_name(),
468                        node.name(),
469                    )),
470                );
471            }
472            if let Some(existing) = providers.insert(node.provides_state_type_id(), index) {
473                return SubtrActorError::new_result(
474                    SubtrActorErrorVariant::CallbackError(format!(
475                        "analysis node graph error: Duplicate providers for state {}: '{}' and '{}'",
476                        node.provides_state_type_name(),
477                        self.nodes[existing].name(),
478                        node.name(),
479                    )),
480                );
481            }
482        }
483        Ok(providers)
484    }
485
486    fn required_input_states(&self) -> HashMap<TypeId, &'static str> {
487        let mut required = HashMap::new();
488        for node in &self.nodes {
489            for dependency in node.dependencies() {
490                let type_id = dependency.state_type_id();
491                if self.declared_input_states.contains_key(&type_id)
492                    && !self.root_states.contains_key(&type_id)
493                {
494                    required.insert(type_id, dependency.state_type_name());
495                }
496            }
497        }
498        required
499    }
500
501    fn visit_node(
502        &self,
503        index: usize,
504        providers: &HashMap<TypeId, usize>,
505        visiting: &mut HashSet<usize>,
506        visited: &mut HashSet<usize>,
507        order: &mut Vec<usize>,
508        stack: &mut Vec<&'static str>,
509    ) -> SubtrActorResult<()> {
510        if visited.contains(&index) {
511            return Ok(());
512        }
513        if !visiting.insert(index) {
514            stack.push(self.nodes[index].name());
515            let cycle = stack.join(" -> ");
516            stack.pop();
517            return Err(analysis_node_graph_error(format!(
518                "Cycle detected in analysis node graph: {cycle}"
519            )));
520        }
521
522        stack.push(self.nodes[index].name());
523        for dependency in self.nodes[index].dependencies() {
524            if self
525                .declared_root_states
526                .contains_key(&dependency.state_type_id())
527                || self
528                    .declared_input_states
529                    .contains_key(&dependency.state_type_id())
530            {
531                continue;
532            }
533
534            let Some(dependency_index) = providers.get(&dependency.state_type_id()).copied() else {
535                stack.pop();
536                return Err(analysis_node_graph_error(format!(
537                    "Node '{}' depends on missing state {}",
538                    self.nodes[index].name(),
539                    dependency.state_type_name(),
540                )));
541            };
542            self.visit_node(dependency_index, providers, visiting, visited, order, stack)?;
543        }
544        stack.pop();
545
546        visiting.remove(&index);
547        visited.insert(index);
548        order.push(index);
549        Ok(())
550    }
551}
552
553fn analysis_node_graph_error(message: String) -> SubtrActorError {
554    SubtrActorError::new(SubtrActorErrorVariant::CallbackError(format!(
555        "analysis node graph error: {message}"
556    )))
557}
558
559#[cfg(test)]
560#[path = "graph_tests.rs"]
561mod tests;