Skip to main content

ruvector_graph/
schema.rs

//! Optional, schema-first type layer for the graph (HelixDB-inspired, ADR-252 P1/P2).
//!
//! RuVector's graph is schemaless by default, and its Cypher engine is interpreted
//! at runtime. This module adds an **opt-in** schema that catches type errors
//! *before* execution: declared node labels, typed edges with `from`/`to` label
//! constraints, indexed properties, and **vector types bound to a node label +
//! property** (so a vector hit can be traversed back into the graph as a
//! first-class, validated relationship rather than a runtime string + property
//! name).
//!
//! The module is pure Rust with no storage/HNSW dependency, so it compiles for
//! WASM. It coexists with schemaless mode: only declared labels/edges are checked;
//! undeclared ones pass through untouched.
14
15use crate::edge::Edge;
16use crate::error::{GraphError, Result};
17use crate::node::Node;
18use crate::types::PropertyValue;
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21
/// Declared type of a node/edge property.
///
/// Checked against actual values by [`PropertyType::accepts`] during node and
/// edge validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PropertyType {
    /// Accepts only `Boolean`.
    Boolean,
    /// Accepts only `Integer` (a `Float` value is rejected).
    Integer,
    /// Accepts `Float` and (widening) `Integer`.
    Float,
    /// Accepts only `String`.
    String,
    /// Dense embedding: a `FloatArray`, or a non-empty `Array`/`List` whose
    /// items are all `Float`/`Integer`.
    Vector,
    /// Heterogeneous list (`Array` or `List`; a `FloatArray` is not accepted).
    Array,
    /// Accepts only `Map`.
    Map,
    /// Accepts any value (escape hatch).
    Any,
}
38
39impl PropertyType {
40    /// Does `value` satisfy this declared type?
41    pub fn accepts(&self, value: &PropertyValue) -> bool {
42        match self {
43            PropertyType::Any => true,
44            PropertyType::Boolean => matches!(value, PropertyValue::Boolean(_)),
45            PropertyType::Integer => matches!(value, PropertyValue::Integer(_)),
46            // Float is permissive: an integer literal is a valid float.
47            PropertyType::Float => {
48                matches!(value, PropertyValue::Float(_) | PropertyValue::Integer(_))
49            }
50            PropertyType::String => matches!(value, PropertyValue::String(_)),
51            PropertyType::Vector => extract_vector(value).is_some(),
52            PropertyType::Array => {
53                matches!(value, PropertyValue::Array(_) | PropertyValue::List(_))
54            }
55            PropertyType::Map => matches!(value, PropertyValue::Map(_)),
56        }
57    }
58}
59
/// Distance metric for a vector type. Search always ranks by a *higher-is-better*
/// score, so `Euclidean` is surfaced as the negated distance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Cosine similarity; scores `0.0` when either vector has zero norm.
    Cosine,
    /// Raw dot product (no normalization).
    DotProduct,
    /// Negated L2 distance, so closer vectors score higher.
    Euclidean,
}
68
69impl DistanceMetric {
70    /// Higher score == more similar, for any metric. Convenience wrapper that
71    /// computes the query's norm inline; prefer [`DistanceMetric::query_norm`] +
72    /// [`DistanceMetric::score_pre`] in a scan loop to amortize the query norm.
73    pub fn score(&self, a: &[f32], b: &[f32]) -> f32 {
74        self.score_pre(a, b, self.query_norm(a))
75    }
76
77    /// Precompute the query-side norm once per query. Only `Cosine` needs it;
78    /// the others return `1.0`.
79    #[inline]
80    pub fn query_norm(&self, q: &[f32]) -> f32 {
81        match self {
82            DistanceMetric::Cosine => dot(q, q).sqrt(),
83            _ => 1.0,
84        }
85    }
86
87    /// Score `candidate` against `query`, reusing a precomputed `query_norm`.
88    /// Hoists the query norm out of the per-candidate hot loop.
89    #[inline]
90    pub fn score_pre(&self, query: &[f32], candidate: &[f32], query_norm: f32) -> f32 {
91        match self {
92            DistanceMetric::DotProduct => dot(query, candidate),
93            DistanceMetric::Cosine => {
94                // Single fused pass: accumulate q·c and c·c together so the
95                // candidate slice is read once (half the memory traffic of two
96                // separate `dot` calls).
97                let n = query.len().min(candidate.len());
98                let mut qc = 0.0f32;
99                let mut cc = 0.0f32;
100                for i in 0..n {
101                    let c = candidate[i];
102                    qc += query[i] * c;
103                    cc += c * c;
104                }
105                let cn = cc.sqrt();
106                if query_norm == 0.0 || cn == 0.0 {
107                    0.0
108                } else {
109                    qc / (query_norm * cn)
110                }
111            }
112            DistanceMetric::Euclidean => {
113                let n = query.len().min(candidate.len());
114                let mut sum = 0.0f32;
115                for i in 0..n {
116                    let d = query[i] - candidate[i];
117                    sum += d * d;
118                }
119                -sum.sqrt()
120            }
121        }
122    }
123}
124
125/// Score a vector-shaped property against a query without allocating in the
126/// common `FloatArray` case (zero-copy slice scoring). Returns `None` if the
127/// property is not vector-shaped or its dimension does not match the query.
128#[inline]
129pub fn score_property(
130    metric: DistanceMetric,
131    query: &[f32],
132    query_norm: f32,
133    value: &PropertyValue,
134) -> Option<f32> {
135    match value {
136        // Fast path: borrow the stored slice directly, no clone.
137        PropertyValue::FloatArray(v) => {
138            if v.len() == query.len() {
139                Some(metric.score_pre(query, v, query_norm))
140            } else {
141                None
142            }
143        }
144        // Slow path: heterogeneous numeric list must be materialized.
145        PropertyValue::Array(_) | PropertyValue::List(_) => {
146            let v = extract_vector(value)?;
147            if v.len() == query.len() {
148                Some(metric.score_pre(query, &v, query_norm))
149            } else {
150                None
151            }
152        }
153        _ => None,
154    }
155}
156
/// Plain dot product over the overlapping prefix of `a` and `b`.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    // Zipped iterators stop at the shorter slice and need no indexing, so there
    // are no bounds checks. SIMD via `simsimd`/ruvector-core is a follow-up
    // (ADR-252 P5). It is deliberately not a hard dependency here, so the schema
    // layer stays WASM-safe and builds without extra features.
    let products = a.iter().zip(b).map(|(&x, &y)| x * y);
    products.sum::<f32>()
}
165
166/// Coerce a property value into a dense `Vec<f32>` if it is vector-shaped.
167pub fn extract_vector(value: &PropertyValue) -> Option<Vec<f32>> {
168    match value {
169        PropertyValue::FloatArray(v) => Some(v.clone()),
170        PropertyValue::Array(items) | PropertyValue::List(items) => {
171            let mut out = Vec::with_capacity(items.len());
172            for it in items {
173                match it {
174                    PropertyValue::Float(f) => out.push(*f as f32),
175                    PropertyValue::Integer(i) => out.push(*i as f32),
176                    _ => return None,
177                }
178            }
179            if out.is_empty() {
180                None
181            } else {
182                Some(out)
183            }
184        }
185        _ => None,
186    }
187}
188
/// Declaration for a single property.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PropertySchema {
    /// Property key, looked up in the node's/edge's property map.
    pub name: String,
    /// Declared type; values are checked with [`PropertyType::accepts`].
    pub ptype: PropertyType,
    /// Must be present on every instance.
    pub required: bool,
    /// Hint that this property is secondary-indexed (HelixQL `INDEX`). The
    /// validators in this module do not consult it.
    pub indexed: bool,
}
199
200impl PropertySchema {
201    pub fn new(name: impl Into<String>, ptype: PropertyType) -> Self {
202        Self {
203            name: name.into(),
204            ptype,
205            required: false,
206            indexed: false,
207        }
208    }
209    pub fn required(mut self) -> Self {
210        self.required = true;
211        self
212    }
213    pub fn indexed(mut self) -> Self {
214        self.indexed = true;
215        self
216    }
217}
218
/// Schema for a node label (`N::` in HelixQL).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeSchema {
    /// Node label this schema applies to (also its key in the graph schema).
    pub label: String,
    /// Declared properties, checked for presence (`required`) and type.
    pub properties: Vec<PropertySchema>,
    /// If true, properties not declared here are rejected. On a multi-label
    /// node, a property declared by any of the node's declared labels counts
    /// as declared.
    pub strict: bool,
}
227
228impl NodeSchema {
229    pub fn new(label: impl Into<String>) -> Self {
230        Self {
231            label: label.into(),
232            properties: Vec::new(),
233            strict: false,
234        }
235    }
236    pub fn property(mut self, p: PropertySchema) -> Self {
237        self.properties.push(p);
238        self
239    }
240    pub fn strict(mut self) -> Self {
241        self.strict = true;
242        self
243    }
244}
245
/// Schema for an edge type (`E::` in HelixQL) with `from`/`to` label constraints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeSchema {
    /// Edge type this schema applies to (also its key in the graph schema).
    pub edge_type: String,
    /// Label that must appear among the source node's labels.
    pub from_label: String,
    /// Label that must appear among the target node's labels.
    pub to_label: String,
    /// Declared edge properties, checked for presence (`required`) and type.
    pub properties: Vec<PropertySchema>,
}
254
255impl EdgeSchema {
256    pub fn new(
257        edge_type: impl Into<String>,
258        from_label: impl Into<String>,
259        to_label: impl Into<String>,
260    ) -> Self {
261        Self {
262            edge_type: edge_type.into(),
263            from_label: from_label.into(),
264            to_label: to_label.into(),
265            properties: Vec::new(),
266        }
267    }
268    pub fn property(mut self, p: PropertySchema) -> Self {
269        self.properties.push(p);
270        self
271    }
272}
273
/// Schema for a vector type (`V::` in HelixQL), bound to a node label + property.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSchema {
    /// Vector type name (referenced by `search_then_traverse`).
    pub name: String,
    /// Node label whose instances carry this embedding.
    pub label: String,
    /// Property holding the embedding.
    pub property: String,
    /// Required query length, enforced by `GraphSchema::validate_vector_dims`.
    pub dimensions: usize,
    /// Metric used to score candidates against a query.
    pub metric: DistanceMetric,
}
286
287impl VectorSchema {
288    pub fn new(
289        name: impl Into<String>,
290        label: impl Into<String>,
291        property: impl Into<String>,
292        dimensions: usize,
293        metric: DistanceMetric,
294    ) -> Self {
295        Self {
296            name: name.into(),
297            label: label.into(),
298            property: property.into(),
299            dimensions,
300            metric,
301        }
302    }
303}
304
/// A complete, optional graph schema.
///
/// Holds the declared node labels, edge types, and vector types, each keyed by
/// its name; declaring a name again replaces the earlier declaration. Labels and
/// edge types with no declaration here are not checked (schemaless coexistence).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GraphSchema {
    /// Node schemas keyed by label.
    nodes: HashMap<String, NodeSchema>,
    /// Edge schemas keyed by edge type.
    edges: HashMap<String, EdgeSchema>,
    /// Vector schemas keyed by vector type name.
    vectors: HashMap<String, VectorSchema>,
}
312
313impl GraphSchema {
314    pub fn new() -> Self {
315        Self::default()
316    }
317
318    pub fn add_node(&mut self, schema: NodeSchema) -> &mut Self {
319        self.nodes.insert(schema.label.clone(), schema);
320        self
321    }
322    pub fn add_edge(&mut self, schema: EdgeSchema) -> &mut Self {
323        self.edges.insert(schema.edge_type.clone(), schema);
324        self
325    }
326    pub fn add_vector(&mut self, schema: VectorSchema) -> &mut Self {
327        self.vectors.insert(schema.name.clone(), schema);
328        self
329    }
330
331    pub fn node(&self, label: &str) -> Option<&NodeSchema> {
332        self.nodes.get(label)
333    }
334    pub fn edge(&self, edge_type: &str) -> Option<&EdgeSchema> {
335        self.edges.get(edge_type)
336    }
337    pub fn vector(&self, name: &str) -> Option<&VectorSchema> {
338        self.vectors.get(name)
339    }
340
341    /// Node schemas sorted by label (deterministic — for codegen).
342    pub fn node_schemas_sorted(&self) -> Vec<&NodeSchema> {
343        let mut v: Vec<&NodeSchema> = self.nodes.values().collect();
344        v.sort_by(|a, b| a.label.cmp(&b.label));
345        v
346    }
347    /// Edge schemas sorted by edge type (deterministic — for codegen).
348    pub fn edge_schemas_sorted(&self) -> Vec<&EdgeSchema> {
349        let mut v: Vec<&EdgeSchema> = self.edges.values().collect();
350        v.sort_by(|a, b| a.edge_type.cmp(&b.edge_type));
351        v
352    }
353    /// Vector schemas sorted by name (deterministic — for codegen).
354    pub fn vector_schemas_sorted(&self) -> Vec<&VectorSchema> {
355        let mut v: Vec<&VectorSchema> = self.vectors.values().collect();
356        v.sort_by(|a, b| a.name.cmp(&b.name));
357        v
358    }
359
360    /// Validate the schema's own internal consistency: every edge's `from`/`to`
361    /// label and every vector's bound label must reference a declared node. Run
362    /// this once after building the schema (HelixQL's compile-time check).
363    pub fn validate_self(&self) -> Result<()> {
364        for e in self.edges.values() {
365            if !self.nodes.contains_key(&e.from_label) {
366                return Err(GraphError::SchemaViolation(format!(
367                    "edge '{}' references undeclared from-label '{}'",
368                    e.edge_type, e.from_label
369                )));
370            }
371            if !self.nodes.contains_key(&e.to_label) {
372                return Err(GraphError::SchemaViolation(format!(
373                    "edge '{}' references undeclared to-label '{}'",
374                    e.edge_type, e.to_label
375                )));
376            }
377        }
378        for v in self.vectors.values() {
379            if !self.nodes.contains_key(&v.label) {
380                return Err(GraphError::SchemaViolation(format!(
381                    "vector '{}' bound to undeclared label '{}'",
382                    v.name, v.label
383                )));
384            }
385        }
386        Ok(())
387    }
388
389    /// Validate a node against any declared schema for its labels. Labels with no
390    /// schema pass through (schemaless coexistence).
391    pub fn validate_node(&self, node: &Node) -> Result<()> {
392        // Collect every property allowed by any matching (declared) label.
393        let mut allowed: Vec<&str> = Vec::new();
394        let mut any_strict = false;
395        let mut matched_any = false;
396
397        for label in &node.labels {
398            let Some(ns) = self.nodes.get(&label.name) else {
399                continue;
400            };
401            matched_any = true;
402            any_strict |= ns.strict;
403            for p in &ns.properties {
404                allowed.push(p.name.as_str());
405                match node.properties.get(&p.name) {
406                    None if p.required => {
407                        return Err(GraphError::SchemaViolation(format!(
408                            "node '{}' (:{}) missing required property '{}'",
409                            node.id, label.name, p.name
410                        )));
411                    }
412                    Some(v) if !p.ptype.accepts(v) => {
413                        return Err(GraphError::SchemaViolation(format!(
414                            "node '{}' (:{}) property '{}' has wrong type (expected {:?})",
415                            node.id, label.name, p.name, p.ptype
416                        )));
417                    }
418                    _ => {}
419                }
420            }
421        }
422
423        if matched_any && any_strict {
424            for key in node.properties.keys() {
425                if !allowed.iter().any(|a| a == key) {
426                    return Err(GraphError::SchemaViolation(format!(
427                        "node '{}' has undeclared property '{}' (strict schema)",
428                        node.id, key
429                    )));
430                }
431            }
432        }
433        Ok(())
434    }
435
436    /// Validate an edge given the labels of its endpoints. Undeclared edge types
437    /// pass through. Pass the actual from/to node labels so direction + endpoint
438    /// types are checked.
439    pub fn validate_edge(&self, edge: &Edge, from_labels: &[String], to_labels: &[String]) -> Result<()> {
440        let Some(es) = self.edges.get(&edge.edge_type) else {
441            return Ok(());
442        };
443        if !from_labels.iter().any(|l| l == &es.from_label) {
444            return Err(GraphError::SchemaViolation(format!(
445                "edge '{}' requires from-label '{}', got {:?}",
446                edge.edge_type, es.from_label, from_labels
447            )));
448        }
449        if !to_labels.iter().any(|l| l == &es.to_label) {
450            return Err(GraphError::SchemaViolation(format!(
451                "edge '{}' requires to-label '{}', got {:?}",
452                edge.edge_type, es.to_label, to_labels
453            )));
454        }
455        for p in &es.properties {
456            match edge.properties.get(&p.name) {
457                None if p.required => {
458                    return Err(GraphError::SchemaViolation(format!(
459                        "edge '{}' missing required property '{}'",
460                        edge.edge_type, p.name
461                    )));
462                }
463                Some(v) if !p.ptype.accepts(v) => {
464                    return Err(GraphError::SchemaViolation(format!(
465                        "edge '{}' property '{}' has wrong type (expected {:?})",
466                        edge.edge_type, p.name, p.ptype
467                    )));
468                }
469                _ => {}
470            }
471        }
472        Ok(())
473    }
474
475    /// Validate that a query vector matches a declared vector type's dimension.
476    pub fn validate_vector_dims(&self, vector_type: &str, query: &[f32]) -> Result<&VectorSchema> {
477        let vs = self.vectors.get(vector_type).ok_or_else(|| {
478            GraphError::SchemaViolation(format!("unknown vector type '{}'", vector_type))
479        })?;
480        if query.len() != vs.dimensions {
481            return Err(GraphError::SchemaViolation(format!(
482                "vector type '{}' expects dimension {}, got {}",
483                vector_type,
484                vs.dimensions,
485                query.len()
486            )));
487        }
488        Ok(vs)
489    }
490}
491
/// Reciprocal Rank Fusion over several ranked id lists (ADR-252 P4 core).
///
/// `score(id) = Σ 1 / (k_const + rank)` with `rank` 1-based per list. The common
/// default for `k_const` is 60. An id that appears more than once in a list
/// contributes once per occurrence.
///
/// Returns `(id, fused_score)` pairs sorted by score descending. Ties are broken
/// by id ascending, so the output is identical on every run.
pub fn reciprocal_rank_fusion(rankings: &[Vec<String>], k_const: f32) -> Vec<(String, f32)> {
    let mut scores: HashMap<String, f32> = HashMap::new();
    for ranking in rankings {
        for (rank, id) in ranking.iter().enumerate() {
            // `enumerate` is 0-based; RRF ranks are 1-based.
            let contribution = 1.0 / (k_const + (rank as f32 + 1.0));
            *scores.entry(id.clone()).or_insert(0.0) += contribution;
        }
    }
    let mut fused: Vec<(String, f32)> = scores.into_iter().collect();
    // HashMap iteration order is randomized per process, so equal scores need
    // an explicit tie-break to be reproducible. `total_cmp` is a total order,
    // so a degenerate `k_const` (NaN/inf scores) still sorts consistently.
    fused.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    fused
}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::NodeBuilder;
    use crate::types::Label;

    fn person_schema() -> GraphSchema {
        let mut s = GraphSchema::new();
        s.add_node(
            NodeSchema::new("Person")
                .property(PropertySchema::new("name", PropertyType::String).required().indexed())
                .property(PropertySchema::new("age", PropertyType::Integer))
                .property(PropertySchema::new("embedding", PropertyType::Vector)),
        );
        s.add_node(NodeSchema::new("Company"));
        s.add_edge(EdgeSchema::new("WORKS_AT", "Person", "Company"));
        s.add_vector(VectorSchema::new("PersonEmb", "Person", "embedding", 3, DistanceMetric::Cosine));
        s
    }

    #[test]
    fn self_validation_catches_dangling_refs() {
        let mut s = GraphSchema::new();
        s.add_edge(EdgeSchema::new("KNOWS", "Person", "Person"));
        assert!(s.validate_self().is_err());
        s.add_node(NodeSchema::new("Person"));
        assert!(s.validate_self().is_ok());
    }

    #[test]
    fn self_validation_catches_dangling_vector_label() {
        let mut s = GraphSchema::new();
        s.add_vector(VectorSchema::new("Emb", "Doc", "embedding", 4, DistanceMetric::DotProduct));
        assert!(s.validate_self().is_err());
        s.add_node(NodeSchema::new("Doc"));
        assert!(s.validate_self().is_ok());
    }

    #[test]
    fn person_schema_is_self_consistent() {
        assert!(person_schema().validate_self().is_ok());
    }

    #[test]
    fn node_validation_required_and_types() {
        let s = person_schema();
        // Valid.
        let ok = NodeBuilder::new().label("Person").property("name", "Alice").property("age", 30i64).build();
        assert!(s.validate_node(&ok).is_ok());
        // Missing required `name`.
        let missing = NodeBuilder::new().label("Person").property("age", 30i64).build();
        assert!(s.validate_node(&missing).is_err());
        // Wrong type for `age` (string where integer expected).
        let wrong = NodeBuilder::new().label("Person").property("name", "Bob").property("age", "old").build();
        assert!(s.validate_node(&wrong).is_err());
        // Undeclared label passes through (schemaless coexistence).
        let other = NodeBuilder::new().label("Alien").property("planet", "Mars").build();
        assert!(s.validate_node(&other).is_ok());
    }

    #[test]
    fn float_widens_integer_and_vector_shape_is_checked() {
        let mut s = GraphSchema::new();
        s.add_node(
            NodeSchema::new("Doc")
                .property(PropertySchema::new("score", PropertyType::Float))
                .property(PropertySchema::new("embedding", PropertyType::Vector)),
        );
        // An integer is a valid float.
        let widened = NodeBuilder::new().label("Doc").property("score", 3i64).build();
        assert!(s.validate_node(&widened).is_ok());
        // A string is not vector-shaped.
        let bad_vec = NodeBuilder::new().label("Doc").property("embedding", "not a vector").build();
        assert!(s.validate_node(&bad_vec).is_err());
        let good_vec = NodeBuilder::new()
            .label("Doc")
            .property("embedding", PropertyValue::FloatArray(vec![0.1, 0.2]))
            .build();
        assert!(s.validate_node(&good_vec).is_ok());
    }

    #[test]
    fn strict_node_rejects_undeclared_props() {
        let mut s = GraphSchema::new();
        s.add_node(NodeSchema::new("Tag").property(PropertySchema::new("name", PropertyType::String)).strict());
        let bad = NodeBuilder::new().label("Tag").property("name", "x").property("extra", 1i64).build();
        assert!(s.validate_node(&bad).is_err());
    }

    #[test]
    fn strict_label_accepts_props_declared_by_sibling_label() {
        let mut s = GraphSchema::new();
        s.add_node(NodeSchema::new("Tag").property(PropertySchema::new("name", PropertyType::String)).strict());
        s.add_node(NodeSchema::new("Colored").property(PropertySchema::new("color", PropertyType::String)));
        let ok = Node::new(
            "t1".into(),
            vec![Label::new("Tag"), Label::new("Colored")],
            [
                ("name".to_string(), PropertyValue::String("x".into())),
                ("color".to_string(), PropertyValue::String("red".into())),
            ]
            .into_iter()
            .collect(),
        );
        assert!(s.validate_node(&ok).is_ok());
        let extra = Node::new(
            "t2".into(),
            vec![Label::new("Tag"), Label::new("Colored")],
            [
                ("name".to_string(), PropertyValue::String("x".into())),
                ("size".to_string(), PropertyValue::Integer(3)),
            ]
            .into_iter()
            .collect(),
        );
        assert!(s.validate_node(&extra).is_err());
    }

    #[test]
    fn edge_validation_checks_endpoint_labels() {
        let s = person_schema();
        let e = Edge::create("p1".into(), "c1".into(), "WORKS_AT");
        assert!(s
            .validate_edge(&e, &["Person".into()], &["Company".into()])
            .is_ok());
        // Wrong from-label.
        assert!(s
            .validate_edge(&e, &["Company".into()], &["Company".into()])
            .is_err());
        // Wrong to-label.
        assert!(s
            .validate_edge(&e, &["Person".into()], &["Person".into()])
            .is_err());
        // Undeclared edge type passes through.
        let e2 = Edge::create("p1".into(), "p2".into(), "LIKES");
        assert!(s.validate_edge(&e2, &["Person".into()], &["Person".into()]).is_ok());
    }

    #[test]
    fn edge_validation_requires_declared_props() {
        let mut s = GraphSchema::new();
        s.add_node(NodeSchema::new("Person"));
        s.add_node(NodeSchema::new("Company"));
        s.add_edge(
            EdgeSchema::new("WORKS_AT", "Person", "Company")
                .property(PropertySchema::new("since", PropertyType::Integer).required()),
        );
        let e = Edge::create("p1".into(), "c1".into(), "WORKS_AT");
        assert!(s
            .validate_edge(&e, &["Person".into()], &["Company".into()])
            .is_err());
    }

    #[test]
    fn vector_dim_validation() {
        let s = person_schema();
        assert!(s.validate_vector_dims("PersonEmb", &[1.0, 2.0, 3.0]).is_ok());
        assert!(s.validate_vector_dims("PersonEmb", &[1.0, 2.0]).is_err());
        assert!(s.validate_vector_dims("Missing", &[1.0, 2.0, 3.0]).is_err());
    }

    #[test]
    fn distance_metrics_rank_higher_is_better() {
        let q = [1.0f32, 0.0, 0.0];
        let near = [0.9f32, 0.1, 0.0];
        let far = [0.0f32, 1.0, 0.0];
        for m in [DistanceMetric::Cosine, DistanceMetric::DotProduct, DistanceMetric::Euclidean] {
            assert!(m.score(&q, &near) > m.score(&q, &far), "{:?}", m);
        }
    }

    #[test]
    fn score_property_checks_shape_and_dimension() {
        let q = [1.0f32, 0.0];
        let m = DistanceMetric::DotProduct;
        let qn = m.query_norm(&q);
        assert_eq!(score_property(m, &q, qn, &PropertyValue::FloatArray(vec![2.0, 5.0])), Some(2.0));
        assert_eq!(
            score_property(
                m,
                &q,
                qn,
                &PropertyValue::List(vec![PropertyValue::Integer(3), PropertyValue::Float(1.0)])
            ),
            Some(3.0)
        );
        // Dimension mismatch.
        assert_eq!(score_property(m, &q, qn, &PropertyValue::FloatArray(vec![1.0])), None);
        assert_eq!(
            score_property(m, &q, qn, &PropertyValue::Array(vec![PropertyValue::Integer(1)])),
            None
        );
        // Not vector-shaped.
        assert_eq!(score_property(m, &q, qn, &PropertyValue::String("x".into())), None);
    }

    #[test]
    fn extract_vector_handles_shapes() {
        assert_eq!(extract_vector(&PropertyValue::FloatArray(vec![1.0, 2.0])), Some(vec![1.0, 2.0]));
        assert_eq!(
            extract_vector(&PropertyValue::Array(vec![PropertyValue::Integer(1), PropertyValue::Float(2.0)])),
            Some(vec![1.0, 2.0])
        );
        assert_eq!(extract_vector(&PropertyValue::String("x".into())), None);
    }

    #[test]
    fn rrf_fuses_and_ranks() {
        let a = vec!["x".to_string(), "y".to_string(), "z".to_string()];
        let b = vec!["y".to_string(), "x".to_string()];
        let fused = reciprocal_rank_fusion(&[a, b], 60.0);
        // `y`: 1/62 + 1/61; `x`: 1/61 + 1/62 — tie; `z`: 1/63. x & y lead z.
        assert_eq!(fused.len(), 3);
        assert_eq!(fused[2].0, "z");
    }

    #[test]
    fn multi_label_node_validation() {
        let mut s = GraphSchema::new();
        s.add_node(NodeSchema::new("A").property(PropertySchema::new("a", PropertyType::Integer).required()));
        s.add_node(NodeSchema::new("B").property(PropertySchema::new("b", PropertyType::String).required()));
        let n = Node::new(
            "n1".into(),
            vec![Label::new("A"), Label::new("B")],
            [
                ("a".to_string(), PropertyValue::Integer(1)),
                ("b".to_string(), PropertyValue::String("x".into())),
            ]
            .into_iter()
            .collect(),
        );
        assert!(s.validate_node(&n).is_ok());
    }
}