1use crate::{
4 lang::query::{
5 QueryCommonConstraint, QueryConstraintSortByKeyJson, QueryGraphConstraint,
6 QueryGraphConstraintJson, QueryGraphConstraintLimitJson, QueryGraphJson, QueryJson,
7 QueryResultJson, QueryVectorConstraint, QueryVectorConstraintJson, QueryVectorJson,
8 },
9 schema::{format_edge_table_name, format_node_table_name},
10};
11use sea_orm::{ConnectionTrait, DbConn, DbErr, FromQueryResult, Order};
12use sea_query::{Alias, Expr, SelectStatement};
13use serde::{Deserialize, Serialize};
14use serde_repr::{Deserialize_repr, Serialize_repr};
15use std::collections::{HashMap, HashSet};
16
/// One node row in a query result.
///
/// Built from SQL rows via `FromQueryResult`; every statement that produces this type
/// selects `name`, `weight` and `depth` columns.
#[derive(Debug, Clone, Serialize, Deserialize, FromQueryResult)]
pub struct QueryResultNode {
    /// Node name (the `name` column of the entity's node table).
    pub name: String,
    /// Value of the requested sort-key column, or `None` when no sort key was given.
    pub weight: Option<f64>,
    /// Graph queries: traversal level at which the node was added (roots are `0`).
    /// Vector queries: always `None`.
    pub depth: Option<u64>,
}
29
/// Single-column row used when only node names are read from a node table.
#[derive(Debug, Clone, FromQueryResult)]
struct NodeName {
    /// Value of the `name` column.
    name: String,
}
35
/// Directed edge of a graph query result, identifying both endpoints by node name.
///
/// Serialized with camelCase keys (`fromNode`, `toNode`). `Hash`/`Eq` let the graph
/// traversal deduplicate edges in a `HashSet`.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, FromQueryResult)]
#[serde(rename_all = "camelCase")]
pub struct QueryResultEdge {
    /// Name of the node the edge starts at.
    pub from_node: String,
    /// Name of the node the edge points to.
    pub to_node: String,
}
45
46impl QueryResultEdge {
47 pub fn to_flipped(self) -> Self {
49 Self {
50 from_node: self.to_node,
51 to_node: self.from_node,
52 }
53 }
54}
55
/// Parameters of a breadth-first graph traversal, assembled from a graph query's constraints.
#[derive(Debug)]
pub struct QueryGraphParams {
    /// Entity whose node table is traversed. `Default` leaves an `Err` describing that it
    /// is unspecified.
    pub entity_name: Result<String, DbErr>,
    /// Relation whose edge table supplies the edges. `Default` leaves an `Err` describing
    /// that it is unspecified; an edge constraint sets it.
    pub relation_name: Result<String, DbErr>,
    /// When `true`, edges are followed from `to_node` back to `from_node`.
    pub reverse_direction: bool,
    /// Names of the nodes the traversal starts from.
    pub root_node_names: Vec<String>,
    /// Maximum number of levels to expand; `None` means no depth limit.
    pub max_depth: Option<u64>,
    /// Node-table column used to order each batch of newly reached nodes and reported as
    /// each node's `weight`; `None` leaves batches unordered and weights empty.
    pub batch_sort_key: Option<String>,
    /// Direction for `batch_sort_key`: ascending when `true`, descending otherwise.
    pub batch_sort_asc: bool,
    /// Maximum number of new nodes admitted per level; `None` means unlimited.
    pub max_batch_size: Option<usize>,
    /// Limit on the total number of nodes the traversal collects; `None` means unlimited.
    pub max_total_size: Option<usize>,
}
84
85impl Default for QueryGraphParams {
86 fn default() -> Self {
87 Self {
88 entity_name: Err(DbErr::Custom("Entity name is unspecified.".to_owned())),
89 relation_name: Err(DbErr::Custom("Relation name is unspecified.".to_owned())),
90 reverse_direction: false,
91 root_node_names: vec![],
92 max_depth: Some(6),
93 batch_sort_key: None,
94 batch_sort_asc: false,
95 max_batch_size: Some(6),
96 max_total_size: Some(10000),
97 }
98 }
99}
100
101impl QueryGraphParams {
102 pub fn from_query_graph_metadata(metadata: QueryGraphJson) -> Self {
104 let mut params = Self {
105 entity_name: Ok(metadata.of),
106 ..Default::default()
107 };
108
109 metadata
110 .constraints
111 .into_iter()
112 .for_each(|constraint| match constraint {
113 QueryGraphConstraintJson::Common(constraint) => {
114 params.handle_common_constraint(constraint)
115 }
116 QueryGraphConstraintJson::Exclusive(constraint) => {
117 params.handle_graph_constraint(constraint)
118 }
119 });
120
121 params
122 }
123
124 fn handle_common_constraint(&mut self, constraint: QueryCommonConstraint) {
125 match constraint {
126 QueryCommonConstraint::SortBy(sort_by) => {
127 self.batch_sort_key = match sort_by.key {
128 QueryConstraintSortByKeyJson::Connectivity { of, r#type } => {
129 Some(r#type.to_column_name(of))
130 }
131 };
132 self.batch_sort_asc = !sort_by.desc;
133 }
134 QueryCommonConstraint::Limit(limit) => self.max_total_size = Some(limit as usize),
135 }
136 }
137
138 fn handle_graph_constraint(&mut self, constraint: QueryGraphConstraint) {
139 match constraint {
140 QueryGraphConstraint::Edge { of, traversal } => {
141 self.relation_name = Ok(of);
142 self.reverse_direction = traversal.reverse_direction;
143 }
144 QueryGraphConstraint::RootNodes(root_node_names) => {
145 self.root_node_names = root_node_names;
146 }
147 QueryGraphConstraint::Limit(limit) => match limit {
148 QueryGraphConstraintLimitJson::Depth(depth) => self.max_depth = depth,
149 QueryGraphConstraintLimitJson::BatchSize(batch_size) => {
150 self.max_batch_size = batch_size
151 }
152 },
153 }
154 }
155}
156
157#[derive(Debug)]
159pub struct Query;
160
161impl Query {
162 pub async fn query(db: &DbConn, query_json: QueryJson) -> Result<QueryResultJson, DbErr> {
164 match query_json {
165 QueryJson::Vector(metadata) => Self::query_vector(db, metadata).await,
166 QueryJson::Graph(metadata) => Self::query_graph(db, metadata).await,
167 }
168 }
169
170 async fn query_vector(
171 db: &DbConn,
172 metadata: QueryVectorJson,
173 ) -> Result<QueryResultJson, DbErr> {
174 let mut stmt = sea_query::Query::select();
175
176 stmt.column(Alias::new("name"))
177 .expr_as(Expr::value(Option::<f64>::None), Alias::new("weight"))
178 .expr_as(Expr::val(Option::<u64>::None), Alias::new("depth"))
179 .from(Alias::new(&format_node_table_name(metadata.of)));
180
181 for constraint in metadata.constraints {
182 match constraint {
183 QueryVectorConstraintJson::Common(constraint) => {
184 Self::handle_common_constraint(&mut stmt, constraint)
185 }
186 QueryVectorConstraintJson::Exclusive(constraint) => {
187 Self::handle_vector_constraint(&mut stmt, constraint)
188 }
189 }
190 }
191
192 let builder = db.get_database_backend();
193
194 Ok(QueryResultJson::Vector(
195 QueryResultNode::find_by_statement(builder.build(&stmt))
196 .all(db)
197 .await?,
198 ))
199 }
200
201 fn handle_common_constraint(stmt: &mut SelectStatement, constraint: QueryCommonConstraint) {
202 match constraint {
203 QueryCommonConstraint::SortBy(sort_by) => {
204 let col_name = match sort_by.key {
205 QueryConstraintSortByKeyJson::Connectivity { of, r#type } => {
206 r#type.to_column_name(of)
207 }
208 };
209 stmt.expr_as(Expr::col(Alias::new(&col_name)), Alias::new("weight"))
210 .order_by(
211 Alias::new(&col_name),
212 if sort_by.desc {
213 Order::Desc
214 } else {
215 Order::Asc
216 },
217 );
218 }
219 QueryCommonConstraint::Limit(limit) => {
220 stmt.limit(limit);
221 }
222 }
223 }
224
225 fn handle_vector_constraint(_: &mut SelectStatement, constraint: QueryVectorConstraint) {
226 match constraint {
227 }
229 }
230
231 async fn query_graph(db: &DbConn, metadata: QueryGraphJson) -> Result<QueryResultJson, DbErr> {
232 let params = QueryGraphParams::from_query_graph_metadata(metadata);
233
234 println!("Querying a graph with params:\n{:?}", params);
235
236 Self::traverse_with_params(db, params).await
237 }
238
239 async fn traverse_with_params(
240 db: &DbConn,
241 params: QueryGraphParams,
242 ) -> Result<QueryResultJson, DbErr> {
243 let builder = db.get_database_backend();
244 let edge_table = &format_edge_table_name(params.relation_name?);
245 let node_table = &format_node_table_name(params.entity_name?);
246
247 let mut pending_nodes: Vec<String> = {
249 let root_node_set: HashSet<String> =
250 HashSet::from_iter(params.root_node_names.into_iter());
251
252 let root_node_stmt = sea_query::Query::select()
253 .column(Alias::new("name"))
254 .from(Alias::new(node_table))
255 .to_owned();
256
257 NodeName::find_by_statement(builder.build(&root_node_stmt))
258 .all(db)
259 .await?
260 .into_iter()
261 .filter_map(|node| {
262 if root_node_set.contains(&node.name) {
263 Some(node.name)
264 } else {
265 None
266 }
267 })
268 .collect()
269 };
270
271 let mut result_nodes: HashSet<String> = HashSet::from_iter(pending_nodes.iter().cloned());
272 let mut node_depths: HashMap<String, u64> = HashMap::new();
273 let mut result_edges: HashSet<QueryResultEdge> = HashSet::new();
274
275 let join_col = if !params.reverse_direction {
278 "from_node"
279 } else {
280 "to_node"
281 };
282
283 let mut depth = 0;
284 while params.max_depth.is_none() || depth < params.max_depth.unwrap() {
285 let target_edges = {
287 let target_edge_stmt = sea_query::Query::select()
288 .columns([Alias::new("from_node"), Alias::new("to_node")])
289 .from(Alias::new(edge_table))
290 .inner_join(
291 Alias::new(node_table),
292 Expr::tbl(Alias::new(node_table), Alias::new("name"))
293 .equals(Alias::new(edge_table), Alias::new(join_col)),
294 )
295 .and_where(Expr::col(Alias::new(join_col)).is_in(pending_nodes))
296 .to_owned();
297
298 QueryResultEdge::find_by_statement(builder.build(&target_edge_stmt))
299 .all(db)
300 .await?
301 };
302
303 let mut total_nodes_full = false;
304
305 pending_nodes = target_edges
306 .into_iter()
307 .filter_map(|edge| {
308 let target_node_name = if !params.reverse_direction {
309 edge.to_node.clone()
310 } else {
311 edge.from_node.clone()
312 };
313
314 if result_edges.insert(edge) && !result_nodes.contains(&target_node_name) {
315 if let Some(max_total_size) = params.max_total_size {
316 if result_nodes.len() >= max_total_size {
317 total_nodes_full = true;
318 }
319 }
320 Some(target_node_name)
321 } else {
322 None
323 }
324 })
325 .collect();
326
327 pending_nodes.iter().for_each(|node_name| {
328 if !node_depths.contains_key(node_name) {
329 node_depths.insert(node_name.clone(), depth + 1);
330 }
331 });
332
333 if let Some(order_by_key) = ¶ms.batch_sort_key {
335 pending_nodes = {
336 let pending_nodes_set: HashSet<String> =
337 HashSet::from_iter(pending_nodes.into_iter());
338
339 let stmt = sea_query::Query::select()
340 .column(Alias::new("name"))
341 .from(Alias::new(node_table))
342 .order_by(
343 Alias::new(order_by_key),
344 if params.batch_sort_asc {
345 Order::Asc
346 } else {
347 Order::Desc
348 },
349 )
350 .to_owned();
351
352 NodeName::find_by_statement(builder.build(&stmt))
353 .all(db)
354 .await?
355 .into_iter()
356 .filter_map(|node| {
357 if pending_nodes_set.contains(&node.name) {
358 Some(node.name)
359 } else {
360 None
361 }
362 })
363 .collect()
364 };
365 }
366
367 if let Some(max_batch_size) = params.max_batch_size {
368 if max_batch_size < pending_nodes.len() {
369 pending_nodes = pending_nodes[0..max_batch_size].to_vec();
370 }
371 }
372
373 result_nodes.extend(pending_nodes.iter().cloned());
374
375 if pending_nodes.is_empty() || total_nodes_full {
376 break;
377 }
378
379 depth += 1;
380 }
381
382 let edges: Vec<QueryResultEdge> = {
384 let iter = result_edges.into_iter().filter(|edge| {
385 result_nodes.contains(&edge.from_node) && result_nodes.contains(&edge.to_node)
386 });
387
388 if params.reverse_direction {
389 iter.map(|edge| edge.to_flipped()).collect()
390 } else {
391 iter.collect()
392 }
393 };
394
395 let nodes: Vec<QueryResultNode> = if let Some(weight_key) = params.batch_sort_key {
397 let stmt = sea_query::Query::select()
398 .column(Alias::new("name"))
399 .expr_as(Expr::col(Alias::new(&weight_key)), Alias::new("weight"))
400 .expr_as(Expr::val(Some(0_u64)), Alias::new("depth"))
401 .from(Alias::new(node_table))
402 .and_where(Expr::col(Alias::new("name")).is_in(result_nodes))
403 .to_owned();
404
405 QueryResultNode::find_by_statement(builder.build(&stmt))
406 .all(db)
407 .await?
408 .into_iter()
409 .map(|mut node| {
410 let depth = node_depths.get(&node.name).cloned().unwrap_or_default();
411 node.depth = Some(depth);
412 node
413 })
414 .collect()
415 } else {
416 result_nodes
417 .into_iter()
418 .map(|name| {
419 let depth = node_depths.get(&name).cloned().unwrap_or_default();
420 QueryResultNode {
421 name,
422 weight: None,
423 depth: Some(depth),
424 }
425 })
426 .collect()
427 };
428
429 Ok(QueryResultJson::Graph { nodes, edges })
430 }
431}
432
/// Serializable graph: a list of nodes plus the links between them.
///
/// NOTE(review): presumably each link's `source`/`target` refers to a node `id`; confirm
/// with the producer/consumer of this payload.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GraphData {
    /// Nodes of the graph.
    nodes: Vec<GraphNodeData>,
    /// Links between nodes.
    links: Vec<GraphLinkData>,
}
441
/// Node of a [`GraphData`] graph. Equality and hashing consider only `id`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GraphNodeData {
    /// Node identifier.
    id: String,
    /// Numeric weight of the node. NOTE(review): its meaning is not defined in this file.
    weight: f64,
}
450
451impl PartialEq for GraphNodeData {
452 fn eq(&self, other: &Self) -> bool {
453 self.id == other.id
454 }
455}
456
457impl Eq for GraphNodeData {}
458
459impl std::hash::Hash for GraphNodeData {
460 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
461 self.id.hash(state);
462 }
463}
464
/// Serializable tree: a list of typed nodes plus typed links between them.
///
/// NOTE(review): presumably each link's `source`/`target` refers to a node `id`; confirm
/// with the producer/consumer of this payload.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TreeData {
    /// Nodes of the tree.
    nodes: Vec<TreeNodeData>,
    /// Links between nodes.
    links: Vec<TreeLinkData>,
}
473
/// Node of a [`TreeData`] tree.
///
/// Equality and hashing use `id` together with `r#type`, so one id may appear once per
/// role. `depth_inv` does not take part.
#[derive(Debug, Clone, Eq, Deserialize, Serialize)]
pub struct TreeNodeData {
    /// Node identifier.
    id: String,
    /// Role of the node in the tree.
    r#type: TreeNodeType,
    /// NOTE(review): presumably an inverted depth (distance counted from the deepest
    /// level); not used in this file, confirm against its producer.
    depth_inv: i32,
}
485
486impl PartialEq for TreeNodeData {
487 fn eq(&self, other: &Self) -> bool {
488 self.id == other.id && self.r#type == other.r#type
489 }
490}
491
492impl std::hash::Hash for TreeNodeData {
493 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
494 self.id.hash(state);
495 self.r#type.hash(state);
496 }
497}
498
/// Role of a node (or link) in a [`TreeData`] tree, serialized as its `u8` discriminant
/// via `serde_repr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize_repr, Serialize_repr)]
#[repr(u8)]
pub enum TreeNodeType {
    /// The tree's root node.
    Root = 0,
    /// NOTE(review): presumably a node the root depends on; confirm.
    Dependency = 1,
    /// NOTE(review): presumably a node that depends on the root; confirm.
    Dependent = 2,
}
510
/// Weighting-scheme selector, serialized as its `u8` discriminant via `serde_repr`.
///
/// Not referenced elsewhere in this file. NOTE(review): the meaning of each variant
/// (including the decay rates) is defined by whatever consumes it; confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize_repr, Serialize_repr)]
#[repr(u8)]
pub enum NodeWeight {
    Simple = 0,
    FastDecay = 1,
    MediumDecay = 2,
    SlowDecay = 3,
    Compound = 4,
}
526
/// Link of a [`GraphData`] graph.
///
/// NOTE(review): presumably `source`/`target` hold node ids; confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub struct GraphLinkData {
    /// Start of the link.
    source: String,
    /// End of the link.
    target: String,
}
535
/// Link of a [`TreeData`] tree, tagged with the role it belongs to.
///
/// NOTE(review): presumably `source`/`target` hold node ids; confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub struct TreeLinkData {
    /// Start of the link.
    source: String,
    /// End of the link.
    target: String,
    /// Role associated with this link.
    r#type: TreeNodeType,
}