spec_ai_knowledge_graph/
types.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use serde_json::Value as JsonValue;
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct GraphNode {
8    pub id: i64,
9    pub session_id: String,
10    pub node_type: NodeType,
11    pub label: String,
12    pub properties: JsonValue,
13    pub embedding_id: Option<i64>,
14    pub created_at: DateTime<Utc>,
15    pub updated_at: DateTime<Utc>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
19pub enum NodeType {
20    Entity,
21    Concept,
22    Fact,
23    Message,
24    ToolResult,
25    Event,
26    Goal,
27}
28
29impl NodeType {
30    pub fn as_str(&self) -> &'static str {
31        match self {
32            NodeType::Entity => "entity",
33            NodeType::Concept => "concept",
34            NodeType::Fact => "fact",
35            NodeType::Message => "message",
36            NodeType::ToolResult => "tool_result",
37            NodeType::Event => "event",
38            NodeType::Goal => "goal",
39        }
40    }
41
42    pub fn from_str(value: &str) -> Self {
43        match value.to_ascii_lowercase().as_str() {
44            "entity" => NodeType::Entity,
45            "concept" => NodeType::Concept,
46            "fact" => NodeType::Fact,
47            "message" => NodeType::Message,
48            "tool_result" => NodeType::ToolResult,
49            "event" => NodeType::Event,
50            "goal" => NodeType::Goal,
51            _ => NodeType::Entity,
52        }
53    }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct GraphEdge {
58    pub id: i64,
59    pub session_id: String,
60    pub source_id: i64,
61    pub target_id: i64,
62    pub edge_type: EdgeType,
63    pub predicate: Option<String>,
64    pub properties: Option<JsonValue>,
65    pub weight: f32,
66    pub temporal_start: Option<DateTime<Utc>>,
67    pub temporal_end: Option<DateTime<Utc>>,
68    pub created_at: DateTime<Utc>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72pub enum EdgeType {
73    RelatesTo,
74    CausedBy,
75    PartOf,
76    Mentions,
77    FollowsFrom,
78    Uses,
79    Produces,
80    DependsOn,
81    Custom(String),
82}
83
84impl EdgeType {
85    pub fn as_str(&self) -> String {
86        match self {
87            EdgeType::RelatesTo => "RELATES_TO".to_string(),
88            EdgeType::CausedBy => "CAUSED_BY".to_string(),
89            EdgeType::PartOf => "PART_OF".to_string(),
90            EdgeType::Mentions => "MENTIONS".to_string(),
91            EdgeType::FollowsFrom => "FOLLOWS_FROM".to_string(),
92            EdgeType::Uses => "USES".to_string(),
93            EdgeType::Produces => "PRODUCES".to_string(),
94            EdgeType::DependsOn => "DEPENDS_ON".to_string(),
95            EdgeType::Custom(value) => value.clone(),
96        }
97    }
98
99    pub fn from_str(value: &str) -> Self {
100        match value.to_uppercase().as_str() {
101            "RELATES_TO" => EdgeType::RelatesTo,
102            "CAUSED_BY" => EdgeType::CausedBy,
103            "PART_OF" => EdgeType::PartOf,
104            "MENTIONS" => EdgeType::Mentions,
105            "FOLLOWS_FROM" => EdgeType::FollowsFrom,
106            "USES" => EdgeType::Uses,
107            "PRODUCES" => EdgeType::Produces,
108            "DEPENDS_ON" => EdgeType::DependsOn,
109            custom => EdgeType::Custom(custom.to_string()),
110        }
111    }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct GraphQuery {
116    pub pattern: String,
117    pub parameters: HashMap<String, JsonValue>,
118    pub limit: Option<usize>,
119    pub return_type: GraphQueryReturnType,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub enum GraphQueryReturnType {
124    Nodes,
125    Edges,
126    Paths,
127    Count,
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct GraphPath {
132    pub nodes: Vec<GraphNode>,
133    pub edges: Vec<GraphEdge>,
134    pub length: usize,
135    pub weight: f32,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct GraphQueryResult {
140    pub nodes: Vec<GraphNode>,
141    pub edges: Vec<GraphEdge>,
142    pub paths: Vec<GraphPath>,
143    pub count: Option<usize>,
144}
145
/// Which edges to follow when walking the graph from a node.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TraversalDirection {
    /// Edges leaving the node (presumably those whose `source_id` is the node).
    Outgoing,
    /// Edges entering the node (presumably those whose `target_id` is the node).
    Incoming,
    /// Edges in either direction.
    Both,
}
152
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn node_type_round_trips() {
        let variants = [
            (NodeType::Entity, "entity"),
            (NodeType::Concept, "concept"),
            (NodeType::Fact, "fact"),
            (NodeType::Message, "message"),
            (NodeType::ToolResult, "tool_result"),
            (NodeType::Event, "event"),
            (NodeType::Goal, "goal"),
        ];

        for (variant, label) in variants {
            assert_eq!(variant.as_str(), label);
            assert_eq!(NodeType::from_str(label), variant);
        }

        // Unknown strings should fall back to Entity
        assert_eq!(NodeType::from_str("unknown"), NodeType::Entity);
    }

    #[test]
    fn node_type_parsing_ignores_ascii_case() {
        assert_eq!(NodeType::from_str("ENTITY"), NodeType::Entity);
        assert_eq!(NodeType::from_str("Tool_Result"), NodeType::ToolResult);
        assert_eq!(NodeType::from_str("GoAl"), NodeType::Goal);
        // Empty input is just another unknown label and falls back to Entity.
        assert_eq!(NodeType::from_str(""), NodeType::Entity);
    }

    #[test]
    fn edge_type_round_trips() {
        let variants = [
            (EdgeType::RelatesTo, "RELATES_TO"),
            (EdgeType::CausedBy, "CAUSED_BY"),
            (EdgeType::PartOf, "PART_OF"),
            (EdgeType::Mentions, "MENTIONS"),
            (EdgeType::FollowsFrom, "FOLLOWS_FROM"),
            (EdgeType::Uses, "USES"),
            (EdgeType::Produces, "PRODUCES"),
            (EdgeType::DependsOn, "DEPENDS_ON"),
        ];

        for (variant, label) in variants {
            assert_eq!(variant.as_str(), label);
            assert_eq!(EdgeType::from_str(label), variant);
        }

        // Custom types should be preserved verbatim
        if let EdgeType::Custom(value) = EdgeType::from_str("MY_EDGE") {
            assert_eq!(value, "MY_EDGE");
        } else {
            panic!("expected custom edge type");
        }
    }

    #[test]
    fn edge_type_parsing_ignores_case_for_builtins() {
        assert_eq!(EdgeType::from_str("relates_to"), EdgeType::RelatesTo);
        assert_eq!(EdgeType::from_str("Depends_On"), EdgeType::DependsOn);
        assert_eq!(EdgeType::from_str("uses"), EdgeType::Uses);
    }

    #[test]
    fn custom_edge_type_as_str_returns_stored_label() {
        assert_eq!(EdgeType::Custom("MY_EDGE".to_string()).as_str(), "MY_EDGE");
    }
}