vegafusion_core/task_graph/
graph.rs1use crate::error::{Result, ResultWithContext, VegaFusionError};
2use crate::proto::gen::tasks::{
3 IncomingEdge, NodeValueIndex, OutgoingEdge, Task, TaskGraph, TaskNode, Variable,
4};
5use crate::task_graph::scope::TaskScope;
6use petgraph::algo::toposort;
7use petgraph::graph::NodeIndex;
8use petgraph::prelude::EdgeRef;
9use petgraph::Direction;
10use std::collections::HashMap;
11
12use crate::task_graph::task_value::TaskValue;
13
14use crate::proto::gen::tasks::task::TaskKind;
15use crate::proto::gen::tasks::task_value::Data;
16use crate::proto::gen::tasks::TaskValue as ProtoTaskValue;
17use std::convert::TryFrom;
18use std::hash::{BuildHasher, Hash, Hasher};
19
20struct PetgraphEdge {
21 output_var: Option<Variable>,
22}
23
24pub type ScopedVariable = (Variable, Vec<u32>);
25
26impl TaskGraph {
27 pub fn new(tasks: Vec<Task>, task_scope: &TaskScope) -> Result<Self> {
28 let mut graph: petgraph::graph::DiGraph<ScopedVariable, PetgraphEdge> =
29 petgraph::graph::DiGraph::new();
30 let mut tasks_map: HashMap<ScopedVariable, (NodeIndex, Task)> = HashMap::new();
31
32 for task in tasks {
34 let scoped_var = (task.variable().clone(), task.scope.clone());
36 let node_index = graph.add_node(scoped_var.clone());
37 tasks_map.insert(scoped_var, (node_index, task));
38 }
39
40 for (node_index, task) in tasks_map.values() {
42 let usage_scope = task.scope();
43 let input_vars = task.input_vars();
44 for input_var in input_vars {
45 let resolved = task_scope.resolve_scope(&input_var.var, usage_scope)?;
46 let input_scoped_var = (resolved.var.clone(), resolved.scope.clone());
47 let (input_node_index, _) =
48 tasks_map.get(&input_scoped_var).with_context(|| {
49 format!(
50 "No variable {:?} with scope {:?}",
51 input_scoped_var.0, input_scoped_var.1
52 )
53 })?;
54
55 if input_node_index != node_index {
57 graph.add_edge(
60 *input_node_index,
61 *node_index,
62 PetgraphEdge {
63 output_var: resolved.output_var.clone(),
64 },
65 );
66 }
67 }
68 }
69
70 let toposorted: Vec<NodeIndex> = match toposort(&graph, None) {
72 Err(err) => {
73 return Err(VegaFusionError::internal(format!(
74 "failed to sort dependency graph topologically: {err:?}"
75 )))
76 }
77 Ok(toposorted) => toposorted,
78 };
79
80 let toposorted_node_indexes: HashMap<NodeIndex, usize> = toposorted
81 .iter()
82 .enumerate()
83 .map(|(sorted_index, node_index)| (*node_index, sorted_index))
84 .collect();
85
86 let task_nodes = toposorted
88 .iter()
89 .map(|node_index| {
90 let scoped_var = graph.node_weight(*node_index).unwrap();
91 let (_, task) = tasks_map.get(scoped_var).unwrap();
92
93 let outgoing_node_ids: Vec<_> = graph
95 .edges_directed(*node_index, Direction::Outgoing)
96 .map(|edge| edge.target())
97 .collect();
98
99 let outgoing: Vec<_> = outgoing_node_ids
100 .iter()
101 .map(|node_index| {
102 let sorted_index = *toposorted_node_indexes.get(node_index).unwrap() as u32;
103 OutgoingEdge {
104 target: sorted_index,
105 propagate: true,
106 }
107 })
108 .collect();
109
110 let incoming_node_ids: Vec<_> = graph
112 .edges_directed(*node_index, Direction::Incoming)
113 .map(|edge| (edge.source(), &edge.weight().output_var))
114 .collect();
115
116 let incoming_vars: HashMap<_, _> = incoming_node_ids
118 .iter()
119 .map(|(node_index, output_var)| {
120 let var = graph.node_weight(*node_index).unwrap().0.clone();
121 ((var, (*output_var).clone()), node_index)
122 })
123 .collect();
124
125 let incoming: Vec<_> = task
126 .input_vars()
127 .iter()
128 .filter_map(|var| {
129 let resolved = task_scope
130 .resolve_scope(&var.var, scoped_var.1.as_slice())
131 .unwrap();
132 let output_var = resolved.output_var.clone();
133 let resolved = (resolved.var, resolved.output_var);
134
135 let node_index = *incoming_vars.get(&resolved)?;
136 let sorted_index = *toposorted_node_indexes.get(node_index).unwrap() as u32;
137
138 if let Some(output_var) = output_var {
139 let weight = graph.node_weight(*node_index).unwrap();
140 let (_, input_task) = tasks_map.get(weight).unwrap();
141
142 let output_index = match input_task
143 .output_vars()
144 .iter()
145 .position(|v| v == &output_var)
146 {
147 Some(output_index) => output_index,
148 None => {
149 return Some(Err(VegaFusionError::internal(
150 "Failed to find output variable",
151 )))
152 }
153 };
154
155 Some(Ok(IncomingEdge {
156 source: sorted_index,
157 output: Some(output_index as u32),
158 }))
159 } else {
160 Some(Ok(IncomingEdge {
161 source: sorted_index,
162 output: None,
163 }))
164 }
165 })
166 .collect::<Result<Vec<_>>>()?;
167
168 Ok(TaskNode {
169 task: Some(task.clone()),
170 incoming,
171 outgoing,
172 id_fingerprint: 0,
173 state_fingerprint: 0,
174 })
175 })
176 .collect::<Result<Vec<_>>>()?;
177
178 let mut this = Self { nodes: task_nodes };
179
180 this.init_identity_fingerprints()?;
181 this.update_state_fingerprints()?;
182
183 Ok(this)
184 }
185
186 pub fn build_mapping(&self) -> HashMap<ScopedVariable, NodeValueIndex> {
187 let mut mapping: HashMap<ScopedVariable, NodeValueIndex> = Default::default();
188 for (node_index, node) in self.nodes.iter().enumerate() {
189 let task = node.task();
190 let _scope = task.scope.clone();
191 let scoped_var = (task.variable().clone(), task.scope.clone());
192 mapping.insert(scoped_var, NodeValueIndex::new(node_index as u32, None));
193
194 for (output_index, output_var) in task.output_vars().into_iter().enumerate() {
195 let scope_output_var = (output_var, task.scope.clone());
196 mapping.insert(
197 scope_output_var,
198 NodeValueIndex::new(node_index as u32, Some(output_index as u32)),
199 );
200 }
201 }
202 mapping
203 }
204
205 fn init_identity_fingerprints(&mut self) -> Result<()> {
206 let mut id_fingerprints: Vec<u64> = Vec::with_capacity(self.nodes.len());
208 for (i, node) in self.nodes.iter().enumerate() {
209 let task = node.task();
210 let mut hasher = ahash::RandomState::with_seed(123).build_hasher();
211 if let TaskKind::Value(value) = task.task_kind() {
212 task.variable().hash(&mut hasher);
215 task.scope.hash(&mut hasher);
216 match value.data.as_ref().unwrap() {
217 Data::Scalar(_) => "scalar".hash(&mut hasher),
218 Data::Table(_) => "data".hash(&mut hasher),
219 }
220 } else {
221 for parent_index in self.parent_indices(i)? {
223 id_fingerprints[parent_index].hash(&mut hasher);
224 }
225
226 task.hash(&mut hasher)
228 }
229
230 id_fingerprints.push(hasher.finish());
231 }
232
233 self.nodes
235 .iter_mut()
236 .zip(id_fingerprints)
237 .for_each(|(node, fingerprint)| {
238 node.id_fingerprint = fingerprint;
239 });
240
241 Ok(())
242 }
243
244 pub fn update_state_fingerprints(&mut self) -> Result<Vec<usize>> {
246 let mut state_fingerprints: Vec<u64> = Vec::with_capacity(self.nodes.len());
248 for (i, node) in self.nodes.iter().enumerate() {
249 let task = node.task();
250 let mut hasher = ahash::RandomState::with_seed(123).build_hasher();
251
252 if matches!(task.task_kind(), TaskKind::Value(_)) {
253 task.hash(&mut hasher);
255 } else {
256 for parent_index in self.parent_indices(i)? {
258 state_fingerprints[parent_index].hash(&mut hasher);
259 }
260
261 node.id_fingerprint.hash(&mut hasher);
263 }
264
265 state_fingerprints.push(hasher.finish());
266 }
267
268 let updated: Vec<_> = self
270 .nodes
271 .iter_mut()
272 .zip(state_fingerprints)
273 .enumerate()
274 .filter_map(|(node_index, (node, fingerprint))| {
275 if node.state_fingerprint != fingerprint {
276 node.state_fingerprint = fingerprint;
277 Some(node_index)
278 } else {
279 None
280 }
281 })
282 .collect();
283
284 Ok(updated)
285 }
286
287 pub fn update_value(
288 &mut self,
289 node_index: usize,
290 value: TaskValue,
291 ) -> Result<Vec<NodeValueIndex>> {
292 let node = self
293 .nodes
294 .get_mut(node_index)
295 .ok_or_else(|| VegaFusionError::internal("Missing node"))?;
296 if !matches!(node.task().task_kind(), TaskKind::Value(_)) {
297 return Err(VegaFusionError::internal(
298 "Task with index {} is not a Value",
299 ));
300 }
301
302 node.task = Some(Task {
303 variable: node.task().variable.clone(),
304 scope: node.task().scope.clone(),
305 task_kind: Some(TaskKind::Value(ProtoTaskValue::try_from(&value)?)),
306 tz_config: None,
307 });
308
309 let mut node_value_indexes = Vec::new();
310 for node_index in self.update_state_fingerprints()? {
311 node_value_indexes.push(NodeValueIndex::new(node_index as u32, None));
312
313 for output_index in 0..self
314 .nodes
315 .get(node_index)
316 .unwrap()
317 .task()
318 .output_vars()
319 .len()
320 {
321 node_value_indexes.push(NodeValueIndex::new(
322 node_index as u32,
323 Some(output_index as u32),
324 ));
325 }
326 }
327 Ok(node_value_indexes)
328 }
329
330 pub fn parent_nodes(&self, node_index: usize) -> Result<Vec<&TaskNode>> {
331 let node = self
332 .nodes
333 .get(node_index)
334 .with_context(|| format!("Node index {node_index} out of bounds"))?;
335 Ok(node
336 .incoming
337 .iter()
338 .map(|edge| self.nodes.get(edge.source as usize).unwrap())
339 .collect())
340 }
341
342 pub fn parent_indices(&self, node_index: usize) -> Result<Vec<usize>> {
343 let node = self
344 .nodes
345 .get(node_index)
346 .with_context(|| format!("Node index {node_index} out of bounds"))?;
347 Ok(node
348 .incoming
349 .iter()
350 .map(|edge| edge.source as usize)
351 .collect())
352 }
353
354 pub fn child_nodes(&self, node_index: usize) -> Result<Vec<&TaskNode>> {
355 let node = self
356 .nodes
357 .get(node_index)
358 .with_context(|| format!("Node index {node_index} out of bounds"))?;
359 Ok(node
360 .outgoing
361 .iter()
362 .map(|edge| self.nodes.get(edge.target as usize).unwrap())
363 .collect())
364 }
365
366 pub fn child_indices(&self, node_index: usize) -> Result<Vec<usize>> {
367 let node = self
368 .nodes
369 .get(node_index)
370 .with_context(|| format!("Node index {node_index} out of bounds"))?;
371 Ok(node
372 .outgoing
373 .iter()
374 .map(|edge| edge.target as usize)
375 .collect())
376 }
377
378 pub fn node(&self, node_index: usize) -> Result<&TaskNode> {
379 self.nodes
380 .get(node_index)
381 .with_context(|| format!("Node index {node_index} out of bounds"))
382 }
383}
384
385impl NodeValueIndex {
386 pub fn new(node_index: u32, output_index: Option<u32>) -> Self {
387 Self {
388 node_index,
389 output_index,
390 }
391 }
392}
393
394impl TaskNode {
395 pub fn task(&self) -> &Task {
396 self.task.as_ref().unwrap()
397 }
398}