ruvector_data_framework/
cut_aware_hnsw.rs

1//! Cut-Aware HNSW: Dynamic Min-Cut Integration with Vector Search
2//!
3//! This module bridges dynamic minimum cut tracking with HNSW vector search,
4//! enabling coherence-aware navigation that respects graph boundaries.
5//!
6//! ## Overview
7//!
8//! Traditional HNSW blindly follows similarity edges during search. This module
9//! adds "coherence gates" - weak cuts in the graph that represent semantic boundaries.
10//! When searching, we can optionally halt expansion at these boundaries to stay
11//! within coherent regions.
12//!
13//! ## Key Concepts
14//!
15//! - **Cut Value**: The total weight of edges crossing a partition
16//! - **Coherence Boundary**: A weak cut indicating semantic separation
17//! - **Gated Search**: Search that respects coherence boundaries
18//! - **Coherence Zone**: A region of the graph with strong internal connections
19//!
20//! ## References
21//!
22//! - Stoer-Wagner algorithm for global min-cut
23//! - Euler Tour Trees for dynamic connectivity
24//! - HNSW for approximate nearest neighbor search
25
26use std::collections::{HashMap, HashSet, BinaryHeap, VecDeque};
27use std::sync::{Arc, RwLock};
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::cmp::Reverse;
30
31use chrono::{DateTime, Utc};
32use serde::{Deserialize, Serialize};
33
34use crate::hnsw::{HnswIndex, HnswConfig, HnswSearchResult};
35use crate::ruvector_native::SemanticVector;
36use crate::FrameworkError;
37
38// ============================================================================
39// Configuration and Metrics
40// ============================================================================
41
42/// Configuration for cut-aware HNSW
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CutAwareConfig {
45    // Standard HNSW parameters
46    pub m: usize,
47    pub ef_construction: usize,
48    pub ef_search: usize,
49
50    // Cut-aware parameters
51    /// Threshold for considering a cut "weak" (gates expansion)
52    pub coherence_gate_threshold: f64,
53
54    /// Maximum number of hops across weak cuts before stopping
55    pub max_cross_cut_hops: usize,
56
57    /// Enable pruning of edges that cross weak cuts
58    pub enable_cut_pruning: bool,
59
60    /// Recompute cuts every N insertions
61    pub cut_recompute_interval: usize,
62
63    /// Minimum zone size (nodes) to track separately
64    pub min_zone_size: usize,
65}
66
67impl Default for CutAwareConfig {
68    fn default() -> Self {
69        Self {
70            m: 16,
71            ef_construction: 200,
72            ef_search: 50,
73            coherence_gate_threshold: 0.3,
74            max_cross_cut_hops: 2,
75            enable_cut_pruning: false,
76            cut_recompute_interval: 100,
77            min_zone_size: 5,
78        }
79    }
80}
81
/// Usage counters for the cut-aware index.
///
/// Every counter is an `AtomicU64`, so it can be bumped through `&self`. All
/// updates in this module use `Ordering::Relaxed`, so concurrent readers see
/// individually consistent values but no cross-counter snapshot.
#[derive(Default, Debug)]
pub struct CutAwareMetrics {
    /// Gated plus ungated searches.
    pub searches_performed: AtomicU64,
    /// Gated searches in which at least one candidate crossed a cut.
    pub cut_gates_triggered: AtomicU64,
    /// Candidates dropped by gating.
    pub expansions_pruned: AtomicU64,
    /// NOTE(review): only ever reset in this module, never updated.
    pub avg_search_depth: AtomicU64,
    /// Coherence-zone recomputations.
    pub cut_recomputations: AtomicU64,
    /// Cross-zone search hits flagged as crossing a zone boundary.
    pub zone_boundary_crossings: AtomicU64,
}

impl CutAwareMetrics {
    /// Sets every counter back to zero.
    pub fn reset(&self) {
        let counters = [
            &self.searches_performed,
            &self.cut_gates_triggered,
            &self.expansions_pruned,
            &self.avg_search_depth,
            &self.cut_recomputations,
            &self.zone_boundary_crossings,
        ];
        counters
            .iter()
            .for_each(|counter| counter.store(0, Ordering::Relaxed));
    }
}
103
104// ============================================================================
105// Dynamic Cut Tracking Structures
106// ============================================================================
107
/// An edge between `from` and `to` carrying a weight.
///
/// NOTE(review): nothing in the visible part of this module constructs or reads
/// this type; the cut watcher keeps edges in nested hash maps instead.
#[derive(Clone, Debug)]
struct WeightedEdge {
    /// One endpoint.
    from: u32,
    /// The other endpoint.
    to: u32,
    /// Edge weight.
    weight: f64,
}
115
/// Tracks the global minimum cut of an undirected, weighted graph together with
/// the edges that cross it.
///
/// Edges live in a symmetric adjacency map. Every mutation bumps a version
/// counter. The cut (value, partition and boundary edges) is recomputed lazily
/// the next time any of them is queried, from scratch, with Stoer–Wagner.
///
/// There is no incremental maintenance. A recomputation costs `O(V³)` time and
/// `O(V²)` memory for the dense working matrix, so batch mutations before
/// querying.
///
/// Stoer–Wagner assumes finite, non-negative weights; callers are expected to
/// supply such weights.
pub struct DynamicCutWatcher {
    /// Symmetric adjacency map: `adjacency[u][v] == adjacency[v][u] == weight`.
    adjacency: HashMap<u32, HashMap<u32, f64>>,

