// traitclaw_core/traits/context_manager.rs
//! Async context window management.
//!
//! [`ContextManager`] provides pluggable, async context window management.
//! It supports LLM-powered compression and accurate token counting.
//!
//! # Example
//!
//! ```rust
//! use traitclaw_core::traits::context_manager::ContextManager;
//! use traitclaw_core::types::message::Message;
//! use traitclaw_core::types::agent_state::AgentState;
//! use async_trait::async_trait;
//!
//! struct MyCompressor;
//!
//! #[async_trait]
//! impl ContextManager for MyCompressor {
//!     async fn prepare(
//!         &self,
//!         messages: &mut Vec<Message>,
//!         context_window: usize,
//!         state: &mut AgentState,
//!     ) {
//!         // Custom async compression logic
//!         let tokens = self.estimate_tokens(messages);
//!         if tokens > context_window {
//!             // Compress...
//!         }
//!     }
//! }
//! ```

use async_trait::async_trait;

use crate::types::agent_state::AgentState;
use crate::types::message::Message;

38/// Async trait for pluggable context window management.
39///
40/// Called before each LLM request to ensure the message list fits within
41/// the model's context window. Supports async operations such as
42/// LLM-powered summarization and external token-counting APIs.
43///
44/// Implementations MUST NOT remove system messages.
45#[async_trait]
46pub trait ContextManager: Send + Sync {
47 /// Prepare the message list by pruning or compressing if necessary.
48 ///
49 /// `context_window` is the model's maximum token capacity.
50 /// This method is async to support LLM-powered compression strategies.
51 async fn prepare(
52 &self,
53 messages: &mut Vec<Message>,
54 context_window: usize,
55 state: &mut AgentState,
56 );
57
58 /// Estimate the total token count for a message list.
59 ///
60 /// Default implementation uses the 4-characters ≈ 1-token approximation.
61 /// Override with `TikTokenCounter` for model-accurate counting.
62 fn estimate_tokens(&self, messages: &[Message]) -> usize {
63 messages.iter().map(|m| m.content.len() / 4 + 1).sum()
64 }
65}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// No-op implementation shared by the tests below; relies on the
    /// trait's default `estimate_tokens`.
    struct Dummy;

    #[async_trait]
    impl ContextManager for Dummy {
        async fn prepare(
            &self,
            _messages: &mut Vec<Message>,
            _context_window: usize,
            _state: &mut AgentState,
        ) {
        }
    }

    // ── Object safety: confirm Arc<dyn ContextManager> compiles ──────────
    #[test]
    fn test_context_manager_is_object_safe() {
        let _: Arc<dyn ContextManager> = Arc::new(Dummy);
    }

    // ── Default estimate_tokens() ───────────────────────────────────────
    #[test]
    fn test_default_estimate_tokens() {
        let cm = Dummy;
        let messages = vec![
            Message {
                role: crate::types::message::MessageRole::User,
                content: "a".repeat(400), // 400 chars → 400/4 + 1 = 101 tokens
                tool_call_id: None,
            },
            Message {
                role: crate::types::message::MessageRole::Assistant,
                content: "b".repeat(800), // 800 chars → 800/4 + 1 = 201 tokens
                tool_call_id: None,
            },
        ];

        let tokens = cm.estimate_tokens(&messages);
        assert_eq!(
            tokens, 302,
            "4-chars ≈ 1-token: (400/4+1) + (800/4+1) = 302"
        );
    }
}
127}