Skip to main content

tensorlogic_sklears_kernels/learned_composition/
trainable.rs

//! Thin trainable-parameter adapter around [`LearnedMixtureKernel`].
//!
//! Exposes the mixture as a parameter container compatible with
//! `tensorlogic-train`:
//!
//! * `parameters()` / `set_parameters()` read and replace the logits.
//! * `step(gradient, learning_rate)` applies a vanilla gradient-descent
//!   update to the logits.
//!
//! The adapter intentionally does not own an optimizer — choice of
//! optimizer (SGD, Adam, etc.) stays with the caller. It just bundles the
//! forward evaluation and the analytical gradient so the caller can write
//!
//! ```text
//! let (k, g) = mixture.evaluate_with_gradient(x, y)?;
//! let grad = dloss_dk * g;                       // scale by upstream grad
//! mixture.step(&grad, learning_rate)?;           // vanilla SGD step
//! ```
19
20use crate::error::Result;
21use crate::learned_composition::mixture::LearnedMixtureKernel;
22
23/// Trainable adapter around a [`LearnedMixtureKernel`].
24#[derive(Clone, Debug)]
25pub struct TrainableKernelMixture {
26    inner: LearnedMixtureKernel,
27}
28
29impl TrainableKernelMixture {
30    /// Wrap an existing mixture kernel.
31    pub fn new(inner: LearnedMixtureKernel) -> Self {
32        Self { inner }
33    }
34
35    /// Number of trainable logits.
36    pub fn num_parameters(&self) -> usize {
37        self.inner.num_kernels()
38    }
39
40    /// Borrow the trainable parameters (logits).
41    pub fn parameters(&self) -> &[f64] {
42        self.inner.logits()
43    }
44
45    /// Mixture weights after softmax.
46    pub fn weights(&self) -> Vec<f64> {
47        self.inner.weights()
48    }
49
50    /// Forward pass — mixture kernel value.
51    pub fn evaluate(&self, x: &[f64], y: &[f64]) -> Result<f64> {
52        self.inner.evaluate(x, y)
53    }
54
55    /// Forward + gradient in a single pass.
56    pub fn evaluate_with_gradient(&self, x: &[f64], y: &[f64]) -> Result<(f64, Vec<f64>)> {
57        self.inner.evaluate_with_gradient(x, y)
58    }
59
60    /// Pure gradient (no forward re-use).
61    pub fn gradient(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
62        self.inner.gradient_wrt_logits(x, y)
63    }
64
65    /// Apply a vanilla gradient-descent step `w <- w - lr * g`.
66    pub fn step(&mut self, gradient: &[f64], learning_rate: f64) -> Result<()> {
67        self.inner.apply_gradient_step(gradient, learning_rate)
68    }
69
70    /// Replace the logits outright (useful for optimizer-driven updates).
71    pub fn set_parameters(&mut self, new_logits: Vec<f64>) -> Result<()> {
72        self.inner.set_logits(new_logits)
73    }
74
75    /// Borrow the underlying mixture kernel (read-only).
76    pub fn inner(&self) -> &LearnedMixtureKernel {
77        &self.inner
78    }
79
80    /// Consume the adapter and return the underlying mixture kernel.
81    pub fn into_inner(self) -> LearnedMixtureKernel {
82        self.inner
83    }
84}
85
86impl From<LearnedMixtureKernel> for TrainableKernelMixture {
87    fn from(inner: LearnedMixtureKernel) -> Self {
88        Self::new(inner)
89    }
90}