Skip to main content

rustant_core/channels/
style_tracker.rs

//! Communication style tracking for channel senders.
//!
//! Analyzes message patterns per sender to learn communication preferences:
//! message length, formality, emoji usage, and common greetings. (Profiles also
//! carry topic and response-time fields, which are not populated yet.)
//! Every N messages from a sender, produces fact strings for long-term memory.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// A tracked style profile for a single sender.
11#[derive(Debug, Clone, Default, Serialize, Deserialize)]
12pub struct SenderStyleProfile {
13    /// Sender identifier (e.g., Slack user ID).
14    pub sender_id: String,
15    /// Channel type (e.g., "slack", "email").
16    pub channel_type: String,
17    /// Total messages analyzed.
18    pub message_count: usize,
19    /// Average message length in characters.
20    pub avg_message_length: f64,
21    /// Formality score: 0.0 (very casual) to 1.0 (very formal).
22    pub formality_score: f64,
23    /// Whether the sender frequently uses emoji.
24    pub uses_emoji: bool,
25    /// Common greeting patterns observed.
26    pub common_greetings: Vec<String>,
27    /// Frequently discussed topics/keywords.
28    pub frequent_topics: Vec<String>,
29    /// Average response time in seconds (if measurable).
30    pub avg_response_time_secs: Option<f64>,
31}
32
33/// Tracks communication styles across multiple senders.
34#[derive(Debug, Default, Serialize, Deserialize)]
35pub struct CommunicationStyleTracker {
36    /// Per-sender style profiles.
37    pub profiles: HashMap<String, SenderStyleProfile>,
38    /// Total messages processed.
39    pub total_messages: usize,
40    /// Threshold: generate facts every N messages per sender.
41    pub fact_threshold: usize,
42}
43
44impl CommunicationStyleTracker {
45    /// Create a new tracker with the given fact generation threshold.
46    pub fn new(fact_threshold: usize) -> Self {
47        Self {
48            profiles: HashMap::new(),
49            total_messages: 0,
50            fact_threshold,
51        }
52    }
53
54    /// Analyze a message and update the sender's style profile.
55    ///
56    /// Returns a list of fact strings if the threshold is reached.
57    pub fn track_message(
58        &mut self,
59        sender_id: &str,
60        channel_type: &str,
61        message: &str,
62    ) -> Vec<String> {
63        self.total_messages += 1;
64
65        let profile = self
66            .profiles
67            .entry(sender_id.to_string())
68            .or_insert_with(|| SenderStyleProfile {
69                sender_id: sender_id.to_string(),
70                channel_type: channel_type.to_string(),
71                ..Default::default()
72            });
73
74        let msg_len = message.len() as f64;
75        let old_count = profile.message_count as f64;
76        profile.message_count += 1;
77        let new_count = profile.message_count as f64;
78
79        // Running average of message length
80        profile.avg_message_length = (profile.avg_message_length * old_count + msg_len) / new_count;
81
82        // Formality heuristic
83        let formality = compute_formality(message);
84        profile.formality_score = (profile.formality_score * old_count + formality) / new_count;
85
86        // Emoji detection
87        if contains_emoji(message) {
88            profile.uses_emoji = true;
89        }
90
91        // Greeting detection
92        let greeting = detect_greeting(message);
93        if let Some(g) = greeting {
94            if !profile.common_greetings.contains(&g) && profile.common_greetings.len() < 5 {
95                profile.common_greetings.push(g);
96            }
97        }
98
99        // Generate facts at threshold
100        let mut facts = Vec::new();
101        if profile.message_count > 0
102            && profile.message_count.is_multiple_of(self.fact_threshold)
103            && self.fact_threshold > 0
104        {
105            facts.push(format!(
106                "Sender '{}' on {} typically writes {} messages (avg {} chars). \
107                 Formality: {:.1}/1.0. Uses emoji: {}.",
108                profile.sender_id,
109                profile.channel_type,
110                if profile.avg_message_length > 200.0 {
111                    "long"
112                } else if profile.avg_message_length > 50.0 {
113                    "medium"
114                } else {
115                    "short"
116                },
117                profile.avg_message_length as usize,
118                profile.formality_score,
119                profile.uses_emoji,
120            ));
121
122            if !profile.common_greetings.is_empty() {
123                facts.push(format!(
124                    "Sender '{}' commonly greets with: {}",
125                    profile.sender_id,
126                    profile.common_greetings.join(", ")
127                ));
128            }
129        }
130
131        facts
132    }
133
134    /// Get a sender's style profile.
135    pub fn get_profile(&self, sender_id: &str) -> Option<&SenderStyleProfile> {
136        self.profiles.get(sender_id)
137    }
138
139    /// Get all tracked profiles.
140    pub fn all_profiles(&self) -> &HashMap<String, SenderStyleProfile> {
141        &self.profiles
142    }
143}
144
145/// Compute a formality score from 0.0 (casual) to 1.0 (formal).
146fn compute_formality(message: &str) -> f64 {
147    let mut score = 0.5_f64; // neutral baseline
148
149    // Formal indicators
150    if message.contains("Dear ") || message.contains("Regards") || message.contains("Sincerely") {
151        score += 0.2_f64;
152    }
153    if message.ends_with('.') || message.ends_with('!') {
154        score += 0.05_f64;
155    }
156    // Starts with capital letter
157    if message
158        .chars()
159        .next()
160        .map(|c| c.is_uppercase())
161        .unwrap_or(false)
162    {
163        score += 0.05_f64;
164    }
165
166    // Casual indicators
167    if message.contains("lol") || message.contains("haha") || message.contains("lmao") {
168        score -= 0.2_f64;
169    }
170    if message == message.to_lowercase() && message.len() > 10 {
171        score -= 0.1_f64;
172    }
173    if contains_emoji(message) {
174        score -= 0.05_f64;
175    }
176
177    score.clamp(0.0_f64, 1.0_f64)
178}
179
/// Check whether `s` contains at least one emoji character.
///
/// This is a fast heuristic over the main emoji code-point blocks, not a full
/// Unicode emoji-property check. It can miss emoji outside these blocks, and it
/// also flags non-emoji symbols inside the Misc Symbols and Dingbats blocks
/// (e.g. '✓', U+2713).
fn contains_emoji(s: &str) -> bool {
    s.chars().any(|c| {
        matches!(
            c,
            '\u{1F1E6}'..='\u{1F1FF}' // Regional indicators (flag pairs)
                | '\u{1F300}'..='\u{1F5FF}' // Misc symbols & pictographs
                | '\u{1F600}'..='\u{1F64F}' // Emoticons
                | '\u{1F680}'..='\u{1F6FF}' // Transport & map symbols
                | '\u{1F900}'..='\u{1F9FF}' // Supplemental symbols & pictographs
                | '\u{1FA70}'..='\u{1FAFF}' // Symbols & pictographs extended-A
                | '\u{2600}'..='\u{26FF}' // Misc symbols
                | '\u{2700}'..='\u{27BF}' // Dingbats
        )
    })
}
192
/// Detect a greeting at the start of a message.
///
/// Matching is case-insensitive and ignores leading whitespace. A greeting only
/// counts as a whole word or phrase: it must be followed by the end of the message
/// or a non-alphanumeric character. So "Hi, team" yields `"hi"`, but "History" and
/// "Your report" do not match `"hi"` / `"yo"`.
///
/// Returns the lowercase greeting, or `None` if the message does not open with one.
fn detect_greeting(message: &str) -> Option<String> {
    // Checked in order; the first whole-word match wins.
    const GREETINGS: [&str; 11] = [
        "hi",
        "hello",
        "hey",
        "good morning",
        "good afternoon",
        "good evening",
        "greetings",
        "howdy",
        "sup",
        "yo",
        "dear",
    ];

    let lower = message.trim_start().to_lowercase();

    GREETINGS
        .iter()
        .find(|greeting| {
            lower
                .strip_prefix(**greeting)
                .is_some_and(|rest| !rest.starts_with(char::is_alphanumeric))
        })
        .map(|greeting| greeting.to_string())
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Send each of `messages` as `user1` on slack, returning the facts produced by
    /// the last one (empty if `messages` is empty).
    fn feed(tracker: &mut CommunicationStyleTracker, messages: &[&str]) -> Vec<String> {
        let mut last_facts = Vec::new();
        for message in messages {
            last_facts = tracker.track_message("user1", "slack", message);
        }
        last_facts
    }

    #[test]
    fn test_track_single_message() {
        let mut tracker = CommunicationStyleTracker::new(50);
        let facts = feed(&mut tracker, &["Hello, how are you today?"]);

        // One message is well below the threshold of 50.
        assert!(facts.is_empty());
        assert_eq!(tracker.total_messages, 1);

        let profile = tracker
            .get_profile("user1")
            .expect("profile should exist after the first message");
        assert_eq!(profile.message_count, 1);
        assert!(profile.avg_message_length > 0.0);
    }

    #[test]
    fn test_formality_formal() {
        let formal = "Dear John, I hope this message finds you well. Regards.";
        assert!(compute_formality(formal) > 0.6, "expected a formal score");
    }

    #[test]
    fn test_formality_casual() {
        let casual = "hey lol whats up haha";
        assert!(compute_formality(casual) < 0.4, "expected a casual score");
    }

    #[test]
    fn test_contains_emoji() {
        // Emoji are written as escape sequences to keep the source ASCII.
        assert!(!contains_emoji("Hello world"));
        assert!(contains_emoji("Hello \u{1F600}"));
    }

    #[test]
    fn test_detect_greeting() {
        let cases = [
            ("Hello there!", Some("hello")),
            ("hey what's up", Some("hey")),
            ("Thanks for the update", None),
        ];
        for (input, expected) in cases {
            assert_eq!(detect_greeting(input).as_deref(), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_fact_generation_at_threshold() {
        let mut tracker = CommunicationStyleTracker::new(3);
        let facts = feed(&mut tracker, &["Message 1", "Message 2", "Message 3"]);

        // The third message lands exactly on the threshold of 3.
        assert!(!facts.is_empty());
    }

    #[test]
    fn test_greeting_tracking() {
        let mut tracker = CommunicationStyleTracker::new(50);
        feed(&mut tracker, &["Hello everyone!", "Hey, quick question"]);

        let greetings = &tracker.get_profile("user1").unwrap().common_greetings;
        for expected in ["hello", "hey"] {
            assert!(
                greetings.iter().any(|g| g == expected),
                "missing greeting {expected:?}"
            );
        }
    }
}