    /// Min-cut value from the last recomputation.
    cached_min_cut: Option<f64>,

    /// The two sides of the last computed min cut (`None` with fewer than two nodes).
    cached_partition: Option<(HashSet<u32>, HashSet<u32>)>,

    /// Edges crossing the last computed min cut, stored as `(smaller, larger)`.
    boundary_edges: HashSet<(u32, u32)>,

    /// Bumped on every mutation.
    version: u64,

    /// Value of `version` at the time the cached cut was computed.
    cut_version: u64,
}

impl DynamicCutWatcher {
    /// Creates a watcher with no nodes, no edges and no cached cut.
    pub fn new() -> Self {
        Self {
            adjacency: HashMap::new(),
            cached_min_cut: None,
            cached_partition: None,
            boundary_edges: HashSet::new(),
            version: 0,
            cut_version: 0,
        }
    }

    /// Adds the undirected edge `u — v`, or overwrites its weight if it exists.
    pub fn add_edge(&mut self, u: u32, v: u32, weight: f64) {
        self.adjacency.entry(u).or_default().insert(v, weight);
        self.adjacency.entry(v).or_default().insert(u, weight);
        self.version += 1;
    }

    /// Removes the edge `u — v` if present.
    ///
    /// Both endpoints remain nodes of the graph even when they lose their last
    /// edge. An isolated node drives the min cut to `0`.
    pub fn remove_edge(&mut self, u: u32, v: u32) {
        if let Some(neighbors) = self.adjacency.get_mut(&u) {
            neighbors.remove(&v);
        }
        if let Some(neighbors) = self.adjacency.get_mut(&v) {
            neighbors.remove(&u);
        }
        self.version += 1;
    }

    /// Returns the current global min-cut value, recomputing it if the graph
    /// changed. Graphs with fewer than two nodes report `0.0`.
    pub fn min_cut_value(&mut self) -> f64 {
        self.ensure_fresh();
        self.cached_min_cut.unwrap_or(0.0)
    }

    /// Returns `true` when `u` and `v` lie on opposite sides of the current min
    /// cut **and** that cut is weak (its value is below `threshold`).
    ///
    /// Only partition membership is consulted, so `u — v` need not be an actual
    /// edge. Returns `false` when no partition exists (fewer than two nodes).
    pub fn crosses_weak_cut(&mut self, u: u32, v: u32, threshold: f64) -> bool {
        self.ensure_fresh();
        match &self.cached_partition {
            Some((side_a, _)) if side_a.contains(&u) != side_a.contains(&v) => {
                self.cached_min_cut.unwrap_or(f64::INFINITY) < threshold
            }
            _ => false,
        }
    }

    /// Returns `true` when `u` and `v` lie on the same side of the current min cut.
    ///
    /// Nodes absent from the graph are treated as belonging to the second side.
    /// Without a partition, every pair is reported as being together.
    pub fn same_partition(&mut self, u: u32, v: u32) -> bool {
        self.ensure_fresh();
        match &self.cached_partition {
            Some((side_a, _)) => side_a.contains(&u) == side_a.contains(&v),
            None => true,
        }
    }

    /// Recomputes the cached cut if the graph changed since it was computed.
    fn ensure_fresh(&mut self) {
        if self.version != self.cut_version {
            self.recompute_min_cut();
        }
    }

    /// Recomputes the min cut, its partition and its boundary edges.
    fn recompute_min_cut(&mut self) {
        let nodes: Vec<u32> = self.adjacency.keys().copied().collect();
        // Never carry boundary edges over from a previous cut.
        self.boundary_edges.clear();

        if nodes.len() < 2 {
            // A cut needs two non-empty sides.
            self.cached_min_cut = Some(0.0);
            self.cached_partition = None;
            self.cut_version = self.version;
            return;
        }

        let (min_cut, partition) = self.stoer_wagner(&nodes);

        let (side_a, side_b) = &partition;
        for &u in side_a {
            if let Some(neighbors) = self.adjacency.get(&u) {
                for &v in neighbors.keys() {
                    if side_b.contains(&v) {
                        self.boundary_edges.insert((u.min(v), u.max(v)));
                    }
                }
            }
        }

        self.cached_min_cut = Some(min_cut);
        self.cached_partition = Some(partition);
        self.cut_version = self.version;
    }

