//! DagSonaEngine: Main orchestration for SONA learning

use super::{
    DagReasoningBank, DagTrajectory, DagTrajectoryBuffer, EwcConfig, EwcPlusPlus, MicroLoRA,
    MicroLoRAConfig, ReasoningBankConfig,
};
use crate::dag::{OperatorType, QueryDag};
use ndarray::Array1;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Orchestrates SONA learning over query DAGs: instant pre-query adaptation,
/// post-query trajectory capture, and periodic background consolidation.
pub struct DagSonaEngine {
    // Low-rank adapter applied to DAG embeddings on the query hot path.
    micro_lora: MicroLoRA,
    // Bounded buffer of recent query trajectories awaiting `background_learn`.
    trajectory_buffer: DagTrajectoryBuffer,
    // Store of high-quality DAG embedding patterns used for similarity lookup.
    reasoning_bank: DagReasoningBank,
    // Elastic weight consolidation state; constructed but not yet wired into
    // the learning loop, hence the dead_code allow.
    #[allow(dead_code)]
    ewc: EwcPlusPlus,
    // Dimensionality of embeddings produced by `compute_dag_embedding` and of
    // patterns stored in the reasoning bank.
    embedding_dim: usize,
}

21impl DagSonaEngine {
22    pub fn new(embedding_dim: usize) -> Self {
23        Self {
24            micro_lora: MicroLoRA::new(MicroLoRAConfig::default(), embedding_dim),
25            trajectory_buffer: DagTrajectoryBuffer::new(1000),
26            reasoning_bank: DagReasoningBank::new(ReasoningBankConfig {
27                pattern_dim: embedding_dim,
28                ..Default::default()
29            }),
30            ewc: EwcPlusPlus::new(EwcConfig::default()),
31            embedding_dim,
32        }
33    }
34
35    /// Pre-query instant adaptation (<100μs)
36    pub fn pre_query(&mut self, dag: &QueryDag) -> Vec<f32> {
37        let embedding = self.compute_dag_embedding(dag);
38
39        // Query similar patterns
40        let similar = self.reasoning_bank.query_similar(&embedding, 3);
41
42        // If we have similar patterns, adapt MicroLoRA
43        if !similar.is_empty() {
44            let adaptation_signal = self.compute_adaptation_signal(&similar, &embedding);
45            self.micro_lora
46                .adapt(&Array1::from_vec(adaptation_signal), 0.01);
47        }
48
49        // Return enhanced embedding
50        self.micro_lora
51            .forward(&Array1::from_vec(embedding))
52            .to_vec()
53    }
54
55    /// Post-query trajectory recording
56    pub fn post_query(
57        &mut self,
58        dag: &QueryDag,
59        execution_time_ms: f64,
60        baseline_time_ms: f64,
61        attention_mechanism: &str,
62    ) {
63        let embedding = self.compute_dag_embedding(dag);
64        let trajectory = DagTrajectory::new(
65            self.hash_dag(dag),
66            embedding,
67            attention_mechanism.to_string(),
68            execution_time_ms,
69            baseline_time_ms,
70        );
71
72        self.trajectory_buffer.push(trajectory);
73    }
74
75    /// Background learning cycle (called periodically)
76    pub fn background_learn(&mut self) {
77        let trajectories = self.trajectory_buffer.drain();
78        if trajectories.is_empty() {
79            return;
80        }
81
82        // Store high-quality patterns
83        for t in &trajectories {
84            if t.quality() > 0.6 {
85                self.reasoning_bank
86                    .store_pattern(t.dag_embedding.clone(), t.quality());
87            }
88        }
89
90        // Recompute clusters periodically (every 100 patterns)
91        if self.reasoning_bank.pattern_count() % 100 == 0 {
92            self.reasoning_bank.recompute_clusters();
93        }
94    }
95
96    fn compute_dag_embedding(&self, dag: &QueryDag) -> Vec<f32> {
97        // Compute embedding from DAG structure
98        let mut embedding = vec![0.0; self.embedding_dim];
99
100        if dag.node_count() == 0 {
101            return embedding;
102        }
103
104        // Encode operator type distribution (20 different types)
105        let mut type_counts = vec![0usize; 20];
106        for node in dag.nodes() {
107            let type_idx = match &node.op_type {
108                OperatorType::SeqScan { .. } => 0,
109                OperatorType::IndexScan { .. } => 1,
110                OperatorType::HnswScan { .. } => 2,
111                OperatorType::IvfFlatScan { .. } => 3,
112                OperatorType::NestedLoopJoin => 4,
113                OperatorType::HashJoin { .. } => 5,
114                OperatorType::MergeJoin { .. } => 6,
115                OperatorType::Aggregate { .. } => 7,
116                OperatorType::GroupBy { .. } => 8,
117                OperatorType::Filter { .. } => 9,
118                OperatorType::Project { .. } => 10,
119                OperatorType::Sort { .. } => 11,
120                OperatorType::Limit { .. } => 12,
121                OperatorType::VectorDistance { .. } => 13,
122                OperatorType::Rerank { .. } => 14,
123                OperatorType::Materialize => 15,
124                OperatorType::Result => 16,
125                #[allow(deprecated)]
126                OperatorType::Scan => 0, // Treat as SeqScan
127                #[allow(deprecated)]
128                OperatorType::Join => 4, // Treat as NestedLoopJoin
129            };
130            if type_idx < type_counts.len() {
131                type_counts[type_idx] += 1;
132            }
133        }
134
135        // Normalize and place in embedding
136        let total = dag.node_count() as f32;
137        for (i, count) in type_counts.iter().enumerate() {
138            if i < self.embedding_dim / 2 {
139                embedding[i] = *count as f32 / total;
140            }
141        }
142
143        // Encode structural features (depth, breadth, connectivity)
144        let depth = self.compute_dag_depth(dag);
145        let avg_fanout = dag.node_count() as f32 / (dag.leaves().len().max(1) as f32);
146
147        if self.embedding_dim > 20 {
148            embedding[20] = (depth as f32) / 10.0; // Normalize depth
149            embedding[21] = avg_fanout / 5.0; // Normalize fanout
150        }
151
152        // Encode cost statistics
153        let costs: Vec<f64> = dag.nodes().map(|n| n.estimated_cost).collect();
154        if !costs.is_empty() && self.embedding_dim > 22 {
155            let avg_cost = costs.iter().sum::<f64>() / costs.len() as f64;
156            let max_cost = costs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
157            embedding[22] = (avg_cost / 1000.0) as f32; // Normalize
158            embedding[23] = (max_cost / 1000.0) as f32;
159        }
160
161        // Normalize entire embedding
162        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
163        if norm > 0.0 {
164            embedding.iter_mut().for_each(|x| *x /= norm);
165        }
166
167        embedding
168    }
169
170    fn compute_dag_depth(&self, dag: &QueryDag) -> usize {
171        // BFS to find maximum depth
172        use std::collections::VecDeque;
173
174        let mut max_depth = 0;
175        let mut queue = VecDeque::new();
176
177        if let Some(root) = dag.root() {
178            queue.push_back((root, 0));
179        }
180
181        while let Some((node_id, depth)) = queue.pop_front() {
182            max_depth = max_depth.max(depth);
183            for &child in dag.children(node_id) {
184                queue.push_back((child, depth + 1));
185            }
186        }
187
188        max_depth
189    }
190
191    fn compute_adaptation_signal(
192        &self,
193        _similar: &[(u64, f32)],
194        _current_embedding: &[f32],
195    ) -> Vec<f32> {
196        // Weighted average of similar pattern embeddings
197        // For now, just return zeros as we'd need to store pattern vectors
198        vec![0.0; self.embedding_dim]
199    }
200
201    fn hash_dag(&self, dag: &QueryDag) -> u64 {
202        let mut hasher = DefaultHasher::new();
203
204        // Hash node types and edges
205        for node in dag.nodes() {
206            node.id.hash(&mut hasher);
207            // Hash operator type discriminant
208            match &node.op_type {
209                OperatorType::SeqScan { table } => {
210                    0u8.hash(&mut hasher);
211                    table.hash(&mut hasher);
212                }
213                OperatorType::IndexScan { index, table } => {
214                    1u8.hash(&mut hasher);
215                    index.hash(&mut hasher);
216                    table.hash(&mut hasher);
217                }
218                OperatorType::HnswScan { index, ef_search } => {
219                    2u8.hash(&mut hasher);
220                    index.hash(&mut hasher);
221                    ef_search.hash(&mut hasher);
222                }
223                OperatorType::IvfFlatScan { index, nprobe } => {
224                    3u8.hash(&mut hasher);
225                    index.hash(&mut hasher);
226                    nprobe.hash(&mut hasher);
227                }
228                OperatorType::NestedLoopJoin => 4u8.hash(&mut hasher),
229                OperatorType::HashJoin { hash_key } => {
230                    5u8.hash(&mut hasher);
231                    hash_key.hash(&mut hasher);
232                }
233                OperatorType::MergeJoin { merge_key } => {
234                    6u8.hash(&mut hasher);
235                    merge_key.hash(&mut hasher);
236                }
237                OperatorType::Aggregate { functions } => {
238                    7u8.hash(&mut hasher);
239                    for func in functions {
240                        func.hash(&mut hasher);
241                    }
242                }
243                OperatorType::GroupBy { keys } => {
244                    8u8.hash(&mut hasher);
245                    for key in keys {
246                        key.hash(&mut hasher);
247                    }
248                }
249                OperatorType::Filter { predicate } => {
250                    9u8.hash(&mut hasher);
251                    predicate.hash(&mut hasher);
252                }
253                OperatorType::Project { columns } => {
254                    10u8.hash(&mut hasher);
255                    for col in columns {
256                        col.hash(&mut hasher);
257                    }
258                }
259                OperatorType::Sort { keys, descending } => {
260                    11u8.hash(&mut hasher);
261                    for key in keys {
262                        key.hash(&mut hasher);
263                    }
264                    for &desc in descending {
265                        desc.hash(&mut hasher);
266                    }
267                }
268                OperatorType::Limit { count } => {
269                    12u8.hash(&mut hasher);
270                    count.hash(&mut hasher);
271                }
272                OperatorType::VectorDistance { metric } => {
273                    13u8.hash(&mut hasher);
274                    metric.hash(&mut hasher);
275                }
276                OperatorType::Rerank { model } => {
277                    14u8.hash(&mut hasher);
278                    model.hash(&mut hasher);
279                }
280                OperatorType::Materialize => 15u8.hash(&mut hasher),
281                OperatorType::Result => 16u8.hash(&mut hasher),
282                #[allow(deprecated)]
283                OperatorType::Scan => 0u8.hash(&mut hasher),
284                #[allow(deprecated)]
285                OperatorType::Join => 4u8.hash(&mut hasher),
286            }
287        }
288
289        hasher.finish()
290    }
291
292    pub fn pattern_count(&self) -> usize {
293        self.reasoning_bank.pattern_count()
294    }
295
296    pub fn trajectory_count(&self) -> usize {
297        self.trajectory_buffer.total_count()
298    }
299
300    pub fn cluster_count(&self) -> usize {
301        self.reasoning_bank.cluster_count()
302    }
303}
304
305impl Default for DagSonaEngine {
306    fn default() -> Self {
307        Self::new(256)
308    }
309}