1use std::collections::HashMap;
24use std::sync::Arc;
25use serde::{Serialize, Deserialize};
26use anyhow::Result;
27use thiserror::Error;
28
/// Errors produced by the salience system.
#[derive(Error, Debug)]
pub enum SalienceError {
    /// A token id the system cannot process.
    #[error("Invalid token: {0}")]
    InvalidToken(u32),
    /// A salience computation failed; the payload describes the failure.
    #[error("Computation failed: {0}")]
    ComputationError(String),
    /// The supplied configuration was rejected.
    #[error("Configuration error: {0}")]
    ConfigError(String),
    /// Memory could not be allocated.
    #[error("Memory allocation failed")]
    MemoryError,
}
40
/// Tunable parameters for the salience system.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SalienceConfig {
    /// Step size for dopamine and reward-prediction updates.
    pub learning_rate: f64,
    /// Reward discount factor. NOTE(review): declared but not read anywhere
    /// in this file — confirm it is consumed elsewhere.
    pub discount_factor: f64,
    /// Salience threshold used for attention focus and adaptive thresholding.
    pub threshold: f64,
    /// Outer averaging iterations for the foraging-probability estimate.
    pub outer_loop_iterations: usize,
    /// Inner averaging iterations for the foraging-probability estimate.
    pub inner_loop_iterations: usize,
    /// When true, run phoneme-preservation analysis per token.
    pub phoneme_preservation: bool,
    /// When true, compute a foraging probability per token; otherwise 0.5.
    pub enable_foraging: bool,
    /// When true, adjust salience against an adaptive threshold.
    pub adaptive_threshold: bool,
}
52
impl Default for SalienceConfig {
    /// Baseline configuration: modest learning rate, strong discounting,
    /// mid-point threshold, 100x10 foraging iterations, and every feature
    /// flag enabled.
    fn default() -> Self {
        Self {
            learning_rate: 0.01,
            discount_factor: 0.95,
            threshold: 0.5,
            outer_loop_iterations: 100,
            inner_loop_iterations: 10,
            phoneme_preservation: true,
            enable_foraging: true,
            adaptive_threshold: true,
        }
    }
}
67
/// Per-token output of a salience computation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SalienceResult {
    /// Token this result describes.
    pub token_id: u32,
    /// Final salience in [0, 1] (after optional adaptive thresholding).
    pub salience_score: f32,
    /// Confidence in the score: blend of history consistency, dopamine
    /// level, and the score itself.
    pub confidence: f32,
    /// Whether the token's phoneme pattern is considered preserved
    /// (always true when the analysis is disabled).
    pub phoneme_preserved: bool,
    /// Exploration/exploitation blend in [0, 1]; 0.5 when foraging disabled.
    pub foraging_probability: f32,
    /// Coarse role bucket inferred from the token id, if any.
    pub role_inference: Option<String>,
}
77
/// Mutable "mesolimbic" state that evolves across salience batches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MesolimbicState {
    /// Current dopamine level, kept in [0, 1] by the update step.
    pub dopamine_level: f64,
    /// Tokens whose last salience exceeded the configured threshold.
    pub attention_focus: Vec<u32>,
    /// Running reward prediction, nudged toward batch-average salience.
    pub reward_prediction: f64,
    /// Exploration/exploitation mix, kept in [0.05, 0.3] by the update step.
    pub exploration_factor: f64,
}
85
impl Default for MesolimbicState {
    /// Neutral starting state: mid-level dopamine, empty attention focus,
    /// no reward history, and low exploration.
    fn default() -> Self {
        Self {
            dopamine_level: 0.5,
            attention_focus: Vec::new(),
            reward_prediction: 0.0,
            exploration_factor: 0.1,
        }
    }
}
96
/// Core salience engine combining configuration, evolving mesolimbic state,
/// and per-token caches.
pub struct UnifiedSalienceSystem {
    // Tunable parameters (see SalienceConfig).
    config: SalienceConfig,
    // Dopamine / attention / reward state updated after each batch.
    state: MesolimbicState,
    // Per-token salience history, capped at 100 entries per token.
    token_history: HashMap<u32, Vec<f32>>,
    // Cached phoneme patterns (base-20 digits of the token id).
    phoneme_patterns: HashMap<u32, Vec<u32>>,
    // Cached role labels per token.
    role_mappings: HashMap<u32, String>,
}
105
106impl UnifiedSalienceSystem {
107 pub fn new(config: SalienceConfig) -> Self {
108 Self {
109 config,
110 state: MesolimbicState::default(),
111 token_history: HashMap::new(),
112 phoneme_patterns: HashMap::new(),
113 role_mappings: HashMap::new(),
114 }
115 }
116
117 pub fn compute_salience(&mut self, tokens: &[u32]) -> Result<Vec<SalienceResult>, SalienceError> {
119 let mut results = Vec::with_capacity(tokens.len());
120
121 for &token_id in tokens {
122 let result = self.compute_token_salience(token_id)?;
123 results.push(result);
124 }
125
126 self.update_mesolimbic_state(&results);
128
129 Ok(results)
130 }
131
132 fn compute_token_salience(&mut self, token_id: u32) -> Result<SalienceResult, SalienceError> {
133 let base_salience = self.compute_base_salience(token_id);
135
136 let phoneme_preserved = if self.config.phoneme_preservation {
138 self.analyze_phoneme_preservation(token_id)
139 } else {
140 true
141 };
142
143 let foraging_probability = if self.config.enable_foraging {
145 self.compute_foraging_probability(token_id)
146 } else {
147 0.5
148 };
149
150 let role_inference = self.infer_token_role(token_id);
152
153 let confidence = self.compute_confidence(token_id, base_salience);
155
156 let final_salience = if self.config.adaptive_threshold {
158 self.apply_adaptive_threshold(base_salience, token_id)
159 } else {
160 base_salience
161 };
162
163 self.update_token_history(token_id, final_salience);
165
166 Ok(SalienceResult {
167 token_id,
168 salience_score: final_salience,
169 confidence,
170 phoneme_preserved,
171 foraging_probability,
172 role_inference,
173 })
174 }
175
176 fn compute_base_salience(&self, token_id: u32) -> f32 {
177 let frequency_factor = self.compute_frequency_factor(token_id);
179 let context_factor = self.compute_context_factor(token_id);
180 let novelty_factor = self.compute_novelty_factor(token_id);
181 let attention_factor = self.compute_attention_factor(token_id);
182
183 let salience = frequency_factor * 0.3
185 + context_factor * 0.3
186 + novelty_factor * 0.2
187 + attention_factor * 0.2;
188
189 salience.clamp(0.0, 1.0)
190 }
191
192 fn compute_frequency_factor(&self, token_id: u32) -> f32 {
193 let history = self.token_history.get(&token_id);
195 match history {
196 Some(hist) if !hist.is_empty() => {
197 let avg_occurrence = hist.len() as f32 / 1000.0; (1.0 - avg_occurrence).max(0.1)
199 }
200 _ => 0.8 }
202 }
203
204 fn compute_context_factor(&self, token_id: u32) -> f32 {
205 if self.state.attention_focus.contains(&token_id) {
207 0.9
208 } else {
209 let related_score = self.state.attention_focus.iter()
211 .map(|&focus_token| self.compute_token_similarity(token_id, focus_token))
212 .fold(0.0f32, |acc, sim| acc.max(sim));
213 related_score * 0.7
214 }
215 }
216
217 fn compute_novelty_factor(&self, token_id: u32) -> f32 {
218 match self.token_history.get(&token_id) {
220 Some(history) if !history.is_empty() => {
221 let recent_occurrences = history.iter().rev().take(10).count();
222 (10 - recent_occurrences) as f32 / 10.0
223 }
224 _ => 1.0 }
226 }
227
228 fn compute_attention_factor(&self, token_id: u32) -> f32 {
229 let dopamine_influence = (self.state.dopamine_level as f32).clamp(0.0, 1.0);
231 let exploration_influence = (self.state.exploration_factor as f32).clamp(0.0, 1.0);
232
233 let base_attention = 0.5;
235 base_attention + (dopamine_influence * 0.3) + (exploration_influence * 0.2)
236 }
237
238 fn compute_token_similarity(&self, token1: u32, token2: u32) -> f32 {
239 if token1 == token2 {
241 return 1.0;
242 }
243
244 if let (Some(pattern1), Some(pattern2)) = (
246 self.phoneme_patterns.get(&token1),
247 self.phoneme_patterns.get(&token2)
248 ) {
249 let common_phonemes = pattern1.iter()
250 .filter(|&p| pattern2.contains(p))
251 .count();
252 let total_phonemes = (pattern1.len() + pattern2.len()).max(1);
253 (common_phonemes * 2) as f32 / total_phonemes as f32
254 } else {
255 let diff = (token1 as i64 - token2 as i64).abs() as f32;
257 (1.0 / (1.0 + diff / 1000.0)).clamp(0.0, 1.0)
258 }
259 }
260
261 fn analyze_phoneme_preservation(&mut self, token_id: u32) -> bool {
262 if let Some(pattern) = self.phoneme_patterns.get(&token_id) {
264 let critical_phonemes = [1, 2, 3, 5, 8, 13]; pattern.iter().any(|&p| critical_phonemes.contains(&p))
267 } else {
268 let pattern = self.generate_phoneme_pattern(token_id);
270 let preserved = pattern.len() > 2; self.phoneme_patterns.insert(token_id, pattern);
272 preserved
273 }
274 }
275
276 fn generate_phoneme_pattern(&self, token_id: u32) -> Vec<u32> {
277 let mut pattern = Vec::new();
279 let mut id = token_id;
280
281 while id > 0 && pattern.len() < 5 {
283 pattern.push(id % 20); id /= 20;
285 }
286
287 if pattern.is_empty() {
288 pattern.push(0); }
290
291 pattern
292 }
293
294 fn compute_foraging_probability(&self, token_id: u32) -> f32 {
295 let mut total_probability = 0.0;
297
298 for _ in 0..self.config.outer_loop_iterations {
300 let mut inner_probability = 0.0;
301
302 for _ in 0..self.config.inner_loop_iterations {
304 let exploration_reward = self.compute_exploration_reward(token_id);
305 let exploitation_reward = self.compute_exploitation_reward(token_id);
306
307 let probability = exploration_reward * self.state.exploration_factor as f32
309 + exploitation_reward * (1.0 - self.state.exploration_factor as f32);
310
311 inner_probability += probability;
312 }
313
314 total_probability += inner_probability / self.config.inner_loop_iterations as f32;
315 }
316
317 (total_probability / self.config.outer_loop_iterations as f32).clamp(0.0, 1.0)
318 }
319
320 fn compute_exploration_reward(&self, token_id: u32) -> f32 {
321 let novelty = self.compute_novelty_factor(token_id);
323 let uncertainty = 1.0 - self.compute_confidence(token_id, 0.5);
324 (novelty + uncertainty) / 2.0
325 }
326
327 fn compute_exploitation_reward(&self, token_id: u32) -> f32 {
328 let history_value = self.token_history.get(&token_id)
330 .map(|hist| hist.iter().sum::<f32>() / hist.len() as f32)
331 .unwrap_or(0.5);
332
333 let attention_value = if self.state.attention_focus.contains(&token_id) { 0.8 } else { 0.2 };
334
335 (history_value + attention_value) / 2.0
336 }
337
338 fn infer_token_role(&mut self, token_id: u32) -> Option<String> {
339 if let Some(existing_role) = self.role_mappings.get(&token_id) {
341 return Some(existing_role.clone());
342 }
343
344 let role = if token_id < 100 {
346 "function_word"
347 } else if token_id < 1000 {
348 "content_word"
349 } else if token_id < 10000 {
350 "domain_specific"
351 } else {
352 "rare_token"
353 };
354
355 let role_string = role.to_string();
356 self.role_mappings.insert(token_id, role_string.clone());
357 Some(role_string)
358 }
359
360 fn compute_confidence(&self, token_id: u32, salience: f32) -> f32 {
361 let history_consistency = self.token_history.get(&token_id)
363 .map(|hist| {
364 if hist.len() < 2 {
365 0.5
366 } else {
367 let variance = self.compute_variance(hist);
368 (1.0 - variance).clamp(0.0, 1.0)
369 }
370 })
371 .unwrap_or(0.3);
372
373 let state_confidence = (self.state.dopamine_level as f32).clamp(0.0, 1.0);
374 let salience_confidence = salience;
375
376 (history_consistency + state_confidence + salience_confidence) / 3.0
377 }
378
379 fn compute_variance(&self, values: &[f32]) -> f32 {
380 if values.len() < 2 {
381 return 0.0;
382 }
383
384 let mean = values.iter().sum::<f32>() / values.len() as f32;
385 let variance = values.iter()
386 .map(|&x| (x - mean).powi(2))
387 .sum::<f32>() / values.len() as f32;
388 variance
389 }
390
391 fn apply_adaptive_threshold(&mut self, salience: f32, token_id: u32) -> f32 {
392 let base_threshold = self.config.threshold as f32;
394
395 let recent_avg = self.compute_recent_average_salience();
397 let adaptive_threshold = if recent_avg > base_threshold {
398 base_threshold * 1.1 } else {
400 base_threshold * 0.9 };
402
403 if salience > adaptive_threshold {
405 salience
406 } else {
407 salience * 0.8 }
409 }
410
411 fn compute_recent_average_salience(&self) -> f32 {
412 let recent_values: Vec<f32> = self.token_history.values()
413 .filter_map(|hist| hist.last().copied())
414 .collect();
415
416 if recent_values.is_empty() {
417 0.5
418 } else {
419 recent_values.iter().sum::<f32>() / recent_values.len() as f32
420 }
421 }
422
423 fn update_token_history(&mut self, token_id: u32, salience: f32) {
424 let history = self.token_history.entry(token_id).or_insert_with(Vec::new);
425 history.push(salience);
426
427 if history.len() > 100 {
429 history.remove(0);
430 }
431 }
432
433 fn update_mesolimbic_state(&mut self, results: &[SalienceResult]) {
434 let avg_salience = results.iter().map(|r| r.salience_score).sum::<f32>() / results.len() as f32;
436 let dopamine_delta = (avg_salience - 0.5) as f64 * self.config.learning_rate;
437 self.state.dopamine_level = (self.state.dopamine_level + dopamine_delta).clamp(0.0, 1.0);
438
439 self.state.attention_focus.clear();
441 for result in results {
442 if result.salience_score > self.config.threshold as f32 {
443 self.state.attention_focus.push(result.token_id);
444 }
445 }
446
447 let current_reward = avg_salience as f64;
449 let prediction_error = current_reward - self.state.reward_prediction;
450 self.state.reward_prediction += self.config.learning_rate * prediction_error;
451
452 if prediction_error.abs() > 0.1 {
454 self.state.exploration_factor = (self.state.exploration_factor + 0.01).min(0.3);
455 } else {
456 self.state.exploration_factor = (self.state.exploration_factor - 0.01).max(0.05);
457 }
458 }
459
460 pub fn get_state(&self) -> &MesolimbicState {
462 &self.state
463 }
464
465 pub fn reset(&mut self) {
467 self.state = MesolimbicState::default();
468 self.token_history.clear();
469 self.phoneme_patterns.clear();
470 self.role_mappings.clear();
471 }
472
473 pub fn update_config(&mut self, config: SalienceConfig) {
475 self.config = config;
476 }
477}
478
/// Convenience constructor: builds a [`UnifiedSalienceSystem`] from `config`.
pub fn create_salience_system(config: SalienceConfig) -> UnifiedSalienceSystem {
    UnifiedSalienceSystem::new(config)
}
483
484pub fn compute_token_salience(tokens: &[u32]) -> Result<Vec<SalienceResult>, SalienceError> {
486 let mut system = UnifiedSalienceSystem::new(SalienceConfig::default());
487 system.compute_salience(tokens)
488}