// ruvector_dag/attention/traits.rs
use std::collections::HashMap;

use crate::dag::QueryDag;

/// Per-node attention scores: a map from node index (`usize`) to score (`f32`).
///
/// NOTE(review): keys are presumably the `usize` node ids used elsewhere in this
/// module (e.g. `AttentionError::NodeNotFound`); confirm against implementations.
pub type AttentionScores = HashMap<usize, f32>;

/// Configuration knobs for attention computations.
///
/// The [`Default`] impl yields `normalize = true`, `temperature = 1.0` and
/// `dropout = 0.0`. No validation is performed on these values here.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionConfig {
    /// Whether the resulting scores should be normalized.
    ///
    /// NOTE(review): the exact normalization (e.g. sum-to-one) is left to each
    /// implementation; confirm.
    pub normalize: bool,
    /// Temperature applied to the scores; `1.0` by default.
    ///
    /// NOTE(review): presumably divides logits before a softmax (values > 1.0
    /// flatten, < 1.0 sharpen); confirm against implementations.
    pub temperature: f32,
    /// Dropout rate; `0.0` by default.
    ///
    /// NOTE(review): presumably a probability in `[0.0, 1.0]`; it is not range
    /// checked here.
    pub dropout: f32,
}

impl Default for AttentionConfig {
    /// Returns a config with normalization on, temperature `1.0`, and dropout `0.0`.
    fn default() -> Self {
        Self {
            normalize: true,
            temperature: 1.0,
            dropout: 0.0,
        }
    }
}

27#[derive(Debug, thiserror::Error)]
29pub enum AttentionError {
30 #[error("Empty DAG")]
31 EmptyDag,
32 #[error("Cycle detected in DAG")]
33 CycleDetected,
34 #[error("Node {0} not found")]
35 NodeNotFound(usize),
36 #[error("Computation failed: {0}")]
37 ComputationFailed(String),
38}
39
40pub trait DagAttention: Send + Sync {
42 fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError>;
44
45 fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>);
47
48 fn name(&self) -> &'static str;
50
51 fn complexity(&self) -> &'static str;
53}