vegafusion_core/
chart_state.rs

1use crate::{
2    data::dataset::VegaFusionDataset,
3    planning::{
4        apply_pre_transform::apply_pre_transform_datasets,
5        plan::SpecPlan,
6        stitch::CommPlan,
7        watch::{ExportUpdateArrow, ExportUpdateJSON, ExportUpdateNamespace},
8    },
9    proto::gen::{
10        pretransform::PreTransformSpecWarning,
11        tasks::{NodeValueIndex, TaskGraph, TzConfig, Variable, VariableNamespace},
12    },
13    runtime::VegaFusionRuntimeTrait,
14    spec::chart::ChartSpec,
15    task_graph::{graph::ScopedVariable, task_value::TaskValue},
16};
17use datafusion_common::ScalarValue;
18use std::{
19    collections::{HashMap, HashSet},
20    sync::{Arc, Mutex},
21};
22use vegafusion_common::{
23    data::{scalar::ScalarValueHelpers, table::VegaFusionTable},
24    error::{Result, ResultWithContext, VegaFusionError},
25};
26
27#[derive(Clone, Debug)]
28pub struct ChartStateOpts {
29    pub tz_config: TzConfig,
30    pub row_limit: Option<u32>,
31}
32
33impl Default for ChartStateOpts {
34    fn default() -> Self {
35        Self {
36            tz_config: TzConfig {
37                local_tz: "UTC".to_string(),
38                default_input_tz: None,
39            },
40            row_limit: None,
41        }
42    }
43}
44
45#[derive(Clone)]
46pub struct ChartState {
47    input_spec: ChartSpec,
48    transformed_spec: ChartSpec,
49    plan: SpecPlan,
50    inline_datasets: HashMap<String, VegaFusionDataset>,
51    task_graph: Arc<Mutex<TaskGraph>>,
52    task_graph_mapping: Arc<HashMap<ScopedVariable, NodeValueIndex>>,
53    server_to_client_value_indices: Arc<HashSet<NodeValueIndex>>,
54    warnings: Vec<PreTransformSpecWarning>,
55}
56
57impl ChartState {
58    pub async fn try_new(
59        runtime: &dyn VegaFusionRuntimeTrait,
60        spec: ChartSpec,
61        inline_datasets: HashMap<String, VegaFusionDataset>,
62        opts: ChartStateOpts,
63    ) -> Result<Self> {
64        let dataset_fingerprints = inline_datasets
65            .iter()
66            .map(|(k, ds)| (k.clone(), ds.fingerprint()))
67            .collect::<HashMap<_, _>>();
68
69        let plan = SpecPlan::try_new(&spec, &Default::default())?;
70
71        let task_scope = plan
72            .server_spec
73            .to_task_scope()
74            .with_context(|| "Failed to create task scope for server spec")?;
75        let tasks = plan
76            .server_spec
77            .to_tasks(&opts.tz_config, &dataset_fingerprints)
78            .unwrap();
79        let task_graph = TaskGraph::new(tasks, &task_scope).unwrap();
80        let task_graph_mapping = task_graph.build_mapping();
81        let server_to_client_value_indices: Arc<HashSet<_>> = Arc::new(
82            plan.comm_plan
83                .server_to_client
84                .iter()
85                .map(|scoped_var| *task_graph_mapping.get(scoped_var).unwrap())
86                .collect(),
87        );
88
89        // Gather values of server-to-client values using query_request
90        let indices: Vec<NodeValueIndex> = plan
91            .comm_plan
92            .server_to_client
93            .iter()
94            .map(|var| *task_graph_mapping.get(var).unwrap())
95            .collect();
96
97        let response_task_values = runtime
98            .query_request(Arc::new(task_graph.clone()), &indices, &inline_datasets)
99            .await?;
100
101        let mut init = Vec::new();
102        for response_value in response_task_values {
103            let variable = response_value.variable;
104
105            let scope = response_value.scope;
106            let value = response_value.value;
107
108            init.push(ExportUpdateArrow {
109                namespace: ExportUpdateNamespace::try_from(variable.ns()).unwrap(),
110                name: variable.name.clone(),
111                scope,
112                value,
113            });
114        }
115
116        let (transformed_spec, warnings) =
117            apply_pre_transform_datasets(&spec, &plan, init, opts.row_limit)?;
118
119        Ok(Self {
120            input_spec: spec,
121            transformed_spec,
122            plan,
123            inline_datasets,
124            task_graph: Arc::new(Mutex::new(task_graph)),
125            task_graph_mapping: Arc::new(task_graph_mapping),
126            server_to_client_value_indices,
127            warnings,
128        })
129    }
130
131    pub async fn update(
132        &self,
133        runtime: &dyn VegaFusionRuntimeTrait,
134        updates: Vec<ExportUpdateJSON>,
135    ) -> Result<Vec<ExportUpdateJSON>> {
136        // Scope the mutex guard to ensure it's dropped before the async call
137        let (indices, cloned_task_graph) = {
138            let mut task_graph = self.task_graph.lock().map_err(|err| {
139                VegaFusionError::internal(format!("Failed to acquire task graph lock: {:?}", err))
140            })?;
141            let server_to_client = self.server_to_client_value_indices.clone();
142            let mut indices: Vec<NodeValueIndex> = Vec::new();
143
144            for export_update in &updates {
145                let var = match export_update.namespace {
146                    ExportUpdateNamespace::Signal => Variable::new_signal(&export_update.name),
147                    ExportUpdateNamespace::Data => Variable::new_data(&export_update.name),
148                };
149                let scoped_var: ScopedVariable = (var, export_update.scope.clone());
150                let node_value_index = *self
151                    .task_graph_mapping
152                    .get(&scoped_var)
153                    .with_context(|| format!("No task graph node found for {scoped_var:?}"))?;
154
155                let value = match export_update.namespace {
156                    ExportUpdateNamespace::Signal => {
157                        TaskValue::Scalar(ScalarValue::from_json(&export_update.value)?)
158                    }
159                    ExportUpdateNamespace::Data => {
160                        TaskValue::Table(VegaFusionTable::from_json(&export_update.value)?)
161                    }
162                };
163
164                indices
165                    .extend(task_graph.update_value(node_value_index.node_index as usize, value)?);
166            }
167
168            // Filter to update nodes in the comm plan
169            let indices: Vec<_> = indices
170                .iter()
171                .filter(|&node| server_to_client.contains(node))
172                .cloned()
173                .collect();
174
175            // Clone the task graph while we still have the lock
176            let cloned_task_graph = task_graph.clone();
177
178            // Return both values we need
179            (indices, cloned_task_graph)
180        }; // MutexGuard is dropped here
181
182        // Now we can safely make the async call
183        let response_task_values = runtime
184            .query_request(
185                Arc::new(cloned_task_graph),
186                indices.as_slice(),
187                &self.inline_datasets,
188            )
189            .await?;
190
191        let mut response_updates = response_task_values
192            .into_iter()
193            .map(|response_value| {
194                let variable = response_value.variable;
195                let scope = response_value.scope;
196                let value = response_value.value;
197
198                Ok(ExportUpdateJSON {
199                    namespace: match variable.ns() {
200                        VariableNamespace::Signal => ExportUpdateNamespace::Signal,
201                        VariableNamespace::Data => ExportUpdateNamespace::Data,
202                        VariableNamespace::Scale => {
203                            return Err(VegaFusionError::internal("Unexpected scale variable"))
204                        }
205                    },
206                    name: variable.name.clone(),
207                    scope: scope.clone(),
208                    value: value.to_json()?,
209                })
210            })
211            .collect::<Result<Vec<_>>>()?;
212
213        // Sort for deterministic ordering
214        response_updates.sort_by_key(|update| update.name.clone());
215
216        Ok(response_updates)
217    }
218
219    pub fn get_input_spec(&self) -> &ChartSpec {
220        &self.input_spec
221    }
222
223    pub fn get_server_spec(&self) -> &ChartSpec {
224        &self.plan.server_spec
225    }
226
227    pub fn get_client_spec(&self) -> &ChartSpec {
228        &self.plan.client_spec
229    }
230
231    pub fn get_transformed_spec(&self) -> &ChartSpec {
232        &self.transformed_spec
233    }
234
235    pub fn get_comm_plan(&self) -> &CommPlan {
236        &self.plan.comm_plan
237    }
238
239    pub fn get_warnings(&self) -> &Vec<PreTransformSpecWarning> {
240        &self.warnings
241    }
242}