Skip to main content

zeph_core/agent/
utils.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::{LlmProvider, Message, MessagePart, Role};
5
6use super::{Agent, CODE_CONTEXT_PREFIX};
7use crate::channel::Channel;
8use crate::metrics::{MetricsSnapshot, SECURITY_EVENT_CAP, SecurityEvent, SecurityEventCategory};
9use zeph_tools::FilterStats;
10
11impl<C: Channel> Agent<C> {
12    /// Read the community-detection failure counter from `SemanticMemory` and update metrics.
13    pub fn sync_community_detection_failures(&self) {
14        if let Some(memory) = self.memory_state.memory.as_ref() {
15            let failures = memory.community_detection_failures();
16            self.update_metrics(|m| {
17                m.graph_community_detection_failures = failures;
18            });
19        }
20    }
21
22    /// Sync all graph counters (extraction count/failures) from `SemanticMemory` to metrics.
23    pub fn sync_graph_extraction_metrics(&self) {
24        if let Some(memory) = self.memory_state.memory.as_ref() {
25            let count = memory.graph_extraction_count();
26            let failures = memory.graph_extraction_failures();
27            self.update_metrics(|m| {
28                m.graph_extraction_count = count;
29                m.graph_extraction_failures = failures;
30            });
31        }
32    }
33
34    /// Fetch entity/edge/community counts from the graph store and write to metrics.
35    pub async fn sync_graph_counts(&self) {
36        let Some(memory) = self.memory_state.memory.as_ref() else {
37            return;
38        };
39        let Some(store) = memory.graph_store.as_ref() else {
40            return;
41        };
42        let (entities, edges, communities) = tokio::join!(
43            store.entity_count(),
44            store.active_edge_count(),
45            store.community_count()
46        );
47        self.update_metrics(|m| {
48            m.graph_entities_total = entities.unwrap_or(0).cast_unsigned();
49            m.graph_edges_total = edges.unwrap_or(0).cast_unsigned();
50            m.graph_communities_total = communities.unwrap_or(0).cast_unsigned();
51        });
52    }
53
54    /// Perform a real health check on the vector store and update metrics.
55    pub async fn check_vector_store_health(&self, backend_name: &str) {
56        let connected = match self.memory_state.memory.as_ref() {
57            Some(m) => m.is_vector_store_connected().await,
58            None => false,
59        };
60        let name = backend_name.to_owned();
61        self.update_metrics(|m| {
62            m.qdrant_available = connected;
63            m.vector_backend = name;
64        });
65    }
66
67    /// Fetch compression-guidelines metadata from `SQLite` and write to metrics.
68    ///
69    /// Only fetches version and `created_at`; does not load the full guidelines text.
70    /// Feature-gated: compiled only when `compression-guidelines` is enabled.
71    #[cfg(feature = "compression-guidelines")]
72    pub async fn sync_guidelines_status(&self) {
73        let Some(memory) = self.memory_state.memory.as_ref() else {
74            return;
75        };
76        let cid = self.memory_state.conversation_id;
77        match memory.sqlite().load_compression_guidelines_meta(cid).await {
78            Ok((version, created_at)) => {
79                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
80                let version_u32 = u32::try_from(version).unwrap_or(0);
81                self.update_metrics(|m| {
82                    m.guidelines_version = version_u32;
83                    m.guidelines_updated_at = created_at;
84                });
85            }
86            Err(e) => {
87                tracing::warn!("failed to sync guidelines status: {e:#}");
88            }
89        }
90    }
91
92    pub(super) fn record_filter_metrics(&mut self, fs: &FilterStats) {
93        let saved = fs.estimated_tokens_saved() as u64;
94        let raw = (fs.raw_chars / 4) as u64;
95        let confidence = fs.confidence;
96        let was_filtered = fs.filtered_chars < fs.raw_chars;
97        self.update_metrics(|m| {
98            m.filter_raw_tokens += raw;
99            m.filter_saved_tokens += saved;
100            m.filter_applications += 1;
101            m.filter_total_commands += 1;
102            if was_filtered {
103                m.filter_filtered_commands += 1;
104            }
105            if let Some(c) = confidence {
106                match c {
107                    zeph_tools::FilterConfidence::Full => {
108                        m.filter_confidence_full += 1;
109                    }
110                    zeph_tools::FilterConfidence::Partial => {
111                        m.filter_confidence_partial += 1;
112                    }
113                    zeph_tools::FilterConfidence::Fallback => {
114                        m.filter_confidence_fallback += 1;
115                    }
116                }
117            }
118        });
119    }
120
121    pub(super) fn update_metrics(&self, f: impl FnOnce(&mut MetricsSnapshot)) {
122        if let Some(ref tx) = self.metrics.metrics_tx {
123            let elapsed = self.lifecycle.start_time.elapsed().as_secs();
124            tx.send_modify(|m| {
125                m.uptime_seconds = elapsed;
126                f(m);
127            });
128        }
129    }
130
131    pub(super) fn push_security_event(
132        &self,
133        category: SecurityEventCategory,
134        source: &str,
135        detail: impl Into<String>,
136    ) {
137        if let Some(ref tx) = self.metrics.metrics_tx {
138            let event = SecurityEvent::new(category, source, detail);
139            let elapsed = self.lifecycle.start_time.elapsed().as_secs();
140            tx.send_modify(|m| {
141                m.uptime_seconds = elapsed;
142                if m.security_events.len() >= SECURITY_EVENT_CAP {
143                    m.security_events.pop_front();
144                }
145                m.security_events.push_back(event);
146            });
147        }
148    }
149
150    pub(super) fn recompute_prompt_tokens(&mut self) {
151        self.providers.cached_prompt_tokens = self
152            .msg
153            .messages
154            .iter()
155            .map(|m| self.metrics.token_counter.count_message_tokens(m) as u64)
156            .sum();
157    }
158
159    pub(super) fn push_message(&mut self, msg: Message) {
160        self.providers.cached_prompt_tokens +=
161            self.metrics.token_counter.count_message_tokens(&msg) as u64;
162        self.msg.messages.push(msg);
163    }
164
165    pub(crate) fn record_cost(&self, prompt_tokens: u64, completion_tokens: u64) {
166        if let Some(ref tracker) = self.metrics.cost_tracker {
167            tracker.record_usage(&self.runtime.model_name, prompt_tokens, completion_tokens);
168            self.update_metrics(|m| {
169                m.cost_spent_cents = tracker.current_spend();
170            });
171        }
172    }
173
174    pub(crate) fn record_cache_usage(&self) {
175        if let Some((creation, read)) = self.provider.last_cache_usage() {
176            self.update_metrics(|m| {
177                m.cache_creation_tokens += creation;
178                m.cache_read_tokens += read;
179            });
180        }
181    }
182
183    /// Inject pre-formatted code context into the message list.
184    /// The caller is responsible for retrieving and formatting the text.
185    pub fn inject_code_context(&mut self, text: &str) {
186        self.remove_code_context_messages();
187        if text.is_empty() || self.msg.messages.len() <= 1 {
188            return;
189        }
190        let content = format!("{CODE_CONTEXT_PREFIX}{text}");
191        self.msg.messages.insert(
192            1,
193            Message::from_parts(
194                Role::System,
195                vec![MessagePart::CodeContext { text: content }],
196            ),
197        );
198    }
199
200    #[must_use]
201    pub fn context_messages(&self) -> &[Message] {
202        &self.msg.messages
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;
    use zeph_llm::provider::{MessageMetadata, MessagePart};

    /// Build an agent wired to mock collaborators; shared by every test below.
    fn test_agent() -> Agent<MockChannel> {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        Agent::new(provider, channel, registry, None, 5, executor)
    }

    /// Build a plain-text message with no parts and default metadata.
    fn text_message(role: Role, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn push_message_increments_cached_tokens() {
        let mut agent = test_agent();

        let before = agent.providers.cached_prompt_tokens;
        let msg = text_message(Role::User, "hello world!!");
        let expected_delta = agent.metrics.token_counter.count_message_tokens(&msg) as u64;
        agent.push_message(msg);
        assert_eq!(
            agent.providers.cached_prompt_tokens,
            before + expected_delta
        );
    }

    #[test]
    fn recompute_prompt_tokens_matches_sum() {
        let mut agent = test_agent();

        // Push directly, bypassing `push_message`, so the cache is stale until
        // `recompute_prompt_tokens` runs.
        agent.msg.messages.push(text_message(Role::User, "1234"));
        agent
            .msg
            .messages
            .push(text_message(Role::Assistant, "5678"));

        agent.recompute_prompt_tokens();

        let expected: u64 = agent
            .msg
            .messages
            .iter()
            .map(|m| agent.metrics.token_counter.count_message_tokens(m) as u64)
            .sum();
        assert_eq!(agent.providers.cached_prompt_tokens, expected);
    }

    #[test]
    fn inject_code_context_into_messages_with_existing_content() {
        let mut agent = test_agent();

        // Add a user message so we have more than 1 message
        agent.push_message(text_message(Role::User, "question"));

        agent.inject_code_context("some code here");

        let found = agent.msg.messages.iter().any(|m| {
            m.parts.iter().any(|p| {
                matches!(p, MessagePart::CodeContext { text } if text.contains("some code here"))
            })
        });
        assert!(found, "code context should be injected into messages");

        // The context is inserted at index 1, not merely somewhere in the list.
        let at_index_one = agent.msg.messages.get(1).is_some_and(|m| {
            m.parts.iter().any(|p| {
                matches!(p, MessagePart::CodeContext { text } if text.contains("some code here"))
            })
        });
        assert!(at_index_one, "code context should be inserted at index 1");
    }

    #[test]
    fn inject_code_context_empty_text_is_noop() {
        let mut agent = test_agent();

        agent.push_message(text_message(Role::User, "question"));
        let count_before = agent.msg.messages.len();

        agent.inject_code_context("");

        // No code context message inserted for empty text
        assert_eq!(agent.msg.messages.len(), count_before);
    }

    #[test]
    fn inject_code_context_with_single_message_is_noop() {
        let mut agent = test_agent();
        // Only system prompt → len == 1 → inject should be noop
        let count_before = agent.msg.messages.len();

        agent.inject_code_context("some code");

        assert_eq!(agent.msg.messages.len(), count_before);
    }

    #[test]
    fn context_messages_returns_all_messages() {
        let mut agent = test_agent();

        agent.push_message(text_message(Role::User, "test"));

        assert_eq!(agent.context_messages().len(), agent.msg.messages.len());
    }
}