rustlearn/
lib.rs

1//! A machine learning crate for Rust.
2//!
3//!
4//! # Introduction
5//!
6//! This crate contains reasonably effective implementations
7//! of a number of common machine learning algorithms.
8//!
9//! At the moment, `rustlearn` uses its own basic dense and sparse array types, but I will be happy
10//! to use something more robust once a clear winner in that space emerges.
11//!
12//! # Features
13//!
14//! ## Matrix primitives
15//!
16//! - [dense matrices](array/dense/index.html)
17//! - [sparse matrices](array/sparse/index.html)
18//!
19//! ## Models
20//!
21//! - [logistic regression](linear_models/sgdclassifier/index.html) using stochastic gradient descent,
22//! - [support vector machines](svm/libsvm/svc/index.html) using the `libsvm` library,
23//! - [decision trees](trees/decision_tree/index.html) using the CART algorithm,
24//! - [random forests](ensemble/random_forest/index.html) using CART decision trees, and
25//! - [factorization machines](factorization/factorization_machines/index.html).
26//!
27//! All the models support fitting and prediction on both dense and sparse data, and the implementations
28//! should be roughly competitive with Python `sklearn` implementations, both in accuracy and performance.
29//!
30//! ## Cross-validation
31//!
32//! - [k-fold cross-validation](cross_validation/cross_validation/index.html)
33//! - [shuffle split](cross_validation/shuffle_split/index.html)
34//!
35//! ## Metrics
36//!
37//! - [accuracy](metrics/fn.accuracy_score.html)
38//! - [mean_absolute_error](metrics/fn.mean_absolute_error.html)
39//! - [mean_squared_error](metrics/fn.mean_squared_error.html)
40//! - [ROC AUC score](metrics/ranking/fn.roc_auc_score.html)
41//! - [dcg_score](metrics/ranking/fn.dcg_score.html)
42//! - [ndcg_score](metrics/ranking/fn.ndcg_score.html)
43//!
44//! ## Parallelization
45//!
46//! A number of models support both parallel model fitting and prediction.
47//!
48//! ## Model serialization
49//!
50//! Model serialization is supported via `serde`.
51//!
52//! # Using `rustlearn`
53//! Usage should be straightforward.
54//!
55//! - import the prelude for all the linear algebra primitives and common traits:
56//!
57//! ```
58//! use rustlearn::prelude::*;
59//! ```
60//!
61//! - import individual models and utilities from submodules:
62//!
63//! ```
64//! use rustlearn::prelude::*;
65//!
66//! use rustlearn::linear_models::sgdclassifier::Hyperparameters;
67//! // more imports
68//! ```
69//!
70//! # Examples
71//!
72//! ## Logistic regression
73//!
74//! ```
75//! use rustlearn::prelude::*;
76//! use rustlearn::datasets::iris;
77//! use rustlearn::cross_validation::CrossValidation;
78//! use rustlearn::linear_models::sgdclassifier::Hyperparameters;
79//! use rustlearn::metrics::accuracy_score;
80//!
81//!
82//! let (X, y) = iris::load_data();
83//!
84//! let num_splits = 10;
85//! let num_epochs = 5;
86//!
87//! let mut accuracy = 0.0;
88//!
89//! for (train_idx, test_idx) in CrossValidation::new(X.rows(), num_splits) {
90//!
91//!     let X_train = X.get_rows(&train_idx);
92//!     let y_train = y.get_rows(&train_idx);
93//!     let X_test = X.get_rows(&test_idx);
94//!     let y_test = y.get_rows(&test_idx);
95//!
96//!     let mut model = Hyperparameters::new(X.cols())
97//!                                     .learning_rate(0.5)
98//!                                     .l2_penalty(0.0)
99//!                                     .l1_penalty(0.0)
100//!                                     .one_vs_rest();
101//!
102//!     for _ in 0..num_epochs {
103//!         model.fit(&X_train, &y_train).unwrap();
104//!     }
105//!
106//!     let prediction = model.predict(&X_test).unwrap();
107//!     accuracy += accuracy_score(&y_test, &prediction);
108//! }
109//!
110//! accuracy /= num_splits as f32;
111//!
112//! ```
113//!
114//! ## Random forest
115//!
116//! ```
117//! use rustlearn::prelude::*;
118//!
119//! use rustlearn::ensemble::random_forest::Hyperparameters;
120//! use rustlearn::datasets::iris;
121//! use rustlearn::trees::decision_tree;
122//!
123//! let (data, target) = iris::load_data();
124//!
125//! let mut tree_params = decision_tree::Hyperparameters::new(data.cols());
126//! tree_params.min_samples_split(10)
127//!     .max_features(4);
128//!
129//! let mut model = Hyperparameters::new(tree_params, 10)
130//!     .one_vs_rest();
131//!
132//! model.fit(&data, &target).unwrap();
133//!
134//! // Optionally serialize and deserialize the model
135//!
136//! // let encoded = bincode::serialize(&model).unwrap();
137//! // let decoded: OneVsRestWrapper<RandomForest> = bincode::deserialize(&encoded).unwrap();
138//!
139//! let prediction = model.predict(&data).unwrap();
140//! ```
141
142// Only use unstable features when we are benchmarking
143#![cfg_attr(feature = "bench", feature(test))]
144// Allow conventional capital X for feature arrays.
145#![allow(non_snake_case)]
146
147#[cfg(feature = "bench")]
148extern crate test;
149
150#[cfg(test)]
151extern crate bincode;
152
153#[cfg(test)]
154extern crate csv;
155
156#[cfg(test)]
157extern crate serde_json;
158
159extern crate crossbeam;
160extern crate rand;
161extern crate serde;
162#[macro_use]
163extern crate serde_derive;
164
165pub mod array;
166pub mod cross_validation;
167pub mod datasets;
168pub mod ensemble;
169pub mod factorization;
170pub mod feature_extraction;
171pub mod linear_models;
172pub mod metrics;
173pub mod multiclass;
174pub mod svm;
175pub mod traits;
176pub mod trees;
177pub mod utils;
178
// Re-exports are glob imports; silence warnings if a consumer build
// leaves some of them unused.
#[allow(unused_imports)]
pub mod prelude {
    //! Basic data structures and traits used throughout `rustlearn`.
    //!
    //! Intended to be glob-imported (`use rustlearn::prelude::*;`) so that
    //! the array primitives and the common model traits are in scope.
    // Dense/sparse matrix types and their indexing/iteration helpers.
    pub use array::prelude::*;
    // Common model traits (fitting/prediction interfaces) shared by all models.
    pub use traits::*;
}