Skip to main content

tuitbot_core/context/
graph_expansion.rs

1//! Graph-aware neighbor expansion, ranking, and classification.
2//!
3//! Expands 1-hop neighbors around a selected note using direct links,
4//! backlinks, and shared tags. Ranks them deterministically and attaches
5//! human-readable reason labels for the frontend.
6
7use std::collections::HashMap;
8
9use crate::error::StorageError;
10use crate::storage::watchtower;
11use crate::storage::DbPool;
12
/// Neighbor cap applied when the caller asks for `0` neighbors.
pub const DEFAULT_MAX_NEIGHBORS: u32 = 8;

/// Upper bound on fragments drawn from any one graph neighbor in the final prompt.
pub const MAX_GRAPH_FRAGMENTS_PER_NOTE: u32 = 3;

/// Byte budget for a neighbor snippet, ellipsis included.
const SNIPPET_LEN: usize = 120;

// ============================================================================
// Scoring weights
// ============================================================================

// Per-signal weights: a direct link counts most, then a backlink, then a
// shared tag; the best chunk's retrieval boost contributes at half strength.
const WEIGHT_DIRECT_LINK: f64 = 3.0;
const WEIGHT_BACKLINK: f64 = 2.0;
const WEIGHT_SHARED_TAG: f64 = 1.0;
const WEIGHT_CHUNK_BOOST: f64 = 0.5;
30
31// ============================================================================
32// Types
33// ============================================================================
34
35/// Reason label explaining why a related note was suggested.
36#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
37#[serde(rename_all = "snake_case")]
38pub enum SuggestionReason {
39    LinkedNote,
40    Backlink,
41    SharedTag,
42    MutualLink,
43}
44
45/// Intent hint for the frontend to frame the suggestion.
46#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
47#[serde(rename_all = "snake_case")]
48pub enum SuggestionIntent {
49    ProTip,
50    Counterpoint,
51    Evidence,
52    Related,
53}
54
55/// A related note discovered via graph expansion.
56#[derive(Debug, Clone, serde::Serialize)]
57pub struct GraphNeighbor {
58    pub node_id: i64,
59    pub node_title: Option<String>,
60    pub relative_path: String,
61    pub reason: SuggestionReason,
62    pub reason_label: String,
63    pub intent: SuggestionIntent,
64    pub matched_tags: Vec<String>,
65    pub edge_count: u32,
66    pub shared_tag_count: u32,
67    pub score: f64,
68    pub snippet: Option<String>,
69    pub best_chunk_id: Option<i64>,
70    pub heading_path: Option<String>,
71}
72
73/// Graph state for API responses.
74#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
75#[serde(rename_all = "snake_case")]
76pub enum GraphState {
77    Available,
78    NoRelatedNotes,
79    UnresolvedLinks,
80    FallbackActive,
81    NodeNotIndexed,
82}
83
84// ============================================================================
85// Scoring (pure functions)
86// ============================================================================
87
88/// Compute composite neighbor score from edge counts and chunk boost.
89pub fn compute_neighbor_score(
90    direct_links: u32,
91    backlinks: u32,
92    shared_tags: u32,
93    best_chunk_boost: f64,
94) -> f64 {
95    WEIGHT_DIRECT_LINK * f64::from(direct_links)
96        + WEIGHT_BACKLINK * f64::from(backlinks)
97        + WEIGHT_SHARED_TAG * f64::from(shared_tags)
98        + WEIGHT_CHUNK_BOOST * best_chunk_boost
99}
100
101// ============================================================================
102// Classification (pure functions)
103// ============================================================================
104
105/// Classify the primary reason a neighbor was suggested.
106pub fn classify_suggestion_reason(
107    direct_count: u32,
108    backlink_count: u32,
109    shared_tag_count: u32,
110) -> SuggestionReason {
111    let has_direct = direct_count > 0;
112    let has_backlink = backlink_count > 0;
113    if has_direct && has_backlink {
114        SuggestionReason::MutualLink
115    } else if has_direct {
116        SuggestionReason::LinkedNote
117    } else if has_backlink {
118        SuggestionReason::Backlink
119    } else if shared_tag_count > 0 {
120        SuggestionReason::SharedTag
121    } else {
122        // Shouldn't happen, but safe default.
123        SuggestionReason::LinkedNote
124    }
125}
126
127/// Classify the intent from an edge label using keyword heuristics.
128pub fn classify_suggestion_intent(edge_label: Option<&str>) -> SuggestionIntent {
129    let label = match edge_label {
130        Some(l) => l.to_lowercase(),
131        None => return SuggestionIntent::Related,
132    };
133
134    if label.contains("counterpoint")
135        || label.contains(" vs ")
136        || label.contains("alternative")
137        || label.contains("contrast")
138    {
139        SuggestionIntent::Counterpoint
140    } else if label.contains("tip")
141        || label.contains("how-to")
142        || label.contains("how to")
143        || label.contains("guide")
144    {
145        SuggestionIntent::ProTip
146    } else if label.contains("data")
147        || label.contains("evidence")
148        || label.contains("study")
149        || label.contains("stat")
150    {
151        SuggestionIntent::Evidence
152    } else {
153        SuggestionIntent::Related
154    }
155}
156
157/// Build a human-readable reason label string.
158pub fn build_reason_label(reason: &SuggestionReason, matched_tags: &[String]) -> String {
159    match reason {
160        SuggestionReason::LinkedNote => "linked note".to_string(),
161        SuggestionReason::Backlink => "backlink".to_string(),
162        SuggestionReason::MutualLink => "mutual link".to_string(),
163        SuggestionReason::SharedTag => {
164            if matched_tags.is_empty() {
165                "shared tag".to_string()
166            } else {
167                let tags: Vec<String> = matched_tags.iter().map(|t| format!("#{t}")).collect();
168                format!("shared tag: {}", tags.join(", "))
169            }
170        }
171    }
172}
173
174// ============================================================================
175// Graph expansion (DB-backed)
176// ============================================================================
177
/// Per-neighbor tallies collected while scanning edges and shared tags,
/// before the neighbor is scored.
struct NeighborAccum {
    /// Number of edges counted as forward links to this neighbor.
    direct_links: u32,
    /// Number of edges counted as backlinks from this neighbor.
    backlinks: u32,
    /// Distinct shared tag names, in first-seen order.
    shared_tags: Vec<String>,
    /// Label of the first non-tag edge seen; feeds intent classification.
    best_edge_label: Option<String>,
}
185
186/// Expand 1-hop graph neighbors around a selected note.
187///
188/// Queries outgoing edges (forward links) and incoming edges (backlinks),
189/// plus shared-tag neighbors. Groups by target node, scores, enriches with
190/// node metadata and best chunk, and returns top `max_neighbors` results.
191pub async fn expand_graph_neighbors(
192    pool: &DbPool,
193    account_id: &str,
194    node_id: i64,
195    max_neighbors: u32,
196) -> Result<Vec<GraphNeighbor>, StorageError> {
197    let max = if max_neighbors == 0 {
198        DEFAULT_MAX_NEIGHBORS
199    } else {
200        max_neighbors
201    };
202
203    // 1. Fetch outgoing edges (forward links from this node).
204    let outgoing = watchtower::get_edges_for_source(pool, account_id, node_id).await?;
205
206    // 2. Fetch incoming edges (backlinks pointing to this node).
207    let incoming = watchtower::get_edges_for_target(pool, account_id, node_id).await?;
208
209    // 3. Fetch shared-tag neighbors.
210    let tag_neighbors =
211        watchtower::find_shared_tag_neighbors(pool, account_id, node_id, max * 2).await?;
212
213    // 4. Group all neighbors by target node ID.
214    let mut accum: HashMap<i64, NeighborAccum> = HashMap::new();
215
216    for edge in &outgoing {
217        let entry = accum.entry(edge.target_node_id).or_insert(NeighborAccum {
218            direct_links: 0,
219            backlinks: 0,
220            shared_tags: Vec::new(),
221            best_edge_label: None,
222        });
223        match edge.edge_type.as_str() {
224            "backlink" => entry.backlinks += 1,
225            "shared_tag" => {
226                if let Some(label) = &edge.edge_label {
227                    if !entry.shared_tags.contains(label) {
228                        entry.shared_tags.push(label.clone());
229                    }
230                }
231            }
232            _ => entry.direct_links += 1, // wikilink, markdown_link
233        }
234        if entry.best_edge_label.is_none() && edge.edge_type != "shared_tag" {
235            entry.best_edge_label = edge.edge_label.clone();
236        }
237    }
238
239    for edge in &incoming {
240        // Skip self-referential edges.
241        if edge.source_node_id == node_id {
242            continue;
243        }
244        let entry = accum.entry(edge.source_node_id).or_insert(NeighborAccum {
245            direct_links: 0,
246            backlinks: 0,
247            shared_tags: Vec::new(),
248            best_edge_label: None,
249        });
250        match edge.edge_type.as_str() {
251            "wikilink" | "markdown_link" => entry.backlinks += 1,
252            "shared_tag" => {
253                if let Some(label) = &edge.edge_label {
254                    if !entry.shared_tags.contains(label) {
255                        entry.shared_tags.push(label.clone());
256                    }
257                }
258            }
259            _ => entry.backlinks += 1,
260        }
261        if entry.best_edge_label.is_none() && edge.edge_type != "shared_tag" {
262            entry.best_edge_label = edge.edge_label.clone();
263        }
264    }
265
266    for (neighbor_node_id, tag_text) in &tag_neighbors {
267        let entry = accum.entry(*neighbor_node_id).or_insert(NeighborAccum {
268            direct_links: 0,
269            backlinks: 0,
270            shared_tags: Vec::new(),
271            best_edge_label: None,
272        });
273        if !entry.shared_tags.contains(tag_text) {
274            entry.shared_tags.push(tag_text.clone());
275        }
276    }
277
278    if accum.is_empty() {
279        return Ok(Vec::new());
280    }
281
282    // 5. Batch-fetch node metadata.
283    let neighbor_ids: Vec<i64> = accum.keys().copied().collect();
284    let nodes = watchtower::get_nodes_by_ids(pool, account_id, &neighbor_ids).await?;
285    let node_map: HashMap<i64, &watchtower::ContentNode> =
286        nodes.iter().map(|n| (n.id, n)).collect();
287
288    // 6. Batch-fetch best chunk per neighbor.
289    let best_chunks =
290        watchtower::get_best_chunks_for_nodes(pool, account_id, &neighbor_ids).await?;
291    let chunk_map: HashMap<i64, &watchtower::ContentChunk> =
292        best_chunks.iter().map(|c| (c.node_id, c)).collect();
293
294    // 7. Build scored neighbor list.
295    let mut neighbors: Vec<GraphNeighbor> = accum
296        .into_iter()
297        .filter_map(|(nid, acc)| {
298            let node = node_map.get(&nid)?;
299            let shared_tag_count = acc.shared_tags.len() as u32;
300            let edge_count = acc.direct_links + acc.backlinks + shared_tag_count;
301
302            let chunk_boost = chunk_map
303                .get(&nid)
304                .map(|c| c.retrieval_boost)
305                .unwrap_or(0.0);
306
307            let score = compute_neighbor_score(
308                acc.direct_links,
309                acc.backlinks,
310                shared_tag_count,
311                chunk_boost,
312            );
313
314            let reason =
315                classify_suggestion_reason(acc.direct_links, acc.backlinks, shared_tag_count);
316            let intent = classify_suggestion_intent(acc.best_edge_label.as_deref());
317            let reason_label = build_reason_label(&reason, &acc.shared_tags);
318
319            let (snippet, best_chunk_id, heading_path) = match chunk_map.get(&nid) {
320                Some(c) => (
321                    Some(truncate(c.chunk_text.as_str(), SNIPPET_LEN)),
322                    Some(c.id),
323                    if c.heading_path.is_empty() {
324                        None
325                    } else {
326                        Some(c.heading_path.clone())
327                    },
328                ),
329                None => (None, None, None),
330            };
331
332            Some(GraphNeighbor {
333                node_id: nid,
334                node_title: node.title.clone(),
335                relative_path: node.relative_path.clone(),
336                reason,
337                reason_label,
338                intent,
339                matched_tags: acc.shared_tags,
340                edge_count,
341                shared_tag_count,
342                score,
343                snippet,
344                best_chunk_id,
345                heading_path,
346            })
347        })
348        .collect();
349
350    // 8. Sort: score DESC, edge_count DESC, node_id ASC.
351    neighbors.sort_by(|a, b| {
352        b.score
353            .partial_cmp(&a.score)
354            .unwrap_or(std::cmp::Ordering::Equal)
355            .then(b.edge_count.cmp(&a.edge_count))
356            .then(a.node_id.cmp(&b.node_id))
357    });
358
359    // 9. Cap at max neighbors.
360    neighbors.truncate(max as usize);
361
362    Ok(neighbors)
363}
364
365// ============================================================================
366// Helpers
367// ============================================================================
368
/// Truncate `text` to at most `max_len` bytes, appending `...` when cut.
///
/// Cuts only on UTF-8 character boundaries, so multi-byte characters are never
/// split (the result may therefore be a few bytes shorter than `max_len`).
/// When `max_len` is too small to hold the ellipsis (< 3 bytes), the text is
/// cut without one, so the result never exceeds `max_len`.
fn truncate(text: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if text.len() <= max_len {
        return text.to_string();
    }
    // Reserve room for the ellipsis only when it actually fits.
    let (budget, suffix) = if max_len >= ELLIPSIS.len() {
        (max_len - ELLIPSIS.len(), ELLIPSIS)
    } else {
        (max_len, "")
    };
    // Walk back to the nearest char boundary; index 0 is always a boundary.
    let mut end = budget;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}{suffix}", &text[..end])
}
380
381// ============================================================================
382// Tests
383// ============================================================================
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `actual` equals `expected` to within `f64::EPSILON`.
    fn assert_close(actual: f64, expected: f64) {
        let diff = (actual - expected).abs();
        assert!(diff < f64::EPSILON, "expected {expected}, got {actual}");
    }

    /// Classify a present edge label.
    fn intent_of(label: &str) -> SuggestionIntent {
        classify_suggestion_intent(Some(label))
    }

    /// Serialize `value` to a JSON string, panicking on failure.
    fn to_json<T: serde::Serialize>(value: &T) -> String {
        serde_json::to_string(value).unwrap()
    }

    // -- compute_neighbor_score --

    #[test]
    fn score_weights_verified() {
        // 3.0 + 2.0 + 1.0 + 0.5 = 6.5
        assert_close(compute_neighbor_score(1, 1, 1, 1.0), 6.5);
    }

    #[test]
    fn score_zero_inputs() {
        assert_close(compute_neighbor_score(0, 0, 0, 0.0), 0.0);
    }

    #[test]
    fn score_direct_only() {
        assert_close(compute_neighbor_score(2, 0, 0, 0.0), 6.0);
    }

    #[test]
    fn score_backlink_only() {
        assert_close(compute_neighbor_score(0, 3, 0, 0.0), 6.0);
    }

    #[test]
    fn score_shared_tag_only() {
        assert_close(compute_neighbor_score(0, 0, 4, 0.0), 4.0);
    }

    #[test]
    fn score_chunk_boost_contribution() {
        assert_close(compute_neighbor_score(0, 0, 0, 2.5), 1.25);
    }

    #[test]
    fn score_tag_only_neighbor() {
        // 0 direct, 0 backlinks, 2 shared tags, no chunk boost = 2.0
        assert_close(compute_neighbor_score(0, 0, 2, 0.0), 2.0);
    }

    // -- classify_suggestion_reason --

    #[test]
    fn reason_mutual_link() {
        let reason = classify_suggestion_reason(1, 1, 0);
        assert_eq!(reason, SuggestionReason::MutualLink);
    }

    #[test]
    fn reason_linked_note() {
        let reason = classify_suggestion_reason(1, 0, 0);
        assert_eq!(reason, SuggestionReason::LinkedNote);
    }

    #[test]
    fn reason_backlink() {
        let reason = classify_suggestion_reason(0, 1, 0);
        assert_eq!(reason, SuggestionReason::Backlink);
    }

    #[test]
    fn reason_shared_tag() {
        let reason = classify_suggestion_reason(0, 0, 2);
        assert_eq!(reason, SuggestionReason::SharedTag);
    }

    #[test]
    fn reason_mutual_takes_precedence_over_tags() {
        let reason = classify_suggestion_reason(1, 1, 3);
        assert_eq!(reason, SuggestionReason::MutualLink);
    }

    #[test]
    fn classify_reason_zero_direct_zero_backlink_with_tags() {
        let reason = classify_suggestion_reason(0, 0, 5);
        assert_eq!(reason, SuggestionReason::SharedTag);
    }

    #[test]
    fn classify_reason_zero_everything_defaults_linked() {
        // Edge case: no links, no backlinks, no tags -> defaults to LinkedNote.
        let reason = classify_suggestion_reason(0, 0, 0);
        assert_eq!(reason, SuggestionReason::LinkedNote);
    }

    // -- classify_suggestion_intent --

    #[test]
    fn intent_none_label() {
        assert_eq!(classify_suggestion_intent(None), SuggestionIntent::Related);
    }

    #[test]
    fn intent_counterpoint() {
        assert_eq!(intent_of("see counterpoint"), SuggestionIntent::Counterpoint);
    }

    #[test]
    fn intent_vs() {
        assert_eq!(intent_of("React vs Vue"), SuggestionIntent::Counterpoint);
    }

    #[test]
    fn intent_pro_tip() {
        assert_eq!(intent_of("quick tip"), SuggestionIntent::ProTip);
    }

    #[test]
    fn intent_guide() {
        assert_eq!(intent_of("setup guide"), SuggestionIntent::ProTip);
    }

    #[test]
    fn intent_evidence() {
        assert_eq!(intent_of("research data"), SuggestionIntent::Evidence);
    }

    #[test]
    fn intent_study() {
        assert_eq!(intent_of("case study"), SuggestionIntent::Evidence);
    }

    #[test]
    fn intent_default_related() {
        assert_eq!(intent_of("just a note"), SuggestionIntent::Related);
    }

    // -- build_reason_label --

    #[test]
    fn label_linked_note() {
        let label = build_reason_label(&SuggestionReason::LinkedNote, &[]);
        assert_eq!(label, "linked note");
    }

    #[test]
    fn label_backlink() {
        let label = build_reason_label(&SuggestionReason::Backlink, &[]);
        assert_eq!(label, "backlink");
    }

    #[test]
    fn label_mutual_link() {
        let label = build_reason_label(&SuggestionReason::MutualLink, &[]);
        assert_eq!(label, "mutual link");
    }

    #[test]
    fn label_shared_tag_no_tags() {
        let label = build_reason_label(&SuggestionReason::SharedTag, &[]);
        assert_eq!(label, "shared tag");
    }

    #[test]
    fn label_shared_tag_single() {
        let tags = ["rust".to_string()];
        let label = build_reason_label(&SuggestionReason::SharedTag, &tags);
        assert_eq!(label, "shared tag: #rust");
    }

    #[test]
    fn label_shared_tag_multiple() {
        let tags = ["rust".to_string(), "async".to_string()];
        let label = build_reason_label(&SuggestionReason::SharedTag, &tags);
        assert_eq!(label, "shared tag: #rust, #async");
    }

    // -- truncate --

    #[test]
    fn truncate_short() {
        assert_eq!(truncate("hello", 10), "hello");
    }

    #[test]
    fn truncate_long() {
        let cut = truncate("hello world this is long text", 10);
        assert!(cut.ends_with("..."));
        assert!(cut.len() <= 13);
    }

    // -- SuggestionReason serde --

    #[test]
    fn reason_serializes_snake_case() {
        let cases = [
            (SuggestionReason::LinkedNote, "\"linked_note\""),
            (SuggestionReason::MutualLink, "\"mutual_link\""),
            (SuggestionReason::SharedTag, "\"shared_tag\""),
        ];
        for (reason, expected) in &cases {
            assert_eq!(to_json(reason), *expected);
        }
    }

    // -- GraphState serde --

    #[test]
    fn graph_state_serializes_snake_case() {
        let cases = [
            (GraphState::NoRelatedNotes, "\"no_related_notes\""),
            (GraphState::FallbackActive, "\"fallback_active\""),
        ];
        for (state, expected) in &cases {
            assert_eq!(to_json(state), *expected);
        }
    }

    #[test]
    fn graph_state_all_variants_serialize() {
        let cases = [
            (GraphState::Available, "\"available\""),
            (GraphState::UnresolvedLinks, "\"unresolved_links\""),
            (GraphState::NodeNotIndexed, "\"node_not_indexed\""),
        ];
        for (state, expected) in &cases {
            assert_eq!(to_json(state), *expected);
        }
    }
}