rust_ml/builders/
logistic_regression.rs

//! Builder implementation for LogisticRegression models.
//!
//! This module provides a builder-pattern implementation for creating LogisticRegression
//! models with customizable configuration: feature count, activation function, and
//! classification threshold. The builder offers a fluent API for setting model parameters
//! before construction.
6use crate::builders::builder::Builder;
7use crate::core::activations::activation_functions::ActivationFn;
8use crate::core::error::ModelError;
9use crate::core::types::{Matrix, Vector};
10use crate::model::logistic_regression::LogisticRegression;
11
12// use crate::model::logistic_regression::LogisticRegression;
13
14/// Builder for creating LogisticRegression models with customizable configurations.
15///
16/// The LogisticRegressionBuilder provides methods to configure the properties of a
17/// LogisticRegression model before it is instantiated, following the Builder design pattern.
18///
19/// # Fields
20///
21/// * `n_features` - The number of input features for the logistic regression model
22/// * `activation_fn` - The activation function to use (default: Sigmoid)
23///
24/// # Examples
25///
26/// ```
27/// use rust_ml::model::logistic_regression::LogisticRegression;
28/// use rust_ml::core::activations::activation_functions::ActivationFn;
29/// use rust_ml::builders::builder::Builder;
30///
31/// // Create a logistic regression model with 4 features and sigmoid activation
32/// let model = LogisticRegression::builder()
33///     .n_features(4)
34///     .activation_function(ActivationFn::Sigmoid)
35///     .build()
36///     .unwrap();
37/// ```
38pub struct LogisticRegressionBuilder {
39    /// Number of input features for the model
40    n_features: usize,
41    /// Activation function to be used in the model
42    activation_fn: ActivationFn,
43    /// Classification threshold
44    threshold: f64,
45}
46
47impl LogisticRegressionBuilder {
48    /// Creates a new LogisticRegressionBuilder with default parameter values.
49    ///
50    /// The default configuration uses 1 feature and the Sigmoid activation function.
51    ///
52    /// # Returns
53    ///
54    /// * `Self` - A new LogisticRegressionBuilder instance with default settings
55    pub fn new() -> Self {
56        Self {
57            n_features: 1,
58            activation_fn: ActivationFn::Sigmoid,
59            threshold: 0.5,
60        }
61    }
62
63    /// Sets the number of input features for the logistic regression model.
64    ///
65    /// # Arguments
66    ///
67    /// * `n_features` - The number of independent variables (features) in the input data
68    ///
69    /// # Returns
70    ///
71    /// * `Self` - Builder instance with updated feature count for method chaining
72    pub fn n_features(mut self, n_features: usize) -> Self {
73        self.n_features = n_features;
74        self
75    }
76
77    /// Sets the activation function to use in the logistic regression model.
78    ///
79    /// While sigmoid is the traditional activation function for logistic regression,
80    /// other functions like ReLU or Tanh could be used for specific use cases.
81    ///
82    /// # Arguments
83    ///
84    /// * `activation_function` - The activation function to use
85    ///
86    /// # Returns
87    ///
88    /// * `Self` - Builder instance with updated activation function for method chaining
89    pub fn activation_function(mut self, activation_function: ActivationFn) -> Self {
90        self.activation_fn = activation_function;
91        self
92    }
93    /// Sets the classification threshold for the logistic regression model.
94    ///
95    /// # Arguments
96    ///
97    /// * `threshold` - The threshold value for classifying predictions (between 0 and 1)
98    ///
99    /// # Returns
100    ///
101    /// * `Self` - Builder instance with updated threshold for method chaining
102    pub fn threshold(mut self, threshold: f64) -> Self {
103        if !(0.0..=1.0).contains(&threshold) {
104            panic!("Threshold must be between 0 and 1");
105        }
106        self.threshold = threshold;
107        self
108    }
109}
110
111impl Builder<LogisticRegression, Matrix, Vector> for LogisticRegressionBuilder {
112    /// Builds and returns a new LogisticRegression model with the configured parameters.
113    ///
114    /// # Returns
115    ///
116    /// * `Result<LogisticRegression, ModelError>` - A new LogisticRegression instance with the
117    ///   specified configuration, or an error if construction fails
118    fn build(&self) -> Result<LogisticRegression, ModelError> {
119        Ok(LogisticRegression::new(
120            self.n_features,
121            self.activation_fn,
122            self.threshold,
123        ))
124    }
125}
126
127impl Default for LogisticRegressionBuilder {
128    /// Creates a new LogisticRegressionBuilder with default parameter values.
129    ///
130    /// The default configuration uses 1 feature and the Sigmoid activation function.
131    ///
132    /// # Returns
133    ///
134    /// * `Self` - A new LogisticRegressionBuilder instance with default settings
135    fn default() -> Self {
136        Self::new()
137    }
138}