vegafusion_core/task_graph/
graph.rs

1use crate::error::{Result, ResultWithContext, VegaFusionError};
2use crate::proto::gen::tasks::{
3    IncomingEdge, NodeValueIndex, OutgoingEdge, Task, TaskGraph, TaskNode, Variable,
4};
5use crate::task_graph::scope::TaskScope;
6use petgraph::algo::toposort;
7use petgraph::graph::NodeIndex;
8use petgraph::prelude::EdgeRef;
9use petgraph::Direction;
10use std::collections::HashMap;
11
12use crate::task_graph::task_value::TaskValue;
13
14use crate::proto::gen::tasks::task::TaskKind;
15use crate::proto::gen::tasks::task_value::Data;
16use crate::proto::gen::tasks::TaskValue as ProtoTaskValue;
17use std::convert::TryFrom;
18use std::hash::{BuildHasher, Hash, Hasher};
19
20struct PetgraphEdge {
21    output_var: Option<Variable>,
22}
23
24pub type ScopedVariable = (Variable, Vec<u32>);
25
26impl TaskGraph {
27    pub fn new(tasks: Vec<Task>, task_scope: &TaskScope) -> Result<Self> {
28        let mut graph: petgraph::graph::DiGraph<ScopedVariable, PetgraphEdge> =
29            petgraph::graph::DiGraph::new();
30        let mut tasks_map: HashMap<ScopedVariable, (NodeIndex, Task)> = HashMap::new();
31
32        // Add graph nodes
33        for task in tasks {
34            // Add scope variable
35            let scoped_var = (task.variable().clone(), task.scope.clone());
36            let node_index = graph.add_node(scoped_var.clone());
37            tasks_map.insert(scoped_var, (node_index, task));
38        }
39
40        // Resolve and add edges
41        for (node_index, task) in tasks_map.values() {
42            let usage_scope = task.scope();
43            let input_vars = task.input_vars();
44            for input_var in input_vars {
45                let resolved = task_scope.resolve_scope(&input_var.var, usage_scope)?;
46                let input_scoped_var = (resolved.var.clone(), resolved.scope.clone());
47                let (input_node_index, _) =
48                    tasks_map.get(&input_scoped_var).with_context(|| {
49                        format!(
50                            "No variable {:?} with scope {:?}",
51                            input_scoped_var.0, input_scoped_var.1
52                        )
53                    })?;
54
55                // Add graph edge
56                if input_node_index != node_index {
57                    // If a task depends on information generated by the task,that will be handled
58                    // internally to the task. So we avoid making a cycle
59                    graph.add_edge(
60                        *input_node_index,
61                        *node_index,
62                        PetgraphEdge {
63                            output_var: resolved.output_var.clone(),
64                        },
65                    );
66                }
67            }
68        }
69
70        // Create mapping from toposorted node_index to the final linear node index
71        let toposorted: Vec<NodeIndex> = match toposort(&graph, None) {
72            Err(err) => {
73                return Err(VegaFusionError::internal(format!(
74                    "failed to sort dependency graph topologically: {err:?}"
75                )))
76            }
77            Ok(toposorted) => toposorted,
78        };
79
80        let toposorted_node_indexes: HashMap<NodeIndex, usize> = toposorted
81            .iter()
82            .enumerate()
83            .map(|(sorted_index, node_index)| (*node_index, sorted_index))
84            .collect();
85
86        // Create linear vec of TaskNodes, with edges as sorted index references to nodes
87        let task_nodes = toposorted
88            .iter()
89            .map(|node_index| {
90                let scoped_var = graph.node_weight(*node_index).unwrap();
91                let (_, task) = tasks_map.get(scoped_var).unwrap();
92
93                // Collect outgoing node indexes
94                let outgoing_node_ids: Vec<_> = graph
95                    .edges_directed(*node_index, Direction::Outgoing)
96                    .map(|edge| edge.target())
97                    .collect();
98
99                let outgoing: Vec<_> = outgoing_node_ids
100                    .iter()
101                    .map(|node_index| {
102                        let sorted_index = *toposorted_node_indexes.get(node_index).unwrap() as u32;
103                        OutgoingEdge {
104                            target: sorted_index,
105                            propagate: true,
106                        }
107                    })
108                    .collect();
109
110                // Collect incoming node indexes
111                let incoming_node_ids: Vec<_> = graph
112                    .edges_directed(*node_index, Direction::Incoming)
113                    .map(|edge| (edge.source(), &edge.weight().output_var))
114                    .collect();
115
116                // Sort incoming nodes to match order expected by the task
117                let incoming_vars: HashMap<_, _> = incoming_node_ids
118                    .iter()
119                    .map(|(node_index, output_var)| {
120                        let var = graph.node_weight(*node_index).unwrap().0.clone();
121                        ((var, (*output_var).clone()), node_index)
122                    })
123                    .collect();
124
125                let incoming: Vec<_> = task
126                    .input_vars()
127                    .iter()
128                    .filter_map(|var| {
129                        let resolved = task_scope
130                            .resolve_scope(&var.var, scoped_var.1.as_slice())
131                            .unwrap();
132                        let output_var = resolved.output_var.clone();
133                        let resolved = (resolved.var, resolved.output_var);
134
135                        let node_index = *incoming_vars.get(&resolved)?;
136                        let sorted_index = *toposorted_node_indexes.get(node_index).unwrap() as u32;
137
138                        if let Some(output_var) = output_var {
139                            let weight = graph.node_weight(*node_index).unwrap();
140                            let (_, input_task) = tasks_map.get(weight).unwrap();
141
142                            let output_index = match input_task
143                                .output_vars()
144                                .iter()
145                                .position(|v| v == &output_var)
146                            {
147                                Some(output_index) => output_index,
148                                None => {
149                                    return Some(Err(VegaFusionError::internal(
150                                        "Failed to find output variable",
151                                    )))
152                                }
153                            };
154
155                            Some(Ok(IncomingEdge {
156                                source: sorted_index,
157                                output: Some(output_index as u32),
158                            }))
159                        } else {
160                            Some(Ok(IncomingEdge {
161                                source: sorted_index,
162                                output: None,
163                            }))
164                        }
165                    })
166                    .collect::<Result<Vec<_>>>()?;
167
168                Ok(TaskNode {
169                    task: Some(task.clone()),
170                    incoming,
171                    outgoing,
172                    id_fingerprint: 0,
173                    state_fingerprint: 0,
174                })
175            })
176            .collect::<Result<Vec<_>>>()?;
177
178        let mut this = Self { nodes: task_nodes };
179
180        this.init_identity_fingerprints()?;
181        this.update_state_fingerprints()?;
182
183        Ok(this)
184    }
185
186    pub fn build_mapping(&self) -> HashMap<ScopedVariable, NodeValueIndex> {
187        let mut mapping: HashMap<ScopedVariable, NodeValueIndex> = Default::default();
188        for (node_index, node) in self.nodes.iter().enumerate() {
189            let task = node.task();
190            let _scope = task.scope.clone();
191            let scoped_var = (task.variable().clone(), task.scope.clone());
192            mapping.insert(scoped_var, NodeValueIndex::new(node_index as u32, None));
193
194            for (output_index, output_var) in task.output_vars().into_iter().enumerate() {
195                let scope_output_var = (output_var, task.scope.clone());
196                mapping.insert(
197                    scope_output_var,
198                    NodeValueIndex::new(node_index as u32, Some(output_index as u32)),
199                );
200            }
201        }
202        mapping
203    }
204
205    fn init_identity_fingerprints(&mut self) -> Result<()> {
206        // Compute new identity fingerprints
207        let mut id_fingerprints: Vec<u64> = Vec::with_capacity(self.nodes.len());
208        for (i, node) in self.nodes.iter().enumerate() {
209            let task = node.task();
210            let mut hasher = ahash::RandomState::with_seed(123).build_hasher();
211            if let TaskKind::Value(value) = task.task_kind() {
212                // Only hash the distinction between Scalar and Table, not the value itself.
213                // The state fingerprint takes the value into account.
214                task.variable().hash(&mut hasher);
215                task.scope.hash(&mut hasher);
216                match value.data.as_ref().unwrap() {
217                    Data::Scalar(_) => "scalar".hash(&mut hasher),
218                    Data::Table(_) => "data".hash(&mut hasher),
219                }
220            } else {
221                // Include id_fingerprint of parents in the hash
222                for parent_index in self.parent_indices(i)? {
223                    id_fingerprints[parent_index].hash(&mut hasher);
224                }
225
226                // Include current task in hash
227                task.hash(&mut hasher)
228            }
229
230            id_fingerprints.push(hasher.finish());
231        }
232
233        // Apply fingerprints
234        self.nodes
235            .iter_mut()
236            .zip(id_fingerprints)
237            .for_each(|(node, fingerprint)| {
238                node.id_fingerprint = fingerprint;
239            });
240
241        Ok(())
242    }
243
244    /// Update state finger prints of nodes, and return indices of nodes that were updated
245    pub fn update_state_fingerprints(&mut self) -> Result<Vec<usize>> {
246        // Compute new identity fingerprints
247        let mut state_fingerprints: Vec<u64> = Vec::with_capacity(self.nodes.len());
248        for (i, node) in self.nodes.iter().enumerate() {
249            let task = node.task();
250            let mut hasher = ahash::RandomState::with_seed(123).build_hasher();
251
252            if matches!(task.task_kind(), TaskKind::Value(_)) {
253                // Hash the task with inline TaskValue
254                task.hash(&mut hasher);
255            } else {
256                // Include state fingerprint of parents in the hash
257                for parent_index in self.parent_indices(i)? {
258                    state_fingerprints[parent_index].hash(&mut hasher);
259                }
260
261                // Include id fingerprint of current task
262                node.id_fingerprint.hash(&mut hasher);
263            }
264
265            state_fingerprints.push(hasher.finish());
266        }
267
268        // Apply fingerprints
269        let updated: Vec<_> = self
270            .nodes
271            .iter_mut()
272            .zip(state_fingerprints)
273            .enumerate()
274            .filter_map(|(node_index, (node, fingerprint))| {
275                if node.state_fingerprint != fingerprint {
276                    node.state_fingerprint = fingerprint;
277                    Some(node_index)
278                } else {
279                    None
280                }
281            })
282            .collect();
283
284        Ok(updated)
285    }
286
287    pub fn update_value(
288        &mut self,
289        node_index: usize,
290        value: TaskValue,
291    ) -> Result<Vec<NodeValueIndex>> {
292        let node = self
293            .nodes
294            .get_mut(node_index)
295            .ok_or_else(|| VegaFusionError::internal("Missing node"))?;
296        if !matches!(node.task().task_kind(), TaskKind::Value(_)) {
297            return Err(VegaFusionError::internal(
298                "Task with index {} is not a Value",
299            ));
300        }
301
302        node.task = Some(Task {
303            variable: node.task().variable.clone(),
304            scope: node.task().scope.clone(),
305            task_kind: Some(TaskKind::Value(ProtoTaskValue::try_from(&value)?)),
306            tz_config: None,
307        });
308
309        let mut node_value_indexes = Vec::new();
310        for node_index in self.update_state_fingerprints()? {
311            node_value_indexes.push(NodeValueIndex::new(node_index as u32, None));
312
313            for output_index in 0..self
314                .nodes
315                .get(node_index)
316                .unwrap()
317                .task()
318                .output_vars()
319                .len()
320            {
321                node_value_indexes.push(NodeValueIndex::new(
322                    node_index as u32,
323                    Some(output_index as u32),
324                ));
325            }
326        }
327        Ok(node_value_indexes)
328    }
329
330    pub fn parent_nodes(&self, node_index: usize) -> Result<Vec<&TaskNode>> {
331        let node = self
332            .nodes
333            .get(node_index)
334            .with_context(|| format!("Node index {node_index} out of bounds"))?;
335        Ok(node
336            .incoming
337            .iter()
338            .map(|edge| self.nodes.get(edge.source as usize).unwrap())
339            .collect())
340    }
341
342    pub fn parent_indices(&self, node_index: usize) -> Result<Vec<usize>> {
343        let node = self
344            .nodes
345            .get(node_index)
346            .with_context(|| format!("Node index {node_index} out of bounds"))?;
347        Ok(node
348            .incoming
349            .iter()
350            .map(|edge| edge.source as usize)
351            .collect())
352    }
353
354    pub fn child_nodes(&self, node_index: usize) -> Result<Vec<&TaskNode>> {
355        let node = self
356            .nodes
357            .get(node_index)
358            .with_context(|| format!("Node index {node_index} out of bounds"))?;
359        Ok(node
360            .outgoing
361            .iter()
362            .map(|edge| self.nodes.get(edge.target as usize).unwrap())
363            .collect())
364    }
365
366    pub fn child_indices(&self, node_index: usize) -> Result<Vec<usize>> {
367        let node = self
368            .nodes
369            .get(node_index)
370            .with_context(|| format!("Node index {node_index} out of bounds"))?;
371        Ok(node
372            .outgoing
373            .iter()
374            .map(|edge| edge.target as usize)
375            .collect())
376    }
377
378    pub fn node(&self, node_index: usize) -> Result<&TaskNode> {
379        self.nodes
380            .get(node_index)
381            .with_context(|| format!("Node index {node_index} out of bounds"))
382    }
383}
384
385impl NodeValueIndex {
386    pub fn new(node_index: u32, output_index: Option<u32>) -> Self {
387        Self {
388            node_index,
389            output_index,
390        }
391    }
392}
393
394impl TaskNode {
395    pub fn task(&self) -> &Task {
396        self.task.as_ref().unwrap()
397    }
398}