1use crate::input::{StoredInput, boxed_input, downcast_input};
2use crate::{DerivedNode, Graph, GraphError, GraphResult, InputNode, NodeId, NodeKind};
3use std::collections::BTreeSet;
4use std::sync::Arc;
5
/// Type-erased derive callback.
///
/// Given a [`DeriveContext`] of any lifetime `'ctx`, the callback returns the computed value
/// boxed as a [`StoredInput`] trait object, or a [`DeriveError`]. The callback itself must be
/// `Send + Sync`.
type ComputeFn<C, O> = dyn for<'ctx> Fn(&DeriveContext<'ctx, C, O>) -> Result<Box<dyn StoredInput>, DeriveError>
    + Send
    + Sync;
9
/// Holds a type-erased compute callback behind an [`Arc`] so the spec can be cloned cheaply.
pub(crate) struct DerivedSpec<C, O> {
    /// Shared callback that produces the boxed value from a [`DeriveContext`].
    compute: Arc<ComputeFn<C, O>>,
}
13
14impl<C, O> Clone for DerivedSpec<C, O> {
15 fn clone(&self) -> Self {
16 Self {
17 compute: Arc::clone(&self.compute),
18 }
19 }
20}
21
22impl<C, O> DerivedSpec<C, O> {
23 pub(crate) fn new<T, F>(derive: F) -> Self
24 where
25 T: Clone + PartialEq + Send + Sync + 'static,
26 F: for<'ctx> Fn(&DeriveContext<'ctx, C, O>) -> Result<T, DeriveError>
27 + Send
28 + Sync
29 + 'static,
30 {
31 Self {
32 compute: Arc::new(move |ctx| derive(ctx).map(boxed_input)),
33 }
34 }
35
36 pub(crate) fn compute(
37 &self,
38 ctx: &DeriveContext<'_, C, O>,
39 ) -> Result<Box<dyn StoredInput>, DeriveError> {
40 (self.compute)(ctx)
41 }
42}
43
/// Read-only view of a [`Graph`] handed to a derive callback.
///
/// Value lookups through this context succeed only for nodes listed in
/// `declared_dependencies`.
pub struct DeriveContext<'graph, C = (), O = ()> {
    /// Graph whose stored input and derived values are read.
    graph: &'graph Graph<C, O>,
    /// Nodes the running computation is allowed to read.
    declared_dependencies: &'graph [NodeId],
}
49
50impl<'graph, C, O> DeriveContext<'graph, C, O> {
51 pub(crate) fn new(graph: &'graph Graph<C, O>, declared_dependencies: &'graph [NodeId]) -> Self {
52 Self {
53 graph,
54 declared_dependencies,
55 }
56 }
57
58 pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
60 where
61 T: Clone + PartialEq + Send + Sync + 'static,
62 {
63 let node = input.id();
64 self.require_declared(node)?;
65 self.graph
66 .input_values
67 .get(&node)
68 .and_then(|value| downcast_input::<T>(value.as_ref()))
69 .ok_or(DeriveError::MissingValue(node))
70 }
71
72 pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
74 where
75 T: Clone + PartialEq + Send + Sync + 'static,
76 {
77 let node = derived.id();
78 self.require_declared(node)?;
79 self.graph
80 .derived_values
81 .get(&node)
82 .and_then(|value| downcast_input::<T>(value.as_ref()))
83 .ok_or(DeriveError::MissingValue(node))
84 }
85
86 fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
87 if self.declared_dependencies.contains(&node) {
88 Ok(())
89 } else {
90 Err(DeriveError::UndeclaredDependency(node))
91 }
92 }
93}
94
/// Errors a derive computation can report.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DeriveError {
    /// The computation tried to read a node that is not among its declared dependencies.
    UndeclaredDependency(NodeId),
    /// No value is stored for the node, or the stored value could not be downcast to the
    /// requested type.
    MissingValue(NodeId),
    /// NOTE(review): not constructed in this part of the file; presumably reported when a
    /// node's collection has an unexpected type. Confirm against the code that raises it.
    WrongCollectionType(NodeId),
    /// Free-form failure description; see [`DeriveError::message`].
    Message(String),
}
107
108impl DeriveError {
109 pub fn message(message: impl Into<String>) -> Self {
111 Self::Message(message.into())
112 }
113}
114
115impl<C, O> Graph<C, O> {
116 pub(crate) fn recompute_dirty_derived(
117 &mut self,
118 initial_changed: &[NodeId],
119 ) -> GraphResult<RecomputeTrace> {
120 let order = self.derived_topological_order()?;
121 let mut changed: BTreeSet<NodeId> = initial_changed.iter().copied().collect();
122 let mut changed_derived = Vec::new();
123 let mut recomputed = Vec::new();
124
125 for node in order {
126 let dependencies = self
127 .nodes
128 .get(&node)
129 .expect("derived node metadata exists")
130 .dependencies()
131 .clone();
132 let is_dirty = changed.contains(&node)
133 || dependencies
134 .as_slice()
135 .iter()
136 .any(|dependency| changed.contains(dependency));
137
138 if !is_dirty {
139 continue;
140 }
141
142 recomputed.push(node);
143 let next_value = self.compute_derived(node, dependencies.as_slice())?;
144 let changed_value = self
145 .derived_values
146 .get(&node)
147 .is_none_or(|current| !current.equals(next_value.as_ref()));
148
149 if changed_value {
150 self.derived_values.insert(node, next_value);
151 changed.insert(node);
152 changed_derived.push(node);
153 }
154 }
155
156 Ok(RecomputeTrace {
157 recomputed,
158 changed: changed_derived,
159 })
160 }
161
162 pub(crate) fn compute_derived(
163 &self,
164 node: NodeId,
165 dependencies: &[NodeId],
166 ) -> GraphResult<Box<dyn StoredInput>> {
167 let spec = self
168 .derived_specs
169 .get(&node)
170 .ok_or(GraphError::UnknownNode(node))?;
171 let ctx = DeriveContext::new(self, dependencies);
172 spec.compute(&ctx)
173 .map_err(|error| GraphError::DeriveFailed(node, error))
174 }
175
176 pub(crate) fn derived_topological_order(&self) -> GraphResult<Vec<NodeId>> {
177 let mut order = Vec::new();
178 let mut temporary = BTreeSet::new();
179 let mut permanent = BTreeSet::new();
180
181 for node in self.nodes.keys().copied() {
182 if self
183 .nodes
184 .get(&node)
185 .is_some_and(|meta| meta.kind() == NodeKind::Derived)
186 {
187 self.visit_derived(node, &mut temporary, &mut permanent, &mut order)?;
188 }
189 }
190
191 Ok(order)
192 }
193
194 fn visit_derived(
195 &self,
196 node: NodeId,
197 temporary: &mut BTreeSet<NodeId>,
198 permanent: &mut BTreeSet<NodeId>,
199 order: &mut Vec<NodeId>,
200 ) -> GraphResult<()> {
201 if permanent.contains(&node) {
202 return Ok(());
203 }
204 if !temporary.insert(node) {
205 return Err(GraphError::CycleDetected(node));
206 }
207
208 let dependencies = self
209 .nodes
210 .get(&node)
211 .expect("derived node metadata exists")
212 .dependencies();
213 for dependency in dependencies.as_slice() {
214 if self
215 .nodes
216 .get(dependency)
217 .is_some_and(|meta| meta.kind() == NodeKind::Derived)
218 {
219 self.visit_derived(*dependency, temporary, permanent, order)?;
220 }
221 }
222
223 temporary.remove(&node);
224 permanent.insert(node);
225 order.push(node);
226 Ok(())
227 }
228}
229
/// Summary of one `recompute_dirty_derived` pass.
pub(crate) struct RecomputeTrace {
    /// Derived nodes whose compute callback was run, in the order they were processed.
    pub(crate) recomputed: Vec<NodeId>,
    /// Subset of `recomputed` whose stored value was replaced (new or unequal value).
    pub(crate) changed: Vec<NodeId>,
}