1use crate::input::{StoredInput, boxed_input, downcast_input};
2use crate::{DerivedNode, Graph, GraphError, GraphResult, InputNode, NodeId, NodeKind};
3use std::collections::BTreeSet;
4use std::sync::Arc;
5
6type ComputeFn<C, O> =
7 dyn for<'ctx> Fn(&DeriveContext<'ctx, C, O>) -> Result<Box<dyn StoredInput>, DeriveError>;
8
9pub(crate) struct DerivedSpec<C, O> {
10 compute: Arc<ComputeFn<C, O>>,
11}
12
13impl<C, O> Clone for DerivedSpec<C, O> {
14 fn clone(&self) -> Self {
15 Self {
16 compute: Arc::clone(&self.compute),
17 }
18 }
19}
20
21impl<C, O> DerivedSpec<C, O> {
22 pub(crate) fn new<T, F>(derive: F) -> Self
23 where
24 T: Clone + PartialEq + 'static,
25 F: for<'ctx> Fn(&DeriveContext<'ctx, C, O>) -> Result<T, DeriveError> + 'static,
26 {
27 Self {
28 compute: Arc::new(move |ctx| derive(ctx).map(boxed_input)),
29 }
30 }
31
32 pub(crate) fn compute(
33 &self,
34 ctx: &DeriveContext<'_, C, O>,
35 ) -> Result<Box<dyn StoredInput>, DeriveError> {
36 (self.compute)(ctx)
37 }
38}
39
40pub struct DeriveContext<'graph, C = (), O = ()> {
42 graph: &'graph Graph<C, O>,
43 declared_dependencies: &'graph [NodeId],
44}
45
46impl<'graph, C, O> DeriveContext<'graph, C, O> {
47 pub(crate) fn new(graph: &'graph Graph<C, O>, declared_dependencies: &'graph [NodeId]) -> Self {
48 Self {
49 graph,
50 declared_dependencies,
51 }
52 }
53
54 pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
56 where
57 T: Clone + PartialEq + 'static,
58 {
59 let node = input.id();
60 self.require_declared(node)?;
61 self.graph
62 .input_values
63 .get(&node)
64 .and_then(|value| downcast_input::<T>(value.as_ref()))
65 .ok_or(DeriveError::MissingValue(node))
66 }
67
68 pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
70 where
71 T: Clone + PartialEq + 'static,
72 {
73 let node = derived.id();
74 self.require_declared(node)?;
75 self.graph
76 .derived_values
77 .get(&node)
78 .and_then(|value| downcast_input::<T>(value.as_ref()))
79 .ok_or(DeriveError::MissingValue(node))
80 }
81
82 fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
83 if self.declared_dependencies.contains(&node) {
84 Ok(())
85 } else {
86 Err(DeriveError::UndeclaredDependency(node))
87 }
88 }
89}
90
91#[derive(Clone, Debug, Eq, PartialEq)]
93pub enum DeriveError {
94 UndeclaredDependency(NodeId),
96 MissingValue(NodeId),
98 WrongCollectionType(NodeId),
100 Message(String),
102}
103
104impl DeriveError {
105 pub fn message(message: impl Into<String>) -> Self {
107 Self::Message(message.into())
108 }
109}
110
111impl<C, O> Graph<C, O> {
112 pub(crate) fn recompute_dirty_derived(
113 &mut self,
114 initial_changed: &[NodeId],
115 ) -> GraphResult<Vec<NodeId>> {
116 let order = self.derived_topological_order()?;
117 let mut changed: BTreeSet<NodeId> = initial_changed.iter().copied().collect();
118 let mut changed_derived = Vec::new();
119
120 for node in order {
121 let dependencies = self
122 .nodes
123 .get(&node)
124 .expect("derived node metadata exists")
125 .dependencies()
126 .clone();
127 let is_dirty = changed.contains(&node)
128 || dependencies
129 .as_slice()
130 .iter()
131 .any(|dependency| changed.contains(dependency));
132
133 if !is_dirty {
134 continue;
135 }
136
137 let next_value = self.compute_derived(node, dependencies.as_slice())?;
138 let changed_value = self
139 .derived_values
140 .get(&node)
141 .is_none_or(|current| !current.equals(next_value.as_ref()));
142
143 if changed_value {
144 self.derived_values.insert(node, next_value);
145 changed.insert(node);
146 changed_derived.push(node);
147 }
148 }
149
150 Ok(changed_derived)
151 }
152
153 pub(crate) fn compute_derived(
154 &self,
155 node: NodeId,
156 dependencies: &[NodeId],
157 ) -> GraphResult<Box<dyn StoredInput>> {
158 let spec = self
159 .derived_specs
160 .get(&node)
161 .ok_or(GraphError::UnknownNode(node))?;
162 let ctx = DeriveContext::new(self, dependencies);
163 spec.compute(&ctx)
164 .map_err(|error| GraphError::DeriveFailed(node, error))
165 }
166
167 pub(crate) fn derived_topological_order(&self) -> GraphResult<Vec<NodeId>> {
168 let mut order = Vec::new();
169 let mut temporary = BTreeSet::new();
170 let mut permanent = BTreeSet::new();
171
172 for node in self.nodes.keys().copied() {
173 if self
174 .nodes
175 .get(&node)
176 .is_some_and(|meta| meta.kind() == NodeKind::Derived)
177 {
178 self.visit_derived(node, &mut temporary, &mut permanent, &mut order)?;
179 }
180 }
181
182 Ok(order)
183 }
184
185 fn visit_derived(
186 &self,
187 node: NodeId,
188 temporary: &mut BTreeSet<NodeId>,
189 permanent: &mut BTreeSet<NodeId>,
190 order: &mut Vec<NodeId>,
191 ) -> GraphResult<()> {
192 if permanent.contains(&node) {
193 return Ok(());
194 }
195 if !temporary.insert(node) {
196 return Err(GraphError::CycleDetected(node));
197 }
198
199 let dependencies = self
200 .nodes
201 .get(&node)
202 .expect("derived node metadata exists")
203 .dependencies();
204 for dependency in dependencies.as_slice() {
205 if self
206 .nodes
207 .get(dependency)
208 .is_some_and(|meta| meta.kind() == NodeKind::Derived)
209 {
210 self.visit_derived(*dependency, temporary, permanent, order)?;
211 }
212 }
213
214 temporary.remove(&node);
215 permanent.insert(node);
216 order.push(node);
217 Ok(())
218 }
219}