1use ndarray::{Array2, Array1, Axis};
4use serde::{Serialize, Deserialize};
5use ubiquity_kernel::{NodeId, ArrowId, Result, UbiqError};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub enum Activation {
9 Linear,
10 Relu,
11 Tanh,
12}
13
14impl Activation {
15 fn apply(&self, v: &Array1<f32>) -> Array1<f32> {
16 match self {
17 Activation::Linear => v.clone(),
18 Activation::Relu => v.mapv(|x| x.max(0.0)),
19 Activation::Tanh => v.mapv(|x| x.tanh()),
20 }
21 }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct Node {
26 pub id: NodeId,
27 pub dim: usize,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct Arrow {
32 pub id: ArrowId,
33 pub src: NodeId,
34 pub dst: NodeId,
35 pub w: Array2<f32>,
36 pub b: Option<Array1<f32>>,
37 pub act: Activation,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct Quiver {
42 pub nodes: Vec<Node>,
43 pub arrows: Vec<Arrow>,
44}
45
46impl Quiver {
47 pub fn forward(&self, inputs: &[(NodeId, Array1<f32>)]) -> Result<Vec<(NodeId, Array1<f32>)>> {
48 let mut state = std::collections::HashMap::<NodeId, Array1<f32>>::new();
50 for (id, v) in inputs.iter() { state.insert(*id, v.clone()); }
51 for a in &self.arrows {
52 let x = state.get(&a.src).ok_or_else(|| UbiqError::Graph(format!("missing src {:?}", a.src)))?;
53 if x.len() != a.w.shape()[1] {
54 return Err(UbiqError::Dim(format!("arrow {:?} expected {}, got {}", a.id, a.w.shape()[1], x.len())));
55 }
56 let mut y = a.w.dot(x);
57 if let Some(b) = &a.b { y = &y + b; }
58 let y = a.act.apply(&y);
59 state.insert(a.dst, y);
60 }
61 Ok(self.nodes.iter().filter_map(|n| state.get(&n.id).map(|v| (n.id, v.clone()))).collect())
62 }
63}