// vegafusion_runtime/task_graph/runtime.rs

1use crate::datafusion::context::make_datafusion_context;
2use crate::task_graph::cache::VegaFusionCache;
3use crate::task_graph::task::TaskCall;
4use crate::task_graph::timezone::RuntimeTzConfig;
5use async_recursion::async_recursion;
6use cfg_if::cfg_if;
7use datafusion::prelude::SessionContext;
8use futures_util::{future, FutureExt};
9use std::any::Any;
10use std::collections::HashMap;
11use std::convert::TryInto;
12use std::panic::AssertUnwindSafe;
13use std::sync::Arc;
14use vegafusion_core::data::dataset::VegaFusionDataset;
15use vegafusion_core::error::{Result, ResultWithContext, VegaFusionError};
16use vegafusion_core::proto::gen::tasks::inline_dataset::Dataset;
17use vegafusion_core::proto::gen::tasks::{
18    task::TaskKind, InlineDataset, InlineDatasetTable, NodeValueIndex, TaskGraph,
19};
20use vegafusion_core::runtime::VegaFusionRuntimeTrait;
21use vegafusion_core::task_graph::task_value::{NamedTaskValue, TaskValue};
22
23#[cfg(feature = "proto")]
24use {
25    datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes},
26    vegafusion_core::proto::gen::tasks::InlineDatasetPlan,
27};
28
/// Cached result of evaluating a task graph node: the node's primary `TaskValue`
/// paired with the values of its output variables (indexed by output position).
type CacheValue = (TaskValue, Vec<TaskValue>);
30
/// Runtime for evaluating VegaFusion task graphs.
///
/// Cloning is cheap: the cache and the DataFusion session context are shared
/// handles, so clones observe the same cached values and session state.
#[derive(Clone)]
pub struct VegaFusionRuntime {
    // Memoizes computed node values, keyed by each node's state fingerprint
    pub cache: VegaFusionCache,
    // Shared DataFusion session used to evaluate tasks and decode logical plans
    pub ctx: Arc<SessionContext>,
}
36
37impl VegaFusionRuntime {
38    pub fn new(cache: Option<VegaFusionCache>) -> Self {
39        Self {
40            cache: cache.unwrap_or_else(|| VegaFusionCache::new(Some(32), None)),
41            ctx: Arc::new(make_datafusion_context()),
42        }
43    }
44
45    pub async fn get_node_value(
46        &self,
47        task_graph: Arc<TaskGraph>,
48        node_value_index: &NodeValueIndex,
49        inline_datasets: HashMap<String, VegaFusionDataset>,
50    ) -> Result<TaskValue> {
51        // We shouldn't panic inside get_or_compute_node_value, but since this may be used
52        // in a server context, wrap in catch_unwind just in case.
53        let node_value = AssertUnwindSafe(get_or_compute_node_value(
54            task_graph,
55            node_value_index.node_index as usize,
56            self.cache.clone(),
57            inline_datasets,
58            self.ctx.clone(),
59        ))
60        .catch_unwind()
61        .await;
62
63        let mut node_value = node_value
64            .ok()
65            .with_context(|| "Unknown panic".to_string())??;
66
67        Ok(match node_value_index.output_index {
68            None => node_value.0,
69            Some(output_index) => node_value.1.remove(output_index as usize),
70        })
71    }
72
73    pub async fn clear_cache(&self) {
74        self.cache.clear().await;
75    }
76}
77
78#[async_trait::async_trait]
79impl VegaFusionRuntimeTrait for VegaFusionRuntime {
80    fn as_any(&self) -> &dyn Any {
81        self
82    }
83
84    async fn query_request(
85        &self,
86        task_graph: Arc<TaskGraph>,
87        indices: &[NodeValueIndex],
88        inline_datasets: &HashMap<String, VegaFusionDataset>,
89    ) -> Result<Vec<NamedTaskValue>> {
90        // Clone task_graph and task_graph_runtime for use in closure
91        let task_graph_runtime = self.clone();
92        let response_value_futures: Vec<_> = indices
93            .iter()
94            .map(|node_value_index| {
95                let node = task_graph
96                    .nodes
97                    .get(node_value_index.node_index as usize)
98                    .with_context(|| {
99                        format!(
100                            "Node index {} out of bounds for graph with size {}",
101                            node_value_index.node_index,
102                            task_graph.nodes.len()
103                        )
104                    })?;
105                let task = node.task();
106                let variable = match node_value_index.output_index {
107                    None => task.variable().clone(),
108                    Some(output_index) => task.output_vars()[output_index as usize].clone(),
109                };
110
111                let scope = node.task().scope.clone();
112
113                // Clone task_graph and task_graph_runtime for use in closure
114                let task_graph_runtime = task_graph_runtime.clone();
115                let task_graph = task_graph.clone();
116
117                Ok(async move {
118                    let value = task_graph_runtime
119                        .clone()
120                        .get_node_value(task_graph, node_value_index, inline_datasets.clone())
121                        .await?;
122
123                    Ok::<_, VegaFusionError>(NamedTaskValue {
124                        variable,
125                        scope,
126                        value,
127                    })
128                })
129            })
130            .collect::<Result<Vec<_>>>()?;
131
132        future::try_join_all(response_value_futures).await
133    }
134}
135
/// Recursively evaluate (or fetch from cache) the value of the node at
/// `node_index` in `task_graph`, returning the node's primary value together
/// with the values of its output variables.
///
/// Results are memoized in `cache`, keyed by the node's state fingerprint, so
/// shared ancestors are only computed once.
#[async_recursion]
async fn get_or_compute_node_value(
    task_graph: Arc<TaskGraph>,
    node_index: usize,
    cache: VegaFusionCache,
    inline_datasets: HashMap<String, VegaFusionDataset>,
    ctx: Arc<SessionContext>,
) -> Result<CacheValue> {
    // Get the cache key for requested node
    let node = task_graph.node(node_index).unwrap();
    let task = node.task();

    if let TaskKind::Value(value) = task.task_kind() {
        // Root nodes are stored in the graph, so we don't add them to the cache
        Ok((value.try_into().unwrap(), Vec::new()))
    } else {
        // Collect input node indices
        let input_node_indexes = task_graph.parent_indices(node_index).unwrap();
        let input_edges = node.incoming.clone();

        // Clone task so we can move it to async block
        let task = task.clone();
        // Invalid timezone strings are silently dropped here (tz_config becomes None)
        let tz_config = task.tz_config.clone().and_then(|tz_config| {
            RuntimeTzConfig::try_new(&tz_config.local_tz, &tz_config.default_input_tz).ok()
        });

        let cache_key = node.state_fingerprint;
        let cloned_cache = cache.clone();

        let fut = async move {
            // Create future to compute node value (will only be executed if not present in cache)
            let mut inputs_futures = Vec::new();
            for input_node_index in input_node_indexes {
                let node_fut = get_or_compute_node_value(
                    task_graph.clone(),
                    input_node_index,
                    cloned_cache.clone(),
                    inline_datasets.clone(),
                    ctx.clone(),
                );

                cfg_if! {
                    if #[cfg(target_arch = "wasm32")] {
                        // Add future directly
                        inputs_futures.push(node_fut);
                    } else {
                        // In non-wasm environment, use tokio::spawn for multi-threading
                        inputs_futures.push(tokio::spawn(node_fut));
                    }
                }
            }

            // join_all preserves input order, so results line up with input_edges below
            let input_values = futures::future::join_all(inputs_futures).await;

            // Extract the appropriate value from
            let input_values = input_values
                .into_iter()
                .zip(input_edges)
                .map(|(value, edge)| {
                    cfg_if! {
                        if #[cfg(target_arch = "wasm32")] {
                            // Futures were awaited directly, so `value` is already Result<CacheValue>
                            let mut value = match value {
                                Ok(value) => value,
                                Err(join_err) => {
                                    return Err(join_err)
                                }
                            };
                        } else {
                            // Convert outer JoinHandle error to internal VegaFusionError so we can propagate it.
                            let mut value = match value {
                                Ok(value) => value?,
                                Err(join_err) => {
                                    return Err(VegaFusionError::internal(join_err.to_string()))
                                }
                            };
                        }
                    }

                    // Select either the parent's primary value or one of its output values,
                    // depending on which edge connects it to this node
                    let value = match edge.output {
                        None => value.0,
                        Some(output_index) => value.1.remove(output_index as usize),
                    };
                    Ok(value)
                })
                .collect::<Result<Vec<_>>>()?;

            task.eval(&input_values, &tz_config, inline_datasets, ctx)
                .await
        };

        // get or construct from cache
        cache.get_or_try_insert_with(cache_key, fut).await
    }
}
230
231pub async fn decode_inline_datasets(
232    inline_pretransform_datasets: Vec<InlineDataset>,
233    ctx: &SessionContext,
234) -> Result<HashMap<String, VegaFusionDataset>> {
235    let mut inline_datasets = HashMap::new();
236    for inline_dataset in inline_pretransform_datasets {
237        let (name, dataset) = match inline_dataset.dataset.as_ref().unwrap() {
238            Dataset::Table(table) => {
239                let dataset = VegaFusionDataset::from_table_ipc_bytes(&table.table)?;
240                (table.name.clone(), dataset)
241            }
242            #[cfg(feature = "proto")]
243            Dataset::Plan(plan) => {
244                let logical_plan = logical_plan_from_bytes(&plan.plan, ctx)?;
245                let dataset = VegaFusionDataset::from_plan(logical_plan);
246                (plan.name.clone(), dataset)
247            }
248            #[cfg(not(feature = "proto"))]
249            Dataset::Plan(_plan) => {
250                return Err(VegaFusionError::internal("proto feature is not enabled"))
251            }
252        };
253        inline_datasets.insert(name, dataset);
254    }
255    Ok(inline_datasets)
256}
257
258pub fn encode_inline_datasets(
259    datasets: &HashMap<String, VegaFusionDataset>,
260) -> Result<Vec<InlineDataset>> {
261    datasets
262        .iter()
263        .map(|(name, dataset)| {
264            let encoded_dataset = match dataset {
265                VegaFusionDataset::Table { table, .. } => InlineDataset {
266                    dataset: Some(Dataset::Table(InlineDatasetTable {
267                        name: name.clone(),
268                        table: table.to_ipc_bytes()?,
269                    })),
270                },
271                #[cfg(feature = "proto")]
272                VegaFusionDataset::Plan { plan } => InlineDataset {
273                    dataset: Some(Dataset::Plan(InlineDatasetPlan {
274                        name: name.clone(),
275                        plan: logical_plan_to_bytes(plan)?.to_vec(),
276                    })),
277                },
278                #[cfg(not(feature = "proto"))]
279                VegaFusionDataset::Plan { .. } => {
280                    return Err(VegaFusionError::internal("proto feature is not enabled"))
281                }
282            };
283            Ok(encoded_dataset)
284        })
285        .collect::<Result<Vec<InlineDataset>>>()
286}