Skip to main content

tensorlogic_sklears_kernels/deep_kernel/
builder.rs

1//! Fluent builder for common Deep Kernel topologies.
2//!
3//! Mirrors the style of
4//! [`crate::learned_composition::LearnedMixtureBuilder`]: chained
5//! setters plus a terminating [`DeepKernelBuilder::build`] that returns
6//! `Result<DeepKernel<MLPFeatureExtractor, K>>`.
7//!
8//! Typical usage:
9//!
10//! ```rust
11//! use tensorlogic_sklears_kernels::{
12//!     deep_kernel::{Activation, DeepKernelBuilder},
13//!     RbfKernel, RbfKernelConfig,
14//! };
15//!
16//! let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
17//! let dkl = DeepKernelBuilder::new()
18//!     .input_dim(4)
19//!     .hidden_layer(8, Activation::ReLU)
20//!     .hidden_layer(4, Activation::Tanh)
21//!     .output_dim(2, Activation::Identity)
22//!     .seed(42)
23//!     .build(rbf)
24//!     .expect("valid topology");
25//! let _ = dkl;
26//! ```
27
28use crate::deep_kernel::feature_extractor::MLPFeatureExtractor;
29use crate::deep_kernel::kernel::DeepKernel;
30use crate::deep_kernel::layer::Activation;
31use crate::error::{KernelError, Result};
32use crate::types::Kernel;
33
34/// Fluent builder for Deep Kernel networks.
35///
36/// The builder records layer widths and activations in order. The first
37/// width is the input dimension (set via [`Self::input_dim`]). Every
38/// subsequent hidden layer appends a new width and its activation. The
39/// terminating [`Self::output_dim`] records the final width and the
40/// output-layer activation.
41#[derive(Clone, Debug, Default)]
42pub struct DeepKernelBuilder {
43    widths: Vec<usize>,
44    activations: Vec<Activation>,
45    seed: Option<u64>,
46    has_output: bool,
47}
48
49impl DeepKernelBuilder {
50    /// Empty builder — no layers configured yet.
51    pub fn new() -> Self {
52        Self::default()
53    }
54
55    /// Set the input dimension. Must be called exactly once before the
56    /// first hidden or output layer.
57    pub fn input_dim(mut self, dim: usize) -> Self {
58        if self.widths.is_empty() {
59            self.widths.push(dim);
60        } else {
61            self.widths[0] = dim;
62        }
63        self
64    }
65
66    /// Append a hidden layer with the given width and activation.
67    /// Ignored if the builder has already been closed with
68    /// [`Self::output_dim`] — the builder reports that as an error at
69    /// build time via
70    /// [`KernelError::InvalidParameter`].
71    pub fn hidden_layer(mut self, width: usize, activation: Activation) -> Self {
72        if !self.has_output {
73            self.widths.push(width);
74            self.activations.push(activation);
75        }
76        self
77    }
78
79    /// Finalise the topology by appending the output layer. Must be
80    /// called exactly once.
81    pub fn output_dim(mut self, width: usize, activation: Activation) -> Self {
82        if !self.has_output {
83            self.widths.push(width);
84            self.activations.push(activation);
85            self.has_output = true;
86        }
87        self
88    }
89
90    /// Set the RNG seed used for Xavier initialisation.
91    pub fn seed(mut self, seed: u64) -> Self {
92        self.seed = Some(seed);
93        self
94    }
95
96    /// Produce an owned [`MLPFeatureExtractor`] for the configured
97    /// topology, without a base kernel. Useful when the caller wants to
98    /// combine the MLP with a kernel that needs additional plumbing.
99    pub fn build_extractor(&self) -> Result<MLPFeatureExtractor> {
100        if self.widths.len() < 2 {
101            return Err(KernelError::InvalidParameter {
102                parameter: "widths".to_string(),
103                value: format!("{:?}", self.widths),
104                reason: "builder needs at least input_dim + output_dim".to_string(),
105            });
106        }
107        if !self.has_output {
108            return Err(KernelError::InvalidParameter {
109                parameter: "output_dim".to_string(),
110                value: "unset".to_string(),
111                reason: "call output_dim before build".to_string(),
112            });
113        }
114        let seed = self.seed.unwrap_or(0);
115        MLPFeatureExtractor::xavier_init(&self.widths, &self.activations, seed)
116    }
117
118    /// Finalise the builder against a base kernel and produce a fully
119    /// wired [`DeepKernel`].
120    pub fn build<K: Kernel>(self, base: K) -> Result<DeepKernel<MLPFeatureExtractor, K>> {
121        let extractor = self.build_extractor()?;
122        Ok(DeepKernel::new(extractor, base))
123    }
124}
125
126#[cfg(test)]
127mod tests {
128    use super::*;
129    use crate::types::RbfKernelConfig;
130    use crate::RbfKernel;
131
132    #[test]
133    fn builder_assembles_three_layer_mlp() {
134        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
135        let dkl = DeepKernelBuilder::new()
136            .input_dim(3)
137            .hidden_layer(5, Activation::ReLU)
138            .output_dim(2, Activation::Identity)
139            .seed(123)
140            .build(rbf)
141            .expect("valid build");
142        assert_eq!(dkl.feature_extractor().num_layers(), 2);
143    }
144
145    #[test]
146    fn builder_fails_without_output_dim() {
147        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
148        let result = DeepKernelBuilder::new()
149            .input_dim(3)
150            .hidden_layer(5, Activation::ReLU)
151            .build(rbf);
152        match result {
153            Ok(_) => panic!("missing output_dim must fail"),
154            Err(KernelError::InvalidParameter { .. }) => {}
155            Err(other) => panic!("unexpected error variant: {}", other),
156        }
157    }
158
159    #[test]
160    fn builder_fails_when_only_input_set() {
161        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
162        let result = DeepKernelBuilder::new().input_dim(3).build(rbf);
163        match result {
164            Ok(_) => panic!("only input_dim set must fail"),
165            Err(KernelError::InvalidParameter { .. }) => {}
166            Err(other) => panic!("unexpected error variant: {}", other),
167        }
168    }
169}