1use serde::{Deserialize, Serialize};
7use std::fmt;
8
/// A dynamically-typed payload: a numeric tensor, an arbitrary JSON
/// document, a raw byte buffer, or nothing at all.
///
/// Serialization uses serde's adjacently-tagged representation
/// (`tag = "type"`, `content = "data"`): the variant name is stored under the
/// `"type"` key and the variant's payload, if any, under the `"data"` key.
///
/// The enum is `#[non_exhaustive]`, so `match`es in other crates must include
/// a wildcard arm; new variants may be added without a breaking change.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", content = "data")]
#[non_exhaustive]
pub enum Value {
    /// A numeric tensor: element data in `values`, dimension sizes in `shape`.
    ///
    /// NOTE(review): nothing in this type checks that the product of `shape`
    /// equals `values.len()` — not construction, not deserialization.
    Tensor { values: Vec<f64>, shape: Vec<usize> },

    /// An arbitrary JSON document.
    Json(serde_json::Value),

    /// An opaque byte buffer.
    Bytes(Vec<u8>),

    /// No payload.
    Empty,
}
26
27impl Value {
28 pub fn tensor(values: Vec<f64>, shape: Vec<usize>) -> Self {
29 Self::Tensor { values, shape }
30 }
31
32 pub fn json(val: serde_json::Value) -> Self {
33 Self::Json(val)
34 }
35
36 pub fn bytes(data: Vec<u8>) -> Self {
37 Self::Bytes(data)
38 }
39
40 pub fn is_empty(&self) -> bool {
41 matches!(self, Self::Empty)
42 }
43
44 pub fn as_tensor(&self) -> Option<(&[f64], &[usize])> {
46 match self {
47 Self::Tensor { values, shape } => Some((values, shape)),
48 _ => None,
49 }
50 }
51
52 pub fn as_json(&self) -> Option<&serde_json::Value> {
54 match self {
55 Self::Json(v) => Some(v),
56 _ => None,
57 }
58 }
59
60 pub fn size(&self) -> usize {
62 match self {
63 Self::Tensor { values, .. } => values.len(),
64 Self::Json(v) => v.to_string().len(),
65 Self::Bytes(b) => b.len(),
66 Self::Empty => 0,
67 }
68 }
69}
70
71impl fmt::Display for Value {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 match self {
74 Self::Tensor { shape, values } => {
75 write!(f, "Tensor(shape={shape:?}, len={})", values.len())
76 }
77 Self::Json(v) => write!(f, "Json({v})"),
78 Self::Bytes(b) => write!(f, "Bytes(len={})", b.len()),
79 Self::Empty => write!(f, "Empty"),
80 }
81 }
82}
83
84impl From<Vec<f64>> for Value {
85 fn from(values: Vec<f64>) -> Self {
86 let len = values.len();
87 Self::Tensor {
88 values,
89 shape: vec![len],
90 }
91 }
92}
93
94impl From<serde_json::Value> for Value {
95 fn from(v: serde_json::Value) -> Self {
96 Self::Json(v)
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn tensor_creation_and_access() {
        let v = Value::tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let (data, shape) = v.as_tensor().unwrap();
        assert_eq!(data, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(shape, &[2, 2]);
    }

    #[test]
    fn json_value() {
        let v = Value::json(json!({"key": "value"}));
        let j = v.as_json().unwrap();
        assert_eq!(j["key"], "value");
    }

    #[test]
    fn accessors_return_none_for_other_variants() {
        assert!(Value::tensor(vec![1.0], vec![1]).as_json().is_none());
        assert!(Value::json(json!(null)).as_tensor().is_none());
        assert!(Value::bytes(vec![1]).as_json().is_none());
        assert!(Value::Empty.as_tensor().is_none());
        assert!(Value::Empty.as_json().is_none());
    }

    #[test]
    fn empty_value() {
        let v = Value::Empty;
        assert!(v.is_empty());
        assert_eq!(v.size(), 0);

        // `is_empty` is a variant check, not a size check.
        assert!(!Value::bytes(Vec::new()).is_empty());
    }

    #[test]
    fn from_vec_f64() {
        let v: Value = vec![1.0, 2.0, 3.0].into();
        let (data, shape) = v.as_tensor().unwrap();
        assert_eq!(data, &[1.0, 2.0, 3.0]);
        assert_eq!(shape, &[3]);
    }

    #[test]
    fn from_json_value() {
        let v: Value = json!({"a": 1}).into();
        assert_eq!(v, Value::json(json!({"a": 1})));
    }

    #[test]
    fn display_formatting() {
        let t = Value::tensor(vec![1.0, 2.0], vec![2]);
        assert_eq!(t.to_string(), "Tensor(shape=[2], len=2)");

        let j = Value::json(json!({"a": 1}));
        assert_eq!(j.to_string(), "Json({\"a\":1})");

        let b = Value::bytes(vec![0xDE, 0xAD]);
        assert_eq!(b.to_string(), "Bytes(len=2)");

        let e = Value::Empty;
        assert_eq!(e.to_string(), "Empty");
    }

    #[test]
    fn serde_roundtrip() {
        let values = vec![
            Value::tensor(vec![1.0, 2.0, 3.0], vec![3]),
            Value::json(json!({"a": 1})),
            Value::bytes(vec![0xDE, 0xAD]),
            Value::Empty,
        ];

        for v in values {
            let serialized = serde_json::to_string(&v).unwrap();
            let deserialized: Value = serde_json::from_str(&serialized).unwrap();
            assert_eq!(v, deserialized);
        }
    }

    #[test]
    fn serde_wire_format_is_adjacently_tagged() {
        // Pins the `type`/`data` layout other systems may depend on.
        assert_eq!(
            serde_json::to_value(Value::tensor(vec![1.0, 2.0], vec![2])).unwrap(),
            json!({"type": "Tensor", "data": {"values": [1.0, 2.0], "shape": [2]}})
        );
        assert_eq!(
            serde_json::to_value(Value::json(json!({"a": 1}))).unwrap(),
            json!({"type": "Json", "data": {"a": 1}})
        );
        assert_eq!(
            serde_json::to_value(Value::bytes(vec![0xDE, 0xAD])).unwrap(),
            json!({"type": "Bytes", "data": [222, 173]})
        );
        assert_eq!(
            serde_json::to_value(Value::Empty).unwrap(),
            json!({"type": "Empty"})
        );
    }

    #[test]
    fn size_returns_correct_values() {
        assert_eq!(Value::tensor(vec![1.0; 100], vec![10, 10]).size(), 100);
        assert_eq!(Value::bytes(vec![0; 50]).size(), 50);
        assert_eq!(Value::bytes(Vec::new()).size(), 0);
        // Compact encoding `{"key":"val"}` is 13 bytes.
        assert_eq!(Value::json(json!({"key": "val"})).size(), 13);
        assert_eq!(Value::Empty.size(), 0);
    }
}