Skip to main content

trellis_core/
derive.rs

1use crate::input::{StoredInput, boxed_input, downcast_input};
2use crate::{DerivedNode, Graph, GraphError, GraphResult, InputNode, NodeId, NodeKind};
3use std::collections::BTreeSet;
4use std::sync::Arc;
5
6type ComputeFn<C, O> =
7    dyn for<'ctx> Fn(&DeriveContext<'ctx, C, O>) -> Result<Box<dyn StoredInput>, DeriveError>;
8
9pub(crate) struct DerivedSpec<C, O> {
10    compute: Arc<ComputeFn<C, O>>,
11}
12
13impl<C, O> Clone for DerivedSpec<C, O> {
14    fn clone(&self) -> Self {
15        Self {
16            compute: Arc::clone(&self.compute),
17        }
18    }
19}
20
21impl<C, O> DerivedSpec<C, O> {
22    pub(crate) fn new<T, F>(derive: F) -> Self
23    where
24        T: Clone + PartialEq + 'static,
25        F: for<'ctx> Fn(&DeriveContext<'ctx, C, O>) -> Result<T, DeriveError> + 'static,
26    {
27        Self {
28            compute: Arc::new(move |ctx| derive(ctx).map(boxed_input)),
29        }
30    }
31
32    pub(crate) fn compute(
33        &self,
34        ctx: &DeriveContext<'_, C, O>,
35    ) -> Result<Box<dyn StoredInput>, DeriveError> {
36        (self.compute)(ctx)
37    }
38}
39
40/// Read-only context passed to pure derived node computations.
41pub struct DeriveContext<'graph, C = (), O = ()> {
42    graph: &'graph Graph<C, O>,
43    declared_dependencies: &'graph [NodeId],
44}
45
46impl<'graph, C, O> DeriveContext<'graph, C, O> {
47    pub(crate) fn new(graph: &'graph Graph<C, O>, declared_dependencies: &'graph [NodeId]) -> Self {
48        Self {
49            graph,
50            declared_dependencies,
51        }
52    }
53
54    /// Reads a declared input dependency.
55    pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
56    where
57        T: Clone + PartialEq + 'static,
58    {
59        let node = input.id();
60        self.require_declared(node)?;
61        self.graph
62            .input_values
63            .get(&node)
64            .and_then(|value| downcast_input::<T>(value.as_ref()))
65            .ok_or(DeriveError::MissingValue(node))
66    }
67
68    /// Reads a declared derived dependency.
69    pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
70    where
71        T: Clone + PartialEq + 'static,
72    {
73        let node = derived.id();
74        self.require_declared(node)?;
75        self.graph
76            .derived_values
77            .get(&node)
78            .and_then(|value| downcast_input::<T>(value.as_ref()))
79            .ok_or(DeriveError::MissingValue(node))
80    }
81
82    fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
83        if self.declared_dependencies.contains(&node) {
84            Ok(())
85        } else {
86            Err(DeriveError::UndeclaredDependency(node))
87        }
88    }
89}
90
91/// Error returned by a pure derived node computation.
92#[derive(Clone, Debug, Eq, PartialEq)]
93pub enum DeriveError {
94    /// A derive function tried to read a node it did not declare.
95    UndeclaredDependency(NodeId),
96    /// A dependency had no committed value.
97    MissingValue(NodeId),
98    /// A collection dependency was read with the wrong set/map shape or value type.
99    WrongCollectionType(NodeId),
100    /// User-defined derivation failed.
101    Message(String),
102}
103
104impl DeriveError {
105    /// Creates a user-defined derive error.
106    pub fn message(message: impl Into<String>) -> Self {
107        Self::Message(message.into())
108    }
109}
110
111impl<C, O> Graph<C, O> {
112    pub(crate) fn recompute_dirty_derived(
113        &mut self,
114        initial_changed: &[NodeId],
115    ) -> GraphResult<RecomputeTrace> {
116        let order = self.derived_topological_order()?;
117        let mut changed: BTreeSet<NodeId> = initial_changed.iter().copied().collect();
118        let mut changed_derived = Vec::new();
119        let mut recomputed = Vec::new();
120
121        for node in order {
122            let dependencies = self
123                .nodes
124                .get(&node)
125                .expect("derived node metadata exists")
126                .dependencies()
127                .clone();
128            let is_dirty = changed.contains(&node)
129                || dependencies
130                    .as_slice()
131                    .iter()
132                    .any(|dependency| changed.contains(dependency));
133
134            if !is_dirty {
135                continue;
136            }
137
138            recomputed.push(node);
139            let next_value = self.compute_derived(node, dependencies.as_slice())?;
140            let changed_value = self
141                .derived_values
142                .get(&node)
143                .is_none_or(|current| !current.equals(next_value.as_ref()));
144
145            if changed_value {
146                self.derived_values.insert(node, next_value);
147                changed.insert(node);
148                changed_derived.push(node);
149            }
150        }
151
152        Ok(RecomputeTrace {
153            recomputed,
154            changed: changed_derived,
155        })
156    }
157
158    pub(crate) fn compute_derived(
159        &self,
160        node: NodeId,
161        dependencies: &[NodeId],
162    ) -> GraphResult<Box<dyn StoredInput>> {
163        let spec = self
164            .derived_specs
165            .get(&node)
166            .ok_or(GraphError::UnknownNode(node))?;
167        let ctx = DeriveContext::new(self, dependencies);
168        spec.compute(&ctx)
169            .map_err(|error| GraphError::DeriveFailed(node, error))
170    }
171
172    pub(crate) fn derived_topological_order(&self) -> GraphResult<Vec<NodeId>> {
173        let mut order = Vec::new();
174        let mut temporary = BTreeSet::new();
175        let mut permanent = BTreeSet::new();
176
177        for node in self.nodes.keys().copied() {
178            if self
179                .nodes
180                .get(&node)
181                .is_some_and(|meta| meta.kind() == NodeKind::Derived)
182            {
183                self.visit_derived(node, &mut temporary, &mut permanent, &mut order)?;
184            }
185        }
186
187        Ok(order)
188    }
189
190    fn visit_derived(
191        &self,
192        node: NodeId,
193        temporary: &mut BTreeSet<NodeId>,
194        permanent: &mut BTreeSet<NodeId>,
195        order: &mut Vec<NodeId>,
196    ) -> GraphResult<()> {
197        if permanent.contains(&node) {
198            return Ok(());
199        }
200        if !temporary.insert(node) {
201            return Err(GraphError::CycleDetected(node));
202        }
203
204        let dependencies = self
205            .nodes
206            .get(&node)
207            .expect("derived node metadata exists")
208            .dependencies();
209        for dependency in dependencies.as_slice() {
210            if self
211                .nodes
212                .get(dependency)
213                .is_some_and(|meta| meta.kind() == NodeKind::Derived)
214            {
215                self.visit_derived(*dependency, temporary, permanent, order)?;
216            }
217        }
218
219        temporary.remove(&node);
220        permanent.insert(node);
221        order.push(node);
222        Ok(())
223    }
224}
225
226pub(crate) struct RecomputeTrace {
227    pub(crate) recomputed: Vec<NodeId>,
228    pub(crate) changed: Vec<NodeId>,
229}