1use crate::input::{StoredInput, boxed_input, downcast_input};
2use crate::{DerivedNode, Graph, GraphError, GraphResult, InputNode, NodeId, NodeKind};
3use std::collections::BTreeSet;
4use std::sync::Arc;
5
6type ComputeFn<C> = dyn for<'ctx> Fn(&DeriveContext<'ctx, C>) -> Result<Box<dyn StoredInput>, DeriveError>
7 + Send
8 + Sync;
9
10pub(crate) struct DerivedSpec<C> {
11 compute: Arc<ComputeFn<C>>,
12}
13
14impl<C> Clone for DerivedSpec<C> {
15 fn clone(&self) -> Self {
16 Self {
17 compute: Arc::clone(&self.compute),
18 }
19 }
20}
21
22impl<C> DerivedSpec<C> {
23 pub(crate) fn new<T, F>(derive: F) -> Self
24 where
25 T: Clone + PartialEq + Send + Sync + 'static,
26 F: for<'ctx> Fn(&DeriveContext<'ctx, C>) -> Result<T, DeriveError> + Send + Sync + 'static,
27 {
28 Self {
29 compute: Arc::new(move |ctx| derive(ctx).map(boxed_input)),
30 }
31 }
32
33 pub(crate) fn compute(
34 &self,
35 ctx: &DeriveContext<'_, C>,
36 ) -> Result<Box<dyn StoredInput>, DeriveError> {
37 (self.compute)(ctx)
38 }
39}
40
41pub struct DeriveContext<'graph, C = ()> {
43 graph: &'graph Graph<C>,
44 declared_dependencies: &'graph [NodeId],
45}
46
47impl<'graph, C> DeriveContext<'graph, C> {
48 pub(crate) fn new(graph: &'graph Graph<C>, declared_dependencies: &'graph [NodeId]) -> Self {
49 Self {
50 graph,
51 declared_dependencies,
52 }
53 }
54
55 pub fn input<T>(&self, input: InputNode<T>) -> Result<&'graph T, DeriveError>
57 where
58 T: Clone + PartialEq + Send + Sync + 'static,
59 {
60 let node = input.id();
61 self.require_declared(node)?;
62 self.graph
63 .input_values
64 .get(&node)
65 .and_then(|value| downcast_input::<T>(value.as_ref()))
66 .ok_or(DeriveError::MissingValue(node))
67 }
68
69 pub fn derived<T>(&self, derived: DerivedNode<T>) -> Result<&'graph T, DeriveError>
71 where
72 T: Clone + PartialEq + Send + Sync + 'static,
73 {
74 let node = derived.id();
75 self.require_declared(node)?;
76 self.graph
77 .derived_values
78 .get(&node)
79 .and_then(|value| downcast_input::<T>(value.as_ref()))
80 .ok_or(DeriveError::MissingValue(node))
81 }
82
83 fn require_declared(&self, node: NodeId) -> Result<(), DeriveError> {
84 if self.declared_dependencies.contains(&node) {
85 Ok(())
86 } else {
87 Err(DeriveError::UndeclaredDependency(node))
88 }
89 }
90}
91
92#[derive(Clone, Debug, Eq, PartialEq)]
94pub enum DeriveError {
95 UndeclaredDependency(NodeId),
97 MissingValue(NodeId),
99 WrongCollectionType(NodeId),
101 Message(String),
103}
104
105impl DeriveError {
106 pub fn message(message: impl Into<String>) -> Self {
108 Self::Message(message.into())
109 }
110}
111
112impl<C> Graph<C> {
113 pub(crate) fn recompute_dirty_derived(
114 &mut self,
115 initial_changed: &[NodeId],
116 ) -> GraphResult<RecomputeTrace> {
117 let order = self.derived_topological_order()?;
118 let mut changed: BTreeSet<NodeId> = initial_changed.iter().copied().collect();
119 let mut changed_derived = Vec::new();
120 let mut recomputed = Vec::new();
121
122 for node in order {
123 let dependencies = self
124 .nodes
125 .get(&node)
126 .expect("derived node metadata exists")
127 .dependencies();
128 let is_dirty = changed.contains(&node)
129 || dependencies
130 .as_slice()
131 .iter()
132 .any(|dependency| changed.contains(dependency));
133
134 if !is_dirty {
135 continue;
136 }
137
138 recomputed.push(node);
139 let next_value = self.compute_derived(node, dependencies.as_slice())?;
140 let changed_value = self
141 .derived_values
142 .get(&node)
143 .is_none_or(|current| !current.equals(next_value.as_ref()));
144
145 if changed_value {
146 self.derived_values.insert(node, next_value);
147 changed.insert(node);
148 changed_derived.push(node);
149 }
150 }
151
152 Ok(RecomputeTrace {
153 recomputed,
154 changed: changed_derived,
155 })
156 }
157
158 pub(crate) fn compute_derived(
159 &self,
160 node: NodeId,
161 dependencies: &[NodeId],
162 ) -> GraphResult<Box<dyn StoredInput>> {
163 let spec = self
164 .derived_specs
165 .get(&node)
166 .ok_or(GraphError::UnknownNode(node))?;
167 let ctx = DeriveContext::new(self, dependencies);
168 spec.compute(&ctx)
169 .map_err(|error| GraphError::DeriveFailed(node, error))
170 }
171
172 pub(crate) fn derived_topological_order(&mut self) -> GraphResult<Vec<NodeId>> {
173 self.topological_order_for_kind(NodeKind::Derived)
174 }
175}
176
177pub(crate) struct RecomputeTrace {
178 pub(crate) recomputed: Vec<NodeId>,
179 pub(crate) changed: Vec<NodeId>,
180}