rusty_machine/learning/
dbscan.rs1use learning::{LearningResult, UnSupModel};
40use learning::error::{Error, ErrorKind};
41
42use linalg::{Matrix, Vector, BaseMatrix};
43use rulinalg::utils;
44
/// DBSCAN (density-based spatial clustering) model.
///
/// Points are grouped by density: a point with at least `min_points`
/// neighbours (points closer than `eps`) is a core point, and clusters grow
/// outwards from core points. Points that join no cluster are labelled `None`.
#[derive(Debug)]
pub struct DBSCAN {
    /// Neighbourhood radius: a point is a neighbour when its Euclidean
    /// distance is strictly less than `eps`.
    eps: f64,
    /// Minimum neighbourhood size (the point itself counts, as its distance
    /// is 0) required for a point to be a core point.
    min_points: usize,
    /// Per-row labels from the last `train` call: `Some(id)` for cluster
    /// members, `None` for unassigned points. `None` before training.
    clusters: Option<Vector<Option<usize>>>,
    /// Whether `train` stores the training data so that `predict` can be used.
    predictive: bool,
    // Scratch flags marking which input rows have been visited during `train`.
    _visited: Vec<bool>,
    // Copy of the training inputs, stored by `train` when `predictive` is set
    // and read by `predict`.
    _cluster_data: Option<Matrix<f64>>,
}
58
59impl Default for DBSCAN {
65 fn default() -> DBSCAN {
66 DBSCAN {
67 eps: 0.5,
68 min_points: 5,
69 clusters: None,
70 predictive: false,
71 _visited: Vec::new(),
72 _cluster_data: None,
73 }
74 }
75}
76
77impl UnSupModel<Matrix<f64>, Vector<Option<usize>>> for DBSCAN {
78 fn train(&mut self, inputs: &Matrix<f64>) -> LearningResult<()> {
80 self.init_params(inputs.rows());
81 let mut cluster = 0;
82
83 for (idx, point) in inputs.iter_rows().enumerate() {
84 let visited = self._visited[idx];
85
86 if !visited {
87 self._visited[idx] = true;
88
89 let neighbours = self.region_query(point, inputs);
90
91 if neighbours.len() >= self.min_points {
92 self.expand_cluster(inputs, idx, neighbours, cluster);
93 cluster += 1;
94 }
95 }
96 }
97
98 if self.predictive {
99 self._cluster_data = Some(inputs.clone());
100 }
101
102 Ok(())
103 }
104
105 fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<Option<usize>>> {
106 if self.predictive {
107 if let (&Some(ref cluster_data), &Some(ref clusters)) = (&self._cluster_data,
108 &self.clusters) {
109 let mut classes = Vec::with_capacity(inputs.rows());
110
111 for input_point in inputs.iter_rows() {
112 let mut distances = Vec::with_capacity(cluster_data.rows());
113
114 for cluster_point in cluster_data.iter_rows() {
115 let point_distance =
116 utils::vec_bin_op(input_point, cluster_point, |x, y| x - y);
117 distances.push(utils::dot(&point_distance, &point_distance).sqrt());
118 }
119
120 let (closest_idx, closest_dist) = utils::argmin(&distances);
121 if closest_dist < self.eps {
122 classes.push(clusters[closest_idx]);
123 } else {
124 classes.push(None);
125 }
126 }
127
128 Ok(Vector::new(classes))
129 } else {
130 Err(Error::new_untrained())
131 }
132 } else {
133 Err(Error::new(ErrorKind::InvalidState,
134 "Model must be set to predictive. Use `self.set_predictive(true)`."))
135 }
136 }
137}
138
139impl DBSCAN {
140 pub fn new(eps: f64, min_points: usize) -> DBSCAN {
143 assert!(eps > 0f64, "The model epsilon must be positive.");
144
145 DBSCAN {
146 eps: eps,
147 min_points: min_points,
148 clusters: None,
149 predictive: false,
150 _visited: Vec::new(),
151 _cluster_data: None,
152 }
153 }
154
155 pub fn set_predictive(&mut self, predictive: bool) {
161 self.predictive = predictive;
162 }
163
164 pub fn clusters(&self) -> Option<&Vector<Option<usize>>> {
166 self.clusters.as_ref()
167 }
168
169 fn expand_cluster(&mut self,
170 inputs: &Matrix<f64>,
171 point_idx: usize,
172 neighbour_pts: Vec<usize>,
173 cluster: usize) {
174 debug_assert!(point_idx < inputs.rows(),
175 "Point index too large for inputs");
176 debug_assert!(neighbour_pts.iter().all(|x| *x < inputs.rows()),
177 "Neighbour indices too large for inputs");
178
179 self.clusters.as_mut().map(|x| x.mut_data()[point_idx] = Some(cluster));
180
181 for data_point_idx in &neighbour_pts {
182 let visited = self._visited[*data_point_idx];
183 if !visited {
184 self._visited[*data_point_idx] = true;
185 let data_point_row = unsafe { inputs.get_row_unchecked(*data_point_idx) };
186 let sub_neighbours = self.region_query(data_point_row, inputs);
187
188 if sub_neighbours.len() >= self.min_points {
189 self.expand_cluster(inputs, *data_point_idx, sub_neighbours, cluster);
190 }
191 }
192 }
193 }
194
195
196 fn region_query(&self, point: &[f64], inputs: &Matrix<f64>) -> Vec<usize> {
197 debug_assert!(point.len() == inputs.cols(),
198 "point must be of same dimension as inputs");
199
200 let mut in_neighbourhood = Vec::new();
201 for (idx, data_point) in inputs.iter_rows().enumerate() {
202 let point_distance = utils::vec_bin_op(data_point, point, |x, y| x - y);
203 let dist = utils::dot(&point_distance, &point_distance).sqrt();
204
205 if dist < self.eps {
206 in_neighbourhood.push(idx);
207 }
208 }
209
210 in_neighbourhood
211 }
212
213 fn init_params(&mut self, total_points: usize) {
214 unsafe {
215 self._visited.reserve(total_points);
216 self._visited.set_len(total_points);
217 }
218
219 for i in 0..total_points {
220 self._visited[i] = false;
221 }
222
223 self.clusters = Some(Vector::new(vec![None; total_points]));
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::DBSCAN;
    use learning::UnSupModel;
    use linalg::Matrix;

    #[test]
    fn test_region_query() {
        let model = DBSCAN::new(1.0, 3);

        let inputs = Matrix::new(3, 2, vec![1.0, 1.0, 1.1, 1.9, 3.0, 3.0]);

        let neighbours = model.region_query(&[1.0, 1.0], &inputs);

        // (1.1, 1.9) is ~0.906 away (inside eps); (3.0, 3.0) is ~2.83 away.
        assert_eq!(neighbours, vec![0, 1]);
    }

    #[test]
    fn test_region_query_small_eps() {
        let model = DBSCAN::new(0.01, 3);

        let inputs = Matrix::new(3, 2, vec![1.0, 1.0, 1.1, 1.9, 1.1, 1.1]);

        let neighbours = model.region_query(&[1.0, 1.0], &inputs);

        // Only the identical point is closer than 0.01.
        assert_eq!(neighbours, vec![0]);
    }

    #[test]
    #[should_panic]
    fn test_new_rejects_non_positive_eps() {
        let _ = DBSCAN::new(0.0, 3);
    }

    #[test]
    fn test_train_separated_clusters_and_noise() {
        let mut model = DBSCAN::new(1.0, 2);
        let inputs = Matrix::new(5, 2, vec![0.0, 0.0,
                                            0.0, 0.5,
                                            10.0, 10.0,
                                            10.0, 10.5,
                                            50.0, 50.0]);

        model.train(&inputs).unwrap();

        let clusters = model.clusters().unwrap();
        assert_eq!(clusters[0], Some(0));
        assert_eq!(clusters[1], Some(0));
        assert_eq!(clusters[2], Some(1));
        assert_eq!(clusters[3], Some(1));
        assert_eq!(clusters[4], None);
    }

    #[test]
    fn test_predict() {
        let mut model = DBSCAN::new(1.0, 2);
        model.set_predictive(true);
        let inputs = Matrix::new(4, 2, vec![0.0, 0.0,
                                            0.0, 0.5,
                                            10.0, 10.0,
                                            10.0, 10.5]);
        model.train(&inputs).unwrap();

        let new_points = Matrix::new(2, 2, vec![0.1, 0.2, 5.0, 5.0]);
        let classes = model.predict(&new_points).unwrap();

        assert_eq!(classes[0], Some(0));
        assert_eq!(classes[1], None);
    }

    #[test]
    fn test_predict_requires_predictive() {
        let mut model = DBSCAN::new(1.0, 2);
        let inputs = Matrix::new(2, 2, vec![0.0, 0.0, 0.0, 0.5]);
        model.train(&inputs).unwrap();

        assert!(model.predict(&inputs).is_err());
    }

    #[test]
    fn test_predict_untrained() {
        let mut model = DBSCAN::new(1.0, 2);
        model.set_predictive(true);
        let inputs = Matrix::new(1, 2, vec![0.0, 0.0]);

        assert!(model.predict(&inputs).is_err());
    }
}