    /// Stoer–Wagner global minimum cut over `nodes`.
    ///
    /// Returns the cut value and the two sides of one minimum cut.
    ///
    /// Each of the `n - 1` phases grows a set `A` by repeatedly adding the
    /// vertex most tightly connected to it. The weight joining the last-added
    /// vertex `t` to the rest is the "cut of the phase", after which `t` is
    /// merged into the second-to-last vertex `s`. The smallest cut of the phase
    /// is a global min cut.
    fn stoer_wagner(&self, nodes: &[u32]) -> (f64, (HashSet<u32>, HashSet<u32>)) {
        let n = nodes.len();
        if n < 2 {
            return (0.0, (HashSet::new(), HashSet::new()));
        }

        let node_to_idx: HashMap<u32, usize> = nodes
            .iter()
            .enumerate()
            .map(|(i, &node)| (node, i))
            .collect();

        // Dense, symmetric working matrix; merged vertices accumulate weights here.
        let mut adj = vec![vec![0.0_f64; n]; n];
        for (&u, neighbors) in &self.adjacency {
            if let Some(&i) = node_to_idx.get(&u) {
                for (&v, &weight) in neighbors {
                    if let Some(&j) = node_to_idx.get(&v) {
                        adj[i][j] = weight;
                    }
                }
            }
        }

        let mut best_cut = f64::INFINITY;
        let mut best_side: HashSet<usize> = HashSet::new();
        let mut active = vec![true; n];
        // merged[i]: original vertex indices currently represented by vertex i.
        let mut merged: Vec<HashSet<usize>> =
            (0..n).map(|i| std::iter::once(i).collect()).collect();

        for _phase in 1..n {
            let start = match active.iter().position(|&a| a) {
                Some(s) => s,
                None => break,
            };

            let mut in_a = vec![false; n];
            let mut key = vec![0.0_f64; n];
            in_a[start] = true;
            for j in 0..n {
                if active[j] && !in_a[j] {
                    key[j] = adj[start][j];
                }
            }

            let mut s = start;
            let mut t = start;
            let active_count = active.iter().filter(|&&a| a).count();
            for _ in 1..active_count {
                // Pick the active vertex outside A with the largest key; the
                // lowest index wins ties. Seeding from the first candidate,
                // rather than a fixed default index, guarantees `t` is a valid
                // vertex outside A even if some keys are NaN or -inf.
                let mut chosen: Option<usize> = None;
                for j in 0..n {
                    if !active[j] || in_a[j] {
                        continue;
                    }
                    let better = match chosen {
                        None => true,
                        Some(c) => key[j] > key[c],
                    };
                    if better {
                        chosen = Some(j);
                    }
                }
                let next = match chosen {
                    Some(j) => j,
                    None => break,
                };

                s = t;
                t = next;
                in_a[t] = true;
                for j in 0..n {
                    if active[j] && !in_a[j] {
                        key[j] += adj[t][j];
                    }
                }
            }

            // key[t] is the total weight between t and everything added before it.
            let cut_of_phase = key[t];
            if cut_of_phase < best_cut {
                best_cut = cut_of_phase;
                best_side.clone_from(&merged[t]);
            }

            // Merge t into s. t becomes inactive, so its member set can be moved
            // rather than cloned.
            active[t] = false;
            let absorbed = std::mem::take(&mut merged[t]);
            merged[s].extend(absorbed);
            for i in 0..n {
                if active[i] && i != s {
                    adj[s][i] += adj[t][i];
                    adj[i][s] += adj[i][t];
                }
            }
        }

        let side_a: HashSet<u32> = best_side.iter().map(|&idx| nodes[idx]).collect();
        let side_b: HashSet<u32> = nodes
            .iter()
            .copied()
            .filter(|node| !side_a.contains(node))
            .collect();

        (best_cut, (side_a, side_b))
    }

    /// Returns every node that has ever appeared as an edge endpoint.
    pub fn nodes(&self) -> Vec<u32> {
        self.adjacency.keys().copied().collect()
    }

