Skip to main content

trellis_core/
derive.rs

1use crate::input::{StoredInput, boxed_input, downcast_input};
2use crate::{DerivedNode, Graph, GraphError, GraphResult, InputNode, NodeId, NodeKind};
3use std::collections::BTreeSet;
4use std::sync::Arc;
5
6type ComputeFn<C> = dyn for<'ctx> Fn(&DeriveContext<'ctx, C>) -> Result<Box<dyn StoredInput>, DeriveError>
7    + Send
8    + Sync;
9
10pub(crate) struct DerivedSpec<C> {
11    compute: Arc<ComputeFn<C>>,
12}
13
14impl<C> Clone for DerivedSpec<C> {
15    fn clone(&self) -> Self {
16        Self {
17            compute: Arc::clone(&self.compute),
18        }
19    }
20}
21
22impl<C> DerivedSpec<C> {
23    pub(crate) fn new<T, F>(derive: F) -> Self
24    where
25        T: Clone + PartialEq + Send + Sync + 'static,
26        F: for<'ctx> Fn(&DeriveContext<'ctx, C>) -> Result<T, DeriveError> + Send + Sync + 'static,
27    {
28        Self {
29            compute: Arc::new(move |ctx| derive(ctx).map(boxed_input)),
30        }
31    }
32
33    pub(crate) fn compute(
34        &self,
35        ctx: &DeriveContext<'_, C>,
36    ) -> Result<Box<dyn StoredInput>, DeriveError> {
37        (self.compute)(ctx)
38    }
39}
40
41/// Read-only context passed to pure derived node computations.
42pub struct DeriveContext<'graph, C = ()> {
43    graph: &'graph Graph<C>,
44    declared_dependencies: &'graph [NodeId],
45}
46
47impl<'graph, C> DeriveContext<'graph, C> {
48    pub(crate) fn new(graph: &'graph Graph<C>, declared_dependencies: &'graph [NodeId]) -> Self {
49        Self {
50            graph,
51            declared_dependencies,
52        }
53    }
54
55    /// Reads a declared input dependency.
56    pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
57    where
58        T: Clone + PartialEq + Send + Sync + 'static,
59    {
60        let node = input.id();
61        self.require_declared(node)?;
62        self.graph
63            .input_values
64            .get(&node)
65            .and_then(|value| downcast_input::<T>(value.as_ref()))
66            .ok_or(DeriveError::MissingValue(node))
67    }
68
69    /// Reads a declared derived dependency.
70    pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
71    where
72        T: Clone + PartialEq + Send + Sync + 'static,
73    {
74        let node = derived.id();
75        self.require_declared(node)?;
76        self.graph
77            .derived_values
78            .get(&node)
79            .and_then(|value| downcast_input::<T>(value.as_ref()))
80            .ok_or(DeriveError::MissingValue(node))
81    }
82
83    fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
84        if self.declared_dependencies.contains(&node) {
85            Ok(())
86        } else {
87            Err(DeriveError::UndeclaredDependency(node))
88        }
89    }
90}
91
92/// Error returned by a pure derived node computation.
93#[derive(Clone, Debug, Eq, PartialEq)]
94pub enum DeriveError {
95    /// A derive function tried to read a node it did not declare.
96    UndeclaredDependency(NodeId),
97    /// A dependency had no committed value.
98    MissingValue(NodeId),
99    /// A collection dependency was read with the wrong set/map shape or value type.
100    WrongCollectionType(NodeId),
101    /// User-defined derivation failed.
102    Message(String),
103}
104
105impl DeriveError {
106    /// Creates a user-defined derive error.
107    pub fn message(message: impl Into<String>) -> Self {
108        Self::Message(message.into())
109    }
110}
111
112impl<C> Graph<C> {
113    pub(crate) fn recompute_dirty_derived(
114        &mut self,
115        initial_changed: &[NodeId],
116    ) -> GraphResult<RecomputeTrace> {
117        let order = self.derived_topological_order()?;
118        let mut changed: BTreeSet<NodeId> = initial_changed.iter().copied().collect();
119        let mut changed_derived = Vec::new();
120        let mut recomputed = Vec::new();
121
122        for node in order {
123            let dependencies = self
124                .nodes
125                .get(&node)
126                .expect("derived node metadata exists")
127                .dependencies();
128            let is_dirty = changed.contains(&node)
129                || dependencies
130                    .as_slice()
131                    .iter()
132                    .any(|dependency| changed.contains(dependency));
133
134            if !is_dirty {
135                continue;
136            }
137
138            recomputed.push(node);
139            let next_value = self.compute_derived(node, dependencies.as_slice())?;
140            let changed_value = self
141                .derived_values
142                .get(&node)
143                .is_none_or(|current| !current.equals(next_value.as_ref()));
144
145            if changed_value {
146                self.derived_values.insert(node, next_value);
147                changed.insert(node);
148                changed_derived.push(node);
149            }
150        }
151
152        Ok(RecomputeTrace {
153            recomputed,
154            changed: changed_derived,
155        })
156    }
157
158    pub(crate) fn compute_derived(
159        &self,
160        node: NodeId,
161        dependencies: &[NodeId],
162    ) -> GraphResult<Box<dyn StoredInput>> {
163        let spec = self
164            .derived_specs
165            .get(&node)
166            .ok_or(GraphError::UnknownNode(node))?;
167        let ctx = DeriveContext::new(self, dependencies);
168        spec.compute(&ctx)
169            .map_err(|error| GraphError::DeriveFailed(node, error))
170    }
171
172    pub(crate) fn derived_topological_order(&mut self) -> GraphResult<Vec<NodeId>> {
173        self.topological_order_for_kind(NodeKind::Derived)
174    }
175}
176
177pub(crate) struct RecomputeTrace {
178    pub(crate) recomputed: Vec<NodeId>,
179    pub(crate) changed: Vec<NodeId>,
180}