ubiquity_quiver/
lib.rs

1//! ubiquity-quiver: quiver representation with arrows and activations
2
3use ndarray::{Array2, Array1, Axis};
4use serde::{Serialize, Deserialize};
5use ubiquity_kernel::{NodeId, ArrowId, Result, UbiqError};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub enum Activation {
9    Linear,
10    Relu,
11    Tanh,
12}
13
14impl Activation {
15    fn apply(&self, v: &Array1<f32>) -> Array1<f32> {
16        match self {
17            Activation::Linear => v.clone(),
18            Activation::Relu => v.mapv(|x| x.max(0.0)),
19            Activation::Tanh => v.mapv(|x| x.tanh()),
20        }
21    }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct Node {
26    pub id: NodeId,
27    pub dim: usize,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct Arrow {
32    pub id: ArrowId,
33    pub src: NodeId,
34    pub dst: NodeId,
35    pub w: Array2<f32>,
36    pub b: Option<Array1<f32>>,
37    pub act: Activation,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct Quiver {
42    pub nodes: Vec<Node>,
43    pub arrows: Vec<Arrow>,
44}
45
46impl Quiver {
47    pub fn forward(&self, inputs: &[(NodeId, Array1<f32>)]) -> Result<Vec<(NodeId, Array1<f32>)>> {
48        // naive forward: single pass by arrow order; for DAGs, topological order works
49        let mut state = std::collections::HashMap::<NodeId, Array1<f32>>::new();
50        for (id, v) in inputs.iter() { state.insert(*id, v.clone()); }
51        for a in &self.arrows {
52            let x = state.get(&a.src).ok_or_else(|| UbiqError::Graph(format!("missing src {:?}", a.src)))?;
53            if x.len() != a.w.shape()[1] {
54                return Err(UbiqError::Dim(format!("arrow {:?} expected {}, got {}", a.id, a.w.shape()[1], x.len())));
55            }
56            let mut y = a.w.dot(x);
57            if let Some(b) = &a.b { y = &y + b; }
58            let y = a.act.apply(&y);
59            state.insert(a.dst, y);
60        }
61        Ok(self.nodes.iter().filter_map(|n| state.get(&n.id).map(|v| (n.id, v.clone()))).collect())
62    }
63}