1use crate::input::{StoredInput, boxed_input, downcast_input};
2use crate::{DerivedNode, Graph, GraphError, GraphResult, InputNode, NodeId, NodeKind};
3use std::collections::BTreeSet;
4use std::sync::Arc;
5
6type ComputeFn<C, O> =
7 dyn for<'ctx> Fn(&DeriveContext<'ctx, C, O>) -> Result<Box<dyn StoredInput>, DeriveError>;
8
9pub(crate) struct DerivedSpec<C, O> {
10 compute: Arc<ComputeFn<C, O>>,
11}
12
13impl<C, O> Clone for DerivedSpec<C, O> {
14 fn clone(&self) -> Self {
15 Self {
16 compute: Arc::clone(&self.compute),
17 }
18 }
19}
20
21impl<C, O> DerivedSpec<C, O> {
22 pub(crate) fn new<T, F>(derive: F) -> Self
23 where
24 T: Clone + PartialEq + 'static,
25 F: for<'ctx> Fn(&DeriveContext<'ctx, C, O>) -> Result<T, DeriveError> + 'static,
26 {
27 Self {
28 compute: Arc::new(move |ctx| derive(ctx).map(boxed_input)),
29 }
30 }
31
32 pub(crate) fn compute(
33 &self,
34 ctx: &DeriveContext<'_, C, O>,
35 ) -> Result<Box<dyn StoredInput>, DeriveError> {
36 (self.compute)(ctx)
37 }
38}
39
40pub struct DeriveContext<'graph, C = (), O = ()> {
42 graph: &'graph Graph<C, O>,
43 declared_dependencies: &'graph [NodeId],
44}
45
46impl<'graph, C, O> DeriveContext<'graph, C, O> {
47 pub(crate) fn new(graph: &'graph Graph<C, O>, declared_dependencies: &'graph [NodeId]) -> Self {
48 Self {
49 graph,
50 declared_dependencies,
51 }
52 }
53
54 pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
56 where
57 T: Clone + PartialEq + 'static,
58 {
59 let node = input.id();
60 self.require_declared(node)?;
61 self.graph
62 .input_values
63 .get(&node)
64 .and_then(|value| downcast_input::<T>(value.as_ref()))
65 .ok_or(DeriveError::MissingValue(node))
66 }
67
68 pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
70 where
71 T: Clone + PartialEq + 'static,
72 {
73 let node = derived.id();
74 self.require_declared(node)?;
75 self.graph
76 .derived_values
77 .get(&node)
78 .and_then(|value| downcast_input::<T>(value.as_ref()))
79 .ok_or(DeriveError::MissingValue(node))
80 }
81
82 fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
83 if self.declared_dependencies.contains(&node) {
84 Ok(())
85 } else {
86 Err(DeriveError::UndeclaredDependency(node))
87 }
88 }
89}
90
91#[derive(Clone, Debug, Eq, PartialEq)]
93pub enum DeriveError {
94 UndeclaredDependency(NodeId),
96 MissingValue(NodeId),
98 WrongCollectionType(NodeId),
100 Message(String),
102}
103
104impl DeriveError {
105 pub fn message(message: impl Into<String>) -> Self {
107 Self::Message(message.into())
108 }
109}
110
111impl<C, O> Graph<C, O> {
112 pub(crate) fn recompute_dirty_derived(
113 &mut self,
114 initial_changed: &[NodeId],
115 ) -> GraphResult<RecomputeTrace> {
116 let order = self.derived_topological_order()?;
117 let mut changed: BTreeSet<NodeId> = initial_changed.iter().copied().collect();
118 let mut changed_derived = Vec::new();
119 let mut recomputed = Vec::new();
120
121 for node in order {
122 let dependencies = self
123 .nodes
124 .get(&node)
125 .expect("derived node metadata exists")
126 .dependencies()
127 .clone();
128 let is_dirty = changed.contains(&node)
129 || dependencies
130 .as_slice()
131 .iter()
132 .any(|dependency| changed.contains(dependency));
133
134 if !is_dirty {
135 continue;
136 }
137
138 recomputed.push(node);
139 let next_value = self.compute_derived(node, dependencies.as_slice())?;
140 let changed_value = self
141 .derived_values
142 .get(&node)
143 .is_none_or(|current| !current.equals(next_value.as_ref()));
144
145 if changed_value {
146 self.derived_values.insert(node, next_value);
147 changed.insert(node);
148 changed_derived.push(node);
149 }
150 }
151
152 Ok(RecomputeTrace {
153 recomputed,
154 changed: changed_derived,
155 })
156 }
157
158 pub(crate) fn compute_derived(
159 &self,
160 node: NodeId,
161 dependencies: &[NodeId],
162 ) -> GraphResult<Box<dyn StoredInput>> {
163 let spec = self
164 .derived_specs
165 .get(&node)
166 .ok_or(GraphError::UnknownNode(node))?;
167 let ctx = DeriveContext::new(self, dependencies);
168 spec.compute(&ctx)
169 .map_err(|error| GraphError::DeriveFailed(node, error))
170 }
171
172 pub(crate) fn derived_topological_order(&self) -> GraphResult<Vec<NodeId>> {
173 let mut order = Vec::new();
174 let mut temporary = BTreeSet::new();
175 let mut permanent = BTreeSet::new();
176
177 for node in self.nodes.keys().copied() {
178 if self
179 .nodes
180 .get(&node)
181 .is_some_and(|meta| meta.kind() == NodeKind::Derived)
182 {
183 self.visit_derived(node, &mut temporary, &mut permanent, &mut order)?;
184 }
185 }
186
187 Ok(order)
188 }
189
190 fn visit_derived(
191 &self,
192 node: NodeId,
193 temporary: &mut BTreeSet<NodeId>,
194 permanent: &mut BTreeSet<NodeId>,
195 order: &mut Vec<NodeId>,
196 ) -> GraphResult<()> {
197 if permanent.contains(&node) {
198 return Ok(());
199 }
200 if !temporary.insert(node) {
201 return Err(GraphError::CycleDetected(node));
202 }
203
204 let dependencies = self
205 .nodes
206 .get(&node)
207 .expect("derived node metadata exists")
208 .dependencies();
209 for dependency in dependencies.as_slice() {
210 if self
211 .nodes
212 .get(dependency)
213 .is_some_and(|meta| meta.kind() == NodeKind::Derived)
214 {
215 self.visit_derived(*dependency, temporary, permanent, order)?;
216 }
217 }
218
219 temporary.remove(&node);
220 permanent.insert(node);
221 order.push(node);
222 Ok(())
223 }
224}
225
226pub(crate) struct RecomputeTrace {
227 pub(crate) recomputed: Vec<NodeId>,
228 pub(crate) changed: Vec<NodeId>,
229}