ruvector_dag/attention/
trait_def.rs1use crate::dag::QueryDag;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use thiserror::Error;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct AttentionScores {
11 pub scores: Vec<f32>,
13 pub edge_weights: Option<Vec<Vec<f32>>>,
15 pub metadata: HashMap<String, String>,
17}
18
19impl AttentionScores {
20 pub fn new(scores: Vec<f32>) -> Self {
21 Self {
22 scores,
23 edge_weights: None,
24 metadata: HashMap::new(),
25 }
26 }
27
28 pub fn with_edge_weights(mut self, weights: Vec<Vec<f32>>) -> Self {
29 self.edge_weights = Some(weights);
30 self
31 }
32
33 pub fn with_metadata(mut self, key: String, value: String) -> Self {
34 self.metadata.insert(key, value);
35 self
36 }
37}
38
39#[derive(Debug, Error)]
41pub enum AttentionError {
42 #[error("Invalid DAG structure: {0}")]
43 InvalidDag(String),
44
45 #[error("Dimension mismatch: expected {expected}, got {actual}")]
46 DimensionMismatch { expected: usize, actual: usize },
47
48 #[error("Computation failed: {0}")]
49 ComputationFailed(String),
50
51 #[error("Configuration error: {0}")]
52 ConfigError(String),
53}
54
55pub trait DagAttentionMechanism: Send + Sync {
57 fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError>;
59
60 fn name(&self) -> &'static str;
62
63 fn complexity(&self) -> &'static str;
65
66 fn update(&mut self, _dag: &QueryDag, _execution_times: &HashMap<usize, f64>) {
68 }
70
71 fn reset(&mut self) {
73 }
75}