1pub mod svc;
27pub mod svr;
28use core::fmt::Debug;
32
33#[cfg(feature = "serde")]
34use serde::{Deserialize, Serialize};
35
36use crate::error::{Failed, FailedError};
37use crate::linalg::basic::arrays::{Array1, ArrayView1};
38
/// Kernel function used by the SVM estimators to compare two samples.
///
/// When the `serde` feature is enabled (and the target is not wasm32) the trait is
/// registered with `typetag`, so `dyn Kernel` trait objects can be (de)serialized;
/// the concrete kernel is recorded under the `"type"` tag.
#[cfg_attr(
    all(feature = "serde", not(target_arch = "wasm32")),
    typetag::serde(tag = "type")
)]
pub trait Kernel: Debug {
    /// Evaluates the kernel on the two samples `x_i` and `x_j`.
    ///
    /// # Errors
    /// The parameterised kernels in this module return `Err(Failed)` with
    /// `FailedError::ParametersError` when a required parameter is unset.
    // NOTE(review): `&Vec<f64>` (hence the clippy `ptr_arg` allow) rather than `&[f64]`
    // is presumably kept because implementors call array-trait methods (`dot`, `sub`)
    // implemented for `Vec`; changing it would break every implementor — confirm first.
    #[allow(clippy::ptr_arg)]
    fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
}
50
51#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
53#[derive(Debug, Clone)]
54pub struct Kernels;
55
56impl Kernels {
57 pub fn linear() -> LinearKernel {
59 LinearKernel
60 }
61 pub fn rbf() -> RBFKernel {
63 RBFKernel::default()
64 }
65 pub fn polynomial() -> PolynomialKernel {
67 PolynomialKernel::default()
68 }
69 pub fn sigmoid() -> SigmoidKernel {
71 SigmoidKernel::default()
72 }
73}
74
/// Linear kernel: `K(x, y) = <x, y>`.
///
/// Has no parameters, so it is a unit struct.
// The former `#[allow(clippy::derive_partial_eq_without_eq)]` was dropped: `Eq` is
// derived below, so that lint can never fire here.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct LinearKernel;
80
/// Radial basis function (Gaussian) kernel, `exp(-gamma * ||x - y||^2)`.
///
/// `gamma` defaults to `None` and has to be configured, e.g. through
/// [`RBFKernel::with_gamma`], before the kernel can be evaluated.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Default, Clone, PartialEq)]
pub struct RBFKernel {
    /// Multiplier of the squared Euclidean distance; `None` until configured.
    pub gamma: Option<f64>,
}

#[allow(dead_code)]
impl RBFKernel {
    /// Consumes the kernel and returns it with `gamma` set to the given value.
    pub fn with_gamma(self, gamma: f64) -> Self {
        let mut kernel = self;
        kernel.gamma = Some(gamma);
        kernel
    }
}
101
/// Polynomial kernel, `K(x, y) = (gamma * <x, y> + coef0)^degree`.
///
/// `degree` and `gamma` default to `None` and must be set before the kernel is
/// evaluated; `coef0` defaults to `1.0`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct PolynomialKernel {
    /// Exponent applied to the affine dot product.
    pub degree: Option<f64>,
    /// Multiplier of the dot product.
    pub gamma: Option<f64>,
    /// Constant added to the scaled dot product.
    pub coef0: Option<f64>,
}

impl Default for PolynomialKernel {
    fn default() -> Self {
        Self {
            gamma: Option::None,
            degree: Option::None,
            coef0: Some(1f64),
        }
    }
}

impl PolynomialKernel {
    /// Sets `degree`, `gamma` and `coef0` in one call.
    pub fn with_params(mut self, degree: f64, gamma: f64, coef0: f64) -> Self {
        self.degree = Some(degree);
        self.gamma = Some(gamma);
        self.coef0 = Some(coef0);
        self
    }

    /// Sets `gamma`, leaving `degree` and `coef0` untouched.
    pub fn with_gamma(mut self, gamma: f64) -> Self {
        self.gamma = Some(gamma);
        self
    }

    /// Sets `degree` and derives the other parameters from the input dimensionality:
    /// `gamma = 1 / n_features` and `coef0 = 1`.
    ///
    /// Note: `n_features == 0` yields `gamma = inf`.
    pub fn with_degree(self, degree: f64, n_features: usize) -> Self {
        // `with_params` takes (degree, gamma, coef0): gamma scales with the number of
        // features, coef0 keeps its default of 1. These were previously swapped.
        self.with_params(degree, 1f64 / n_features as f64, 1f64)
    }
}
154
/// Sigmoid (hyperbolic tangent) kernel parameterised by `gamma` and `coef0`.
///
/// `gamma` defaults to `None` and must be configured before evaluation;
/// `coef0` defaults to `1.0`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct SigmoidKernel {
    /// Multiplier of the dot product.
    pub gamma: Option<f64>,
    /// Constant offset.
    pub coef0: Option<f64>,
}

