Skip to main content

trellis_core/
derive.rs

1use crate::input::{StoredInput, boxed_input, downcast_input};
2use crate::{DerivedNode, Graph, GraphError, GraphResult, InputNode, NodeId, NodeKind};
3use std::collections::BTreeSet;
4use std::sync::Arc;
5
6type ComputeFn<C, O> = dyn for<'ctx> Fn(&DeriveContext<'ctx, C, O>) -> Result<Box<dyn StoredInput>, DeriveError>
7    + Send
8    + Sync;
9
/// Type-erased recipe for computing the value of a derived node.
pub(crate) struct DerivedSpec<C, O> {
    /// Shared derive callback; kept behind an `Arc` so cloning the spec only
    /// bumps a reference count.
    compute: Arc<ComputeFn<C, O>>,
}
13
14impl<C, O> Clone for DerivedSpec<C, O> {
15    fn clone(&self) -> Self {
16        Self {
17            compute: Arc::clone(&self.compute),
18        }
19    }
20}
21
22impl<C, O> DerivedSpec<C, O> {
23    pub(crate) fn new<T, F>(derive: F) -> Self
24    where
25        T: Clone + PartialEq + Send + Sync + 'static,
26        F: for<'ctx> Fn(&DeriveContext<'ctx, C, O>) -> Result<T, DeriveError>
27            + Send
28            + Sync
29            + 'static,
30    {
31        Self {
32            compute: Arc::new(move |ctx| derive(ctx).map(boxed_input)),
33        }
34    }
35
36    pub(crate) fn compute(
37        &self,
38        ctx: &DeriveContext<'_, C, O>,
39    ) -> Result<Box<dyn StoredInput>, DeriveError> {
40        (self.compute)(ctx)
41    }
42}
43
/// Read-only context passed to pure derived node computations.
///
/// Reads go through [`DeriveContext::input`] and [`DeriveContext::derived`],
/// which reject any node that is not listed in the declared dependencies.
pub struct DeriveContext<'graph, C = (), O = ()> {
    /// Graph whose stored input and derived values are read.
    graph: &'graph Graph<C, O>,
    /// Node ids the running computation is allowed to read.
    declared_dependencies: &'graph [NodeId],
}
49
50impl<'graph, C, O> DeriveContext<'graph, C, O> {
51    pub(crate) fn new(graph: &'graph Graph<C, O>, declared_dependencies: &'graph [NodeId]) -> Self {
52        Self {
53            graph,
54            declared_dependencies,
55        }
56    }
57
58    /// Reads a declared input dependency.
59    pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
60    where
61        T: Clone + PartialEq + Send + Sync + 'static,
62    {
63        let node = input.id();
64        self.require_declared(node)?;
65        self.graph
66            .input_values
67            .get(&node)
68            .and_then(|value| downcast_input::<T>(value.as_ref()))
69            .ok_or(DeriveError::MissingValue(node))
70    }
71
72    /// Reads a declared derived dependency.
73    pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
74    where
75        T: Clone + PartialEq + Send + Sync + 'static,
76    {
77        let node = derived.id();
78        self.require_declared(node)?;
79        self.graph
80            .derived_values
81            .get(&node)
82            .and_then(|value| downcast_input::<T>(value.as_ref()))
83            .ok_or(DeriveError::MissingValue(node))
84    }
85
86    fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
87        if self.declared_dependencies.contains(&node) {
88            Ok(())
89        } else {
90            Err(DeriveError::UndeclaredDependency(node))
91        }
92    }
93}
94
/// Error returned by a pure derived node computation.
///
/// The `NodeId` carried by a variant identifies the dependency whose read
/// failed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DeriveError {
    /// A derive function tried to read a node it did not declare.
    UndeclaredDependency(NodeId),
    /// A dependency had no committed value.
    ///
    /// The context read methods also report this when a stored value could
    /// not be downcast to the requested type.
    MissingValue(NodeId),
    /// A collection dependency was read with the wrong set/map shape or value type.
    WrongCollectionType(NodeId),
    /// User-defined derivation failed.
    Message(String),
}
107
108impl DeriveError {
109    /// Creates a user-defined derive error.
110    pub fn message(message: impl Into<String>) -> Self {
111        Self::Message(message.into())
112    }
113}
114
115impl<C, O> Graph<C, O> {
116    pub(crate) fn recompute_dirty_derived(
117        &mut self,
118        initial_changed: &[NodeId],
119    ) -> GraphResult<RecomputeTrace> {
120        let order = self.derived_topological_order()?;
121        let mut changed: BTreeSet<NodeId> = initial_changed.iter().copied().collect();
122        let mut changed_derived = Vec::new();
123        let mut recomputed = Vec::new();
124
125        for node in order {
126            let dependencies = self
127                .nodes
128                .get(&node)
129                .expect("derived node metadata exists")
130                .dependencies()
131                .clone();
132            let is_dirty = changed.contains(&node)
133                || dependencies
134                    .as_slice()
135                    .iter()
136                    .any(|dependency| changed.contains(dependency));
137
138            if !is_dirty {
139                continue;
140            }
141
142            recomputed.push(node);
143            let next_value = self.compute_derived(node, dependencies.as_slice())?;
144            let changed_value = self
145                .derived_values
146                .get(&node)
147                .is_none_or(|current| !current.equals(next_value.as_ref()));
148
149            if changed_value {
150                self.derived_values.insert(node, next_value);
151                changed.insert(node);
152                changed_derived.push(node);
153            }
154        }
155
156        Ok(RecomputeTrace {
157            recomputed,
158            changed: changed_derived,
159        })
160    }
161
162    pub(crate) fn compute_derived(
163        &self,
164        node: NodeId,
165        dependencies: &[NodeId],
166    ) -> GraphResult<Box<dyn StoredInput>> {
167        let spec = self
168            .derived_specs
169            .get(&node)
170            .ok_or(GraphError::UnknownNode(node))?;
171        let ctx = DeriveContext::new(self, dependencies);
172        spec.compute(&ctx)
173            .map_err(|error| GraphError::DeriveFailed(node, error))
174    }
175
176    pub(crate) fn derived_topological_order(&self) -> GraphResult<Vec<NodeId>> {
177        let mut order = Vec::new();
178        let mut temporary = BTreeSet::new();
179        let mut permanent = BTreeSet::new();
180
181        for node in self.nodes.keys().copied() {
182            if self
183                .nodes
184                .get(&node)
185                .is_some_and(|meta| meta.kind() == NodeKind::Derived)
186            {
187                self.visit_derived(node, &mut temporary, &mut permanent, &mut order)?;
188            }
189        }
190
191        Ok(order)
192    }
193
194    fn visit_derived(
195        &self,
196        node: NodeId,
197        temporary: &mut BTreeSet<NodeId>,
198        permanent: &mut BTreeSet<NodeId>,
199        order: &mut Vec<NodeId>,
200    ) -> GraphResult<()> {
201        if permanent.contains(&node) {
202            return Ok(());
203        }
204        if !temporary.insert(node) {
205            return Err(GraphError::CycleDetected(node));
206        }
207
208        let dependencies = self
209            .nodes
210            .get(&node)
211            .expect("derived node metadata exists")
212            .dependencies();
213        for dependency in dependencies.as_slice() {
214            if self
215                .nodes
216                .get(dependency)
217                .is_some_and(|meta| meta.kind() == NodeKind::Derived)
218            {
219                self.visit_derived(*dependency, temporary, permanent, order)?;
220            }
221        }
222
223        temporary.remove(&node);
224        permanent.insert(node);
225        order.push(node);
226        Ok(())
227    }
228}
229
/// Summary of one derived-node recomputation pass.
pub(crate) struct RecomputeTrace {
    /// Derived nodes whose compute function was run, in evaluation order.
    pub(crate) recomputed: Vec<NodeId>,
    /// Recomputed nodes whose stored value was replaced by a new one.
    pub(crate) changed: Vec<NodeId>,
}