1use crate::edge::Edge;
16use crate::error::{GraphError, Result};
17use crate::node::Node;
18use crate::types::PropertyValue;
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21
/// The kind of value a schema-declared property is allowed to hold.
///
/// Concrete values are matched against a type by [`PropertyType::accepts`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PropertyType {
    /// Accepts only `PropertyValue::Boolean`.
    Boolean,
    /// Accepts only `PropertyValue::Integer`.
    Integer,
    /// Accepts `PropertyValue::Float` and also `PropertyValue::Integer`.
    Float,
    /// Accepts only `PropertyValue::String`.
    String,
    /// Accepts any value that `extract_vector` can turn into an `f32` vector.
    Vector,
    /// Accepts `PropertyValue::Array` and `PropertyValue::List`.
    Array,
    /// Accepts only `PropertyValue::Map`.
    Map,
    /// Accepts every value.
    Any,
}
38
39impl PropertyType {
40 pub fn accepts(&self, value: &PropertyValue) -> bool {
42 match self {
43 PropertyType::Any => true,
44 PropertyType::Boolean => matches!(value, PropertyValue::Boolean(_)),
45 PropertyType::Integer => matches!(value, PropertyValue::Integer(_)),
46 PropertyType::Float => {
48 matches!(value, PropertyValue::Float(_) | PropertyValue::Integer(_))
49 }
50 PropertyType::String => matches!(value, PropertyValue::String(_)),
51 PropertyType::Vector => extract_vector(value).is_some(),
52 PropertyType::Array => {
53 matches!(value, PropertyValue::Array(_) | PropertyValue::List(_))
54 }
55 PropertyType::Map => matches!(value, PropertyValue::Map(_)),
56 }
57 }
58}
59
/// Similarity measure used to compare two `f32` vectors.
///
/// All variants are scored by [`DistanceMetric::score`] /
/// [`DistanceMetric::score_pre`] so that a larger score means a closer match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Dot product divided by the product of both vector norms; `0.0` when either norm is zero.
    Cosine,
    /// Plain dot product of the two vectors.
    DotProduct,
    /// Negated Euclidean (L2) distance, so that nearer vectors score higher.
    Euclidean,
}
68
69impl DistanceMetric {
70 pub fn score(&self, a: &[f32], b: &[f32]) -> f32 {
74 self.score_pre(a, b, self.query_norm(a))
75 }
76
77 #[inline]
80 pub fn query_norm(&self, q: &[f32]) -> f32 {
81 match self {
82 DistanceMetric::Cosine => dot(q, q).sqrt(),
83 _ => 1.0,
84 }
85 }
86
87 #[inline]
90 pub fn score_pre(&self, query: &[f32], candidate: &[f32], query_norm: f32) -> f32 {
91 match self {
92 DistanceMetric::DotProduct => dot(query, candidate),
93 DistanceMetric::Cosine => {
94 let n = query.len().min(candidate.len());
98 let mut qc = 0.0f32;
99 let mut cc = 0.0f32;
100 for i in 0..n {
101 let c = candidate[i];
102 qc += query[i] * c;
103 cc += c * c;
104 }
105 let cn = cc.sqrt();
106 if query_norm == 0.0 || cn == 0.0 {
107 0.0
108 } else {
109 qc / (query_norm * cn)
110 }
111 }
112 DistanceMetric::Euclidean => {
113 let n = query.len().min(candidate.len());
114 let mut sum = 0.0f32;
115 for i in 0..n {
116 let d = query[i] - candidate[i];
117 sum += d * d;
118 }
119 -sum.sqrt()
120 }
121 }
122 }
123}
124
125#[inline]
129pub fn score_property(
130 metric: DistanceMetric,
131 query: &[f32],
132 query_norm: f32,
133 value: &PropertyValue,
134) -> Option<f32> {
135 match value {
136 PropertyValue::FloatArray(v) => {
138 if v.len() == query.len() {
139 Some(metric.score_pre(query, v, query_norm))
140 } else {
141 None
142 }
143 }
144 PropertyValue::Array(_) | PropertyValue::List(_) => {
146 let v = extract_vector(value)?;
147 if v.len() == query.len() {
148 Some(metric.score_pre(query, &v, query_norm))
149 } else {
150 None
151 }
152 }
153 _ => None,
154 }
155}
156
/// Dot product of `a` and `b` over their common prefix.
///
/// Pairing stops at the end of the shorter slice, so extra trailing elements of the
/// longer one are ignored.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
165
166pub fn extract_vector(value: &PropertyValue) -> Option<Vec<f32>> {
168 match value {
169 PropertyValue::FloatArray(v) => Some(v.clone()),
170 PropertyValue::Array(items) | PropertyValue::List(items) => {
171 let mut out = Vec::with_capacity(items.len());
172 for it in items {
173 match it {
174 PropertyValue::Float(f) => out.push(*f as f32),
175 PropertyValue::Integer(i) => out.push(*i as f32),
176 _ => return None,
177 }
178 }
179 if out.is_empty() {
180 None
181 } else {
182 Some(out)
183 }
184 }
185 _ => None,
186 }
187}
188
/// Declaration of a single property on a node or edge schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PropertySchema {
    /// Property key that is looked up on the node or edge being validated.
    pub name: String,
    /// Type a present value must satisfy (checked with `PropertyType::accepts`).
    pub ptype: PropertyType,
    /// When `true`, validation fails if the property is absent.
    pub required: bool,
    /// Index flag set by the `indexed()` builder. Nothing in this file reads it;
    /// presumably consumed by an indexing layer elsewhere (to be confirmed).
    pub indexed: bool,
}
199
200impl PropertySchema {
201 pub fn new(name: impl Into<String>, ptype: PropertyType) -> Self {
202 Self {
203 name: name.into(),
204 ptype,
205 required: false,
206 indexed: false,
207 }
208 }
209 pub fn required(mut self) -> Self {
210 self.required = true;
211 self
212 }
213 pub fn indexed(mut self) -> Self {
214 self.indexed = true;
215 self
216 }
217}
218
/// Declaration of the properties expected on nodes carrying a given label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeSchema {
    /// Node label this schema applies to; also the key it is registered under in `GraphSchema`.
    pub label: String,
    /// Declared properties, in the order they were added.
    pub properties: Vec<PropertySchema>,
    /// When `true`, `GraphSchema::validate_node` rejects properties that none of the node's
    /// matched label schemas declare.
    pub strict: bool,
}
227
228impl NodeSchema {
229 pub fn new(label: impl Into<String>) -> Self {
230 Self {
231 label: label.into(),
232 properties: Vec::new(),
233 strict: false,
234 }
235 }
236 pub fn property(mut self, p: PropertySchema) -> Self {
237 self.properties.push(p);
238 self
239 }
240 pub fn strict(mut self) -> Self {
241 self.strict = true;
242 self
243 }
244}
245
/// Declaration of an edge type: which labels it connects and which properties it carries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeSchema {
    /// Edge type this schema applies to; also the key it is registered under in `GraphSchema`.
    pub edge_type: String,
    /// Label that the source node of such an edge must carry.
    pub from_label: String,
    /// Label that the target node of such an edge must carry.
    pub to_label: String,
    /// Declared edge properties, in the order they were added.
    pub properties: Vec<PropertySchema>,
}
254
255impl EdgeSchema {
256 pub fn new(
257 edge_type: impl Into<String>,
258 from_label: impl Into<String>,
259 to_label: impl Into<String>,
260 ) -> Self {
261 Self {
262 edge_type: edge_type.into(),
263 from_label: from_label.into(),
264 to_label: to_label.into(),
265 properties: Vec::new(),
266 }
267 }
268 pub fn property(mut self, p: PropertySchema) -> Self {
269 self.properties.push(p);
270 self
271 }
272}
273
/// Declaration of a named vector type bound to a node label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSchema {
    /// Name of the vector type; the key it is registered and looked up under in `GraphSchema`.
    pub name: String,
    /// Node label the vector is bound to; `GraphSchema::validate_self` requires it to be declared.
    pub label: String,
    /// Property name associated with the vector. It is only stored in this file;
    /// presumably the node property that holds the embedding (to be confirmed).
    pub property: String,
    /// Required length of query vectors, enforced by `GraphSchema::validate_vector_dims`.
    pub dimensions: usize,
    /// Metric associated with this vector type; only stored in this file.
    pub metric: DistanceMetric,
}
286
287impl VectorSchema {
288 pub fn new(
289 name: impl Into<String>,
290 label: impl Into<String>,
291 property: impl Into<String>,
292 dimensions: usize,
293 metric: DistanceMetric,
294 ) -> Self {
295 Self {
296 name: name.into(),
297 label: label.into(),
298 property: property.into(),
299 dimensions,
300 metric,
301 }
302 }
303}
304
/// Collection of node, edge, and vector declarations that graph data is validated against.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GraphSchema {
    /// Node schemas keyed by label.
    nodes: HashMap<String, NodeSchema>,
    /// Edge schemas keyed by edge type.
    edges: HashMap<String, EdgeSchema>,
    /// Vector schemas keyed by vector name.
    vectors: HashMap<String, VectorSchema>,
}
312
313impl GraphSchema {
314 pub fn new() -> Self {
315 Self::default()
316 }
317
318 pub fn add_node(&mut self, schema: NodeSchema) -> &mut Self {
319 self.nodes.insert(schema.label.clone(), schema);
320 self
321 }
322 pub fn add_edge(&mut self, schema: EdgeSchema) -> &mut Self {
323 self.edges.insert(schema.edge_type.clone(), schema);
324 self
325 }
326 pub fn add_vector(&mut self, schema: VectorSchema) -> &mut Self {
327 self.vectors.insert(schema.name.clone(), schema);
328 self
329 }
330
331 pub fn node(&self, label: &str) -> Option<&NodeSchema> {
332 self.nodes.get(label)
333 }
334 pub fn edge(&self, edge_type: &str) -> Option<&EdgeSchema> {
335 self.edges.get(edge_type)
336 }
337 pub fn vector(&self, name: &str) -> Option<&VectorSchema> {
338 self.vectors.get(name)
339 }
340
341 pub fn node_schemas_sorted(&self) -> Vec<&NodeSchema> {
343 let mut v: Vec<&NodeSchema> = self.nodes.values().collect();
344 v.sort_by(|a, b| a.label.cmp(&b.label));
345 v
346 }
347 pub fn edge_schemas_sorted(&self) -> Vec<&EdgeSchema> {
349 let mut v: Vec<&EdgeSchema> = self.edges.values().collect();
350 v.sort_by(|a, b| a.edge_type.cmp(&b.edge_type));
351 v
352 }
353 pub fn vector_schemas_sorted(&self) -> Vec<&VectorSchema> {
355 let mut v: Vec<&VectorSchema> = self.vectors.values().collect();
356 v.sort_by(|a, b| a.name.cmp(&b.name));
357 v
358 }
359
360 pub fn validate_self(&self) -> Result<()> {
364 for e in self.edges.values() {
365 if !self.nodes.contains_key(&e.from_label) {
366 return Err(GraphError::SchemaViolation(format!(
367 "edge '{}' references undeclared from-label '{}'",
368 e.edge_type, e.from_label
369 )));
370 }
371 if !self.nodes.contains_key(&e.to_label) {
372 return Err(GraphError::SchemaViolation(format!(
373 "edge '{}' references undeclared to-label '{}'",
374 e.edge_type, e.to_label
375 )));
376 }
377 }
378 for v in self.vectors.values() {
379 if !self.nodes.contains_key(&v.label) {
380 return Err(GraphError::SchemaViolation(format!(
381 "vector '{}' bound to undeclared label '{}'",
382 v.name, v.label
383 )));
384 }
385 }
386 Ok(())
387 }
388
389 pub fn validate_node(&self, node: &Node) -> Result<()> {
392 let mut allowed: Vec<&str> = Vec::new();
394 let mut any_strict = false;
395 let mut matched_any = false;
396
397 for label in &node.labels {
398 let Some(ns) = self.nodes.get(&label.name) else {
399 continue;
400 };
401 matched_any = true;
402 any_strict |= ns.strict;
403 for p in &ns.properties {
404 allowed.push(p.name.as_str());
405 match node.properties.get(&p.name) {
406 None if p.required => {
407 return Err(GraphError::SchemaViolation(format!(
408 "node '{}' (:{}) missing required property '{}'",
409 node.id, label.name, p.name
410 )));
411 }
412 Some(v) if !p.ptype.accepts(v) => {
413 return Err(GraphError::SchemaViolation(format!(
414 "node '{}' (:{}) property '{}' has wrong type (expected {:?})",
415 node.id, label.name, p.name, p.ptype
416 )));
417 }
418 _ => {}
419 }
420 }
421 }
422
423 if matched_any && any_strict {
424 for key in node.properties.keys() {
425 if !allowed.iter().any(|a| a == key) {
426 return Err(GraphError::SchemaViolation(format!(
427 "node '{}' has undeclared property '{}' (strict schema)",
428 node.id, key
429 )));
430 }
431 }
432 }
433 Ok(())
434 }
435
436 pub fn validate_edge(&self, edge: &Edge, from_labels: &[String], to_labels: &[String]) -> Result<()> {
440 let Some(es) = self.edges.get(&edge.edge_type) else {
441 return Ok(());
442 };
443 if !from_labels.iter().any(|l| l == &es.from_label) {
444 return Err(GraphError::SchemaViolation(format!(
445 "edge '{}' requires from-label '{}', got {:?}",
446 edge.edge_type, es.from_label, from_labels
447 )));
448 }
449 if !to_labels.iter().any(|l| l == &es.to_label) {
450 return Err(GraphError::SchemaViolation(format!(
451 "edge '{}' requires to-label '{}', got {:?}",
452 edge.edge_type, es.to_label, to_labels
453 )));
454 }
455 for p in &es.properties {
456 match edge.properties.get(&p.name) {
457 None if p.required => {
458 return Err(GraphError::SchemaViolation(format!(
459 "edge '{}' missing required property '{}'",
460 edge.edge_type, p.name
461 )));
462 }
463 Some(v) if !p.ptype.accepts(v) => {
464 return Err(GraphError::SchemaViolation(format!(
465 "edge '{}' property '{}' has wrong type (expected {:?})",
466 edge.edge_type, p.name, p.ptype
467 )));
468 }
469 _ => {}
470 }
471 }
472 Ok(())
473 }
474
475 pub fn validate_vector_dims(&self, vector_type: &str, query: &[f32]) -> Result<&VectorSchema> {
477 let vs = self.vectors.get(vector_type).ok_or_else(|| {
478 GraphError::SchemaViolation(format!("unknown vector type '{}'", vector_type))
479 })?;
480 if query.len() != vs.dimensions {
481 return Err(GraphError::SchemaViolation(format!(
482 "vector type '{}' expects dimension {}, got {}",
483 vector_type,
484 vs.dimensions,
485 query.len()
486 )));
487 }
488 Ok(vs)
489 }
490}
491
/// Fuses several ranked id lists into one ranking using reciprocal rank fusion (RRF).
///
/// Each appearance of an id at 1-based position `r` in a ranking contributes
/// `1 / (k_const + r)` to that id's score; contributions are summed across all rankings.
///
/// Returns `(id, score)` pairs ordered by descending score. Ids with equal scores are
/// ordered by ascending id, so the output is the same on every run regardless of hash-map
/// iteration order. Scores are compared with `f32::total_cmp`, which stays a total order
/// even if a degenerate `k_const` (for example NaN) produces NaN scores.
pub fn reciprocal_rank_fusion(rankings: &[Vec<String>], k_const: f32) -> Vec<(String, f32)> {
    // Accumulate under borrowed keys; each distinct id is cloned only once, at the end.
    let mut scores: HashMap<&str, f32> = HashMap::new();
    for ranking in rankings {
        for (rank, id) in ranking.iter().enumerate() {
            let contribution = 1.0 / (k_const + (rank as f32 + 1.0));
            *scores.entry(id.as_str()).or_insert(0.0) += contribution;
        }
    }

    let mut fused: Vec<(String, f32)> = scores
        .into_iter()
        .map(|(id, score)| (id.to_owned(), score))
        .collect();
    // Ids are unique map keys, so the comparator never reports two entries as equal and
    // an unstable sort is fully deterministic.
    fused.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    fused
}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::NodeBuilder;
    use crate::types::Label;

    // Shared fixture: a `Person` label with a required `name`, optional `age` and
    // `embedding`, an empty `Company` label, a `WORKS_AT` edge from Person to Company,
    // and a 3-dimensional cosine vector type `PersonEmb` bound to `Person`.
    fn person_schema() -> GraphSchema {
        let mut s = GraphSchema::new();
        s.add_node(
            NodeSchema::new("Person")
                .property(PropertySchema::new("name", PropertyType::String).required().indexed())
                .property(PropertySchema::new("age", PropertyType::Integer))
                .property(PropertySchema::new("embedding", PropertyType::Vector)),
        );
        s.add_node(NodeSchema::new("Company"));
        s.add_edge(EdgeSchema::new("WORKS_AT", "Person", "Company"));
        s.add_vector(VectorSchema::new("PersonEmb", "Person", "embedding", 3, DistanceMetric::Cosine));
        s
    }

    // An edge that references an undeclared label fails self-validation until the label
    // is declared.
    #[test]
    fn self_validation_catches_dangling_refs() {
        let mut s = GraphSchema::new();
        s.add_edge(EdgeSchema::new("KNOWS", "Person", "Person"));
        assert!(s.validate_self().is_err());
        s.add_node(NodeSchema::new("Person"));
        assert!(s.validate_self().is_ok());
    }

    #[test]
    fn node_validation_required_and_types() {
        let s = person_schema();
        // All declared properties present with accepted types.
        let ok = NodeBuilder::new().label("Person").property("name", "Alice").property("age", 30i64).build();
        assert!(s.validate_node(&ok).is_ok());
        // Required `name` is absent.
        let missing = NodeBuilder::new().label("Person").property("age", 30i64).build();
        assert!(s.validate_node(&missing).is_err());
        // `age` is declared Integer but holds a string.
        let wrong = NodeBuilder::new().label("Person").property("name", "Bob").property("age", "old").build();
        assert!(s.validate_node(&wrong).is_err());
        // A label with no registered schema is not validated at all.
        let other = NodeBuilder::new().label("Alien").property("planet", "Mars").build();
        assert!(s.validate_node(&other).is_ok());
    }

    // A strict schema rejects a property it does not declare.
    #[test]
    fn strict_node_rejects_undeclared_props() {
        let mut s = GraphSchema::new();
        s.add_node(NodeSchema::new("Tag").property(PropertySchema::new("name", PropertyType::String)).strict());
        let bad = NodeBuilder::new().label("Tag").property("name", "x").property("extra", 1i64).build();
        assert!(s.validate_node(&bad).is_err());
    }

    #[test]
    fn edge_validation_checks_endpoint_labels() {
        let s = person_schema();
        let e = Edge::create("p1".into(), "c1".into(), "WORKS_AT");
        // Endpoint labels match the declared Person -> Company direction.
        assert!(s
            .validate_edge(&e, &["Person".into()], &["Company".into()])
            .is_ok());
        // Source node lacks the required `Person` label.
        assert!(s
            .validate_edge(&e, &["Company".into()], &["Company".into()])
            .is_err());
        // `LIKES` has no registered schema, so the edge passes unchecked.
        let e2 = Edge::create("p1".into(), "p2".into(), "LIKES");
        assert!(s.validate_edge(&e2, &["Person".into()], &["Person".into()]).is_ok());
    }

    // Query length must equal the declared dimension, and the vector type must exist.
    #[test]
    fn vector_dim_validation() {
        let s = person_schema();
        assert!(s.validate_vector_dims("PersonEmb", &[1.0, 2.0, 3.0]).is_ok());
        assert!(s.validate_vector_dims("PersonEmb", &[1.0, 2.0]).is_err());
        assert!(s.validate_vector_dims("Missing", &[1.0, 2.0, 3.0]).is_err());
    }

    // For every metric, the vector nearer to the query must receive the larger score.
    #[test]
    fn distance_metrics_rank_higher_is_better() {
        let q = [1.0f32, 0.0, 0.0];
        let near = [0.9f32, 0.1, 0.0];
        let far = [0.0f32, 1.0, 0.0];
        for m in [DistanceMetric::Cosine, DistanceMetric::DotProduct, DistanceMetric::Euclidean] {
            assert!(m.score(&q, &near) > m.score(&q, &far), "{:?}", m);
        }
    }

    // FloatArray passes through, mixed Integer/Float arrays are cast, non-vectors give None.
    #[test]
    fn extract_vector_handles_shapes() {
        assert_eq!(extract_vector(&PropertyValue::FloatArray(vec![1.0, 2.0])), Some(vec![1.0, 2.0]));
        assert_eq!(
            extract_vector(&PropertyValue::Array(vec![PropertyValue::Integer(1), PropertyValue::Float(2.0)])),
            Some(vec![1.0, 2.0])
        );
        assert_eq!(extract_vector(&PropertyValue::String("x".into())), None);
    }

    #[test]
    fn rrf_fuses_and_ranks() {
        let a = vec!["x".to_string(), "y".to_string(), "z".to_string()];
        let b = vec!["y".to_string(), "x".to_string()];
        let fused = reciprocal_rank_fusion(&[a, b], 60.0);
        assert_eq!(fused.len(), 3);
        // `x` and `y` collect the same two reciprocal-rank terms, so only the position of
        // `z` (which appears once, in last place) is asserted.
        assert_eq!(fused[2].0, "z");
    }

    // A node with two declared labels must satisfy the required properties of both.
    #[test]
    fn multi_label_node_validation() {
        let mut s = GraphSchema::new();
        s.add_node(NodeSchema::new("A").property(PropertySchema::new("a", PropertyType::Integer).required()));
        s.add_node(NodeSchema::new("B").property(PropertySchema::new("b", PropertyType::String).required()));
        let n = Node::new(
            "n1".into(),
            vec![Label::new("A"), Label::new("B")],
            [
                ("a".to_string(), PropertyValue::Integer(1)),
                ("b".to_string(), PropertyValue::String("x".into())),
            ]
            .into_iter()
            .collect(),
        );
        assert!(s.validate_node(&n).is_ok());
    }
}