tensorlogic_sklears_kernels/learned_composition/trainable.rs
1//! Thin trainable-parameter adapter around [`LearnedMixtureKernel`].
2//!
3//! Exposes the mixture as a parameter container compatible with
4//! `tensorlogic-train`:
5//!
//! * `parameters()` / `set_parameters()` read and replace the logits.
7//! * `step(gradient, learning_rate)` applies a vanilla gradient-descent
8//! update to the logits.
9//!
10//! The adapter intentionally does not own an optimizer — choice of
11//! optimizer (SGD, Adam, etc.) stays with the caller. It just bundles the
12//! forward evaluation and the analytical gradient so the caller can write
13//!
14//! ```text
15//! let (k, g) = mixture.evaluate_with_gradient(x, y)?;
16//! let grad = dloss_dk * g; // scale by upstream grad
17//! mixture.step(&grad, learning_rate)?; // vanilla SGD step
18//! ```
19
20use crate::error::Result;
21use crate::learned_composition::mixture::LearnedMixtureKernel;
22
23/// Trainable adapter around a [`LearnedMixtureKernel`].
24#[derive(Clone, Debug)]
25pub struct TrainableKernelMixture {
26 inner: LearnedMixtureKernel,
27}
28
29impl TrainableKernelMixture {
30 /// Wrap an existing mixture kernel.
31 pub fn new(inner: LearnedMixtureKernel) -> Self {
32 Self { inner }
33 }
34
35 /// Number of trainable logits.
36 pub fn num_parameters(&self) -> usize {
37 self.inner.num_kernels()
38 }
39
40 /// Borrow the trainable parameters (logits).
41 pub fn parameters(&self) -> &[f64] {
42 self.inner.logits()
43 }
44
45 /// Mixture weights after softmax.
46 pub fn weights(&self) -> Vec<f64> {
47 self.inner.weights()
48 }
49
50 /// Forward pass — mixture kernel value.
51 pub fn evaluate(&self, x: &[f64], y: &[f64]) -> Result<f64> {
52 self.inner.evaluate(x, y)
53 }
54
55 /// Forward + gradient in a single pass.
56 pub fn evaluate_with_gradient(&self, x: &[f64], y: &[f64]) -> Result<(f64, Vec<f64>)> {
57 self.inner.evaluate_with_gradient(x, y)
58 }
59
60 /// Pure gradient (no forward re-use).
61 pub fn gradient(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
62 self.inner.gradient_wrt_logits(x, y)
63 }
64
65 /// Apply a vanilla gradient-descent step `w <- w - lr * g`.
66 pub fn step(&mut self, gradient: &[f64], learning_rate: f64) -> Result<()> {
67 self.inner.apply_gradient_step(gradient, learning_rate)
68 }
69
70 /// Replace the logits outright (useful for optimizer-driven updates).
71 pub fn set_parameters(&mut self, new_logits: Vec<f64>) -> Result<()> {
72 self.inner.set_logits(new_logits)
73 }
74
75 /// Borrow the underlying mixture kernel (read-only).
76 pub fn inner(&self) -> &LearnedMixtureKernel {
77 &self.inner
78 }
79
80 /// Consume the adapter and return the underlying mixture kernel.
81 pub fn into_inner(self) -> LearnedMixtureKernel {
82 self.inner
83 }
84}
85
86impl From<LearnedMixtureKernel> for TrainableKernelMixture {
87 fn from(inner: LearnedMixtureKernel) -> Self {
88 Self::new(inner)
89 }
90}