web_page_classifier/
lib.rs1mod model;
28pub mod url_heuristics;
29
30use std::sync::OnceLock;
31
32pub use url_heuristics::classify_url;
33
/// Number of numeric features [`classify_ml`] requires.
///
/// `classify_ml` asserts that its `numeric_features` slice has exactly this
/// length and panics otherwise.
pub const N_NUMERIC_FEATURES: usize = 89;
36
/// Category assigned to a web page by this crate's classifiers.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so the variants
/// compare in the order listed here (`Article` is the smallest, `Service`
/// the largest). Reordering variants therefore changes comparison results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum PageType {
    /// Canonical name: `"article"`.
    Article,
    /// Canonical name: `"collection"`.
    Collection,
    /// Canonical name: `"documentation"`.
    Documentation,
    /// Canonical name: `"forum"`.
    Forum,
    /// Canonical name: `"listing"`.
    Listing,
    /// Canonical name: `"product"`.
    Product,
    /// Canonical name: `"service"`.
    Service,
}
48
49impl PageType {
50 #[must_use]
52 pub fn parse(s: &str) -> Option<Self> {
53 match s.to_ascii_lowercase().as_str() {
54 "article" => Some(Self::Article),
55 "collection" | "category" => Some(Self::Collection),
56 "documentation" | "docs" => Some(Self::Documentation),
57 "forum" => Some(Self::Forum),
58 "listing" => Some(Self::Listing),
59 "product" => Some(Self::Product),
60 "service" => Some(Self::Service),
61 _ => None,
62 }
63 }
64
65 #[must_use]
67 pub fn as_str(&self) -> &'static str {
68 match self {
69 Self::Article => "article",
70 Self::Collection => "collection",
71 Self::Documentation => "documentation",
72 Self::Forum => "forum",
73 Self::Listing => "listing",
74 Self::Product => "product",
75 Self::Service => "service",
76 }
77 }
78}
79
80impl std::fmt::Display for PageType {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.write_str(self.as_str())
83 }
84}
85
86impl std::str::FromStr for PageType {
87 type Err = String;
88
89 fn from_str(s: &str) -> Result<Self, Self::Err> {
90 Self::parse(s).ok_or_else(|| format!("unknown page type: {s}"))
91 }
92}
93
/// Number of features [`predict_quality`] requires.
///
/// `predict_quality` asserts that its `features` slice has exactly this
/// length and panics otherwise.
pub const N_QUALITY_FEATURES: usize = 27;
96
97static MODEL_BYTES: &[u8] = include_bytes!("xgboost_v2.bin");
99
100static QUALITY_MODEL_BYTES: &[u8] = include_bytes!("quality_model_v1.bin");
102
103static MODEL: OnceLock<model::Model> = OnceLock::new();
105
106static QUALITY_MODEL: OnceLock<model::QualityModel> = OnceLock::new();
108
109fn get_model() -> &'static model::Model {
110 MODEL.get_or_init(|| {
111 model::Model::from_bytes(MODEL_BYTES)
112 .expect("embedded classifier model is valid")
113 })
114}
115
116fn get_quality_model() -> &'static model::QualityModel {
117 QUALITY_MODEL.get_or_init(|| {
118 model::QualityModel::from_bytes(QUALITY_MODEL_BYTES)
119 .expect("embedded quality model is valid")
120 })
121}
122
123#[must_use]
136pub fn classify_ml(numeric_features: &[f64], title_meta: &str) -> (PageType, f64) {
137 assert_eq!(
138 numeric_features.len(),
139 N_NUMERIC_FEATURES,
140 "Expected {} numeric features, got {}",
141 N_NUMERIC_FEATURES,
142 numeric_features.len()
143 );
144
145 let m = get_model();
146
147 let scaled = m.scale_features(numeric_features);
149
150 let tfidf = m.compute_tfidf(title_meta);
152
153 let mut all_features = Vec::with_capacity(scaled.len() + tfidf.len());
155 all_features.extend_from_slice(&scaled);
156 all_features.extend_from_slice(&tfidf);
157
158 let (class_idx, confidence) = m.predict(&all_features);
160
161 let page_type = m.class_labels.get(class_idx)
162 .and_then(|s| PageType::parse(s))
163 .unwrap_or(PageType::Article);
164
165 (page_type, confidence)
166}
167
168#[must_use]
191pub fn predict_quality(features: &[f64]) -> f64 {
192 assert_eq!(
193 features.len(),
194 N_QUALITY_FEATURES,
195 "Expected {} quality features, got {}",
196 N_QUALITY_FEATURES,
197 features.len()
198 );
199
200 let m = get_quality_model();
201 let scaled = m.scale_features(features);
202 let predicted = m.predict(&scaled);
203 predicted.clamp(0.0, 1.0)
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded classifier loads, reports seven classes (`PageType` has
    /// seven variants) and contains at least one tree.
    #[test]
    fn test_model_loads() {
        let m = get_model();
        assert_eq!(m.n_classes, 7);
        assert!(!m.trees.is_empty());
    }

    /// An all-zero feature vector still yields a confidence in `[0, 1]` and a
    /// page type with a non-empty name.
    #[test]
    fn test_classify_ml_returns_valid_type() {
        let features = vec![0.0f64; N_NUMERIC_FEATURES];
        let (page_type, confidence) = classify_ml(&features, "Example blog post about technology");
        assert!((0.0..=1.0).contains(&confidence));
        assert!(!page_type.as_str().is_empty());
    }

    /// URL heuristics recognise forum, documentation and product URLs.
    #[test]
    fn test_classify_url_basic() {
        assert_eq!(
            classify_url("https://forum.example.com/thread/123"),
            PageType::Forum
        );
        assert_eq!(
            classify_url("https://docs.example.com/api"),
            PageType::Documentation
        );
        assert_eq!(
            classify_url("https://example.com/products/widget"),
            PageType::Product
        );
    }

    /// `Display` and `as_str` both produce the canonical lowercase name.
    #[test]
    fn test_page_type_display() {
        assert_eq!(PageType::Article.to_string(), "article");
        assert_eq!(PageType::Forum.as_str(), "forum");
    }

    /// `parse` is case-insensitive, accepts aliases and rejects unknown names.
    #[test]
    fn test_page_type_parse() {
        assert_eq!(PageType::parse("article"), Some(PageType::Article));
        assert_eq!(PageType::parse("FORUM"), Some(PageType::Forum));
        assert_eq!(PageType::parse("category"), Some(PageType::Collection));
        assert_eq!(PageType::parse("unknown"), None);
    }

    /// Loading the embedded quality model must not panic.
    #[test]
    fn test_quality_model_loads() {
        let _ = get_quality_model();
    }

    /// An all-zero feature vector yields a score in `[0, 1]`.
    #[test]
    fn test_predict_quality() {
        let features = vec![0.0f64; N_QUALITY_FEATURES];
        let quality = predict_quality(&features);
        assert!((0.0..=1.0).contains(&quality));
    }

    /// A vector with several non-zero features still yields a score in
    /// `[0, 1]`.
    #[test]
    fn test_predict_quality_good_extraction() {
        let mut features = vec![0.0f64; N_QUALITY_FEATURES];
        // (feature index, value) pairs; every other feature stays at zero.
        let overrides = [
            (0, 0.95),
            (1, 8000.0),
            (2, 1200.0),
            (3, 0.55),
            (5, 40.0),
            (8, 15.0),
            (13, 1.0),
            (20, 0.8),
        ];
        for (idx, value) in overrides {
            features[idx] = value;
        }

        let quality = predict_quality(&features);
        assert!(
            (0.0..=1.0).contains(&quality),
            "Score should be in [0,1], got {quality}"
        );
    }

    /// The `FromStr` impl accepts known names and errors on unknown ones.
    #[test]
    fn test_page_type_from_str_trait() {
        assert_eq!("article".parse::<PageType>(), Ok(PageType::Article));
        assert!("unknown".parse::<PageType>().is_err());
    }
}