Skip to main content

web_page_classifier/
lib.rs

//! Fast web page type classification.
//!
//! Classifies web pages into 7 types using a compact XGBoost model:
//! Article, Forum, Product, Collection, Listing, Documentation, Service.
//!
//! # Quick Start
//!
//! Classify a page from its URL alone using fast heuristics (no model needed):
//!
//! ```
//! use web_page_classifier::{PageType, classify_url};
//!
//! let page_type = classify_url("https://docs.example.com/api/reference");
//! assert_eq!(page_type, PageType::Documentation);
//! ```
//!
//! # ML Classification
//!
//! For higher accuracy, extract numeric features from the HTML DOM and pass
//! them along with the page's title/description text:
//!
//! ```
//! use web_page_classifier::{classify_ml, N_NUMERIC_FEATURES};
//!
//! let features = vec![0.0f64; N_NUMERIC_FEATURES]; // your extracted features
//! let (page_type, confidence) = classify_ml(&features, "Example Article Title");
//! ```

27mod model;
28pub mod url_heuristics;
29
30use std::sync::OnceLock;
31
32pub use url_heuristics::classify_url;
33
/// Required length of the `numeric_features` slice passed to [`classify_ml`].
///
/// Callers must supply exactly this many raw (unscaled) values.
pub const N_NUMERIC_FEATURES: usize = 89;

/// Web page type classification.
///
/// Variants are declared in alphabetical order, so the derived `PartialOrd`/`Ord`
/// sort page types alphabetically by name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum PageType {
    Article,
    Collection,
    Documentation,
    Forum,
    Listing,
    Product,
    Service,
}

impl PageType {
    /// Every page type, in declaration (and therefore `Ord`) order.
    pub const ALL: [PageType; 7] = [
        Self::Article,
        Self::Collection,
        Self::Documentation,
        Self::Forum,
        Self::Listing,
        Self::Product,
        Self::Service,
    ];

    /// Spellings accepted by [`PageType::parse`]. All entries are lowercase ASCII;
    /// `"category"` and `"docs"` are aliases for `Collection` and `Documentation`.
    const NAMES: [(&'static str, PageType); 9] = [
        ("article", Self::Article),
        ("collection", Self::Collection),
        ("category", Self::Collection),
        ("documentation", Self::Documentation),
        ("docs", Self::Documentation),
        ("forum", Self::Forum),
        ("listing", Self::Listing),
        ("product", Self::Product),
        ("service", Self::Service),
    ];

    /// Parse from string (ASCII case-insensitive). Returns `None` for unknown types.
    ///
    /// Accepts every name produced by [`PageType::as_str`] plus the aliases
    /// `"category"` (→ [`PageType::Collection`]) and `"docs"`
    /// (→ [`PageType::Documentation`]). Surrounding whitespace is not trimmed.
    #[must_use]
    pub fn parse(s: &str) -> Option<Self> {
        // Compare case-insensitively in place rather than building a lowercased
        // `String`: this runs on every `classify_ml` call, so skip the allocation.
        Self::NAMES
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(s))
            .map(|&(_, page_type)| page_type)
    }

    /// Return the canonical type name as a lowercase string.
    ///
    /// The result always round-trips through [`PageType::parse`].
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Article => "article",
            Self::Collection => "collection",
            Self::Documentation => "documentation",
            Self::Forum => "forum",
            Self::Listing => "listing",
            Self::Product => "product",
            Self::Service => "service",
        }
    }
}

impl std::fmt::Display for PageType {
    /// Writes the lowercase name from [`PageType::as_str`].
    ///
    /// Uses `Formatter::pad` so width, fill, alignment and precision flags
    /// (e.g. `{:>12}`) are honored, the same way `str` formats itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}

impl std::str::FromStr for PageType {
    type Err = String;

    /// Delegates to [`PageType::parse`]; on failure the error message names the
    /// rejected input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse(s).ok_or_else(|| format!("unknown page type: {s}"))
    }
}

/// Required length of the `features` slice passed to [`predict_quality`].
///
/// Callers must supply exactly this many raw (unscaled) values.
pub const N_QUALITY_FEATURES: usize = 27;

