1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use serde_json::Value as JsonValue;
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct GraphNode {
8 pub id: i64,
9 pub session_id: String,
10 pub node_type: NodeType,
11 pub label: String,
12 pub properties: JsonValue,
13 pub embedding_id: Option<i64>,
14 pub created_at: DateTime<Utc>,
15 pub updated_at: DateTime<Utc>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
19pub enum NodeType {
20 Entity,
21 Concept,
22 Fact,
23 Message,
24 ToolResult,
25 Event,
26 Goal,
27}
28
29impl NodeType {
30 pub fn as_str(&self) -> &'static str {
31 match self {
32 NodeType::Entity => "entity",
33 NodeType::Concept => "concept",
34 NodeType::Fact => "fact",
35 NodeType::Message => "message",
36 NodeType::ToolResult => "tool_result",
37 NodeType::Event => "event",
38 NodeType::Goal => "goal",
39 }
40 }
41
42 pub fn from_str(value: &str) -> Self {
43 match value.to_ascii_lowercase().as_str() {
44 "entity" => NodeType::Entity,
45 "concept" => NodeType::Concept,
46 "fact" => NodeType::Fact,
47 "message" => NodeType::Message,
48 "tool_result" => NodeType::ToolResult,
49 "event" => NodeType::Event,
50 "goal" => NodeType::Goal,
51 _ => NodeType::Entity,
52 }
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct GraphEdge {
58 pub id: i64,
59 pub session_id: String,
60 pub source_id: i64,
61 pub target_id: i64,
62 pub edge_type: EdgeType,
63 pub predicate: Option<String>,
64 pub properties: Option<JsonValue>,
65 pub weight: f32,
66 pub temporal_start: Option<DateTime<Utc>>,
67 pub temporal_end: Option<DateTime<Utc>>,
68 pub created_at: DateTime<Utc>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72pub enum EdgeType {
73 RelatesTo,
74 CausedBy,
75 PartOf,
76 Mentions,
77 FollowsFrom,
78 Uses,
79 Produces,
80 DependsOn,
81 Custom(String),
82}
83
84impl EdgeType {
85 pub fn as_str(&self) -> String {
86 match self {
87 EdgeType::RelatesTo => "RELATES_TO".to_string(),
88 EdgeType::CausedBy => "CAUSED_BY".to_string(),
89 EdgeType::PartOf => "PART_OF".to_string(),
90 EdgeType::Mentions => "MENTIONS".to_string(),
91 EdgeType::FollowsFrom => "FOLLOWS_FROM".to_string(),
92 EdgeType::Uses => "USES".to_string(),
93 EdgeType::Produces => "PRODUCES".to_string(),
94 EdgeType::DependsOn => "DEPENDS_ON".to_string(),
95 EdgeType::Custom(value) => value.clone(),
96 }
97 }
98
99 pub fn from_str(value: &str) -> Self {
100 match value.to_uppercase().as_str() {
101 "RELATES_TO" => EdgeType::RelatesTo,
102 "CAUSED_BY" => EdgeType::CausedBy,
103 "PART_OF" => EdgeType::PartOf,
104 "MENTIONS" => EdgeType::Mentions,
105 "FOLLOWS_FROM" => EdgeType::FollowsFrom,
106 "USES" => EdgeType::Uses,
107 "PRODUCES" => EdgeType::Produces,
108 "DEPENDS_ON" => EdgeType::DependsOn,
109 custom => EdgeType::Custom(custom.to_string()),
110 }
111 }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct GraphQuery {
116 pub pattern: String,
117 pub parameters: HashMap<String, JsonValue>,
118 pub limit: Option<usize>,
119 pub return_type: GraphQueryReturnType,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub enum GraphQueryReturnType {
124 Nodes,
125 Edges,
126 Paths,
127 Count,
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct GraphPath {
132 pub nodes: Vec<GraphNode>,
133 pub edges: Vec<GraphEdge>,
134 pub length: usize,
135 pub weight: f32,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct GraphQueryResult {
140 pub nodes: Vec<GraphNode>,
141 pub edges: Vec<GraphEdge>,
142 pub paths: Vec<GraphPath>,
143 pub count: Option<usize>,
144}
145
/// Direction selector for graph traversal.
///
/// Unlike the other types in this file it has no serde derives; it is a
/// plain `Copy` value that can be compared with `==`.
// NOTE(review): presumably `Outgoing` follows edges from source to target and
// `Incoming` the reverse — confirm at the traversal call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalDirection {
    /// Outgoing direction.
    Outgoing,
    /// Incoming direction.
    Incoming,
    /// Both directions.
    Both,
}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `NodeType` label maps to its variant and back, parsing ignores
    /// ASCII case, and unrecognised labels fall back to `Entity`.
    #[test]
    fn node_type_round_trips() {
        let variants = [
            (NodeType::Entity, "entity"),
            (NodeType::Concept, "concept"),
            (NodeType::Fact, "fact"),
            (NodeType::Message, "message"),
            (NodeType::ToolResult, "tool_result"),
            (NodeType::Event, "event"),
            (NodeType::Goal, "goal"),
        ];

        for (variant, label) in variants {
            assert_eq!(variant.as_str(), label);
            assert_eq!(NodeType::from_str(label), variant);
            // Same label in upper case must parse to the same variant.
            assert_eq!(NodeType::from_str(&label.to_ascii_uppercase()), variant);
        }

        assert_eq!(NodeType::from_str("unknown"), NodeType::Entity);
        assert_eq!(NodeType::from_str(""), NodeType::Entity);
    }

    /// Every built-in `EdgeType` label maps to its variant and back, in
    /// either case, and an unknown label becomes `Custom`.
    #[test]
    fn edge_type_round_trips() {
        let variants = [
            (EdgeType::RelatesTo, "RELATES_TO"),
            (EdgeType::CausedBy, "CAUSED_BY"),
            (EdgeType::PartOf, "PART_OF"),
            (EdgeType::Mentions, "MENTIONS"),
            (EdgeType::FollowsFrom, "FOLLOWS_FROM"),
            (EdgeType::Uses, "USES"),
            (EdgeType::Produces, "PRODUCES"),
            (EdgeType::DependsOn, "DEPENDS_ON"),
        ];

        for (variant, label) in variants {
            assert_eq!(variant.as_str(), label);
            assert_eq!(EdgeType::from_str(label), variant);
            // Same label in lower case must parse to the same variant.
            assert_eq!(EdgeType::from_str(&label.to_lowercase()), variant);
        }

        assert_eq!(
            EdgeType::from_str("MY_EDGE"),
            EdgeType::Custom("MY_EDGE".to_string())
        );
    }

    /// Custom labels are stored uppercased by `from_str`, while `as_str`
    /// returns a custom payload exactly as stored.
    #[test]
    fn custom_edge_type_is_uppercased_on_parse() {
        let parsed = EdgeType::from_str("my_edge");
        assert_eq!(parsed, EdgeType::Custom("MY_EDGE".to_string()));
        assert_eq!(parsed.as_str(), "MY_EDGE");

        let verbatim = EdgeType::Custom("Mixed_Case".to_string());
        assert_eq!(verbatim.as_str(), "Mixed_Case");
    }
}