1use std::sync::Arc;
9
10use serde::{Deserialize, Deserializer, Serialize};
11use smallvec::SmallVec;
12
13use crate::db_string::DbString;
14use crate::error::{CoreError, CoreResult};
15use crate::extension_type_ids::ExtensionTypeId;
16use crate::identity::{BindingTableId, EdgeId, GraphId, NodeId, RecordTypeId};
17use crate::json_value::JsonValue;
18
/// Largest number of components a [`VectorValue`] may have (`u16::MAX`, i.e. 65 535).
///
/// `VectorValue::new` rejects longer inputs with `CoreError::VectorTooLarge`.
pub const MAX_VECTOR_DIMENSION: usize = u16::MAX as usize;
24
/// A dynamically typed database value.
///
/// `Value` is kept at most 32 bytes by a `const` assertion that follows this definition.
/// Variants with large payloads are boxed to stay under that limit.
///
/// Equality is implemented by hand rather than derived. See the `PartialEq` impl for how NaN
/// floats and durations compare.
///
/// NOTE: the serde representation is derived. Do not reorder or remove variants: derived enum
/// encodings may identify variants by their index in non-self-describing formats.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[non_exhaustive]
pub enum Value {
    /// Boolean.
    Bool(bool),
    /// Signed 64-bit integer.
    Int(i64),
    /// Unsigned 64-bit integer.
    Uint(u64),
    /// Signed 128-bit integer; serialized as 16 little-endian bytes.
    Int128(#[serde(with = "serde_i128_le")] i128),
    /// Unsigned 128-bit integer; serialized as 16 little-endian bytes.
    Uint128(#[serde(with = "serde_u128_le")] u128),
    /// 64-bit float. May hold NaN (unlike `Vector` components).
    Float(f64),
    /// 32-bit float. May hold NaN.
    Float32(f32),
    /// Decimal number; serialized as its decimal string.
    Decimal(#[serde(with = "serde_decimal_str")] rust_decimal::Decimal),
    /// Text.
    String(DbString),
    /// Raw bytes, shared behind an `Arc` so clones do not copy the buffer.
    Bytes(Arc<[u8]>),
    /// Ordered list of values. Nothing here enforces that elements share a variant.
    List(Vec<Value>),
    /// Record with named fields (boxed to keep `Value` small).
    Record(Box<Record>),
    /// Record tagged with a `RecordTypeId` (boxed).
    RecordTyped(Box<RecordTyped>),
    /// Graph path (boxed).
    Path(Box<Path>),
    /// Reference to a node by id.
    NodeRef(NodeId),
    /// Reference to an edge by id.
    EdgeRef(EdgeId),
    /// Reference to a graph by id.
    GraphRef(GraphId),
    /// Reference to a binding table by id.
    TableRef(BindingTableId),
    /// Timestamp with a time zone (boxed).
    ZonedDateTime(Box<jiff::Zoned>),
    /// Civil date and time without a time zone.
    LocalDateTime(jiff::civil::DateTime),
    /// Civil date.
    Date(jiff::civil::Date),
    /// Zoned time, stored as a full `jiff::Zoned` (boxed).
    /// NOTE(review): presumably only the time of day and zone/offset are meaningful — confirm.
    ZonedTime(Box<jiff::Zoned>),
    /// Civil time of day.
    LocalTime(jiff::civil::Time),
    /// Span of time (boxed). Compared field by field, not by total length.
    Duration(Box<jiff::Span>),
    /// Value of an extension type.
    Extended {
        /// Identifies which extension type `payload` belongs to.
        type_id: ExtensionTypeId,
        /// Encoded payload bytes; not interpreted in this module.
        payload: Arc<[u8]>,
    },
    /// The null value.
    Null,
    /// UUID.
    Uuid(uuid::Uuid),
    /// Validated `f32` vector (non-empty, bounded length, finite components).
    Vector(VectorValue),
    /// JSON document.
    Json(JsonValue),
}
119
// Compile-time guard: keep `Value` at most 32 bytes. Bulky payloads are boxed (`Record`,
// `RecordTyped`, `Path`, `ZonedDateTime`, `ZonedTime`, `Duration`). Adding a variant with a
// large inline payload therefore fails the build here instead of silently growing every
// `Value`.
const _: () = assert!(core::mem::size_of::<Value>() <= 32);
134
135impl Value {
136 pub const ALL: &[fn() -> Self] = &[
142 || Self::Bool(false),
143 || Self::Int(0),
144 || Self::Uint(0),
145 || Self::Int128(0),
146 || Self::Uint128(0),
147 || Self::Float(0.0),
148 || Self::Float32(0.0),
149 || Self::Decimal(rust_decimal::Decimal::ZERO),
150 || Self::String(value_variant_string("value.all.string")),
151 || Self::Bytes(Arc::from([0_u8])),
152 || Self::List(Vec::new()),
153 || Self::Record(Box::new(Record::Open(SmallVec::new()))),
154 || {
155 Self::RecordTyped(Box::new(RecordTyped {
156 type_id: RecordTypeId::new(1),
157 values: SmallVec::new(),
158 }))
159 },
160 || {
161 Self::Path(Box::new(Path {
162 graph: GraphId::new(1),
163 start: NodeId::new(1),
164 segments: SmallVec::new(),
165 }))
166 },
167 || Self::NodeRef(NodeId::new(1)),
168 || Self::EdgeRef(EdgeId::new(1)),
169 || Self::GraphRef(GraphId::new(1)),
170 || Self::TableRef(BindingTableId::new(1)),
171 || Self::ZonedDateTime(Box::new(value_variant_zoned())),
172 || Self::LocalDateTime("2024-01-01T00:00:00".parse().unwrap()),
173 || Self::Date("2024-01-01".parse().unwrap()),
174 || Self::ZonedTime(Box::new(value_variant_zoned())),
175 || Self::LocalTime("00:00:00".parse().unwrap()),
176 || Self::Duration(Box::new("PT1S".parse().unwrap())),
177 || Self::Extended {
178 type_id: ExtensionTypeId::FIRST_PARTY_MIN,
179 payload: Arc::from([0_u8]),
180 },
181 || Self::Null,
182 || Self::Uuid(uuid::Uuid::nil()),
183 || Self::Vector(VectorValue::new(vec![0.0]).expect("fixture vector is valid")),
184 || Self::Json(JsonValue::new(serde_json::json!({"fixture": true})).unwrap()),
185 ];
186
187 pub const VARIANT_COUNT: usize = Self::ALL.len();
189
190 #[must_use]
195 pub fn variant_name(&self) -> &'static str {
196 match self {
197 Self::Bool(_) => "Bool",
198 Self::Int(_) => "Int",
199 Self::Uint(_) => "Uint",
200 Self::Int128(_) => "Int128",
201 Self::Uint128(_) => "Uint128",
202 Self::Float(_) => "Float",
203 Self::Float32(_) => "Float32",
204 Self::Decimal(_) => "Decimal",
205 Self::String(_) => "String",
206 Self::Bytes(_) => "Bytes",
207 Self::List(_) => "List",
208 Self::Record(_) => "Record",
209 Self::RecordTyped(_) => "RecordTyped",
210 Self::Path(_) => "Path",
211 Self::NodeRef(_) => "NodeRef",
212 Self::EdgeRef(_) => "EdgeRef",
213 Self::GraphRef(_) => "GraphRef",
214 Self::TableRef(_) => "TableRef",
215 Self::ZonedDateTime(_) => "ZonedDateTime",
216 Self::LocalDateTime(_) => "LocalDateTime",
217 Self::Date(_) => "Date",
218 Self::ZonedTime(_) => "ZonedTime",
219 Self::LocalTime(_) => "LocalTime",
220 Self::Duration(_) => "Duration",
221 Self::Extended { .. } => "Extended",
222 Self::Null => "Null",
223 Self::Uuid(_) => "Uuid",
224 Self::Vector(_) => "Vector",
225 Self::Json(_) => "Json",
226 }
227 }
228}
229
230fn value_variant_string(name: &str) -> DbString {
231 crate::db_string(name).expect("Value::ALL fixture strings fit DB string cap")
232}
233
234fn value_variant_zoned() -> jiff::Zoned {
235 jiff::Timestamp::new(0, 0)
236 .expect("Value::ALL timestamp fixture is in range")
237 .to_zoned(jiff::tz::TimeZone::UTC)
238}
239
240impl PartialEq for Value {
241 fn eq(&self, rhs: &Self) -> bool {
242 match (self, rhs) {
243 (Self::Bool(lhs), Self::Bool(rhs)) => lhs == rhs,
244 (Self::Int(lhs), Self::Int(rhs)) => lhs == rhs,
245 (Self::Uint(lhs), Self::Uint(rhs)) => lhs == rhs,
246 (Self::Int128(lhs), Self::Int128(rhs)) => lhs == rhs,
247 (Self::Uint128(lhs), Self::Uint128(rhs)) => lhs == rhs,
248 (Self::Float(lhs), Self::Float(rhs)) => lhs == rhs || (lhs.is_nan() && rhs.is_nan()),
249 (Self::Float32(lhs), Self::Float32(rhs)) => {
250 lhs == rhs || (lhs.is_nan() && rhs.is_nan())
251 }
252 (Self::Decimal(lhs), Self::Decimal(rhs)) => lhs == rhs,
253 (Self::String(lhs), Self::String(rhs)) => lhs == rhs,
254 (Self::Bytes(lhs), Self::Bytes(rhs)) => lhs == rhs,
255 (Self::List(lhs), Self::List(rhs)) => lhs == rhs,
256 (Self::Record(lhs), Self::Record(rhs)) => lhs == rhs,
257 (Self::RecordTyped(lhs), Self::RecordTyped(rhs)) => lhs == rhs,
258 (Self::Path(lhs), Self::Path(rhs)) => lhs == rhs,
259 (Self::NodeRef(lhs), Self::NodeRef(rhs)) => lhs == rhs,
260 (Self::EdgeRef(lhs), Self::EdgeRef(rhs)) => lhs == rhs,
261 (Self::GraphRef(lhs), Self::GraphRef(rhs)) => lhs == rhs,
262 (Self::TableRef(lhs), Self::TableRef(rhs)) => lhs == rhs,
263 (Self::ZonedDateTime(lhs), Self::ZonedDateTime(rhs)) => lhs == rhs,
264 (Self::LocalDateTime(lhs), Self::LocalDateTime(rhs)) => lhs == rhs,
265 (Self::Date(lhs), Self::Date(rhs)) => lhs == rhs,
266 (Self::ZonedTime(lhs), Self::ZonedTime(rhs)) => lhs == rhs,
267 (Self::LocalTime(lhs), Self::LocalTime(rhs)) => lhs == rhs,
268 (Self::Duration(lhs), Self::Duration(rhs)) => lhs.fieldwise() == rhs.fieldwise(),
269 (
270 Self::Extended {
271 type_id: lhs_type_id,
272 payload: lhs_payload,
273 },
274 Self::Extended {
275 type_id: rhs_type_id,
276 payload: rhs_payload,
277 },
278 ) => lhs_type_id == rhs_type_id && lhs_payload == rhs_payload,
279 (Self::Null, Self::Null) => true,
280 (Self::Uuid(lhs), Self::Uuid(rhs)) => lhs == rhs,
281 (Self::Vector(lhs), Self::Vector(rhs)) => lhs == rhs,
282 (Self::Json(lhs), Self::Json(rhs)) => lhs == rhs,
283 _ => false,
284 }
285 }
286}
287
/// A dense vector of `f32` components.
///
/// `VectorValue::new` and the `Deserialize` impl both enforce these invariants:
/// - at least one component;
/// - at most [`MAX_VECTOR_DIMENSION`] components;
/// - every component finite.
///
/// The components live behind an `Arc`, so cloning a `VectorValue` does not copy them.
/// Serialized transparently as the plain component sequence.
#[derive(Clone, Debug, PartialEq, Serialize)]
#[serde(transparent)]
pub struct VectorValue {
    // Private so values built outside this module go through the validated constructors.
    components: Arc<[f32]>,
}
299
300impl VectorValue {
301 pub fn new(components: impl Into<Vec<f32>>) -> CoreResult<Self> {
303 let components = components.into();
304 if components.is_empty() {
305 return Err(CoreError::VectorEmpty);
306 }
307 if components.len() > MAX_VECTOR_DIMENSION {
308 return Err(CoreError::VectorTooLarge {
309 got: components.len(),
310 max: MAX_VECTOR_DIMENSION,
311 });
312 }
313 for (index, value) in components.iter().copied().enumerate() {
314 if !value.is_finite() {
315 return Err(CoreError::VectorComponentNotFinite { index, value });
316 }
317 }
318 Ok(Self {
319 components: Arc::from(components),
320 })
321 }
322
323 #[must_use]
325 pub fn dimension(&self) -> usize {
326 self.components.len()
327 }
328
329 #[must_use]
331 pub fn as_slice(&self) -> &[f32] {
332 &self.components
333 }
334
335 #[must_use]
337 pub fn as_arc(&self) -> Arc<[f32]> {
338 Arc::clone(&self.components)
339 }
340}
341
342impl TryFrom<Vec<f32>> for VectorValue {
343 type Error = CoreError;
344
345 fn try_from(value: Vec<f32>) -> Result<Self, Self::Error> {
346 Self::new(value)
347 }
348}
349
350impl From<VectorValue> for Vec<f32> {
351 fn from(value: VectorValue) -> Self {
352 value.components.as_ref().to_vec()
353 }
354}
355
356impl<'de> Deserialize<'de> for VectorValue {
357 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
358 where
359 D: Deserializer<'de>,
360 {
361 Vec::<f32>::deserialize(deserializer)
362 .and_then(|components| Self::new(components).map_err(serde::de::Error::custom))
363 }
364}
365
/// A record value with named fields.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[non_exhaustive]
pub enum Record {
    /// Field name/value pairs kept in the order given. Up to four are stored inline before the
    /// `SmallVec` spills to the heap.
    ///
    /// The derived `PartialEq` compares pairs positionally, so the same fields in a different
    /// order are not equal.
    /// NOTE(review): nothing here rejects duplicate field names — confirm callers guarantee it.
    Open(SmallVec<[(DbString, Value); 4]>),
}
373
/// A record whose fields are stored by position and tagged with a record type id.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct RecordTyped {
    /// Identifier of the record type this record is tagged with.
    pub type_id: RecordTypeId,
    /// Field values by position (up to four stored inline).
    /// NOTE(review): presumably positions follow the record type's field order, and `None`
    /// marks an absent field as distinct from `Some(Value::Null)` — confirm.
    pub values: SmallVec<[Option<Value>; 4]>,
}
382
/// A graph path: a starting node followed by zero or more hops.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct Path {
    /// Id of the graph this path refers to.
    pub graph: GraphId,
    /// First node of the path.
    pub start: NodeId,
    /// Hops after `start`, in order (up to four stored inline). Empty means the path consists
    /// of `start` alone.
    pub segments: SmallVec<[PathSegment; 4]>,
}
393
/// One hop of a [`Path`]: an edge, the direction it is traversed in, and a node.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct PathSegment {
    /// Edge traversed by this hop.
    pub edge: EdgeId,
    /// Direction in which `edge` is traversed.
    pub direction: EdgeDirection,
    /// Node for this hop.
    /// NOTE(review): presumably the node reached at the far end of `edge` — confirm.
    pub node: NodeId,
}
404
/// Direction in which a path segment traverses its edge.
/// NOTE(review): presumably relative to the edge's source/target endpoints — confirm.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub enum EdgeDirection {
    /// Traversed along the edge's direction.
    Outgoing,
    /// Traversed against the edge's direction.
    Incoming,
    /// Edge has no direction.
    Undirected,
}
415
416mod serde_i128_le {
417 use serde::{Deserialize, Deserializer, Serialize, Serializer};
418
419 pub(super) fn serialize<S>(value: &i128, serializer: S) -> Result<S::Ok, S::Error>
420 where
421 S: Serializer,
422 {
423 value.to_le_bytes().serialize(serializer)
424 }
425
426 pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<i128, D::Error>
427 where
428 D: Deserializer<'de>,
429 {
430 <[u8; 16]>::deserialize(deserializer).map(i128::from_le_bytes)
431 }
432}
433
434mod serde_u128_le {
435 use serde::{Deserialize, Deserializer, Serialize, Serializer};
436
437 pub(super) fn serialize<S>(value: &u128, serializer: S) -> Result<S::Ok, S::Error>
438 where
439 S: Serializer,
440 {
441 value.to_le_bytes().serialize(serializer)
442 }
443
444 pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<u128, D::Error>
445 where
446 D: Deserializer<'de>,
447 {
448 <[u8; 16]>::deserialize(deserializer).map(u128::from_le_bytes)
449 }
450}
451
452mod serde_decimal_str {
453 use std::str::FromStr;
454
455 use serde::{Deserialize, Deserializer, Serializer};
456
457 pub(super) fn serialize<S>(
458 value: &rust_decimal::Decimal,
459 serializer: S,
460 ) -> Result<S::Ok, S::Error>
461 where
462 S: Serializer,
463 {
464 serializer.serialize_str(&value.to_string())
465 }
466
467 pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<rust_decimal::Decimal, D::Error>
468 where
469 D: Deserializer<'de>,
470 {
471 let value = String::deserialize(deserializer)?;
472 rust_decimal::Decimal::from_str(&value).map_err(serde::de::Error::custom)
473 }
474}
475
// Unit tests live in a separate `tests` module file (declared here without an inline body)
// and are compiled only for test builds.
#[cfg(test)]
mod tests;