1use zeph_llm::provider::{LlmProvider, Message, MessagePart, Role};
5
6use super::{Agent, CODE_CONTEXT_PREFIX};
7use crate::channel::Channel;
8use crate::metrics::{MetricsSnapshot, SECURITY_EVENT_CAP, SecurityEvent, SecurityEventCategory};
9
10impl<C: Channel> Agent<C> {
11 pub fn sync_community_detection_failures(&self) {
13 if let Some(memory) = self.memory_state.memory.as_ref() {
14 let failures = memory.community_detection_failures();
15 self.update_metrics(|m| {
16 m.graph_community_detection_failures = failures;
17 });
18 }
19 }
20
21 pub fn sync_graph_extraction_metrics(&self) {
23 if let Some(memory) = self.memory_state.memory.as_ref() {
24 let count = memory.graph_extraction_count();
25 let failures = memory.graph_extraction_failures();
26 self.update_metrics(|m| {
27 m.graph_extraction_count = count;
28 m.graph_extraction_failures = failures;
29 });
30 }
31 }
32
33 pub async fn sync_graph_counts(&self) {
35 let Some(memory) = self.memory_state.memory.as_ref() else {
36 return;
37 };
38 let Some(store) = memory.graph_store.as_ref() else {
39 return;
40 };
41 let (entities, edges, communities) = tokio::join!(
42 store.entity_count(),
43 store.active_edge_count(),
44 store.community_count()
45 );
46 self.update_metrics(|m| {
47 m.graph_entities_total = entities.unwrap_or(0).cast_unsigned();
48 m.graph_edges_total = edges.unwrap_or(0).cast_unsigned();
49 m.graph_communities_total = communities.unwrap_or(0).cast_unsigned();
50 });
51 }
52
53 pub async fn check_vector_store_health(&self, backend_name: &str) {
55 let connected = match self.memory_state.memory.as_ref() {
56 Some(m) => m.is_vector_store_connected().await,
57 None => false,
58 };
59 let name = backend_name.to_owned();
60 self.update_metrics(|m| {
61 m.qdrant_available = connected;
62 m.vector_backend = name;
63 });
64 }
65
66 pub(super) fn update_metrics(&self, f: impl FnOnce(&mut MetricsSnapshot)) {
67 if let Some(ref tx) = self.metrics_tx {
68 let elapsed = self.start_time.elapsed().as_secs();
69 tx.send_modify(|m| {
70 m.uptime_seconds = elapsed;
71 f(m);
72 });
73 }
74 }
75
76 pub(super) fn push_security_event(
77 &self,
78 category: SecurityEventCategory,
79 source: &str,
80 detail: impl Into<String>,
81 ) {
82 if let Some(ref tx) = self.metrics_tx {
83 let event = SecurityEvent::new(category, source, detail);
84 let elapsed = self.start_time.elapsed().as_secs();
85 tx.send_modify(|m| {
86 m.uptime_seconds = elapsed;
87 if m.security_events.len() >= SECURITY_EVENT_CAP {
88 m.security_events.pop_front();
89 }
90 m.security_events.push_back(event);
91 });
92 }
93 }
94
95 pub(super) fn recompute_prompt_tokens(&mut self) {
96 self.cached_prompt_tokens = self
97 .messages
98 .iter()
99 .map(|m| self.token_counter.count_message_tokens(m) as u64)
100 .sum();
101 }
102
103 pub(super) fn push_message(&mut self, msg: Message) {
104 self.cached_prompt_tokens += self.token_counter.count_message_tokens(&msg) as u64;
105 self.messages.push(msg);
106 }
107
108 pub(crate) fn record_cost(&self, prompt_tokens: u64, completion_tokens: u64) {
109 if let Some(ref tracker) = self.cost_tracker {
110 tracker.record_usage(&self.runtime.model_name, prompt_tokens, completion_tokens);
111 self.update_metrics(|m| {
112 m.cost_spent_cents = tracker.current_spend();
113 });
114 }
115 }
116
117 pub(crate) fn record_cache_usage(&self) {
118 if let Some((creation, read)) = self.provider.last_cache_usage() {
119 self.update_metrics(|m| {
120 m.cache_creation_tokens += creation;
121 m.cache_read_tokens += read;
122 });
123 }
124 }
125
126 pub fn inject_code_context(&mut self, text: &str) {
129 self.remove_code_context_messages();
130 if text.is_empty() || self.messages.len() <= 1 {
131 return;
132 }
133 let content = format!("{CODE_CONTEXT_PREFIX}{text}");
134 self.messages.insert(
135 1,
136 Message::from_parts(
137 Role::System,
138 vec![MessagePart::CodeContext { text: content }],
139 ),
140 );
141 }
142
143 #[must_use]
144 pub fn context_messages(&self) -> &[Message] {
145 &self.messages
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;
    use zeph_llm::provider::{MessageMetadata, MessagePart};

    /// Builds an agent from the shared mocks with the constructor arguments every test
    /// in this module uses.
    fn test_agent() -> Agent<MockChannel> {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        Agent::new(provider, channel, registry, None, 5, executor)
    }

    /// Builds a plain-text message (no parts, default metadata).
    fn text_message(role: Role, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn push_message_increments_cached_tokens() {
        let mut agent = test_agent();

        let before = agent.cached_prompt_tokens;
        let msg = text_message(Role::User, "hello world!!");
        let expected_delta = agent.token_counter.count_message_tokens(&msg) as u64;
        agent.push_message(msg);
        assert_eq!(agent.cached_prompt_tokens, before + expected_delta);
    }

    #[test]
    fn recompute_prompt_tokens_matches_sum() {
        let mut agent = test_agent();

        // Pushed directly, bypassing `push_message`, so the cache is stale until recomputed.
        agent.messages.push(text_message(Role::User, "1234"));
        agent.messages.push(text_message(Role::Assistant, "5678"));

        agent.recompute_prompt_tokens();

        let expected: u64 = agent
            .messages
            .iter()
            .map(|m| agent.token_counter.count_message_tokens(m) as u64)
            .sum();
        assert_eq!(agent.cached_prompt_tokens, expected);
    }

    #[test]
    fn inject_code_context_into_messages_with_existing_content() {
        let mut agent = test_agent();
        agent.push_message(text_message(Role::User, "question"));
        let count_before = agent.messages.len();

        agent.inject_code_context("some code here");

        // Exactly one message is inserted, at index 1, ahead of the existing user turn.
        assert_eq!(agent.messages.len(), count_before + 1);
        let injected = &agent.messages[1];
        assert!(
            matches!(injected.role, Role::System),
            "code context should be injected as a system message"
        );
        let has_prefixed_context = injected.parts.iter().any(|p| {
            matches!(
                p,
                MessagePart::CodeContext { text }
                    if text.starts_with(CODE_CONTEXT_PREFIX) && text.contains("some code here")
            )
        });
        assert!(
            has_prefixed_context,
            "injected message should carry the prefixed code context"
        );
        assert_eq!(agent.messages[count_before].content, "question");
    }

    #[test]
    fn inject_code_context_empty_text_is_noop() {
        let mut agent = test_agent();
        agent.push_message(text_message(Role::User, "question"));
        let count_before = agent.messages.len();

        agent.inject_code_context("");

        assert_eq!(agent.messages.len(), count_before);
    }

    #[test]
    fn inject_code_context_with_single_message_is_noop() {
        let mut agent = test_agent();
        let count_before = agent.messages.len();

        agent.inject_code_context("some code");

        assert_eq!(agent.messages.len(), count_before);
    }

    #[test]
    fn context_messages_returns_all_messages() {
        let mut agent = test_agent();
        agent.push_message(text_message(Role::User, "test"));

        let view = agent.context_messages();

        // The accessor must expose the live history itself, not a copy or a subset.
        assert!(std::ptr::eq(view, agent.messages.as_slice()));
        assert_eq!(view.len(), agent.messages.len());
        assert_eq!(view.last().map(|m| m.content.as_str()), Some("test"));
    }
}