Skip to main content

tt_preview/
classifier.rs

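//! Cheap keyword-based task classifier. Lowercases the last user message and checks it
//! for substring indicators of classification, extraction, and code tasks; conversations
//! that match none of those are labelled agent (more than four messages including a
//! `tool`-role message) or generic chat.
//! The same patterns inform `output-no-max-tokens` Inspect rule fix hints.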
4
5use crate::types::Message;
6
/// Coarse task category assigned to a conversation by [`classify`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskClass {
    /// Labelling / yes-no prompts ("classify", "category", "is this", ...).
    Classification,
    /// Pulling structured data out of text ("extract", "parse", "json schema", ...).
    Extraction,
    /// Programming requests ("function", "refactor", Markdown code fences, ...).
    Code,
    /// More than four messages including a `tool`-role message, with no keyword match.
    Agent,
    /// Fallback when nothing else matches.
    Chat,
}
15
16pub fn classify(messages: &[Message]) -> TaskClass {
17    let last_user = messages
18        .iter()
19        .rev()
20        .find(|m| m.role == "user")
21        .and_then(|m| m.content.as_str())
22        .map(|s| s.to_lowercase())
23        .unwrap_or_default();
24
25    if last_user.contains("classify")
26        || last_user.contains("category")
27        || last_user.contains("label as")
28        || last_user.contains("is this")
29        || last_user.contains("yes or no")
30    {
31        return TaskClass::Classification;
32    }
33    if last_user.contains("extract")
34        || last_user.contains("parse")
35        || last_user.contains("structured output")
36        || last_user.contains("json schema")
37        || last_user.contains("entity")
38        || last_user.contains("pull out")
39    {
40        return TaskClass::Extraction;
41    }
42    if last_user.contains("function")
43        || last_user.contains("code")
44        || last_user.contains("```")
45        || last_user.contains("refactor")
46        || last_user.contains("implement")
47    {
48        return TaskClass::Code;
49    }
50    if messages.len() > 4 && messages.iter().any(|m| m.role == "tool") {
51        return TaskClass::Agent;
52    }
53    TaskClass::Chat
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a message with the given role and plain-string content.
    fn msg(role: &'static str, text: &str) -> Message {
        Message {
            role: role.into(),
            content: json!(text),
        }
    }

    /// Shorthand for a `user` message.
    fn u(t: &str) -> Message {
        msg("user", t)
    }

    #[test]
    fn classify_classification() {
        assert_eq!(
            classify(&[u("Is this email spam? yes or no")]),
            TaskClass::Classification
        );
    }

    #[test]
    fn classify_extraction() {
        assert_eq!(
            classify(&[u("Extract the names from this text")]),
            TaskClass::Extraction
        );
    }

    #[test]
    fn classify_code() {
        assert_eq!(
            classify(&[u("write a function that adds two numbers")]),
            TaskClass::Code
        );
    }

    #[test]
    fn classify_chat_default() {
        assert_eq!(classify(&[u("Hi how are you")]), TaskClass::Chat);
    }

    #[test]
    fn classify_agent_needs_tool_and_more_than_four_messages() {
        let convo = [
            u("look up the weather"),
            msg("assistant", "calling the weather tool"),
            msg("tool", "sunny, 21C"),
            msg("assistant", "It is sunny."),
            u("thanks, and tomorrow?"),
        ];
        assert_eq!(classify(&convo), TaskClass::Agent);
        // Exactly four messages is not enough, even with a tool message present.
        assert_eq!(classify(&convo[1..]), TaskClass::Chat);
    }

    #[test]
    fn classify_long_chat_without_tool_is_chat() {
        let convo = [
            u("hi"),
            msg("assistant", "hello"),
            u("how are you"),
            msg("assistant", "fine"),
            u("great"),
        ];
        assert_eq!(classify(&convo), TaskClass::Chat);
    }

    #[test]
    fn keywords_take_priority_over_agent() {
        let convo = [
            u("start"),
            msg("assistant", "ok"),
            msg("tool", "done"),
            msg("assistant", "finished"),
            u("now refactor it"),
        ];
        assert_eq!(classify(&convo), TaskClass::Code);
    }

    #[test]
    fn only_last_user_message_counts() {
        let convo = [
            u("Extract the names"),
            msg("assistant", "I can classify that code"),
            u("Hi how are you"),
        ];
        assert_eq!(classify(&convo), TaskClass::Chat);
    }

    #[test]
    fn classify_empty_and_no_user_message() {
        assert_eq!(classify(&[]), TaskClass::Chat);
        assert_eq!(
            classify(&[msg("system", "classify everything")]),
            TaskClass::Chat
        );
    }
}