// ruvector_dag/sona/engine.rs
use super::{
    DagReasoningBank, DagTrajectory, DagTrajectoryBuffer, EwcConfig, EwcPlusPlus, MicroLoRA,
    MicroLoRAConfig, ReasoningBankConfig,
};
use crate::dag::{OperatorType, QueryDag};
use ndarray::Array1;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
11
/// Online-learning engine that adapts query execution using DAG embeddings.
///
/// The engine embeds each [`QueryDag`] into a fixed-size vector, adapts it
/// through a micro-LoRA layer informed by similar past patterns, records
/// execution outcomes as trajectories, and periodically distills high-quality
/// trajectories into a reasoning bank (see the `pre_query` / `post_query` /
/// `background_learn` methods).
pub struct DagSonaEngine {
    // Low-rank adapter applied to DAG embeddings (`adapt` + `forward`).
    micro_lora: MicroLoRA,
    // Bounded buffer (capacity 1000) of per-query execution trajectories.
    trajectory_buffer: DagTrajectoryBuffer,
    // Stores embeddings of high-quality trajectories; queried for similar
    // past patterns before each query.
    reasoning_bank: DagReasoningBank,
    // Elastic-weight-consolidation regularizer. Constructed but not yet
    // wired into the learning loop — hence the dead_code allowance.
    #[allow(dead_code)]
    ewc: EwcPlusPlus,
    // Dimensionality of all DAG embeddings produced by this engine.
    embedding_dim: usize,
}
20
21impl DagSonaEngine {
22 pub fn new(embedding_dim: usize) -> Self {
23 Self {
24 micro_lora: MicroLoRA::new(MicroLoRAConfig::default(), embedding_dim),
25 trajectory_buffer: DagTrajectoryBuffer::new(1000),
26 reasoning_bank: DagReasoningBank::new(ReasoningBankConfig {
27 pattern_dim: embedding_dim,
28 ..Default::default()
29 }),
30 ewc: EwcPlusPlus::new(EwcConfig::default()),
31 embedding_dim,
32 }
33 }
34
35 pub fn pre_query(&mut self, dag: &QueryDag) -> Vec<f32> {
37 let embedding = self.compute_dag_embedding(dag);
38
39 let similar = self.reasoning_bank.query_similar(&embedding, 3);
41
42 if !similar.is_empty() {
44 let adaptation_signal = self.compute_adaptation_signal(&similar, &embedding);
45 self.micro_lora
46 .adapt(&Array1::from_vec(adaptation_signal), 0.01);
47 }
48
49 self.micro_lora
51 .forward(&Array1::from_vec(embedding))
52 .to_vec()
53 }
54
55 pub fn post_query(
57 &mut self,
58 dag: &QueryDag,
59 execution_time_ms: f64,
60 baseline_time_ms: f64,
61 attention_mechanism: &str,
62 ) {
63 let embedding = self.compute_dag_embedding(dag);
64 let trajectory = DagTrajectory::new(
65 self.hash_dag(dag),
66 embedding,
67 attention_mechanism.to_string(),
68 execution_time_ms,
69 baseline_time_ms,
70 );
71
72 self.trajectory_buffer.push(trajectory);
73 }
74
75 pub fn background_learn(&mut self) {
77 let trajectories = self.trajectory_buffer.drain();
78 if trajectories.is_empty() {
79 return;
80 }
81
82 for t in &trajectories {
84 if t.quality() > 0.6 {
85 self.reasoning_bank
86 .store_pattern(t.dag_embedding.clone(), t.quality());
87 }
88 }
89
90 if self.reasoning_bank.pattern_count() % 100 == 0 {
92 self.reasoning_bank.recompute_clusters();
93 }
94 }
95
96 fn compute_dag_embedding(&self, dag: &QueryDag) -> Vec<f32> {
97 let mut embedding = vec![0.0; self.embedding_dim];
99
100 if dag.node_count() == 0 {
101 return embedding;
102 }
103
104 let mut type_counts = vec![0usize; 20];
106 for node in dag.nodes() {
107 let type_idx = match &node.op_type {
108 OperatorType::SeqScan { .. } => 0,
109 OperatorType::IndexScan { .. } => 1,
110 OperatorType::HnswScan { .. } => 2,
111 OperatorType::IvfFlatScan { .. } => 3,
112 OperatorType::NestedLoopJoin => 4,
113 OperatorType::HashJoin { .. } => 5,
114 OperatorType::MergeJoin { .. } => 6,
115 OperatorType::Aggregate { .. } => 7,
116 OperatorType::GroupBy { .. } => 8,
117 OperatorType::Filter { .. } => 9,
118 OperatorType::Project { .. } => 10,
119 OperatorType::Sort { .. } => 11,
120 OperatorType::Limit { .. } => 12,
121 OperatorType::VectorDistance { .. } => 13,
122 OperatorType::Rerank { .. } => 14,
123 OperatorType::Materialize => 15,
124 OperatorType::Result => 16,
125 #[allow(deprecated)]
126 OperatorType::Scan => 0, #[allow(deprecated)]
128 OperatorType::Join => 4, };
130 if type_idx < type_counts.len() {
131 type_counts[type_idx] += 1;
132 }
133 }
134
135 let total = dag.node_count() as f32;
137 for (i, count) in type_counts.iter().enumerate() {
138 if i < self.embedding_dim / 2 {
139 embedding[i] = *count as f32 / total;
140 }
141 }
142
143 let depth = self.compute_dag_depth(dag);
145 let avg_fanout = dag.node_count() as f32 / (dag.leaves().len().max(1) as f32);
146
147 if self.embedding_dim > 20 {
148 embedding[20] = (depth as f32) / 10.0; embedding[21] = avg_fanout / 5.0; }
151
152 let costs: Vec<f64> = dag.nodes().map(|n| n.estimated_cost).collect();
154 if !costs.is_empty() && self.embedding_dim > 22 {
155 let avg_cost = costs.iter().sum::<f64>() / costs.len() as f64;
156 let max_cost = costs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
157 embedding[22] = (avg_cost / 1000.0) as f32; embedding[23] = (max_cost / 1000.0) as f32;
159 }
160
161 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
163 if norm > 0.0 {
164 embedding.iter_mut().for_each(|x| *x /= norm);
165 }
166
167 embedding
168 }
169
170 fn compute_dag_depth(&self, dag: &QueryDag) -> usize {
171 use std::collections::VecDeque;
173
174 let mut max_depth = 0;
175 let mut queue = VecDeque::new();
176
177 if let Some(root) = dag.root() {
178 queue.push_back((root, 0));
179 }
180
181 while let Some((node_id, depth)) = queue.pop_front() {
182 max_depth = max_depth.max(depth);
183 for &child in dag.children(node_id) {
184 queue.push_back((child, depth + 1));
185 }
186 }
187
188 max_depth
189 }
190
191 fn compute_adaptation_signal(
192 &self,
193 _similar: &[(u64, f32)],
194 _current_embedding: &[f32],
195 ) -> Vec<f32> {
196 vec![0.0; self.embedding_dim]
199 }
200
201 fn hash_dag(&self, dag: &QueryDag) -> u64 {
202 let mut hasher = DefaultHasher::new();
203
204 for node in dag.nodes() {
206 node.id.hash(&mut hasher);
207 match &node.op_type {
209 OperatorType::SeqScan { table } => {
210 0u8.hash(&mut hasher);
211 table.hash(&mut hasher);
212 }
213 OperatorType::IndexScan { index, table } => {
214 1u8.hash(&mut hasher);
215 index.hash(&mut hasher);
216 table.hash(&mut hasher);
217 }
218 OperatorType::HnswScan { index, ef_search } => {
219 2u8.hash(&mut hasher);
220 index.hash(&mut hasher);
221 ef_search.hash(&mut hasher);
222 }
223 OperatorType::IvfFlatScan { index, nprobe } => {
224 3u8.hash(&mut hasher);
225 index.hash(&mut hasher);
226 nprobe.hash(&mut hasher);
227 }
228 OperatorType::NestedLoopJoin => 4u8.hash(&mut hasher),
229 OperatorType::HashJoin { hash_key } => {
230 5u8.hash(&mut hasher);
231 hash_key.hash(&mut hasher);
232 }
233 OperatorType::MergeJoin { merge_key } => {
234 6u8.hash(&mut hasher);
235 merge_key.hash(&mut hasher);
236 }
237 OperatorType::Aggregate { functions } => {
238 7u8.hash(&mut hasher);
239 for func in functions {
240 func.hash(&mut hasher);
241 }
242 }
243 OperatorType::GroupBy { keys } => {
244 8u8.hash(&mut hasher);
245 for key in keys {
246 key.hash(&mut hasher);
247 }
248 }
249 OperatorType::Filter { predicate } => {
250 9u8.hash(&mut hasher);
251 predicate.hash(&mut hasher);
252 }
253 OperatorType::Project { columns } => {
254 10u8.hash(&mut hasher);
255 for col in columns {
256 col.hash(&mut hasher);
257 }
258 }
259 OperatorType::Sort { keys, descending } => {
260 11u8.hash(&mut hasher);
261 for key in keys {
262 key.hash(&mut hasher);
263 }
264 for &desc in descending {
265 desc.hash(&mut hasher);
266 }
267 }
268 OperatorType::Limit { count } => {
269 12u8.hash(&mut hasher);
270 count.hash(&mut hasher);
271 }
272 OperatorType::VectorDistance { metric } => {
273 13u8.hash(&mut hasher);
274 metric.hash(&mut hasher);
275 }
276 OperatorType::Rerank { model } => {
277 14u8.hash(&mut hasher);
278 model.hash(&mut hasher);
279 }
280 OperatorType::Materialize => 15u8.hash(&mut hasher),
281 OperatorType::Result => 16u8.hash(&mut hasher),
282 #[allow(deprecated)]
283 OperatorType::Scan => 0u8.hash(&mut hasher),
284 #[allow(deprecated)]
285 OperatorType::Join => 4u8.hash(&mut hasher),
286 }
287 }
288
289 hasher.finish()
290 }
291
292 pub fn pattern_count(&self) -> usize {
293 self.reasoning_bank.pattern_count()
294 }
295
296 pub fn trajectory_count(&self) -> usize {
297 self.trajectory_buffer.total_count()
298 }
299
300 pub fn cluster_count(&self) -> usize {
301 self.reasoning_bank.cluster_count()
302 }
303}
304
305impl Default for DagSonaEngine {
306 fn default() -> Self {
307 Self::new(256)
308 }
309}