vegafusion_core/task_graph/
task_value.rs1use crate::proto::gen::tasks::task_value::Data;
2use crate::proto::gen::tasks::ResponseTaskValue;
3use crate::proto::gen::tasks::{TaskGraphValueResponse, TaskValue as ProtoTaskValue, Variable};
4use crate::task_graph::memory::{inner_size_of_scalar, inner_size_of_table};
5use datafusion_common::ScalarValue;
6use serde_json::Value;
7use std::convert::TryFrom;
8use vegafusion_common::arrow::record_batch::RecordBatch;
9use vegafusion_common::data::scalar::ScalarValueHelpers;
10use vegafusion_common::data::table::VegaFusionTable;
11use vegafusion_common::error::{Result, ResultWithContext, VegaFusionError};
12
13#[derive(Debug, Clone)]
14pub enum TaskValue {
15 Scalar(ScalarValue),
16 Table(VegaFusionTable),
17}
18
19impl TaskValue {
20 pub fn as_scalar(&self) -> Result<&ScalarValue> {
21 match self {
22 TaskValue::Scalar(value) => Ok(value),
23 _ => Err(VegaFusionError::internal("Value is not a scalar")),
24 }
25 }
26
27 pub fn as_table(&self) -> Result<&VegaFusionTable> {
28 match self {
29 TaskValue::Table(value) => Ok(value),
30 _ => Err(VegaFusionError::internal("Value is not a table")),
31 }
32 }
33
34 pub fn to_json(&self) -> Result<Value> {
35 match self {
36 TaskValue::Scalar(value) => value.to_json(),
37 TaskValue::Table(value) => Ok(value.to_json()?),
38 }
39 }
40
41 pub fn size_of(&self) -> usize {
42 let inner_size = match self {
43 TaskValue::Scalar(scalar) => inner_size_of_scalar(scalar),
44 TaskValue::Table(table) => inner_size_of_table(table),
45 };
46
47 std::mem::size_of::<Self>() + inner_size
48 }
49}
50
51impl TryFrom<&ProtoTaskValue> for TaskValue {
52 type Error = VegaFusionError;
53
54 fn try_from(value: &ProtoTaskValue) -> std::result::Result<Self, Self::Error> {
55 match value.data.as_ref().unwrap() {
56 Data::Table(value) => Ok(Self::Table(VegaFusionTable::from_ipc_bytes(value)?)),
57 Data::Scalar(value) => {
58 let scalar_table = VegaFusionTable::from_ipc_bytes(value)?;
59 let scalar_rb = scalar_table.to_record_batch()?;
60 let scalar_array = scalar_rb.column(0);
61 let scalar = ScalarValue::try_from_array(scalar_array, 0)?;
62 Ok(Self::Scalar(scalar))
63 }
64 }
65 }
66}
67
68impl TryFrom<&TaskValue> for ProtoTaskValue {
69 type Error = VegaFusionError;
70
71 fn try_from(value: &TaskValue) -> std::result::Result<Self, Self::Error> {
72 match value {
73 TaskValue::Scalar(scalar) => {
74 let scalar_array = scalar.to_array()?;
75 let scalar_rb = RecordBatch::try_from_iter(vec![("value", scalar_array)])?;
76 let ipc_bytes = VegaFusionTable::from(scalar_rb).to_ipc_bytes()?;
77 Ok(Self {
78 data: Some(Data::Scalar(ipc_bytes)),
79 })
80 }
81 TaskValue::Table(table) => Ok(Self {
82 data: Some(Data::Table(table.to_ipc_bytes()?)),
83 }),
84 }
85 }
86}
87
88impl TaskGraphValueResponse {
89 pub fn deserialize(self) -> Result<Vec<(Variable, Vec<u32>, TaskValue)>> {
90 self.response_values
91 .into_iter()
92 .map(|response_value| {
93 let variable = response_value
94 .variable
95 .with_context(|| "Unwrap failed for variable of response value".to_string())?;
96
97 let scope = response_value.scope;
98 let proto_value = response_value.value.with_context(|| {
99 "Unwrap failed for value of response value: {:?}".to_string()
100 })?;
101
102 let value = TaskValue::try_from(&proto_value).with_context(|| {
103 "Deserialization failed for value of response value: {:?}".to_string()
104 })?;
105
106 Ok((variable, scope, value))
107 })
108 .collect::<Result<Vec<_>>>()
109 }
110}
111
112#[derive(Debug, Clone)]
113pub struct NamedTaskValue {
114 pub variable: Variable,
115 pub scope: Vec<u32>,
116 pub value: TaskValue,
117}
118
119impl From<NamedTaskValue> for ResponseTaskValue {
120 fn from(value: NamedTaskValue) -> Self {
121 ResponseTaskValue {
122 variable: Some(value.variable),
123 scope: value.scope,
124 value: Some(ProtoTaskValue::try_from(&value.value).unwrap()),
125 }
126 }
127}
128
129impl From<ResponseTaskValue> for NamedTaskValue {
130 fn from(value: ResponseTaskValue) -> Self {
131 NamedTaskValue {
132 variable: value.variable.unwrap(),
133 scope: value.scope,
134 value: TaskValue::try_from(&value.value.unwrap()).unwrap(),
135 }
136 }
137}