97/// Embedded classifier model binary.
98static MODEL_BYTES: &[u8] = include_bytes!("xgboost_v2.bin");
99
100/// Embedded quality predictor model binary.
101static QUALITY_MODEL_BYTES: &[u8] = include_bytes!("quality_model_v1.bin");
102
103/// Lazily-initialized classifier model.
104static MODEL: OnceLock<model::Model> = OnceLock::new();
105
106/// Lazily-initialized quality model.
107static QUALITY_MODEL: OnceLock<model::QualityModel> = OnceLock::new();
108
109fn get_model() -> &'static model::Model {
110    MODEL.get_or_init(|| {
111        model::Model::from_bytes(MODEL_BYTES)
112            .expect("embedded classifier model is valid")
113    })
114}
115
116fn get_quality_model() -> &'static model::QualityModel {
117    QUALITY_MODEL.get_or_init(|| {
118        model::QualityModel::from_bytes(QUALITY_MODEL_BYTES)
119            .expect("embedded quality model is valid")
120    })
121}
122
123/// Classify a web page using the ML model.
124///
125/// # Arguments
126/// * `numeric_features` - Raw (unscaled) numeric features. Must have length
127///   [`N_NUMERIC_FEATURES`]. The model handles scaling internally.
128/// * `title_meta` - Concatenated title + description text for TF-IDF features.
129///
130/// # Returns
131/// `(PageType, confidence)` where confidence is in `[0.0, 1.0]`.
132///
133/// # Panics
134/// Panics if `numeric_features.len() != N_NUMERIC_FEATURES`.
135#[must_use]
136pub fn classify_ml(numeric_features: &[f64], title_meta: &str) -> (PageType, f64) {
137    assert_eq!(
138        numeric_features.len(),
139        N_NUMERIC_FEATURES,
140        "Expected {} numeric features, got {}",
141        N_NUMERIC_FEATURES,
142        numeric_features.len()
143    );
144
145    let m = get_model();
146
147    // Scale numeric features
148    let scaled = m.scale_features(numeric_features);
149
150    // Compute TF-IDF features
151    let tfidf = m.compute_tfidf(title_meta);
152
153    // Combine into full feature vector
154    let mut all_features = Vec::with_capacity(scaled.len() + tfidf.len());
155    all_features.extend_from_slice(&scaled);
156    all_features.extend_from_slice(&tfidf);
157
158    // Run forest prediction
159    let (class_idx, confidence) = m.predict(&all_features);
160
161    let page_type = m.class_labels.get(class_idx)
162        .and_then(|s| PageType::parse(s))
163        .unwrap_or(PageType::Article);
164
165    (page_type, confidence)
166}
167
168/// Predict extraction quality (estimated F1 score) from post-extraction features.
169///
170/// Returns a value in `[0.0, 1.0]` estimating how well the extraction captured
171/// the page's main content. Low scores (< 0.80) indicate the extraction may be
172/// poor and should be routed to an LLM fallback.
173///
174/// # Arguments
175/// * `features` - Raw (unscaled) quality features. Must have length
176///   [`N_QUALITY_FEATURES`]. Features include content statistics, page type
177///   indicators, and HTML-level signals.
178///
179/// # Feature order (27 features)
180/// 0: heuristic_conf, 1: content_len, 2: word_count, 3: vocab_ratio,
181/// 4: avg_word_len, 5: sentence_count, 6: avg_sentence_len,
182/// 7: sentence_uniqueness, 8: paragraph_count, 9: avg_paragraph_len,
183/// 10: link_count_in_content, 11: link_density, 12: boilerplate_keywords,
184/// 13-19: is_article..is_service (one-hot page type),
185/// 20: length_ratio, 21: html_size, 22: extraction_ratio,
186/// 23: og_overlap, 24: script_count, 25: has_jsonld, 26: top_bigram_freq
187///
188/// # Panics
189/// Panics if `features.len() != N_QUALITY_FEATURES`.
190#[must_use]
191pub fn predict_quality(features: &[f64]) -> f64 {
192    assert_eq!(
193        features.len(),
194        N_QUALITY_FEATURES,
195        "Expected {} quality features, got {}",
196        N_QUALITY_FEATURES,
197        features.len()
198    );
199
200    let m = get_quality_model();
201    let scaled = m.scale_features(features);
202    let predicted = m.predict(&scaled);
203    predicted.clamp(0.0, 1.0)
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, listed explicitly so these tests rely only on the public enum.
    const ALL_TYPES: [PageType; 7] = [
        PageType::Article,
        PageType::Collection,
        PageType::Documentation,
        PageType::Forum,
        PageType::Listing,
        PageType::Product,
        PageType::Service,
    ];

    #[test]
    fn test_model_loads() {
        let m = get_model();
        assert_eq!(m.n_classes, 7);
        assert!(!m.trees.is_empty());
    }

    #[test]
    fn test_model_labels_all_parse() {
        // `classify_ml` silently maps unparseable labels to `Article`; make sure
        // no embedded label ever takes that fallback path.
        let m = get_model();
        for label in m.class_labels.iter() {
            assert!(
                PageType::parse(label).is_some(),
                "model label {label:?} is not a known PageType"
            );
        }
    }

    #[test]
    fn test_classify_ml_returns_valid_type() {
        let features = vec![0.0f64; N_NUMERIC_FEATURES];
        let (page_type, confidence) = classify_ml(&features, "Example blog post about technology");
        assert!(
            (0.0..=1.0).contains(&confidence),
            "confidence out of range: {confidence}"
        );
        // The exact class for all-zero features depends on the embedded model,
        // so only check that a well-formed type came back.
        assert_eq!(PageType::parse(page_type.as_str()), Some(page_type));
    }

    #[test]
    fn test_classify_url_basic() {
        assert_eq!(classify_url("https://forum.example.com/thread/123"), PageType::Forum);
        assert_eq!(classify_url("https://docs.example.com/api"), PageType::Documentation);
        assert_eq!(classify_url("https://example.com/products/widget"), PageType::Product);
    }

    #[test]
    fn test_page_type_display() {
        assert_eq!(PageType::Article.to_string(), "article");
        assert_eq!(PageType::Forum.as_str(), "forum");
    }

    #[test]
    fn test_page_type_parse() {
        assert_eq!(PageType::parse("article"), Some(PageType::Article));
        assert_eq!(PageType::parse("FORUM"), Some(PageType::Forum));
        assert_eq!(PageType::parse("category"), Some(PageType::Collection));
        assert_eq!(PageType::parse("Docs"), Some(PageType::Documentation));
        assert_eq!(PageType::parse("unknown"), None);
        assert_eq!(PageType::parse(""), None);
    }

    #[test]
    fn test_page_type_round_trip() {
        for t in ALL_TYPES {
            assert_eq!(PageType::parse(t.as_str()), Some(t));
            assert_eq!(t.to_string().parse::<PageType>(), Ok(t));
        }
    }

    #[test]
    fn test_quality_model_loads() {
        // Just verify it parses without panicking
        let _ = get_quality_model();
    }

    #[test]
    fn test_predict_quality() {
        let features = vec![0.0f64; N_QUALITY_FEATURES];
        let quality = predict_quality(&features);
        assert!((0.0..=1.0).contains(&quality), "Score should be in [0,1], got {quality}");
    }

    #[test]
    fn test_predict_quality_good_extraction() {
        // Simulate a good article extraction
        let mut features = vec![0.0f64; N_QUALITY_FEATURES];
        features[0] = 0.95; // heuristic_conf
        features[1] = 8000.0; // content_len
        features[2] = 1200.0; // word_count
        features[3] = 0.55; // vocab_ratio
        features[5] = 40.0; // sentence_count
        features[8] = 15.0; // paragraph_count
        features[13] = 1.0; // is_article
        features[20] = 0.8; // length_ratio
        let quality = predict_quality(&features);
        // With partial features, score may not be high, but should be valid
        assert!((0.0..=1.0).contains(&quality), "Score should be in [0,1], got {quality}");
    }

    #[test]
    fn test_page_type_from_str_trait() {
        assert_eq!("article".parse::<PageType>(), Ok(PageType::Article));
        assert!("unknown".parse::<PageType>().is_err());
    }
}