Skip to main content

zeph_core/agent/
utils.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::{LlmProvider, Message, MessagePart, Role};
5
6use super::{Agent, CODE_CONTEXT_PREFIX};
7use crate::channel::Channel;
8use crate::metrics::{MetricsSnapshot, SECURITY_EVENT_CAP, SecurityEvent, SecurityEventCategory};
9use zeph_tools::FilterStats;
10
11impl<C: Channel> Agent<C> {
12    /// Read the community-detection failure counter from `SemanticMemory` and update metrics.
13    pub fn sync_community_detection_failures(&self) {
14        if let Some(memory) = self.memory_state.memory.as_ref() {
15            let failures = memory.community_detection_failures();
16            self.update_metrics(|m| {
17                m.graph_community_detection_failures = failures;
18            });
19        }
20    }
21
22    /// Sync all graph counters (extraction count/failures) from `SemanticMemory` to metrics.
23    pub fn sync_graph_extraction_metrics(&self) {
24        if let Some(memory) = self.memory_state.memory.as_ref() {
25            let count = memory.graph_extraction_count();
26            let failures = memory.graph_extraction_failures();
27            self.update_metrics(|m| {
28                m.graph_extraction_count = count;
29                m.graph_extraction_failures = failures;
30            });
31        }
32    }
33
34    /// Fetch entity/edge/community counts from the graph store and write to metrics.
35    pub async fn sync_graph_counts(&self) {
36        let Some(memory) = self.memory_state.memory.as_ref() else {
37            return;
38        };
39        let Some(store) = memory.graph_store.as_ref() else {
40            return;
41        };
42        let (entities, edges, communities) = tokio::join!(
43            store.entity_count(),
44            store.active_edge_count(),
45            store.community_count()
46        );
47        self.update_metrics(|m| {
48            m.graph_entities_total = entities.unwrap_or(0).cast_unsigned();
49            m.graph_edges_total = edges.unwrap_or(0).cast_unsigned();
50            m.graph_communities_total = communities.unwrap_or(0).cast_unsigned();
51        });
52    }
53
54    /// Perform a real health check on the vector store and update metrics.
55    pub async fn check_vector_store_health(&self, backend_name: &str) {
56        let connected = match self.memory_state.memory.as_ref() {
57            Some(m) => m.is_vector_store_connected().await,
58            None => false,
59        };
60        let name = backend_name.to_owned();
61        self.update_metrics(|m| {
62            m.qdrant_available = connected;
63            m.vector_backend = name;
64        });
65    }
66
67    /// Fetch compression-guidelines metadata from `SQLite` and write to metrics.
68    ///
69    /// Only fetches version and `created_at`; does not load the full guidelines text.
70    /// Feature-gated: compiled only when `compression-guidelines` is enabled.
71    #[cfg(feature = "compression-guidelines")]
72    pub async fn sync_guidelines_status(&self) {
73        let Some(memory) = self.memory_state.memory.as_ref() else {
74            return;
75        };
76        let cid = self.memory_state.conversation_id;
77        match memory.sqlite().load_compression_guidelines_meta(cid).await {
78            Ok((version, created_at)) => {
79                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
80                let version_u32 = u32::try_from(version).unwrap_or(0);
81                self.update_metrics(|m| {
82                    m.guidelines_version = version_u32;
83                    m.guidelines_updated_at = created_at;
84                });
85            }
86            Err(e) => {
87                tracing::warn!("failed to sync guidelines status: {e:#}");
88            }
89        }
90    }
91
92    pub(super) fn record_filter_metrics(&mut self, fs: &FilterStats) {
93        let saved = fs.estimated_tokens_saved() as u64;
94        let raw = (fs.raw_chars / 4) as u64;
95        let confidence = fs.confidence;
96        let was_filtered = fs.filtered_chars < fs.raw_chars;
97        self.update_metrics(|m| {
98            m.filter_raw_tokens += raw;
99            m.filter_saved_tokens += saved;
100            m.filter_applications += 1;
101            m.filter_total_commands += 1;
102            if was_filtered {
103                m.filter_filtered_commands += 1;
104            }
105            if let Some(c) = confidence {
106                match c {
107                    zeph_tools::FilterConfidence::Full => {
108                        m.filter_confidence_full += 1;
109                    }
110                    zeph_tools::FilterConfidence::Partial => {
111                        m.filter_confidence_partial += 1;
112                    }
113                    zeph_tools::FilterConfidence::Fallback => {
114                        m.filter_confidence_fallback += 1;
115                    }
116                }
117            }
118        });
119    }
120
121    pub(super) fn update_metrics(&self, f: impl FnOnce(&mut MetricsSnapshot)) {
122        if let Some(ref tx) = self.metrics.metrics_tx {
123            let elapsed = self.lifecycle.start_time.elapsed().as_secs();
124            tx.send_modify(|m| {
125                m.uptime_seconds = elapsed;
126                f(m);
127            });
128        }
129    }
130
131    /// Push the current classifier metrics snapshot into `MetricsSnapshot`.
132    ///
133    /// Call this after any classifier invocation (injection, PII, feedback) so the TUI panel
134    /// reflects the latest p50/p95 values. No-op when classifier metrics are not configured.
135    pub(super) fn push_classifier_metrics(&self) {
136        if let Some(ref m) = self.metrics.classifier_metrics {
137            let snapshot = m.snapshot();
138            self.update_metrics(|ms| ms.classifier = snapshot);
139        }
140    }
141
142    pub(super) fn push_security_event(
143        &self,
144        category: SecurityEventCategory,
145        source: &str,
146        detail: impl Into<String>,
147    ) {
148        if let Some(ref tx) = self.metrics.metrics_tx {
149            let event = SecurityEvent::new(category, source, detail);
150            let elapsed = self.lifecycle.start_time.elapsed().as_secs();
151            tx.send_modify(|m| {
152                m.uptime_seconds = elapsed;
153                if m.security_events.len() >= SECURITY_EVENT_CAP {
154                    m.security_events.pop_front();
155                }
156                m.security_events.push_back(event);
157            });
158        }
159    }
160
161    pub(super) fn recompute_prompt_tokens(&mut self) {
162        self.providers.cached_prompt_tokens = self
163            .msg
164            .messages
165            .iter()
166            .map(|m| self.metrics.token_counter.count_message_tokens(m) as u64)
167            .sum();
168    }
169
170    pub(super) fn push_message(&mut self, msg: Message) {
171        self.providers.cached_prompt_tokens +=
172            self.metrics.token_counter.count_message_tokens(&msg) as u64;
173        self.msg.messages.push(msg);
174    }
175
176    pub(crate) fn record_cost(&self, prompt_tokens: u64, completion_tokens: u64) {
177        if let Some(ref tracker) = self.metrics.cost_tracker {
178            tracker.record_usage(&self.runtime.model_name, prompt_tokens, completion_tokens);
179            self.update_metrics(|m| {
180                m.cost_spent_cents = tracker.current_spend();
181            });
182        }
183    }
184
185    pub(crate) fn record_cache_usage(&self) {
186        if let Some((creation, read)) = self.provider.last_cache_usage() {
187            self.update_metrics(|m| {
188                m.cache_creation_tokens += creation;
189                m.cache_read_tokens += read;
190            });
191        }
192    }
193
194    /// Inject pre-formatted code context into the message list.
195    /// The caller is responsible for retrieving and formatting the text.
196    pub fn inject_code_context(&mut self, text: &str) {
197        self.remove_code_context_messages();
198        if text.is_empty() || self.msg.messages.len() <= 1 {
199            return;
200        }
201        let content = format!("{CODE_CONTEXT_PREFIX}{text}");
202        self.msg.messages.insert(
203            1,
204            Message::from_parts(
205                Role::System,
206                vec![MessagePart::CodeContext { text: content }],
207            ),
208        );
209    }
210
211    #[must_use]
212    pub fn context_messages(&self) -> &[Message] {
213        &self.msg.messages
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;
    use zeph_llm::provider::{MessageMetadata, MessagePart};

    /// Build an agent backed by mocks with no scripted responses and no tools.
    fn test_agent() -> Agent<MockChannel> {
        Agent::new(
            mock_provider(vec![]),
            MockChannel::new(vec![]),
            create_test_registry(),
            None,
            5,
            MockToolExecutor::no_tools(),
        )
    }

    /// Build a plain-text message with no structured parts and default metadata.
    fn text_message(role: Role, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn push_message_increments_cached_tokens() {
        let mut agent = test_agent();

        let before = agent.providers.cached_prompt_tokens;
        let msg = text_message(Role::User, "hello world!!");
        let expected_delta = agent.metrics.token_counter.count_message_tokens(&msg) as u64;
        agent.push_message(msg);
        assert_eq!(
            agent.providers.cached_prompt_tokens,
            before + expected_delta
        );
    }

    #[test]
    fn recompute_prompt_tokens_matches_sum() {
        let mut agent = test_agent();

        agent.msg.messages.push(text_message(Role::User, "1234"));
        agent.msg.messages.push(text_message(Role::Assistant, "5678"));
        // Poison the cache so the assertion proves an overwrite, not a coincidental match.
        agent.providers.cached_prompt_tokens = u64::MAX;

        agent.recompute_prompt_tokens();

        let expected: u64 = agent
            .msg
            .messages
            .iter()
            .map(|m| agent.metrics.token_counter.count_message_tokens(m) as u64)
            .sum();
        assert_eq!(agent.providers.cached_prompt_tokens, expected);
    }

    #[test]
    fn inject_code_context_into_messages_with_existing_content() {
        let mut agent = test_agent();

        // Add a user message so we have more than 1 message
        agent.push_message(text_message(Role::User, "question"));
        let count_before = agent.msg.messages.len();

        agent.inject_code_context("some code here");

        assert_eq!(agent.msg.messages.len(), count_before + 1);
        let injected = &agent.msg.messages[1];
        assert!(matches!(injected.role, Role::System));
        let found = injected.parts.iter().any(|p| {
            matches!(
                p,
                MessagePart::CodeContext { text }
                    if text.starts_with(CODE_CONTEXT_PREFIX) && text.ends_with("some code here")
            )
        });
        assert!(found, "code context should be injected right after the first message");
    }

    #[test]
    fn inject_code_context_empty_text_is_noop() {
        let mut agent = test_agent();

        agent.push_message(text_message(Role::User, "question"));
        let count_before = agent.msg.messages.len();

        agent.inject_code_context("");

        // No code context message inserted for empty text
        assert_eq!(agent.msg.messages.len(), count_before);
    }

    #[test]
    fn inject_code_context_with_single_message_is_noop() {
        let mut agent = test_agent();
        // Only system prompt → len == 1 → inject should be noop
        let count_before = agent.msg.messages.len();

        agent.inject_code_context("some code");

        assert_eq!(agent.msg.messages.len(), count_before);
    }

    #[test]
    fn context_messages_returns_all_messages() {
        let mut agent = test_agent();

        agent.push_message(text_message(Role::User, "test"));

        assert_eq!(agent.context_messages().len(), agent.msg.messages.len());
    }
}