1use crate::{
2 data::dataset::VegaFusionDataset,
3 planning::{
4 apply_pre_transform::apply_pre_transform_datasets,
5 plan::SpecPlan,
6 stitch::CommPlan,
7 watch::{ExportUpdateArrow, ExportUpdateJSON, ExportUpdateNamespace},
8 },
9 proto::gen::{
10 pretransform::PreTransformSpecWarning,
11 tasks::{NodeValueIndex, TaskGraph, TzConfig, Variable, VariableNamespace},
12 },
13 runtime::VegaFusionRuntimeTrait,
14 spec::chart::ChartSpec,
15 task_graph::{graph::ScopedVariable, task_value::TaskValue},
16};
17use datafusion_common::ScalarValue;
18use std::{
19 collections::{HashMap, HashSet},
20 sync::{Arc, Mutex},
21};
22use vegafusion_common::{
23 data::{scalar::ScalarValueHelpers, table::VegaFusionTable},
24 error::{Result, ResultWithContext, VegaFusionError},
25};
26
27#[derive(Clone, Debug)]
28pub struct ChartStateOpts {
29 pub tz_config: TzConfig,
30 pub row_limit: Option<u32>,
31}
32
33impl Default for ChartStateOpts {
34 fn default() -> Self {
35 Self {
36 tz_config: TzConfig {
37 local_tz: "UTC".to_string(),
38 default_input_tz: None,
39 },
40 row_limit: None,
41 }
42 }
43}
44
45#[derive(Clone)]
46pub struct ChartState {
47 input_spec: ChartSpec,
48 transformed_spec: ChartSpec,
49 plan: SpecPlan,
50 inline_datasets: HashMap<String, VegaFusionDataset>,
51 task_graph: Arc<Mutex<TaskGraph>>,
52 task_graph_mapping: Arc<HashMap<ScopedVariable, NodeValueIndex>>,
53 server_to_client_value_indices: Arc<HashSet<NodeValueIndex>>,
54 warnings: Vec<PreTransformSpecWarning>,
55}
56
57impl ChartState {
58 pub async fn try_new(
59 runtime: &dyn VegaFusionRuntimeTrait,
60 spec: ChartSpec,
61 inline_datasets: HashMap<String, VegaFusionDataset>,
62 opts: ChartStateOpts,
63 ) -> Result<Self> {
64 let dataset_fingerprints = inline_datasets
65 .iter()
66 .map(|(k, ds)| (k.clone(), ds.fingerprint()))
67 .collect::<HashMap<_, _>>();
68
69 let plan = SpecPlan::try_new(&spec, &Default::default())?;
70
71 let task_scope = plan
72 .server_spec
73 .to_task_scope()
74 .with_context(|| "Failed to create task scope for server spec")?;
75 let tasks = plan
76 .server_spec
77 .to_tasks(&opts.tz_config, &dataset_fingerprints)
78 .unwrap();
79 let task_graph = TaskGraph::new(tasks, &task_scope).unwrap();
80 let task_graph_mapping = task_graph.build_mapping();
81 let server_to_client_value_indices: Arc<HashSet<_>> = Arc::new(
82 plan.comm_plan
83 .server_to_client
84 .iter()
85 .map(|scoped_var| *task_graph_mapping.get(scoped_var).unwrap())
86 .collect(),
87 );
88
89 let indices: Vec<NodeValueIndex> = plan
91 .comm_plan
92 .server_to_client
93 .iter()
94 .map(|var| *task_graph_mapping.get(var).unwrap())
95 .collect();
96
97 let response_task_values = runtime
98 .query_request(Arc::new(task_graph.clone()), &indices, &inline_datasets)
99 .await?;
100
101 let mut init = Vec::new();
102 for response_value in response_task_values {
103 let variable = response_value.variable;
104
105 let scope = response_value.scope;
106 let value = response_value.value;
107
108 init.push(ExportUpdateArrow {
109 namespace: ExportUpdateNamespace::try_from(variable.ns()).unwrap(),
110 name: variable.name.clone(),
111 scope,
112 value,
113 });
114 }
115
116 let (transformed_spec, warnings) =
117 apply_pre_transform_datasets(&spec, &plan, init, opts.row_limit)?;
118
119 Ok(Self {
120 input_spec: spec,
121 transformed_spec,
122 plan,
123 inline_datasets,
124 task_graph: Arc::new(Mutex::new(task_graph)),
125 task_graph_mapping: Arc::new(task_graph_mapping),
126 server_to_client_value_indices,
127 warnings,
128 })
129 }
130
131 pub async fn update(
132 &self,
133 runtime: &dyn VegaFusionRuntimeTrait,
134 updates: Vec<ExportUpdateJSON>,
135 ) -> Result<Vec<ExportUpdateJSON>> {
136 let (indices, cloned_task_graph) = {
138 let mut task_graph = self.task_graph.lock().map_err(|err| {
139 VegaFusionError::internal(format!("Failed to acquire task graph lock: {:?}", err))
140 })?;
141 let server_to_client = self.server_to_client_value_indices.clone();
142 let mut indices: Vec<NodeValueIndex> = Vec::new();
143
144 for export_update in &updates {
145 let var = match export_update.namespace {
146 ExportUpdateNamespace::Signal => Variable::new_signal(&export_update.name),
147 ExportUpdateNamespace::Data => Variable::new_data(&export_update.name),
148 };
149 let scoped_var: ScopedVariable = (var, export_update.scope.clone());
150 let node_value_index = *self
151 .task_graph_mapping
152 .get(&scoped_var)
153 .with_context(|| format!("No task graph node found for {scoped_var:?}"))?;
154
155 let value = match export_update.namespace {
156 ExportUpdateNamespace::Signal => {
157 TaskValue::Scalar(ScalarValue::from_json(&export_update.value)?)
158 }
159 ExportUpdateNamespace::Data => {
160 TaskValue::Table(VegaFusionTable::from_json(&export_update.value)?)
161 }
162 };
163
164 indices
165 .extend(task_graph.update_value(node_value_index.node_index as usize, value)?);
166 }
167
168 let indices: Vec<_> = indices
170 .iter()
171 .filter(|&node| server_to_client.contains(node))
172 .cloned()
173 .collect();
174
175 let cloned_task_graph = task_graph.clone();
177
178 (indices, cloned_task_graph)
180 }; let response_task_values = runtime
184 .query_request(
185 Arc::new(cloned_task_graph),
186 indices.as_slice(),
187 &self.inline_datasets,
188 )
189 .await?;
190
191 let mut response_updates = response_task_values
192 .into_iter()
193 .map(|response_value| {
194 let variable = response_value.variable;
195 let scope = response_value.scope;
196 let value = response_value.value;
197
198 Ok(ExportUpdateJSON {
199 namespace: match variable.ns() {
200 VariableNamespace::Signal => ExportUpdateNamespace::Signal,
201 VariableNamespace::Data => ExportUpdateNamespace::Data,
202 VariableNamespace::Scale => {
203 return Err(VegaFusionError::internal("Unexpected scale variable"))
204 }
205 },
206 name: variable.name.clone(),
207 scope: scope.clone(),
208 value: value.to_json()?,
209 })
210 })
211 .collect::<Result<Vec<_>>>()?;
212
213 response_updates.sort_by_key(|update| update.name.clone());
215
216 Ok(response_updates)
217 }
218
219 pub fn get_input_spec(&self) -> &ChartSpec {
220 &self.input_spec
221 }
222
223 pub fn get_server_spec(&self) -> &ChartSpec {
224 &self.plan.server_spec
225 }
226
227 pub fn get_client_spec(&self) -> &ChartSpec {
228 &self.plan.client_spec
229 }
230
231 pub fn get_transformed_spec(&self) -> &ChartSpec {
232 &self.transformed_spec
233 }
234
235 pub fn get_comm_plan(&self) -> &CommPlan {
236 &self.plan.comm_plan
237 }
238
239 pub fn get_warnings(&self) -> &Vec<PreTransformSpecWarning> {
240 &self.warnings
241 }
242}