Skip to main content

ruvector_dag/attention/
trait_def.rs

1//! DagAttention trait definition for pluggable attention mechanisms
2
3use crate::dag::QueryDag;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use thiserror::Error;
7
8/// Attention scores for each node in the DAG
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct AttentionScores {
11    /// Attention score for each node (0.0 to 1.0)
12    pub scores: Vec<f32>,
13    /// Optional attention weights between nodes (adjacency-like)
14    pub edge_weights: Option<Vec<Vec<f32>>>,
15    /// Metadata for debugging
16    pub metadata: HashMap<String, String>,
17}
18
19impl AttentionScores {
20    pub fn new(scores: Vec<f32>) -> Self {
21        Self {
22            scores,
23            edge_weights: None,
24            metadata: HashMap::new(),
25        }
26    }
27
28    pub fn with_edge_weights(mut self, weights: Vec<Vec<f32>>) -> Self {
29        self.edge_weights = Some(weights);
30        self
31    }
32
33    pub fn with_metadata(mut self, key: String, value: String) -> Self {
34        self.metadata.insert(key, value);
35        self
36    }
37}
38
39/// Errors that can occur during attention computation
40#[derive(Debug, Error)]
41pub enum AttentionError {
42    #[error("Invalid DAG structure: {0}")]
43    InvalidDag(String),
44
45    #[error("Dimension mismatch: expected {expected}, got {actual}")]
46    DimensionMismatch { expected: usize, actual: usize },
47
48    #[error("Computation failed: {0}")]
49    ComputationFailed(String),
50
51    #[error("Configuration error: {0}")]
52    ConfigError(String),
53}
54
55/// Trait for DAG attention mechanisms
56pub trait DagAttentionMechanism: Send + Sync {
57    /// Compute attention scores for the given DAG
58    fn forward(&self, dag: &QueryDag) -> Result<AttentionScores, AttentionError>;
59
60    /// Get the mechanism name
61    fn name(&self) -> &'static str;
62
63    /// Get computational complexity as a string
64    fn complexity(&self) -> &'static str;
65
66    /// Optional: Update internal state based on execution feedback
67    fn update(&mut self, _dag: &QueryDag, _execution_times: &HashMap<usize, f64>) {
68        // Default: no-op
69    }
70
71    /// Optional: Reset internal state
72    fn reset(&mut self) {
73        // Default: no-op
74    }
75}