1use zeph_llm::provider::{LlmProvider, Message, MessagePart, Role};
5
6use super::{Agent, CODE_CONTEXT_PREFIX};
7use crate::channel::Channel;
8use crate::metrics::{MetricsSnapshot, SECURITY_EVENT_CAP, SecurityEvent, SecurityEventCategory};
9use zeph_tools::FilterStats;
10
11impl<C: Channel> Agent<C> {
12 pub fn sync_community_detection_failures(&self) {
14 if let Some(memory) = self.memory_state.memory.as_ref() {
15 let failures = memory.community_detection_failures();
16 self.update_metrics(|m| {
17 m.graph_community_detection_failures = failures;
18 });
19 }
20 }
21
22 pub fn sync_graph_extraction_metrics(&self) {
24 if let Some(memory) = self.memory_state.memory.as_ref() {
25 let count = memory.graph_extraction_count();
26 let failures = memory.graph_extraction_failures();
27 self.update_metrics(|m| {
28 m.graph_extraction_count = count;
29 m.graph_extraction_failures = failures;
30 });
31 }
32 }
33
34 pub async fn sync_graph_counts(&self) {
36 let Some(memory) = self.memory_state.memory.as_ref() else {
37 return;
38 };
39 let Some(store) = memory.graph_store.as_ref() else {
40 return;
41 };
42 let (entities, edges, communities) = tokio::join!(
43 store.entity_count(),
44 store.active_edge_count(),
45 store.community_count()
46 );
47 self.update_metrics(|m| {
48 m.graph_entities_total = entities.unwrap_or(0).cast_unsigned();
49 m.graph_edges_total = edges.unwrap_or(0).cast_unsigned();
50 m.graph_communities_total = communities.unwrap_or(0).cast_unsigned();
51 });
52 }
53
54 pub async fn check_vector_store_health(&self, backend_name: &str) {
56 let connected = match self.memory_state.memory.as_ref() {
57 Some(m) => m.is_vector_store_connected().await,
58 None => false,
59 };
60 let name = backend_name.to_owned();
61 self.update_metrics(|m| {
62 m.qdrant_available = connected;
63 m.vector_backend = name;
64 });
65 }
66
67 pub async fn sync_guidelines_status(&self) {
72 let Some(memory) = self.memory_state.memory.as_ref() else {
73 return;
74 };
75 let cid = self.memory_state.conversation_id;
76 match memory.sqlite().load_compression_guidelines_meta(cid).await {
77 Ok((version, created_at)) => {
78 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
79 let version_u32 = u32::try_from(version).unwrap_or(0);
80 self.update_metrics(|m| {
81 m.guidelines_version = version_u32;
82 m.guidelines_updated_at = created_at;
83 });
84 }
85 Err(e) => {
86 tracing::warn!("failed to sync guidelines status: {e:#}");
87 }
88 }
89 }
90
91 pub(super) fn record_filter_metrics(&mut self, fs: &FilterStats) {
92 let saved = fs.estimated_tokens_saved() as u64;
93 let raw = (fs.raw_chars / 4) as u64;
94 let confidence = fs.confidence;
95 let was_filtered = fs.filtered_chars < fs.raw_chars;
96 self.update_metrics(|m| {
97 m.filter_raw_tokens += raw;
98 m.filter_saved_tokens += saved;
99 m.filter_applications += 1;
100 m.filter_total_commands += 1;
101 if was_filtered {
102 m.filter_filtered_commands += 1;
103 }
104 if let Some(c) = confidence {
105 match c {
106 zeph_tools::FilterConfidence::Full => {
107 m.filter_confidence_full += 1;
108 }
109 zeph_tools::FilterConfidence::Partial => {
110 m.filter_confidence_partial += 1;
111 }
112 zeph_tools::FilterConfidence::Fallback => {
113 m.filter_confidence_fallback += 1;
114 }
115 }
116 }
117 });
118 }
119
120 pub(super) fn update_metrics(&self, f: impl FnOnce(&mut MetricsSnapshot)) {
121 if let Some(ref tx) = self.metrics.metrics_tx {
122 let elapsed = self.lifecycle.start_time.elapsed().as_secs();
123 tx.send_modify(|m| {
124 m.uptime_seconds = elapsed;
125 f(m);
126 });
127 }
128 }
129
130 pub(super) fn push_classifier_metrics(&self) {
135 if let Some(ref m) = self.metrics.classifier_metrics {
136 let snapshot = m.snapshot();
137 self.update_metrics(|ms| ms.classifier = snapshot);
138 }
139 }
140
141 pub(super) fn push_security_event(
142 &self,
143 category: SecurityEventCategory,
144 source: &str,
145 detail: impl Into<String>,
146 ) {
147 if let Some(ref tx) = self.metrics.metrics_tx {
148 let event = SecurityEvent::new(category, source, detail);
149 let elapsed = self.lifecycle.start_time.elapsed().as_secs();
150 tx.send_modify(|m| {
151 m.uptime_seconds = elapsed;
152 if m.security_events.len() >= SECURITY_EVENT_CAP {
153 m.security_events.pop_front();
154 }
155 m.security_events.push_back(event);
156 });
157 }
158 }
159
160 pub(super) fn recompute_prompt_tokens(&mut self) {
161 self.providers.cached_prompt_tokens = self
162 .msg
163 .messages
164 .iter()
165 .map(|m| self.metrics.token_counter.count_message_tokens(m) as u64)
166 .sum();
167 }
168
169 pub(super) fn push_message(&mut self, msg: Message) {
170 self.providers.cached_prompt_tokens +=
171 self.metrics.token_counter.count_message_tokens(&msg) as u64;
172 self.msg.messages.push(msg);
173 }
174
175 pub(crate) fn record_cost(&self, prompt_tokens: u64, completion_tokens: u64) {
176 if let Some(ref tracker) = self.metrics.cost_tracker {
177 tracker.record_usage(&self.runtime.model_name, prompt_tokens, completion_tokens);
178 self.update_metrics(|m| {
179 m.cost_spent_cents = tracker.current_spend();
180 });
181 }
182 }
183
184 pub(crate) fn record_cache_usage(&self) {
185 if let Some((creation, read)) = self.provider.last_cache_usage() {
186 self.update_metrics(|m| {
187 m.cache_creation_tokens += creation;
188 m.cache_read_tokens += read;
189 });
190 }
191 }
192
193 pub fn inject_code_context(&mut self, text: &str) {
196 self.remove_code_context_messages();
197 if text.is_empty() || self.msg.messages.len() <= 1 {
198 return;
199 }
200 let content = format!("{CODE_CONTEXT_PREFIX}{text}");
201 self.msg.messages.insert(
202 1,
203 Message::from_parts(
204 Role::System,
205 vec![MessagePart::CodeContext { text: content }],
206 ),
207 );
208 }
209
210 #[must_use]
211 pub fn context_messages(&self) -> &[Message] {
212 &self.msg.messages
213 }
214}
215
216#[cfg(test)]
217mod tests {
218 use super::super::agent_tests::{
219 MockChannel, MockToolExecutor, create_test_registry, mock_provider,
220 };
221 use super::*;
222 use zeph_llm::provider::{MessageMetadata, MessagePart};
223
/// `push_message` must grow the cached token total by exactly the pushed
/// message's token count.
#[test]
fn push_message_increments_cached_tokens() {
    let mut agent = Agent::new(
        mock_provider(vec![]),
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    );

    let user_msg = Message {
        role: Role::User,
        content: String::from("hello world!!"),
        parts: Vec::new(),
        metadata: MessageMetadata::default(),
    };
    let tokens_before = agent.providers.cached_prompt_tokens;
    let delta = agent.metrics.token_counter.count_message_tokens(&user_msg) as u64;

    agent.push_message(user_msg);

    let tokens_after = agent.providers.cached_prompt_tokens;
    assert_eq!(tokens_after, tokens_before + delta);
}
246
/// After `recompute_prompt_tokens`, the cached total must equal the sum of the
/// per-message token counts over all stored messages.
#[test]
fn recompute_prompt_tokens_matches_sum() {
    let mut agent = Agent::new(
        mock_provider(vec![]),
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    );

    // Push directly onto the list so the cache is NOT updated incrementally.
    for (role, body) in [(Role::User, "1234"), (Role::Assistant, "5678")] {
        agent.msg.messages.push(Message {
            role,
            content: body.to_owned(),
            parts: Vec::new(),
            metadata: MessageMetadata::default(),
        });
    }

    agent.recompute_prompt_tokens();

    let counter = &agent.metrics.token_counter;
    let mut expected = 0u64;
    for message in &agent.msg.messages {
        expected += counter.count_message_tokens(message) as u64;
    }
    assert_eq!(agent.providers.cached_prompt_tokens, expected);
}
278
/// With a user message already pushed, injecting non-empty text must leave a
/// `CodeContext` part containing that text somewhere in the message list.
#[test]
fn inject_code_context_into_messages_with_existing_content() {
    let mut agent = Agent::new(
        mock_provider(vec![]),
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    );

    let question = Message {
        role: Role::User,
        content: String::from("question"),
        parts: Vec::new(),
        metadata: MessageMetadata::default(),
    };
    agent.push_message(question);

    let snippet = "some code here";
    agent.inject_code_context(snippet);

    let injected = agent
        .msg
        .messages
        .iter()
        .flat_map(|message| message.parts.iter())
        .any(|part| matches!(part, MessagePart::CodeContext { text } if text.contains(snippet)));
    assert!(injected, "code context should be injected into messages");
}
304
/// Injecting an empty string must not change the number of messages.
#[test]
fn inject_code_context_empty_text_is_noop() {
    let mut agent = Agent::new(
        mock_provider(vec![]),
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    );

    let question = Message {
        role: Role::User,
        content: String::from("question"),
        parts: Vec::new(),
        metadata: MessageMetadata::default(),
    };
    agent.push_message(question);
    let len_before = agent.msg.messages.len();

    agent.inject_code_context("");

    let len_after = agent.msg.messages.len();
    assert_eq!(len_after, len_before);
}
326
/// On a freshly constructed agent with nothing pushed (presumably holding at
/// most the initial system message — confirm in `Agent::new`), injecting code
/// context must not change the number of messages.
#[test]
fn inject_code_context_with_single_message_is_noop() {
    let mut agent = Agent::new(
        mock_provider(vec![]),
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    );
    let len_before = agent.msg.messages.len();

    agent.inject_code_context("some code");

    let len_after = agent.msg.messages.len();
    assert_eq!(len_after, len_before);
}
341
/// `context_messages` must expose as many messages as the agent stores.
#[test]
fn context_messages_returns_all_messages() {
    let mut agent = Agent::new(
        mock_provider(vec![]),
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    );

    let user_msg = Message {
        role: Role::User,
        content: String::from("test"),
        parts: Vec::new(),
        metadata: MessageMetadata::default(),
    };
    agent.push_message(user_msg);

    let exposed = agent.context_messages().len();
    let stored = agent.msg.messages.len();
    assert_eq!(exposed, stored);
}
359}