1use crate::types::*;
4use crate::Result;
5use serde::{Deserialize, Serialize};
6
/// Discrete complexity bucket assigned to an incoming message.
///
/// Serialized by serde in SCREAMING_SNAKE_CASE (e.g. `VERY_COMPLEX`),
/// which matches this type's `Display` output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ComplexityLevel {
    /// Weighted factor score below 0.2.
    Trivial,
    /// Weighted factor score in [0.2, 0.4).
    Simple,
    /// Weighted factor score in [0.4, 0.6).
    Moderate,
    /// Weighted factor score in [0.6, 0.8).
    Complex,
    /// Weighted factor score of 0.8 or above.
    VeryComplex,
}
22
23impl std::fmt::Display for ComplexityLevel {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 ComplexityLevel::Trivial => write!(f, "TRIVIAL"),
27 ComplexityLevel::Simple => write!(f, "SIMPLE"),
28 ComplexityLevel::Moderate => write!(f, "MODERATE"),
29 ComplexityLevel::Complex => write!(f, "COMPLEX"),
30 ComplexityLevel::VeryComplex => write!(f, "VERY_COMPLEX"),
31 }
32 }
33}
34
/// Estimated token usage for handling a single message.
///
/// Serialized by serde with camelCase field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TokenEstimate {
    /// Prompt-side tokens: message (~4 chars/token) + context allowance
    /// + per-level system-prompt budget.
    pub input_tokens: usize,
    /// Completion-side tokens, including a 20% safety buffer.
    pub output_tokens: usize,
    /// Sum of `input_tokens` and `output_tokens`.
    pub total_tokens: usize,
    /// Confidence in the estimate, clamped to [0.5, 0.95].
    pub confidence: f32,
}
48
/// Full result of a complexity assessment for one message.
///
/// Serialized by serde with camelCase field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ComplexityAssessment {
    /// Discrete complexity bucket derived from the weighted factor scores.
    pub level: ComplexityLevel,
    /// One-line human-readable summary of the level and factor scores.
    pub reasoning: String,
    /// Estimated number of processing steps required.
    pub estimated_steps: usize,
    /// Estimated token usage for handling the message.
    pub estimated_tokens: TokenEstimate,
    /// Confidence in the assessment, clamped to [0.5, 0.95].
    pub confidence: f32,
    /// Individual factor scores that produced this assessment.
    pub factors: ComplexityFactors,
}
66
/// Per-dimension heuristic scores, each in [0, 1].
///
/// Serialized by serde with camelCase field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ComplexityFactors {
    /// Word-count score (longer messages score higher).
    pub length_score: f32,
    /// Question-mark density and complex-question phrasing score.
    pub question_score: f32,
    /// Technical/domain keyword score.
    pub domain_score: f32,
    /// Score for how much prior conversation context is needed.
    pub context_score: f32,
    /// Score for required reasoning depth.
    pub reasoning_score: f32,
}
82
83impl ComplexityFactors {
84 pub fn average(&self) -> f32 {
86 (self.length_score
87 + self.question_score
88 + self.domain_score
89 + self.context_score
90 + self.reasoning_score)
91 / 5.0
92 }
93}
94
/// Stateless heuristic analyzer that scores message complexity.
pub struct ComplexityAnalyzer;
97
98impl ComplexityAnalyzer {
99 pub fn new() -> Self {
101 Self
102 }
103
104 pub async fn assess(&self, message: &Memory, state: &State) -> Result<ComplexityAssessment> {
106 let text = &message.content.text;
107
108 let factors = ComplexityFactors {
110 length_score: self.analyze_length(text),
111 question_score: self.analyze_questions(text),
112 domain_score: self.analyze_domain(text),
113 context_score: self.analyze_context_needed(text, state),
114 reasoning_score: self.analyze_reasoning_depth(text),
115 };
116
117 let level = self.determine_level(&factors);
119
120 let estimated_steps = self.estimate_steps(&level, &factors);
122 let estimated_tokens = self.estimate_tokens(&level, &factors, text);
123
124 let confidence = self.calculate_confidence(&factors);
126
127 let reasoning = self.build_reasoning(&factors, &level);
129
130 Ok(ComplexityAssessment {
131 level,
132 reasoning,
133 estimated_steps,
134 estimated_tokens,
135 confidence,
136 factors,
137 })
138 }
139
140 fn analyze_length(&self, text: &str) -> f32 {
142 let words = text.split_whitespace().count();
143
144 match words {
146 0..=5 => 0.1, 6..=15 => 0.2, 16..=50 => 0.4, 51..=150 => 0.6, 151..=300 => 0.8, _ => 1.0, }
153 }
154
155 fn analyze_questions(&self, text: &str) -> f32 {
157 let lower = text.to_lowercase();
158 let mut score = 0.0;
159
160 let question_marks = text.matches('?').count();
162 score += (question_marks as f32 * 0.2).min(0.4);
163
164 let complex_patterns = [
166 "how do i",
167 "how can i",
168 "how would",
169 "why does",
170 "why is",
171 "why would",
172 "what's the difference",
173 "what is the best way",
174 "can you explain",
175 "could you help me understand",
176 "multiple",
177 "several",
178 "various",
179 ];
180
181 for pattern in &complex_patterns {
182 if lower.contains(pattern) {
183 score += 0.15;
184 }
185 }
186
187 if lower.contains(" and ") || lower.contains(" or ") {
189 score += 0.2;
190 }
191
192 score.min(1.0)
193 }
194
195 fn analyze_domain(&self, text: &str) -> f32 {
197 let lower = text.to_lowercase();
198 let mut score: f32 = 0.0;
199
200 let technical_keywords = [
202 "algorithm",
203 "implement",
204 "code",
205 "function",
206 "system",
207 "architecture",
208 "database",
209 "optimization",
210 "performance",
211 "security",
212 "encryption",
213 "network",
214 "protocol",
215 "api",
216 "machine learning",
217 "neural network",
218 "blockchain",
219 "distributed",
220 "concurrent",
221 "async",
222 "runtime",
223 ];
224
225 let mut hits = 0;
226 for keyword in &technical_keywords {
227 if lower.contains(keyword) {
228 hits += 1;
229 }
230 }
231
232 let academic_keywords = ["research", "study", "analysis", "theory", "hypothesis"];
234 score = 0.2 + 0.15 * hits as f32;
236 if lower.contains("consensus") {
237 score += 0.15;
238 }
239 if lower.contains("raft") {
240 score += 0.15;
241 }
242 if lower.contains("leader election") {
243 score += 0.1;
244 }
245 if lower.contains("log replication") {
246 score += 0.1;
247 }
248
249 score.min(1.0)
250 }
251
252 fn analyze_context_needed(&self, text: &str, state: &State) -> f32 {
254 let lower = text.to_lowercase();
255 let mut score: f32 = 0.0;
256
257 let context_patterns = [
259 "previous",
260 "earlier",
261 "before",
262 "last time",
263 "you said",
264 "you mentioned",
265 "as discussed",
266 "continue",
267 "following up",
268 "regarding",
269 ];
270
271 for pattern in &context_patterns {
272 if lower.contains(pattern) {
273 score += 0.2;
274 }
275 }
276
277 let pronouns = ["it", "this", "that", "these", "those", "they"];
279 for pronoun in &pronouns {
280 if lower.contains(&format!(" {} ", pronoun)) {
281 score += 0.1;
282 }
283 }
284
285 if let Some(recent_messages) = state.data.get("recentMessages") {
287 if let Some(arr) = recent_messages.as_array() {
288 if arr.len() > 3 {
289 score += 0.2;
290 }
291 }
292 }
293
294 score.min(1.0)
295 }
296
297 fn analyze_reasoning_depth(&self, text: &str) -> f32 {
299 let lower = text.to_lowercase();
300 let mut score: f32 = 0.2; let reasoning_patterns = [
304 "step by step",
305 "first",
306 "then",
307 "finally",
308 "process",
309 "explain how",
310 "explain why",
311 "reasoning",
312 "logic",
313 "compare",
314 "contrast",
315 "analyze",
316 "evaluate",
317 "pros and cons",
318 "advantages",
319 "disadvantages",
320 "consider",
321 "think about",
322 "take into account",
323 ];
324
325 for pattern in &reasoning_patterns {
326 if lower.contains(pattern) {
327 score += 0.15;
328 }
329 }
330
331 if lower.contains("because") || lower.contains("therefore") || lower.contains("thus") {
333 score += 0.1;
334 }
335
336 if lower.matches(" if ").count() > 1 {
338 score += 0.2;
339 }
340
341 score.min(1.0)
342 }
343
344 fn determine_level(&self, factors: &ComplexityFactors) -> ComplexityLevel {
346 let avg = factors.average();
347
348 let max_score = factors
350 .length_score
351 .max(factors.question_score)
352 .max(factors.domain_score)
353 .max(factors.context_score)
354 .max(factors.reasoning_score);
355
356 let weighted = avg * 0.7 + max_score * 0.3;
358
359 match weighted {
360 x if x < 0.2 => ComplexityLevel::Trivial,
361 x if x < 0.4 => ComplexityLevel::Simple,
362 x if x < 0.6 => ComplexityLevel::Moderate,
363 x if x < 0.8 => ComplexityLevel::Complex,
364 _ => ComplexityLevel::VeryComplex,
365 }
366 }
367
368 fn estimate_steps(&self, level: &ComplexityLevel, factors: &ComplexityFactors) -> usize {
370 let base_steps = match level {
371 ComplexityLevel::Trivial => 1,
372 ComplexityLevel::Simple => 2,
373 ComplexityLevel::Moderate => 4,
374 ComplexityLevel::Complex => 7,
375 ComplexityLevel::VeryComplex => 12,
376 };
377
378 let adjustment = (factors.reasoning_score * 3.0) as usize;
380
381 base_steps + adjustment
382 }
383
384 fn estimate_tokens(
386 &self,
387 level: &ComplexityLevel,
388 factors: &ComplexityFactors,
389 text: &str,
390 ) -> TokenEstimate {
391 let message_tokens = (text.len() / 4).max(1);
393
394 let context_tokens = (factors.context_score * 300.0) as usize;
396
397 let system_tokens = match level {
399 ComplexityLevel::Trivial => 50,
400 ComplexityLevel::Simple => 100,
401 ComplexityLevel::Moderate => 150,
402 ComplexityLevel::Complex => 200,
403 ComplexityLevel::VeryComplex => 300,
404 };
405
406 let input_tokens = message_tokens + context_tokens + system_tokens;
407
408 let base_output = match level {
410 ComplexityLevel::Trivial => 50,
411 ComplexityLevel::Simple => 150,
412 ComplexityLevel::Moderate => 300,
413 ComplexityLevel::Complex => 500,
414 ComplexityLevel::VeryComplex => 1000,
415 };
416
417 let domain_adjustment = (factors.domain_score * 200.0) as usize;
419 let output_tokens = base_output + domain_adjustment;
420
421 let buffered_output = (output_tokens as f32 * 1.2) as usize;
423
424 TokenEstimate {
425 input_tokens,
426 output_tokens: buffered_output,
427 total_tokens: input_tokens + buffered_output,
428 confidence: self.calculate_confidence(factors),
429 }
430 }
431
432 fn calculate_confidence(&self, factors: &ComplexityFactors) -> f32 {
434 let avg = factors.average();
436 let variance = [
437 (factors.length_score - avg).powi(2),
438 (factors.question_score - avg).powi(2),
439 (factors.domain_score - avg).powi(2),
440 (factors.context_score - avg).powi(2),
441 (factors.reasoning_score - avg).powi(2),
442 ]
443 .iter()
444 .sum::<f32>()
445 / 5.0;
446
447 let confidence = 1.0 - variance.sqrt().min(0.5);
449
450 confidence.max(0.5).min(0.95)
451 }
452
453 fn build_reasoning(&self, factors: &ComplexityFactors, level: &ComplexityLevel) -> String {
455 format!(
456 "Complexity: {} | Factors: length={:.2}, questions={:.2}, domain={:.2}, context={:.2}, reasoning={:.2} | Average: {:.2}",
457 level,
458 factors.length_score,
459 factors.question_score,
460 factors.domain_score,
461 factors.context_score,
462 factors.reasoning_score,
463 factors.average()
464 )
465 }
466}
467
468impl Default for ComplexityAnalyzer {
469 fn default() -> Self {
470 Self::new()
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    // Builds a minimal `Memory` containing only `text`; ids are fresh
    // UUIDs, `created_at` is the current timestamp, and all optional
    // fields are left unset.
    fn create_test_message(text: &str) -> Memory {
        Memory {
            id: Uuid::new_v4(),
            entity_id: Uuid::new_v4(),
            agent_id: Uuid::new_v4(),
            room_id: Uuid::new_v4(),
            content: Content {
                text: text.to_string(),
                ..Default::default()
            },
            embedding: None,
            metadata: None,
            created_at: chrono::Utc::now().timestamp(),
            unique: None,
            similarity: None,
        }
    }

    // A bare greeting should land in the lowest complexity buckets with
    // a small total token budget.
    #[tokio::test]
    async fn test_trivial_complexity() {
        let analyzer = ComplexityAnalyzer::new();
        let message = create_test_message("Hi");
        let state = State::new();

        let assessment = analyzer.assess(&message, &state).await.unwrap();
        assert!(matches!(
            assessment.level,
            ComplexityLevel::Trivial | ComplexityLevel::Simple
        ));
        assert!(assessment.estimated_tokens.total_tokens <= 300);
    }

    // A short single question should stay in the low buckets.
    #[tokio::test]
    async fn test_simple_complexity() {
        let analyzer = ComplexityAnalyzer::new();
        let message = create_test_message("What's the weather like today?");
        let state = State::new();

        let assessment = analyzer.assess(&message, &state).await.unwrap();
        assert!(matches!(
            assessment.level,
            ComplexityLevel::Simple | ComplexityLevel::Trivial
        ));
    }

    // A technical multi-part question (consensus/Raft vocabulary) should
    // score at least Moderate and show high domain/question factors.
    #[tokio::test]
    async fn test_complex_technical() {
        let analyzer = ComplexityAnalyzer::new();
        let message = create_test_message(
            "Can you explain how to implement a distributed consensus algorithm \
             using Raft protocol, including the leader election process and log replication?",
        );
        let state = State::new();

        let assessment = analyzer.assess(&message, &state).await.unwrap();
        assert!(matches!(
            assessment.level,
            ComplexityLevel::Moderate | ComplexityLevel::Complex | ComplexityLevel::VeryComplex
        ));
        assert!(assessment.factors.domain_score > 0.5);
        assert!(assessment.factors.question_score > 0.3);
    }

    // Token estimates should grow with message length.
    #[tokio::test]
    async fn test_token_estimation() {
        let analyzer = ComplexityAnalyzer::new();

        let short_msg = create_test_message("Hello");
        let state = State::new();
        let short_assessment = analyzer.assess(&short_msg, &state).await.unwrap();

        let long_msg = create_test_message(
            "This is a much longer message that contains many words and will require \
             more tokens to process and respond to appropriately.",
        );
        let long_assessment = analyzer.assess(&long_msg, &state).await.unwrap();

        assert!(
            long_assessment.estimated_tokens.total_tokens
                > short_assessment.estimated_tokens.total_tokens
        );
    }

    // Display output must match the serialized SCREAMING_SNAKE_CASE names.
    #[test]
    fn test_complexity_level_display() {
        assert_eq!(ComplexityLevel::Trivial.to_string(), "TRIVIAL");
        assert_eq!(ComplexityLevel::Complex.to_string(), "COMPLEX");
    }
}
569}