Skip to main content

zeph_core/agent/
utils.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::{LlmProvider, Message, MessagePart, Role};
5
6use super::{Agent, CODE_CONTEXT_PREFIX};
7use crate::channel::Channel;
8use crate::metrics::{MetricsSnapshot, SECURITY_EVENT_CAP, SecurityEvent, SecurityEventCategory};
9
10impl<C: Channel> Agent<C> {
11    /// Read the community-detection failure counter from `SemanticMemory` and update metrics.
12    pub fn sync_community_detection_failures(&self) {
13        if let Some(memory) = self.memory_state.memory.as_ref() {
14            let failures = memory.community_detection_failures();
15            self.update_metrics(|m| {
16                m.graph_community_detection_failures = failures;
17            });
18        }
19    }
20
21    /// Sync all graph counters (extraction count/failures) from `SemanticMemory` to metrics.
22    pub fn sync_graph_extraction_metrics(&self) {
23        if let Some(memory) = self.memory_state.memory.as_ref() {
24            let count = memory.graph_extraction_count();
25            let failures = memory.graph_extraction_failures();
26            self.update_metrics(|m| {
27                m.graph_extraction_count = count;
28                m.graph_extraction_failures = failures;
29            });
30        }
31    }
32
33    /// Fetch entity/edge/community counts from the graph store and write to metrics.
34    pub async fn sync_graph_counts(&self) {
35        let Some(memory) = self.memory_state.memory.as_ref() else {
36            return;
37        };
38        let Some(store) = memory.graph_store.as_ref() else {
39            return;
40        };
41        let (entities, edges, communities) = tokio::join!(
42            store.entity_count(),
43            store.active_edge_count(),
44            store.community_count()
45        );
46        self.update_metrics(|m| {
47            m.graph_entities_total = entities.unwrap_or(0).cast_unsigned();
48            m.graph_edges_total = edges.unwrap_or(0).cast_unsigned();
49            m.graph_communities_total = communities.unwrap_or(0).cast_unsigned();
50        });
51    }
52
53    /// Perform a real health check on the vector store and update metrics.
54    pub async fn check_vector_store_health(&self, backend_name: &str) {
55        let connected = match self.memory_state.memory.as_ref() {
56            Some(m) => m.is_vector_store_connected().await,
57            None => false,
58        };
59        let name = backend_name.to_owned();
60        self.update_metrics(|m| {
61            m.qdrant_available = connected;
62            m.vector_backend = name;
63        });
64    }
65
66    pub(super) fn update_metrics(&self, f: impl FnOnce(&mut MetricsSnapshot)) {
67        if let Some(ref tx) = self.metrics.metrics_tx {
68            let elapsed = self.lifecycle.start_time.elapsed().as_secs();
69            tx.send_modify(|m| {
70                m.uptime_seconds = elapsed;
71                f(m);
72            });
73        }
74    }
75
76    pub(super) fn push_security_event(
77        &self,
78        category: SecurityEventCategory,
79        source: &str,
80        detail: impl Into<String>,
81    ) {
82        if let Some(ref tx) = self.metrics.metrics_tx {
83            let event = SecurityEvent::new(category, source, detail);
84            let elapsed = self.lifecycle.start_time.elapsed().as_secs();
85            tx.send_modify(|m| {
86                m.uptime_seconds = elapsed;
87                if m.security_events.len() >= SECURITY_EVENT_CAP {
88                    m.security_events.pop_front();
89                }
90                m.security_events.push_back(event);
91            });
92        }
93    }
94
95    pub(super) fn recompute_prompt_tokens(&mut self) {
96        self.providers.cached_prompt_tokens = self
97            .messages
98            .iter()
99            .map(|m| self.metrics.token_counter.count_message_tokens(m) as u64)
100            .sum();
101    }
102
103    pub(super) fn push_message(&mut self, msg: Message) {
104        self.providers.cached_prompt_tokens +=
105            self.metrics.token_counter.count_message_tokens(&msg) as u64;
106        self.messages.push(msg);
107    }
108
109    pub(crate) fn record_cost(&self, prompt_tokens: u64, completion_tokens: u64) {
110        if let Some(ref tracker) = self.metrics.cost_tracker {
111            tracker.record_usage(&self.runtime.model_name, prompt_tokens, completion_tokens);
112            self.update_metrics(|m| {
113                m.cost_spent_cents = tracker.current_spend();
114            });
115        }
116    }
117
118    pub(crate) fn record_cache_usage(&self) {
119        if let Some((creation, read)) = self.provider.last_cache_usage() {
120            self.update_metrics(|m| {
121                m.cache_creation_tokens += creation;
122                m.cache_read_tokens += read;
123            });
124        }
125    }
126
127    /// Inject pre-formatted code context into the message list.
128    /// The caller is responsible for retrieving and formatting the text.
129    pub fn inject_code_context(&mut self, text: &str) {
130        self.remove_code_context_messages();
131        if text.is_empty() || self.messages.len() <= 1 {
132            return;
133        }
134        let content = format!("{CODE_CONTEXT_PREFIX}{text}");
135        self.messages.insert(
136            1,
137            Message::from_parts(
138                Role::System,
139                vec![MessagePart::CodeContext { text: content }],
140            ),
141        );
142    }
143
144    #[must_use]
145    pub fn context_messages(&self) -> &[Message] {
146        &self.messages
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;
    use zeph_llm::provider::{MessageMetadata, MessagePart};

    /// Build an agent backed by the mock provider, channel, registry and tool executor.
    ///
    /// Arguments are evaluated in the same order the individual tests used to build them.
    fn make_agent() -> Agent<impl Channel> {
        Agent::new(
            mock_provider(vec![]),
            MockChannel::new(vec![]),
            create_test_registry(),
            None,
            5,
            MockToolExecutor::no_tools(),
        )
    }

    /// Plain-text message with no structured parts and default metadata.
    fn text_message(role: Role, content: &str) -> Message {
        Message {
            role,
            content: content.to_owned(),
            parts: Vec::new(),
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn push_message_increments_cached_tokens() {
        let mut agent = make_agent();
        let msg = text_message(Role::User, "hello world!!");
        let tokens_before = agent.providers.cached_prompt_tokens;
        let delta = agent.metrics.token_counter.count_message_tokens(&msg) as u64;

        agent.push_message(msg);

        assert_eq!(agent.providers.cached_prompt_tokens, tokens_before + delta);
    }

    #[test]
    fn recompute_prompt_tokens_matches_sum() {
        let mut agent = make_agent();
        // Bypass `push_message` on purpose so only the recompute path sets the cache.
        agent.messages.extend([
            text_message(Role::User, "1234"),
            text_message(Role::Assistant, "5678"),
        ]);

        agent.recompute_prompt_tokens();

        let counter = &agent.metrics.token_counter;
        let expected: u64 = agent
            .messages
            .iter()
            .map(|m| counter.count_message_tokens(m) as u64)
            .sum();
        assert_eq!(agent.providers.cached_prompt_tokens, expected);
    }

    #[test]
    fn inject_code_context_into_messages_with_existing_content() {
        let mut agent = make_agent();
        // A user turn after the system prompt gives the history more than one message.
        agent.push_message(text_message(Role::User, "question"));

        agent.inject_code_context("some code here");

        let injected = agent
            .messages
            .iter()
            .flat_map(|m| m.parts.iter())
            .any(|p| {
                matches!(p, MessagePart::CodeContext { text } if text.contains("some code here"))
            });
        assert!(injected, "code context should be injected into messages");
    }

    #[test]
    fn inject_code_context_empty_text_is_noop() {
        let mut agent = make_agent();
        agent.push_message(text_message(Role::User, "question"));
        let len_before = agent.messages.len();

        agent.inject_code_context("");

        // Empty text must not produce a code-context message.
        assert_eq!(agent.messages.len(), len_before);
    }

    #[test]
    fn inject_code_context_with_single_message_is_noop() {
        // A fresh agent holds only the system prompt (len == 1), so injection is skipped.
        let mut agent = make_agent();
        let len_before = agent.messages.len();

        agent.inject_code_context("some code");

        assert_eq!(agent.messages.len(), len_before);
    }

    #[test]
    fn context_messages_returns_all_messages() {
        let mut agent = make_agent();
        agent.push_message(text_message(Role::User, "test"));

        let exposed = agent.context_messages().len();
        assert_eq!(exposed, agent.messages.len());
    }
}