Skip to main content

ruvector_dag/attention/
traits.rs

1//! Core traits and types for DAG attention mechanisms
2
3use crate::dag::QueryDag;
4use std::collections::HashMap;
5
/// Per-node attention scores produced by a [`DagAttention`] mechanism.
///
/// Each key is a node identifier (presumably the same `usize` ids used by
/// `QueryDag` — confirm) and each value is that node's score.
pub type AttentionScores = std::collections::HashMap<usize, f32>;
8
/// Configuration shared by attention mechanisms.
///
/// The [`Default`] configuration is `normalize: true`, `temperature: 1.0`,
/// `dropout: 0.0`. `PartialEq` is derived so configurations can be compared
/// directly (e.g. in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionConfig {
    /// Whether computed scores should be normalized.
    ///
    /// NOTE(review): the normalization scheme (e.g. sum-to-one) is not defined
    /// in this module — confirm against the implementations.
    pub normalize: bool,
    /// Temperature applied to scores; presumably a softmax-style temperature
    /// where `1.0` leaves scores unchanged — confirm.
    pub temperature: f32,
    /// Dropout rate; presumably a probability in `[0.0, 1.0]` — confirm.
    pub dropout: f32,
}
16
17impl Default for AttentionConfig {
18    fn default() -> Self {
19        Self {
20            normalize: true,
21            temperature: 1.0,
22            dropout: 0.0,
23        }
24    }
25}
26
27/// Errors from attention computation
28#[derive(Debug, thiserror::Error)]
29pub enum AttentionError {
30    #[error("Empty DAG")]
31    EmptyDag,
32    #[error("Cycle detected in DAG")]
33    CycleDetected,
34    #[error("Node {0} not found")]
35    NodeNotFound(usize),
36    #[error("Computation failed: {0}")]
37    ComputationFailed(String),
38}
39
40/// Trait for DAG attention mechanisms
41pub trait DagAttention: Send + Sync {
42    /// Compute attention scores for all nodes
43    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError>;
44
45    /// Update internal state after execution feedback
46    fn update(&mut self, dag: &QueryDag, execution_times: &HashMap<usize, f64>);
47
48    /// Get mechanism name
49    fn name(&self) -> &'static str;
50
51    /// Get computational complexity description
52    fn complexity(&self) -> &'static str;
53}