Skip to main content

traitclaw_core/traits/
context_manager.rs

//! Async context window management.
//!
//! [`ContextManager`] provides pluggable, async context window management.
//! Its async [`ContextManager::prepare`] hook allows LLM-powered compression,
//! and [`ContextManager::estimate_tokens`] ships a cheap length-based default
//! that implementations can override for accurate token counting.
//!
//! # Example
//!
//! ```rust
//! use traitclaw_core::traits::context_manager::ContextManager;
//! use traitclaw_core::types::message::Message;
//! use traitclaw_core::types::agent_state::AgentState;
//! use async_trait::async_trait;
//!
//! struct MyCompressor;
//!
//! #[async_trait]
//! impl ContextManager for MyCompressor {
//!     async fn prepare(
//!         &self,
//!         messages: &mut Vec<Message>,
//!         context_window: usize,
//!         state: &mut AgentState,
//!     ) {
//!         // Custom async compression logic
//!         let tokens = self.estimate_tokens(messages);
//!         if tokens > context_window {
//!             // Compress...
//!         }
//!     }
//! }
//! ```
32
33use async_trait::async_trait;
34
35use crate::types::agent_state::AgentState;
36use crate::types::message::Message;
37
38/// Async trait for pluggable context window management.
39///
40/// Called before each LLM request to ensure the message list fits within
41/// the model's context window. Supports async operations such as
42/// LLM-powered summarization and external token-counting APIs.
43///
44/// Implementations MUST NOT remove system messages.
45#[async_trait]
46pub trait ContextManager: Send + Sync {
47    /// Prepare the message list by pruning or compressing if necessary.
48    ///
49    /// `context_window` is the model's maximum token capacity.
50    /// This method is async to support LLM-powered compression strategies.
51    async fn prepare(
52        &self,
53        messages: &mut Vec<Message>,
54        context_window: usize,
55        state: &mut AgentState,
56    );
57
58    /// Estimate the total token count for a message list.
59    ///
60    /// Default implementation uses the 4-characters ≈ 1-token approximation.
61    /// Override with `TikTokenCounter` for model-accurate counting.
62    fn estimate_tokens(&self, messages: &[Message]) -> usize {
63        messages.iter().map(|m| m.content.len() / 4 + 1).sum()
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::message::MessageRole;
    use std::sync::Arc;

    /// No-op implementation shared by all tests. Only the default
    /// `estimate_tokens` and trait-object coercion are exercised.
    struct Dummy;

    #[async_trait]
    impl ContextManager for Dummy {
        async fn prepare(
            &self,
            _messages: &mut Vec<Message>,
            _context_window: usize,
            _state: &mut AgentState,
        ) {
        }
    }

    /// Builds a message with the given role and content and no tool call id.
    fn msg(role: MessageRole, content: String) -> Message {
        Message {
            role,
            content,
            tool_call_id: None,
        }
    }

    // ── Object safety: confirm Arc<dyn ContextManager> compiles ──────────
    #[test]
    fn test_context_manager_is_object_safe() {
        let _: Arc<dyn ContextManager> = Arc::new(Dummy);
    }

    // ── Default estimate_tokens() ───────────────────────────────────────
    #[test]
    fn test_default_estimate_tokens() {
        let messages = vec![
            msg(MessageRole::User, "a".repeat(400)), // 400 chars → 400/4 + 1 = 101 tokens
            msg(MessageRole::Assistant, "b".repeat(800)), // 800 chars → 800/4 + 1 = 201 tokens
        ];

        let tokens = Dummy.estimate_tokens(&messages);
        assert_eq!(
            tokens, 302,
            "4-chars ≈ 1-token: (400/4+1) + (800/4+1) = 302"
        );
    }

    #[test]
    fn test_estimate_tokens_empty_inputs() {
        // No messages at all → nothing to count.
        assert_eq!(Dummy.estimate_tokens(&[]), 0);

        // An empty message still contributes the +1 floor.
        let messages = vec![msg(MessageRole::User, String::new())];
        assert_eq!(Dummy.estimate_tokens(&messages), 1);
    }

    #[test]
    fn test_estimate_tokens_measures_bytes() {
        // 'é' is 2 bytes in UTF-8: 4 chars = 8 bytes → 8/4 + 1 = 3 tokens
        // (a char-based count would give 4/4 + 1 = 2).
        let messages = vec![msg(MessageRole::User, "é".repeat(4))];
        assert_eq!(Dummy.estimate_tokens(&messages), 3);
    }
}