// starfish_core/query/mod.rs

1//! Graph query engine
2
3use crate::{
4    lang::query::{
5        QueryCommonConstraint, QueryConstraintSortByKeyJson, QueryGraphConstraint,
6        QueryGraphConstraintJson, QueryGraphConstraintLimitJson, QueryGraphJson, QueryJson,
7        QueryResultJson, QueryVectorConstraint, QueryVectorConstraintJson, QueryVectorJson,
8    },
9    schema::{format_edge_table_name, format_node_table_name},
10};
11use sea_orm::{ConnectionTrait, DbConn, DbErr, FromQueryResult, Order};
12use sea_query::{Alias, Expr, SelectStatement};
13use serde::{Deserialize, Serialize};
14use serde_repr::{Deserialize_repr, Serialize_repr};
15use std::collections::{HashMap, HashSet};
16
17#[derive(Debug, Clone, Serialize, Deserialize, FromQueryResult)]
18/// A queried node
19pub struct QueryResultNode {
20    /// Name of the node
21    pub name: String,
22    /// Associated weight (specified in query)
23    pub weight: Option<f64>,
24    /// Depth when this node is first found in the graph.
25    /// Some(0) for root nodes.
26    /// None if querying a vector.
27    pub depth: Option<u64>,
28}
29
30#[derive(Debug, Clone, FromQueryResult)]
31/// A helper struct to temporarily store unique nodes
32struct NodeName {
33    name: String,
34}
35
36#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize, FromQueryResult)]
37#[serde(rename_all = "camelCase")]
38/// A queried edge
39pub struct QueryResultEdge {
40    /// Name of the node in the from side
41    pub from_node: String,
42    /// Name of the node in the to side
43    pub to_node: String,
44}
45
46impl QueryResultEdge {
47    /// Convert self to an edge with flipped directions
48    pub fn to_flipped(self) -> Self {
49        Self {
50            from_node: self.to_node,
51            to_node: self.from_node,
52        }
53    }
54}
55
56#[derive(Debug)]
57/// A helper struct to specify how to perform a graph query
58pub struct QueryGraphParams {
59    /// Which entity to consider for constructing the graph (unformatted)
60    pub entity_name: Result<String, DbErr>,
61    /// Which relation to consider for constructing the graph (unformatted)
62    pub relation_name: Result<String, DbErr>,
63    /// Whether to reverse the direction when constructing the graph
64    pub reverse_direction: bool,
65    /// Specify the root nodes to be the nodes with the supplied names
66    /// The keys in the HashMaps must be Formatted.
67    pub root_node_names: Vec<String>,
68    /// Recursion goes up to this level, 0 means no recursion at all.
69    /// Recursion does not terminate early if this value is None.
70    pub max_depth: Option<u64>,
71    /// Sort each batch on this key (this value is a Formatted column name).
72    /// This key is also used as for filling the `weight` field of queried nodes, if supplied.
73    /// The order is random if this value is None.
74    pub batch_sort_key: Option<String>,
75    /// Sort each batch in an ascending order if this value is true.
76    pub batch_sort_asc: bool,
77    /// Include up to this number of nodes in each batch.
78    /// All nodes are included in all batches if this value is None.
79    pub max_batch_size: Option<usize>,
80    /// Include up to this number of nodes across the whole recursion.
81    /// All nodes are included if this value is None.
82    pub max_total_size: Option<usize>,
83}
84
85impl Default for QueryGraphParams {
86    fn default() -> Self {
87        Self {
88            entity_name: Err(DbErr::Custom("Entity name is unspecified.".to_owned())),
89            relation_name: Err(DbErr::Custom("Relation name is unspecified.".to_owned())),
90            reverse_direction: false,
91            root_node_names: vec![],
92            max_depth: Some(6),
93            batch_sort_key: None,
94            batch_sort_asc: false,
95            max_batch_size: Some(6),
96            max_total_size: Some(10000),
97        }
98    }
99}
100
101impl QueryGraphParams {
102    /// Construct params from metadata
103    pub fn from_query_graph_metadata(metadata: QueryGraphJson) -> Self {
104        let mut params = Self {
105            entity_name: Ok(metadata.of),
106            ..Default::default()
107        };
108
109        metadata
110            .constraints
111            .into_iter()
112            .for_each(|constraint| match constraint {
113                QueryGraphConstraintJson::Common(constraint) => {
114                    params.handle_common_constraint(constraint)
115                }
116                QueryGraphConstraintJson::Exclusive(constraint) => {
117                    params.handle_graph_constraint(constraint)
118                }
119            });
120
121        params
122    }
123
124    fn handle_common_constraint(&mut self, constraint: QueryCommonConstraint) {
125        match constraint {
126            QueryCommonConstraint::SortBy(sort_by) => {
127                self.batch_sort_key = match sort_by.key {
128                    QueryConstraintSortByKeyJson::Connectivity { of, r#type } => {
129                        Some(r#type.to_column_name(of))
130                    }
131                };
132                self.batch_sort_asc = !sort_by.desc;
133            }
134            QueryCommonConstraint::Limit(limit) => self.max_total_size = Some(limit as usize),
135        }
136    }
137
138    fn handle_graph_constraint(&mut self, constraint: QueryGraphConstraint) {
139        match constraint {
140            QueryGraphConstraint::Edge { of, traversal } => {
141                self.relation_name = Ok(of);
142                self.reverse_direction = traversal.reverse_direction;
143            }
144            QueryGraphConstraint::RootNodes(root_node_names) => {
145                self.root_node_names = root_node_names;
146            }
147            QueryGraphConstraint::Limit(limit) => match limit {
148                QueryGraphConstraintLimitJson::Depth(depth) => self.max_depth = depth,
149                QueryGraphConstraintLimitJson::BatchSize(batch_size) => {
150                    self.max_batch_size = batch_size
151                }
152            },
153        }
154    }
155}
156
/// Unit type that serves as the entry point for querying graph data.
#[derive(Debug)]
pub struct Query;
160
161impl Query {
162    /// Query data from db
163    pub async fn query(db: &DbConn, query_json: QueryJson) -> Result<QueryResultJson, DbErr> {
164        match query_json {
165            QueryJson::Vector(metadata) => Self::query_vector(db, metadata).await,
166            QueryJson::Graph(metadata) => Self::query_graph(db, metadata).await,
167        }
168    }
169
170    async fn query_vector(
171        db: &DbConn,
172        metadata: QueryVectorJson,
173    ) -> Result<QueryResultJson, DbErr> {
174        let mut stmt = sea_query::Query::select();
175
176        stmt.column(Alias::new("name"))
177            .expr_as(Expr::value(Option::<f64>::None), Alias::new("weight"))
178            .expr_as(Expr::val(Option::<u64>::None), Alias::new("depth"))
179            .from(Alias::new(&format_node_table_name(metadata.of)));
180
181        for constraint in metadata.constraints {
182            match constraint {
183                QueryVectorConstraintJson::Common(constraint) => {
184                    Self::handle_common_constraint(&mut stmt, constraint)
185                }
186                QueryVectorConstraintJson::Exclusive(constraint) => {
187                    Self::handle_vector_constraint(&mut stmt, constraint)
188                }
189            }
190        }
191
192        let builder = db.get_database_backend();
193
194        Ok(QueryResultJson::Vector(
195            QueryResultNode::find_by_statement(builder.build(&stmt))
196                .all(db)
197                .await?,
198        ))
199    }
200
201    fn handle_common_constraint(stmt: &mut SelectStatement, constraint: QueryCommonConstraint) {
202        match constraint {
203            QueryCommonConstraint::SortBy(sort_by) => {
204                let col_name = match sort_by.key {
205                    QueryConstraintSortByKeyJson::Connectivity { of, r#type } => {
206                        r#type.to_column_name(of)
207                    }
208                };
209                stmt.expr_as(Expr::col(Alias::new(&col_name)), Alias::new("weight"))
210                    .order_by(
211                        Alias::new(&col_name),
212                        if sort_by.desc {
213                            Order::Desc
214                        } else {
215                            Order::Asc
216                        },
217                    );
218            }
219            QueryCommonConstraint::Limit(limit) => {
220                stmt.limit(limit);
221            }
222        }
223    }
224
225    fn handle_vector_constraint(_: &mut SelectStatement, constraint: QueryVectorConstraint) {
226        match constraint {
227            // Empty
228        }
229    }
230
231    async fn query_graph(db: &DbConn, metadata: QueryGraphJson) -> Result<QueryResultJson, DbErr> {
232        let params = QueryGraphParams::from_query_graph_metadata(metadata);
233
234        println!("Querying a graph with params:\n{:?}", params);
235
236        Self::traverse_with_params(db, params).await
237    }
238
239    async fn traverse_with_params(
240        db: &DbConn,
241        params: QueryGraphParams,
242    ) -> Result<QueryResultJson, DbErr> {
243        let builder = db.get_database_backend();
244        let edge_table = &format_edge_table_name(params.relation_name?);
245        let node_table = &format_node_table_name(params.entity_name?);
246
247        // Start with root nodes
248        let mut pending_nodes: Vec<String> = {
249            let root_node_set: HashSet<String> =
250                HashSet::from_iter(params.root_node_names.into_iter());
251
252            let root_node_stmt = sea_query::Query::select()
253                .column(Alias::new("name"))
254                .from(Alias::new(node_table))
255                .to_owned();
256
257            NodeName::find_by_statement(builder.build(&root_node_stmt))
258                .all(db)
259                .await?
260                .into_iter()
261                .filter_map(|node| {
262                    if root_node_set.contains(&node.name) {
263                        Some(node.name)
264                    } else {
265                        None
266                    }
267                })
268                .collect()
269        };
270
271        let mut result_nodes: HashSet<String> = HashSet::from_iter(pending_nodes.iter().cloned());
272        let mut node_depths: HashMap<String, u64> = HashMap::new();
273        let mut result_edges: HashSet<QueryResultEdge> = HashSet::new();
274
275        // Normal direction: Join on "from" -> finding "to"'s
276        // Reverse: Join on "to" -> finding "from"'s
277        let join_col = if !params.reverse_direction {
278            "from_node"
279        } else {
280            "to_node"
281        };
282
283        let mut depth = 0;
284        while params.max_depth.is_none() || depth < params.max_depth.unwrap() {
285            // Fetch target edges from pending_nodes
286            let target_edges = {
287                let target_edge_stmt = sea_query::Query::select()
288                    .columns([Alias::new("from_node"), Alias::new("to_node")])
289                    .from(Alias::new(edge_table))
290                    .inner_join(
291                        Alias::new(node_table),
292                        Expr::tbl(Alias::new(node_table), Alias::new("name"))
293                            .equals(Alias::new(edge_table), Alias::new(join_col)),
294                    )
295                    .and_where(Expr::col(Alias::new(join_col)).is_in(pending_nodes))
296                    .to_owned();
297
298                QueryResultEdge::find_by_statement(builder.build(&target_edge_stmt))
299                    .all(db)
300                    .await?
301            };
302
303            let mut total_nodes_full = false;
304
305            pending_nodes = target_edges
306                .into_iter()
307                .filter_map(|edge| {
308                    let target_node_name = if !params.reverse_direction {
309                        edge.to_node.clone()
310                    } else {
311                        edge.from_node.clone()
312                    };
313
314                    if result_edges.insert(edge) && !result_nodes.contains(&target_node_name) {
315                        if let Some(max_total_size) = params.max_total_size {
316                            if result_nodes.len() >= max_total_size {
317                                total_nodes_full = true;
318                            }
319                        }
320                        Some(target_node_name)
321                    } else {
322                        None
323                    }
324                })
325                .collect();
326
327            pending_nodes.iter().for_each(|node_name| {
328                if !node_depths.contains_key(node_name) {
329                    node_depths.insert(node_name.clone(), depth + 1);
330                }
331            });
332
333            // Sort by specified key if appropriate
334            if let Some(order_by_key) = &params.batch_sort_key {
335                pending_nodes = {
336                    let pending_nodes_set: HashSet<String> =
337                        HashSet::from_iter(pending_nodes.into_iter());
338
339                    let stmt = sea_query::Query::select()
340                        .column(Alias::new("name"))
341                        .from(Alias::new(node_table))
342                        .order_by(
343                            Alias::new(order_by_key),
344                            if params.batch_sort_asc {
345                                Order::Asc
346                            } else {
347                                Order::Desc
348                            },
349                        )
350                        .to_owned();
351
352                    NodeName::find_by_statement(builder.build(&stmt))
353                        .all(db)
354                        .await?
355                        .into_iter()
356                        .filter_map(|node| {
357                            if pending_nodes_set.contains(&node.name) {
358                                Some(node.name)
359                            } else {
360                                None
361                            }
362                        })
363                        .collect()
364                };
365            }
366
367            if let Some(max_batch_size) = params.max_batch_size {
368                if max_batch_size < pending_nodes.len() {
369                    pending_nodes = pending_nodes[0..max_batch_size].to_vec();
370                }
371            }
372
373            result_nodes.extend(pending_nodes.iter().cloned());
374
375            if pending_nodes.is_empty() || total_nodes_full {
376                break;
377            }
378
379            depth += 1;
380        }
381
382        // Make sure all edges in result_edges use only nodes in result_nodes
383        let edges: Vec<QueryResultEdge> = {
384            let iter = result_edges.into_iter().filter(|edge| {
385                result_nodes.contains(&edge.from_node) && result_nodes.contains(&edge.to_node)
386            });
387
388            if params.reverse_direction {
389                iter.map(|edge| edge.to_flipped()).collect()
390            } else {
391                iter.collect()
392            }
393        };
394
395        // Fetch the weights if needed
396        let nodes: Vec<QueryResultNode> = if let Some(weight_key) = params.batch_sort_key {
397            let stmt = sea_query::Query::select()
398                .column(Alias::new("name"))
399                .expr_as(Expr::col(Alias::new(&weight_key)), Alias::new("weight"))
400                .expr_as(Expr::val(Some(0_u64)), Alias::new("depth"))
401                .from(Alias::new(node_table))
402                .and_where(Expr::col(Alias::new("name")).is_in(result_nodes))
403                .to_owned();
404
405            QueryResultNode::find_by_statement(builder.build(&stmt))
406                .all(db)
407                .await?
408                .into_iter()
409                .map(|mut node| {
410                    let depth = node_depths.get(&node.name).cloned().unwrap_or_default();
411                    node.depth = Some(depth);
412                    node
413                })
414                .collect()
415        } else {
416            result_nodes
417                .into_iter()
418                .map(|name| {
419                    let depth = node_depths.get(&name).cloned().unwrap_or_default();
420                    QueryResultNode {
421                        name,
422                        weight: None,
423                        depth: Some(depth),
424                    }
425                })
426                .collect()
427        };
428
429        Ok(QueryResultJson::Graph { nodes, edges })
430    }
431}
432
433/// Graph data
434#[derive(Debug, Clone, Deserialize, Serialize)]
435pub struct GraphData {
436    /// Graph node data
437    nodes: Vec<GraphNodeData>,
438    /// Link data
439    links: Vec<GraphLinkData>,
440}
441
442/// Graph node data
443#[derive(Debug, Clone, Deserialize, Serialize)]
444pub struct GraphNodeData {
445    /// Name of node
446    id: String,
447    /// Weight
448    weight: f64,
449}
450
451impl PartialEq for GraphNodeData {
452    fn eq(&self, other: &Self) -> bool {
453        self.id == other.id
454    }
455}
456
457impl Eq for GraphNodeData {}
458
459impl std::hash::Hash for GraphNodeData {
460    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
461        self.id.hash(state);
462    }
463}
464
465/// Tree data
466#[derive(Debug, Clone, Deserialize, Serialize)]
467pub struct TreeData {
468    /// Tree node data
469    nodes: Vec<TreeNodeData>,
470    /// Link data
471    links: Vec<TreeLinkData>,
472}
473
474/// Tree node data
475#[derive(Debug, Clone, Eq, Deserialize, Serialize)]
476pub struct TreeNodeData {
477    /// Name of node
478    id: String,
479    /// Node type
480    r#type: TreeNodeType,
481    /// Node depth inverse (the higher, the deeper in recursion this node was found)
482    /// This field is not used to identify a tree node.
483    depth_inv: i32,
484}
485
486impl PartialEq for TreeNodeData {
487    fn eq(&self, other: &Self) -> bool {
488        self.id == other.id && self.r#type == other.r#type
489    }
490}
491
492impl std::hash::Hash for TreeNodeData {
493    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
494        self.id.hash(state);
495        self.r#type.hash(state);
496    }
497}
498
499/// Denotes which side a node belongs to, relative to the **root** node
500#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize_repr, Serialize_repr)]
501#[repr(u8)]
502pub enum TreeNodeType {
503    /// Centered
504    Root = 0,
505    /// To the Left
506    Dependency = 1,
507    /// To the Right
508    Dependent = 2,
509}
510
511/// Node weight option
512#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize_repr, Serialize_repr)]
513#[repr(u8)]
514pub enum NodeWeight {
515    /// Simple (Immediately decay to 0)
516    Simple = 0,
517    /// Complex with weight decay factor 0.3
518    FastDecay = 1,
519    /// Complex with weight decay factor 0.5
520    MediumDecay = 2,
521    /// Complex with weight decay factor 0.7
522    SlowDecay = 3,
523    /// Compound (No decay)
524    Compound = 4,
525}
526
527/// Graph link data
528#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
529pub struct GraphLinkData {
530    /// Source node
531    source: String,
532    /// Target node
533    target: String,
534}
535
536/// Tree link data
537#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
538pub struct TreeLinkData {
539    /// Source node
540    source: String,
541    /// Target node
542    target: String,
543    /// Edge type
544    r#type: TreeNodeType,
545}