impl Default for SigmoidKernel {
    fn default() -> Self {
        Self {
            coef0: Some(1f64),
            gamma: None,
        }
    }
}

impl SigmoidKernel {
    /// Sets both `gamma` and `coef0`.
    pub fn with_params(self, gamma: f64, coef0: f64) -> Self {
        let mut kernel = self.with_gamma(gamma);
        kernel.coef0 = Some(coef0);
        kernel
    }

    /// Sets `gamma`, keeping the current `coef0`.
    pub fn with_gamma(self, gamma: f64) -> Self {
        Self {
            gamma: Some(gamma),
            ..self
        }
    }
}
195
196#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
197impl Kernel for LinearKernel {
198 fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
199 Ok(x_i.dot(x_j))
200 }
201}
202
203#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
204impl Kernel for RBFKernel {
205 fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
206 if self.gamma.is_none() {
207 return Err(Failed::because(
208 FailedError::ParametersError,
209 "gamma should be set, use {Kernel}::default().with_gamma(..)",
210 ));
211 }
212 let v_diff = x_i.sub(x_j);
213 Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
214 }
215}
216
217#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
218impl Kernel for PolynomialKernel {
219 fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
220 if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
221 return Err(Failed::because(
222 FailedError::ParametersError, "gamma, coef0, degree should be set,
223 use {Kernel}::default().with_{parameter}(..)")
224 );
225 }
226 let dot = x_i.dot(x_j);
227 Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
228 }
229}
230
231#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
232impl Kernel for SigmoidKernel {
233 fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
234 if self.gamma.is_none() || self.coef0.is_none() {
235 return Err(Failed::because(
236 FailedError::ParametersError, "gamma, coef0, degree should be set,
237 use {Kernel}::default().with_{parameter}(..)")
238 );
239 }
240 let dot = x_i.dot(x_j);
241 Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
242 }
243}
244
245#[cfg(test)]
246mod tests {
247 use super::*;
248 use crate::svm::Kernels;
249
250 #[cfg_attr(
251 all(target_arch = "wasm32", not(target_os = "wasi")),
252 wasm_bindgen_test::wasm_bindgen_test
253 )]
254 #[test]
255 fn linear_kernel() {
256 let v1 = vec![1., 2., 3.];
257 let v2 = vec![4., 5., 6.];
258
259 assert_eq!(32f64, Kernels::linear().apply(&v1, &v2).unwrap());
260 }
261
262 #[cfg_attr(
263 all(target_arch = "wasm32", not(target_os = "wasi")),
264 wasm_bindgen_test::wasm_bindgen_test
265 )]
266 #[test]
267 fn rbf_kernel() {
268 let v1 = vec![1., 2., 3.];
269 let v2 = vec![4., 5., 6.];
270
271 let result = Kernels::rbf()
272 .with_gamma(0.055)
273 .apply(&v1, &v2)
274 .unwrap()
275 .abs();
276
277 assert!((0.2265f64 - result) < 1e-4);
278 }
279
280 #[cfg_attr(
281 all(target_arch = "wasm32", not(target_os = "wasi")),
282 wasm_bindgen_test::wasm_bindgen_test
283 )]
284 #[test]
285 fn polynomial_kernel() {
286 let v1 = vec![1., 2., 3.];
287 let v2 = vec![4., 5., 6.];
288
289 let result = Kernels::polynomial()
290 .with_params(3.0, 0.5, 1.0)
291 .apply(&v1, &v2)
292 .unwrap()
293 .abs();
294
295 assert!((4913f64 - result) < std::f64::EPSILON);
296 }
297
298 #[cfg_attr(
299 all(target_arch = "wasm32", not(target_os = "wasi")),
300 wasm_bindgen_test::wasm_bindgen_test
301 )]
302 #[test]
303 fn sigmoid_kernel() {
304 let v1 = vec![1., 2., 3.];
305 let v2 = vec![4., 5., 6.];
306
307 let result = Kernels::sigmoid()
308 .with_params(0.01, 0.1)
309 .apply(&v1, &v2)
310 .unwrap()
311 .abs();
312
313 assert!((0.3969f64 - result) < 1e-4);
314 }
315}