tensorlogic_sklears_kernels/deep_kernel/
builder.rs1use crate::deep_kernel::feature_extractor::MLPFeatureExtractor;
29use crate::deep_kernel::kernel::DeepKernel;
30use crate::deep_kernel::layer::Activation;
31use crate::error::{KernelError, Result};
32use crate::types::Kernel;
33
34#[derive(Clone, Debug, Default)]
42pub struct DeepKernelBuilder {
43 widths: Vec<usize>,
44 activations: Vec<Activation>,
45 seed: Option<u64>,
46 has_output: bool,
47}
48
49impl DeepKernelBuilder {
50 pub fn new() -> Self {
52 Self::default()
53 }
54
55 pub fn input_dim(mut self, dim: usize) -> Self {
58 if self.widths.is_empty() {
59 self.widths.push(dim);
60 } else {
61 self.widths[0] = dim;
62 }
63 self
64 }
65
66 pub fn hidden_layer(mut self, width: usize, activation: Activation) -> Self {
72 if !self.has_output {
73 self.widths.push(width);
74 self.activations.push(activation);
75 }
76 self
77 }
78
79 pub fn output_dim(mut self, width: usize, activation: Activation) -> Self {
82 if !self.has_output {
83 self.widths.push(width);
84 self.activations.push(activation);
85 self.has_output = true;
86 }
87 self
88 }
89
90 pub fn seed(mut self, seed: u64) -> Self {
92 self.seed = Some(seed);
93 self
94 }
95
96 pub fn build_extractor(&self) -> Result<MLPFeatureExtractor> {
100 if self.widths.len() < 2 {
101 return Err(KernelError::InvalidParameter {
102 parameter: "widths".to_string(),
103 value: format!("{:?}", self.widths),
104 reason: "builder needs at least input_dim + output_dim".to_string(),
105 });
106 }
107 if !self.has_output {
108 return Err(KernelError::InvalidParameter {
109 parameter: "output_dim".to_string(),
110 value: "unset".to_string(),
111 reason: "call output_dim before build".to_string(),
112 });
113 }
114 let seed = self.seed.unwrap_or(0);
115 MLPFeatureExtractor::xavier_init(&self.widths, &self.activations, seed)
116 }
117
118 pub fn build<K: Kernel>(self, base: K) -> Result<DeepKernel<MLPFeatureExtractor, K>> {
121 let extractor = self.build_extractor()?;
122 Ok(DeepKernel::new(extractor, base))
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RbfKernelConfig;
    use crate::RbfKernel;

    /// Base RBF kernel (gamma-style parameter 0.5) shared by every test.
    fn base_kernel() -> RbfKernel {
        RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid")
    }

    /// Passes only when `outcome` is an `InvalidParameter` error.
    ///
    /// Panics with `success_msg` if the build unexpectedly succeeded, and
    /// reports the offending variant for any other kind of error.
    fn assert_invalid_parameter<T>(outcome: Result<T>, success_msg: &str) {
        match outcome {
            Err(KernelError::InvalidParameter { .. }) => {}
            Err(other) => panic!("unexpected error variant: {}", other),
            Ok(_) => panic!("{}", success_msg),
        }
    }

    #[test]
    fn builder_assembles_three_layer_mlp() {
        let builder = DeepKernelBuilder::new()
            .input_dim(3)
            .hidden_layer(5, Activation::ReLU)
            .output_dim(2, Activation::Identity)
            .seed(123);
        let dkl = builder.build(base_kernel()).expect("valid build");
        let layer_count = dkl.feature_extractor().num_layers();
        assert_eq!(layer_count, 2);
    }

    #[test]
    fn builder_fails_without_output_dim() {
        let builder = DeepKernelBuilder::new()
            .input_dim(3)
            .hidden_layer(5, Activation::ReLU);
        assert_invalid_parameter(builder.build(base_kernel()), "missing output_dim must fail");
    }

    #[test]
    fn builder_fails_when_only_input_set() {
        let builder = DeepKernelBuilder::new().input_dim(3);
        assert_invalid_parameter(builder.build(base_kernel()), "only input_dim set must fail");
    }
}
169}