    /// Returns the edges crossing the current min cut, recomputing it if the
    /// graph changed.
    pub fn boundary_edges(&mut self) -> &HashSet<(u32, u32)> {
        self.ensure_fresh();
        &self.boundary_edges
    }
}
368
369// ============================================================================
370// Coherence Zones
371// ============================================================================
372
373/// A coherent region in the graph
374#[derive(Debug, Clone, Serialize, Deserialize)]
375pub struct CoherenceZone {
376    pub id: usize,
377    pub nodes: HashSet<u32>,
378    pub internal_cut: f64,
379    pub boundary_cut: f64,
380    pub coherence_ratio: f64,
381}
382
383impl CoherenceZone {
384    /// Calculate coherence ratio (internal / (internal + boundary))
385    pub fn update_ratio(&mut self) {
386        let total = self.internal_cut + self.boundary_cut;
387        self.coherence_ratio = if total > 0.0 {
388            self.internal_cut / total
389        } else {
390            0.0
391        };
392    }
393}
394
395// ============================================================================
396// Cut-Aware HNSW Index
397// ============================================================================
398
399/// Extended HNSW that respects coherence boundaries
400pub struct CutAwareHNSW {
401    /// Base HNSW index
402    hnsw: HnswIndex,
403
404    /// Cut watcher for tracking graph coherence
405    cut_watcher: Arc<RwLock<DynamicCutWatcher>>,
406
407    /// Configuration
408    config: CutAwareConfig,
409
410    /// Metrics
411    metrics: Arc<CutAwareMetrics>,
412
413    /// Node ID to HNSW ID mapping
414    node_to_hnsw: HashMap<u32, usize>,
415    hnsw_to_node: HashMap<usize, u32>,
416
417    /// Next node ID
418    next_node_id: u32,
419
420    /// Insertions since last cut recomputation
421    insertions_since_recompute: usize,
422
423    /// Coherence zones
424    zones: Vec<CoherenceZone>,
425
426    /// Node to zone mapping
427    node_to_zone: HashMap<u32, usize>,
428}
429
430impl CutAwareHNSW {
431    /// Create a new cut-aware HNSW index
432    pub fn new(config: CutAwareConfig) -> Self {
433        let hnsw_config = HnswConfig {
434            m: config.m,
435            ef_construction: config.ef_construction,
436            ef_search: config.ef_search,
437            dimension: 128, // Default, will be set on first insert
438            ..Default::default()
439        };
440
441        Self {
442            hnsw: HnswIndex::with_config(hnsw_config),
443            cut_watcher: Arc::new(RwLock::new(DynamicCutWatcher::new())),
444            config,
445            metrics: Arc::new(CutAwareMetrics::default()),
446            node_to_hnsw: HashMap::new(),
447            hnsw_to_node: HashMap::new(),
448            next_node_id: 0,
449            insertions_since_recompute: 0,
450            zones: Vec::new(),
451            node_to_zone: HashMap::new(),
452        }
453    }
454
455    /// Insert a vector into the index
456    pub fn insert(&mut self, id: u32, vector: &[f32]) -> Result<(), FrameworkError> {
457        // Convert to SemanticVector for HNSW
458        let semantic_vec = SemanticVector {
459            id: id.to_string(),
460            embedding: vector.to_vec(),
461            domain: crate::ruvector_native::Domain::Research,
462            timestamp: Utc::now(),
463            metadata: HashMap::new(),
464        };
465
466        // Insert into HNSW
467        let hnsw_id = self.hnsw.insert(semantic_vec)?;
468
469        // Track mapping
470        self.node_to_hnsw.insert(id, hnsw_id);
471        self.hnsw_to_node.insert(hnsw_id, id);
472
473        // Update cut watcher with edges to similar nodes
474        self.update_cut_watcher_for_node(id, vector)?;
475
476        self.insertions_since_recompute += 1;
477
478        // Recompute cuts periodically
479        if self.insertions_since_recompute >= self.config.cut_recompute_interval {
480            self.recompute_zones();
481            self.insertions_since_recompute = 0;
482            self.metrics.cut_recomputations.fetch_add(1, Ordering::Relaxed);
483        }
484
485        Ok(())
486    }
487
488    /// Update cut watcher with edges for a newly inserted node
489    fn update_cut_watcher_for_node(&mut self, node_id: u32, vector: &[f32]) -> Result<(), FrameworkError> {
490        let hnsw_id = self.node_to_hnsw[&node_id];
491
492        // Find similar nodes using HNSW
493        let neighbors = self.hnsw.search_knn(vector, self.config.m * 2)?;
494
495        // Add edges to cut watcher
496        let mut watcher = self.cut_watcher.write().unwrap();
497        for neighbor in neighbors {
498            if let Some(&neighbor_node_id) = self.hnsw_to_node.get(&neighbor.node_id) {
499                if neighbor_node_id != node_id {
500                    // Use similarity as edge weight (1.0 - distance for cosine)
501                    let weight = if let Some(sim) = neighbor.similarity {
502                        sim.max(0.0) as f64
503                    } else {
504                        (1.0 - neighbor.distance.min(1.0)) as f64
505                    };
506
507                    watcher.add_edge(node_id, neighbor_node_id, weight);
508                }
509            }
510        }
511
512        Ok(())
513    }
514
515    /// Search with coherence gating
516    pub fn search_gated(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
517        self.search_internal(query, k, true)
518    }
519
520    /// Search without coherence gating
521    pub fn search_ungated(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
522        self.search_internal(query, k, false)
523    }
524
525    /// Internal search implementation
526    fn search_internal(&self, query: &[f32], k: usize, use_gates: bool) -> Vec<SearchResult> {
527        self.metrics.searches_performed.fetch_add(1, Ordering::Relaxed);
528
529        // Perform HNSW search
530        let hnsw_results = match self.hnsw.search_knn(query, k * 2) {
531            Ok(r) => r,
532            Err(_) => return Vec::new(),
533        };
534
535        if !use_gates {
536            // No gating - return direct results
537            return hnsw_results.iter()
538                .take(k)
539                .map(|r| SearchResult {
540                    node_id: self.hnsw_to_node.get(&r.node_id).copied().unwrap_or(0),
541                    distance: r.distance,
542                    crossed_cuts: 0,
543                    coherence_score: 1.0,
544                })
545                .collect();
546        }
547
548        // Gated search - filter by coherence
549        let mut results: Vec<SearchResult> = Vec::new();
550        let mut cross_cut_count: HashMap<u32, usize> = HashMap::new();
551
552        let mut watcher = self.cut_watcher.write().unwrap();
553        let threshold = self.config.coherence_gate_threshold;
554
555        for result in hnsw_results.iter().take(k * 2) {
556            if let Some(&node_id) = self.hnsw_to_node.get(&result.node_id) {
557                // Check path quality (simplified - just check direct connection)
558                let crossed = if results.is_empty() {
559                    0
560                } else {
561                    // Check if crossing weak cut from first result
562                    let first_node = results[0].node_id;
563                    if !watcher.same_partition(first_node, node_id) {
564                        1
565                    } else {
566                        0
567                    }
568                };
569
570                cross_cut_count.insert(node_id, crossed);
571
572                // Gate based on cross-cut hops
573                if crossed <= self.config.max_cross_cut_hops {
574                    let coherence_score = 1.0 / (1.0 + crossed as f64 * 0.5);
575
576                    results.push(SearchResult {
577                        node_id,
578                        distance: result.distance,
579                        crossed_cuts: crossed,
580                        coherence_score,
581                    });
582
583                    if results.len() >= k {
584                        break;
585                    }
586                } else {
587                    self.metrics.expansions_pruned.fetch_add(1, Ordering::Relaxed);
588                }
589            }
590        }
591
592        if !cross_cut_count.is_empty() {
593            let total_crossed: usize = cross_cut_count.values().sum();
594            if total_crossed > 0 {
595                self.metrics.cut_gates_triggered.fetch_add(1, Ordering::Relaxed);
596            }
597        }
598
599        results
600    }
601
602    /// Get nodes reachable without crossing weak cuts
603    pub fn coherent_neighborhood(&self, node: u32, radius: usize) -> Vec<u32> {
604        let mut visited = HashSet::new();
605        let mut queue = VecDeque::new();
606        let mut result: Vec<u32> = Vec::new();
607
608        queue.push_back((node, 0));
609        visited.insert(node);
610
611        let mut watcher = self.cut_watcher.write().unwrap();
612        let threshold = self.config.coherence_gate_threshold;
613
614        while let Some((current, depth)) = queue.pop_front() {
615            if depth > radius {
616                continue;
617            }
618
619            result.push(current);
620
621            // Get HNSW neighbors
622            if let Some(&hnsw_id) = self.node_to_hnsw.get(&current) {
623                if let Some(vector) = self.hnsw.get_vector(hnsw_id) {
624                    if let Ok(neighbors) = self.hnsw.search_knn(vector, self.config.m) {
625                        for neighbor in neighbors {
626                            if let Some(&neighbor_node) = self.hnsw_to_node.get(&neighbor.node_id) {
627                                if visited.insert(neighbor_node) {
628                                    // Only add if not crossing weak cut
629                                    if !watcher.crosses_weak_cut(current, neighbor_node, threshold) {
630                                        queue.push_back((neighbor_node, depth + 1));
631                                    }
632                                }
633                            }
634                        }
635                    }
636                }
637            }
638        }
639
640        result
641    }
642
643    /// Check if path crosses a weak cut
644    fn path_crosses_weak_cut(&self, from: u32, to: u32) -> bool {
645        let mut watcher = self.cut_watcher.write().unwrap();
646        watcher.crosses_weak_cut(from, to, self.config.coherence_gate_threshold)
647    }
648
649    /// Add edge and update cut watcher
650    pub fn add_edge(&mut self, u: u32, v: u32, weight: f64) {
651        let mut watcher = self.cut_watcher.write().unwrap();
652        watcher.add_edge(u, v, weight);
653    }
654
655    /// Remove edge and update cut watcher
656    pub fn remove_edge(&mut self, u: u32, v: u32) {
657        let mut watcher = self.cut_watcher.write().unwrap();
658        watcher.remove_edge(u, v);
659    }
660
661    /// Batch update with efficient cut recomputation
662    pub fn batch_update(&mut self, updates: Vec<EdgeUpdate>) -> UpdateStats {
663        let mut stats = UpdateStats::default();
664
665        {
666            let mut watcher = self.cut_watcher.write().unwrap();
667
668            for update in updates {
669                match update.kind {
670                    UpdateKind::Insert => {
671                        if let Some(weight) = update.weight {
672                            watcher.add_edge(update.u, update.v, weight);
673                            stats.edges_added += 1;
674                        }
675                    }
676                    UpdateKind::Delete => {
677                        watcher.remove_edge(update.u, update.v);
678                        stats.edges_removed += 1;
679                    }
680                    UpdateKind::UpdateWeight => {
681                        if let Some(weight) = update.weight {
682                            watcher.add_edge(update.u, update.v, weight);
683                            stats.edges_updated += 1;
684                        }
685                    }
686                }
687            }
688        }
689
690        // Recompute zones after batch
691        self.recompute_zones();
692        self.metrics.cut_recomputations.fetch_add(1, Ordering::Relaxed);
693
694        stats
695    }
696
697    /// Prune weak edges based on cut analysis
698    pub fn prune_weak_edges(&mut self, threshold: f64) -> usize {
699        let mut pruned = 0;
700
701        let mut watcher = self.cut_watcher.write().unwrap();
702        let boundary_edges = watcher.boundary_edges().clone();
703
704        for (u, v) in boundary_edges {
705            // Check if edge is weak
706            if watcher.crosses_weak_cut(u, v, threshold) {
707                watcher.remove_edge(u, v);
708                pruned += 1;
709            }
710        }
711
712        pruned
713    }
714
715    /// Compute coherence zones
716    pub fn compute_zones(&mut self) -> Vec<CoherenceZone> {
717        self.recompute_zones();
718        self.zones.clone()
719    }
720
721    /// Internal zone recomputation
722    fn recompute_zones(&mut self) {
723        self.zones.clear();
724        self.node_to_zone.clear();
725
726        let mut watcher = self.cut_watcher.write().unwrap();
727        let min_cut = watcher.min_cut_value();
728
729        // Use min-cut partition to identify zones
730        let nodes = watcher.nodes();
731        if nodes.len() < self.config.min_zone_size {
732            return;
733        }
734
735        // For now, create two zones based on the min-cut partition
736        // In production, would use hierarchical clustering
737
738        // Trigger computation
739        let _ = watcher.min_cut_value();
740
741        if let Some((part_a, part_b)) = &watcher.cached_partition {
742            if part_a.len() >= self.config.min_zone_size {
743                let zone_a = CoherenceZone {
744                    id: 0,
745                    nodes: part_a.clone(),
746                    internal_cut: min_cut * 0.8, // Approximation
747                    boundary_cut: min_cut * 0.2,
748                    coherence_ratio: 0.8,
749                };
750
751                for &node in part_a {
752                    self.node_to_zone.insert(node, 0);
753                }
754
755                self.zones.push(zone_a);
756            }
757
758            if part_b.len() >= self.config.min_zone_size {
759                let zone_b = CoherenceZone {
760                    id: 1,
761                    nodes: part_b.clone(),
762                    internal_cut: min_cut * 0.8,
763                    boundary_cut: min_cut * 0.2,
764                    coherence_ratio: 0.8,
765                };
766
767                for &node in part_b {
768                    self.node_to_zone.insert(node, 1);
769                }
770
771                self.zones.push(zone_b);
772            }
773        }
774    }
775
776    /// Get zone for a node
777    pub fn node_zone(&self, node: u32) -> Option<usize> {
778        self.node_to_zone.get(&node).copied()
779    }
780
781    /// Cross-zone search (explicitly crosses boundaries)
782    pub fn cross_zone_search(&self, query: &[f32], k: usize, zones: &[usize]) -> Vec<SearchResult> {
783        let mut all_results = Vec::new();
784
785        // Get HNSW results
786        let hnsw_results = match self.hnsw.search_knn(query, k * 3) {
787            Ok(r) => r,
788            Err(_) => return Vec::new(),
789        };
790
791        // Filter by zones
792        for result in hnsw_results {
793            if let Some(&node_id) = self.hnsw_to_node.get(&result.node_id) {
794                if let Some(zone_id) = self.node_to_zone.get(&node_id) {
795                    if zones.contains(zone_id) {
796                        all_results.push(SearchResult {
797                            node_id,
798                            distance: result.distance,
799                            crossed_cuts: if zones.len() > 1 { 1 } else { 0 },
800                            coherence_score: 0.7, // Lower for cross-zone
801                        });
802                    }
803                }
804            }
805        }
806
807        all_results.truncate(k);
808        self.metrics.zone_boundary_crossings.fetch_add(
809            all_results.iter().filter(|r| r.crossed_cuts > 0).count() as u64,
810            Ordering::Relaxed
811        );
812
813        all_results
814    }
815
816    /// Get current metrics
817    pub fn metrics(&self) -> &CutAwareMetrics {
818        &self.metrics
819    }
820
821    /// Reset metrics
822    pub fn reset_metrics(&self) {
823        self.metrics.reset();
824    }
825
826    /// Export metrics as JSON
827    pub fn export_metrics(&self) -> serde_json::Value {
828        serde_json::json!({
829            "searches_performed": self.metrics.searches_performed.load(Ordering::Relaxed),
830            "cut_gates_triggered": self.metrics.cut_gates_triggered.load(Ordering::Relaxed),
831            "expansions_pruned": self.metrics.expansions_pruned.load(Ordering::Relaxed),
832            "avg_search_depth": self.metrics.avg_search_depth.load(Ordering::Relaxed),
833            "cut_recomputations": self.metrics.cut_recomputations.load(Ordering::Relaxed),
834            "zone_boundary_crossings": self.metrics.zone_boundary_crossings.load(Ordering::Relaxed),
835        })
836    }
837
838    /// Get cut distribution across layers
839    pub fn cut_distribution(&self) -> Vec<LayerCutStats> {
840        let watcher = self.cut_watcher.read().unwrap();
841        let nodes = watcher.nodes();
842
843        if nodes.is_empty() {
844            return Vec::new();
845        }
846
847        // For HNSW, we'd need to analyze per-layer
848        // For now, return overall stats
849        vec![LayerCutStats {
850            layer: 0,
851            avg_cut: watcher.cached_min_cut.unwrap_or(0.0),
852            min_cut: watcher.cached_min_cut.unwrap_or(0.0),
853            max_cut: watcher.cached_min_cut.unwrap_or(0.0),
854            weak_edge_count: watcher.boundary_edges.len(),
855        }]
856    }
857}
858
859// ============================================================================
860// Supporting Types
861// ============================================================================
862
863/// Search result with coherence information
864#[derive(Debug, Clone, Serialize, Deserialize)]
865pub struct SearchResult {
866    pub node_id: u32,
867    pub distance: f32,
868    pub crossed_cuts: usize,
869    pub coherence_score: f64,
870}
871
872/// Edge update operation
873#[derive(Debug, Clone)]
874pub struct EdgeUpdate {
875    pub kind: UpdateKind,
876    pub u: u32,
877    pub v: u32,
878    pub weight: Option<f64>,
879}
880
/// The kind of mutation an [`EdgeUpdate`] performs.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UpdateKind {
    /// Add the edge with the supplied weight.
    Insert,
    /// Remove the edge.
    Delete,
    /// Change the edge's weight (applied the same way as an insert).
    UpdateWeight,
}
888
/// Counters returned by `CutAwareHNSW::batch_update`.
#[derive(Clone, Debug, Default)]
pub struct UpdateStats {
    /// `Insert` updates applied.
    pub edges_added: usize,
    /// `Delete` updates applied (counted even if the edge did not exist).
    pub edges_removed: usize,
    /// `UpdateWeight` updates applied.
    pub edges_updated: usize,
}
896
897/// Cut statistics per layer
898#[derive(Debug, Clone, Serialize, Deserialize)]
899pub struct LayerCutStats {
900    pub layer: usize,
901    pub avg_cut: f64,
902    pub min_cut: f64,
903    pub max_cut: f64,
904    pub weak_edge_count: usize,
905}
906
907// ============================================================================
908// Tests
909// ============================================================================
910
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `dim`-dimensional vector whose components all equal `val`.
    fn create_test_vector(dim: usize, val: f32) -> Vec<f32> {
        vec![val; dim]
    }

    #[test]
    fn test_cut_watcher_basic() {
        let mut watcher = DynamicCutWatcher::new();

        // Path graph 0-1-2 with unit weights: every bipartition severs at least
        // one unit edge, so the exact global min cut is 1.0.
        watcher.add_edge(0, 1, 1.0);
        watcher.add_edge(1, 2, 1.0);

        let min_cut = watcher.min_cut_value();
        assert!(
            (min_cut - 1.0).abs() < 0.01,
            "path graph min cut should be 1.0, got {}",
            min_cut
        );
    }

    #[test]
    fn test_cut_watcher_partition() {
        let mut watcher = DynamicCutWatcher::new();

        // Two unit-weight triangles joined by a single weak bridge.
        watcher.add_edge(0, 1, 1.0);
        watcher.add_edge(1, 2, 1.0);
        watcher.add_edge(2, 0, 1.0);

        watcher.add_edge(3, 4, 1.0);
        watcher.add_edge(4, 5, 1.0);
        watcher.add_edge(5, 3, 1.0);

        watcher.add_edge(2, 3, 0.1); // Weak bridge

        // Separating the two triangles cuts only the 0.1 bridge; any other
        // bipartition cuts at least one unit edge.
        let min_cut = watcher.min_cut_value();
        assert!(
            (min_cut - 0.1).abs() < 0.01,
            "min cut should be the 0.1 bridge, got {}",
            min_cut
        );
    }

    #[test]
    fn test_cut_aware_hnsw_insert() {
        let config = CutAwareConfig::default();
        let mut index = CutAwareHNSW::new(config);

        let vec1 = create_test_vector(128, 1.0);
        let vec2 = create_test_vector(128, 0.9);

        assert!(index.insert(0, &vec1).is_ok());
        assert!(index.insert(1, &vec2).is_ok());
    }

    #[test]
    fn test_gated_vs_ungated_search() {
        let config = CutAwareConfig {
            coherence_gate_threshold: 0.5,
            max_cross_cut_hops: 1,
            ..Default::default()
        };
        let mut index = CutAwareHNSW::new(config);

        // First cluster: component values 1.0 ..= 1.4.
        for i in 0..5 {
            let vec = create_test_vector(128, 1.0 + i as f32 * 0.1);
            index.insert(i, &vec).unwrap();
        }

        // Second cluster: component values -0.5 ..= -0.1.
        for i in 5..10 {
            let vec = create_test_vector(128, -1.0 + i as f32 * 0.1);
            index.insert(i, &vec).unwrap();
        }

        let query = create_test_vector(128, 1.0);

        let gated = index.search_gated(&query, 5);
        let ungated = index.search_ungated(&query, 5);

        // Both search modes should return results.
        assert!(!gated.is_empty(), "gated search returned no results");
        assert!(!ungated.is_empty(), "ungated search returned no results");
    }

    #[test]
    fn test_coherent_neighborhood() {
        let config = CutAwareConfig::default();
        let mut index = CutAwareHNSW::new(config);

        // Create connected nodes.
        for i in 0..5 {
            let vec = create_test_vector(128, i as f32);
            index.insert(i, &vec).unwrap();
        }

        let neighborhood = index.coherent_neighborhood(0, 2);
        assert!(!neighborhood.is_empty());
        assert!(
            neighborhood.contains(&0),
            "neighborhood must include its start node"
        );
    }

    #[test]
    fn test_edge_updates() {
        let config = CutAwareConfig::default();
        let mut index = CutAwareHNSW::new(config);

        index.add_edge(0, 1, 1.0);
        index.add_edge(1, 2, 1.0);

        let updates = vec![
            EdgeUpdate {
                kind: UpdateKind::Insert,
                u: 2,
                v: 3,
                weight: Some(0.8),
            },
            EdgeUpdate {
                kind: UpdateKind::Delete,
                u: 0,
                v: 1,
                weight: None,
            },
        ];

        let stats = index.batch_update(updates);
        assert_eq!(stats.edges_added, 1);
        assert_eq!(stats.edges_removed, 1);
    }

    #[test]
    fn test_zone_computation() {
        let config = CutAwareConfig {
            min_zone_size: 2,
            ..Default::default()
        };
        let mut index = CutAwareHNSW::new(config);

        // Insert enough nodes.
        for i in 0..10 {
            let vec = create_test_vector(128, i as f32);
            index.insert(i, &vec).unwrap();
        }

        let zones = index.compute_zones();
        // May or may not have zones depending on structure.
        assert!(zones.len() <= 2, "unexpected zone count: {}", zones.len());
    }

    #[test]
    fn test_cross_zone_search() {
        let config = CutAwareConfig::default();
        let mut index = CutAwareHNSW::new(config);

        for i in 0..8 {
            let vec = create_test_vector(128, i as f32);
            index.insert(i, &vec).unwrap();
        }

        index.compute_zones();

        let query = create_test_vector(128, 2.0);
        let results = index.cross_zone_search(&query, 3, &[0, 1]);

        assert!(!results.is_empty());
    }

    #[test]
    fn test_prune_weak_edges() {
        let config = CutAwareConfig::default();
        let mut index = CutAwareHNSW::new(config);

        index.add_edge(0, 1, 1.0);
        index.add_edge(1, 2, 0.1); // Weak edge
        index.add_edge(2, 3, 1.0);

        // Whether the weak edge is pruned depends on the partition the watcher
        // computes. This is a smoke test that pruning completes without panicking.
        let _pruned = index.prune_weak_edges(0.5);
    }

    #[test]
    fn test_metrics_tracking() {
        let config = CutAwareConfig::default();
        let mut index = CutAwareHNSW::new(config);

        for i in 0..5 {
            let vec = create_test_vector(128, i as f32);
            index.insert(i, &vec).unwrap();
        }

        let query = create_test_vector(128, 2.0);
        // Results are irrelevant here; only the metric side effect is checked.
        let _ = index.search_gated(&query, 3);
        let _ = index.search_ungated(&query, 3);

        let metrics = index.metrics();
        assert!(metrics.searches_performed.load(Ordering::Relaxed) >= 2);
    }

    #[test]
    fn test_export_metrics() {
        let config = CutAwareConfig::default();
        let index = CutAwareHNSW::new(config);

        let json = index.export_metrics();
        assert!(json.is_object());
        assert!(json["searches_performed"].is_number());
    }

    #[test]
    fn test_cut_distribution() {
        let config = CutAwareConfig::default();
        let mut index = CutAwareHNSW::new(config);

        for i in 0..5 {
            let vec = create_test_vector(128, i as f32);
            index.insert(i, &vec).unwrap();
        }

        let dist = index.cut_distribution();
        assert!(!dist.is_empty());
    }

    #[test]
    fn test_path_crosses_weak_cut() {
        let config = CutAwareConfig {
            coherence_gate_threshold: 0.3,
            ..Default::default()
        };
        let mut index = CutAwareHNSW::new(config);

        // Create two clusters joined by a weak bridge.
        index.add_edge(0, 1, 1.0);
        index.add_edge(1, 2, 1.0);
        index.add_edge(3, 4, 1.0);
        index.add_edge(2, 3, 0.1); // Weak bridge

        // Force recomputation of the cached cut/partition.
        {
            let mut watcher = index.cut_watcher.write().unwrap();
            let _ = watcher.min_cut_value();
        }

        // The answer depends on the partition the watcher computes. This is a
        // smoke test that the query runs without panicking.
        let _crosses = index.path_crosses_weak_cut(0, 4);
    }

    #[test]
    fn test_stoer_wagner_triangle() {
        let mut watcher = DynamicCutWatcher::new();

        // Triangle with equal weights: isolating any vertex cuts two edges.
        watcher.add_edge(0, 1, 1.0);
        watcher.add_edge(1, 2, 1.0);
        watcher.add_edge(2, 0, 1.0);

        let min_cut = watcher.min_cut_value();
        assert!(
            (min_cut - 2.0).abs() < 0.01,
            "triangle min cut should be 2.0, got {}",
            min_cut
        );
    }

    #[test]
    fn test_boundary_edge_tracking() {
        let mut watcher = DynamicCutWatcher::new();

        // Two components with a single bridge.
        watcher.add_edge(0, 1, 1.0);
        watcher.add_edge(1, 0, 1.0);
        watcher.add_edge(2, 3, 1.0);
        watcher.add_edge(3, 2, 1.0);
        watcher.add_edge(1, 2, 0.5); // Bridge

        let _ = watcher.min_cut_value();
        let boundary = watcher.boundary_edges();

        // Should identify the bridge edge.
        assert!(!boundary.is_empty());
    }

    #[test]
    fn test_reset_metrics() {
        let config = CutAwareConfig::default();
        let index = CutAwareHNSW::new(config);

        index.metrics.searches_performed.store(100, Ordering::Relaxed);
        index.reset_metrics();

        assert_eq!(index.metrics.searches_performed.load(Ordering::Relaxed), 0);
    }
}