Skip to main content

zeph_core/agent/
utils.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::{LlmProvider, Message, MessagePart, Role};
5
6use super::{Agent, CODE_CONTEXT_PREFIX};
7use crate::channel::Channel;
8use crate::metrics::{MetricsSnapshot, SECURITY_EVENT_CAP, SecurityEvent, SecurityEventCategory};
9
10impl<C: Channel> Agent<C> {
11    /// Read the community-detection failure counter from `SemanticMemory` and update metrics.
12    pub fn sync_community_detection_failures(&self) {
13        if let Some(memory) = self.memory_state.memory.as_ref() {
14            let failures = memory.community_detection_failures();
15            self.update_metrics(|m| {
16                m.graph_community_detection_failures = failures;
17            });
18        }
19    }
20
21    /// Sync all graph counters (extraction count/failures) from `SemanticMemory` to metrics.
22    pub fn sync_graph_extraction_metrics(&self) {
23        if let Some(memory) = self.memory_state.memory.as_ref() {
24            let count = memory.graph_extraction_count();
25            let failures = memory.graph_extraction_failures();
26            self.update_metrics(|m| {
27                m.graph_extraction_count = count;
28                m.graph_extraction_failures = failures;
29            });
30        }
31    }
32
33    /// Fetch entity/edge/community counts from the graph store and write to metrics.
34    pub async fn sync_graph_counts(&self) {
35        let Some(memory) = self.memory_state.memory.as_ref() else {
36            return;
37        };
38        let Some(store) = memory.graph_store.as_ref() else {
39            return;
40        };
41        let (entities, edges, communities) = tokio::join!(
42            store.entity_count(),
43            store.active_edge_count(),
44            store.community_count()
45        );
46        self.update_metrics(|m| {
47            m.graph_entities_total = entities.unwrap_or(0).cast_unsigned();
48            m.graph_edges_total = edges.unwrap_or(0).cast_unsigned();
49            m.graph_communities_total = communities.unwrap_or(0).cast_unsigned();
50        });
51    }
52
53    /// Perform a real health check on the vector store and update metrics.
54    pub async fn check_vector_store_health(&self, backend_name: &str) {
55        let connected = match self.memory_state.memory.as_ref() {
56            Some(m) => m.is_vector_store_connected().await,
57            None => false,
58        };
59        let name = backend_name.to_owned();
60        self.update_metrics(|m| {
61            m.qdrant_available = connected;
62            m.vector_backend = name;
63        });
64    }
65
66    pub(super) fn update_metrics(&self, f: impl FnOnce(&mut MetricsSnapshot)) {
67        if let Some(ref tx) = self.metrics_tx {
68            let elapsed = self.start_time.elapsed().as_secs();
69            tx.send_modify(|m| {
70                m.uptime_seconds = elapsed;
71                f(m);
72            });
73        }
74    }
75
76    pub(super) fn push_security_event(
77        &self,
78        category: SecurityEventCategory,
79        source: &str,
80        detail: impl Into<String>,
81    ) {
82        if let Some(ref tx) = self.metrics_tx {
83            let event = SecurityEvent::new(category, source, detail);
84            let elapsed = self.start_time.elapsed().as_secs();
85            tx.send_modify(|m| {
86                m.uptime_seconds = elapsed;
87                if m.security_events.len() >= SECURITY_EVENT_CAP {
88                    m.security_events.pop_front();
89                }
90                m.security_events.push_back(event);
91            });
92        }
93    }
94
95    pub(super) fn recompute_prompt_tokens(&mut self) {
96        self.cached_prompt_tokens = self
97            .messages
98            .iter()
99            .map(|m| self.token_counter.count_message_tokens(m) as u64)
100            .sum();
101    }
102
103    pub(super) fn push_message(&mut self, msg: Message) {
104        self.cached_prompt_tokens += self.token_counter.count_message_tokens(&msg) as u64;
105        self.messages.push(msg);
106    }
107
108    pub(crate) fn record_cost(&self, prompt_tokens: u64, completion_tokens: u64) {
109        if let Some(ref tracker) = self.cost_tracker {
110            tracker.record_usage(&self.runtime.model_name, prompt_tokens, completion_tokens);
111            self.update_metrics(|m| {
112                m.cost_spent_cents = tracker.current_spend();
113            });
114        }
115    }
116
117    pub(crate) fn record_cache_usage(&self) {
118        if let Some((creation, read)) = self.provider.last_cache_usage() {
119            self.update_metrics(|m| {
120                m.cache_creation_tokens += creation;
121                m.cache_read_tokens += read;
122            });
123        }
124    }
125
126    /// Inject pre-formatted code context into the message list.
127    /// The caller is responsible for retrieving and formatting the text.
128    pub fn inject_code_context(&mut self, text: &str) {
129        self.remove_code_context_messages();
130        if text.is_empty() || self.messages.len() <= 1 {
131            return;
132        }
133        let content = format!("{CODE_CONTEXT_PREFIX}{text}");
134        self.messages.insert(
135            1,
136            Message::from_parts(
137                Role::System,
138                vec![MessagePart::CodeContext { text: content }],
139            ),
140        );
141    }
142
143    #[must_use]
144    pub fn context_messages(&self) -> &[Message] {
145        &self.messages
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;
    use zeph_llm::provider::{MessageMetadata, MessagePart};

    /// Agent wired to the shared test mocks: `mock_provider(vec![])`,
    /// `MockChannel::new(vec![])`, the test registry, and `MockToolExecutor::no_tools()`.
    fn test_agent() -> Agent<MockChannel> {
        Agent::new(
            mock_provider(vec![]),
            MockChannel::new(vec![]),
            create_test_registry(),
            None,
            5,
            MockToolExecutor::no_tools(),
        )
    }

    /// Plain-text message with no parts and default metadata.
    fn text_message(role: Role, content: &str) -> Message {
        Message {
            role,
            content: content.to_owned(),
            parts: Vec::new(),
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn push_message_increments_cached_tokens() {
        let mut agent = test_agent();
        let baseline = agent.cached_prompt_tokens;

        let message = text_message(Role::User, "hello world!!");
        let delta = agent.token_counter.count_message_tokens(&message) as u64;
        agent.push_message(message);

        assert_eq!(agent.cached_prompt_tokens, baseline + delta);
    }

    #[test]
    fn recompute_prompt_tokens_matches_sum() {
        let mut agent = test_agent();
        // Bypass push_message on purpose so the cache is stale before recomputing.
        agent.messages.extend([
            text_message(Role::User, "1234"),
            text_message(Role::Assistant, "5678"),
        ]);

        agent.recompute_prompt_tokens();

        let mut expected = 0u64;
        for message in &agent.messages {
            expected += agent.token_counter.count_message_tokens(message) as u64;
        }
        assert_eq!(agent.cached_prompt_tokens, expected);
    }

    #[test]
    fn inject_code_context_into_messages_with_existing_content() {
        let mut agent = test_agent();
        // A user turn guarantees the list has more than one message.
        agent.push_message(text_message(Role::User, "question"));

        agent.inject_code_context("some code here");

        let injected = agent
            .messages
            .iter()
            .flat_map(|message| message.parts.iter())
            .any(|part| {
                matches!(part, MessagePart::CodeContext { text } if text.contains("some code here"))
            });
        assert!(injected, "code context should be injected into messages");
    }

    #[test]
    fn inject_code_context_empty_text_is_noop() {
        let mut agent = test_agent();
        agent.push_message(text_message(Role::User, "question"));
        let len_before = agent.messages.len();

        agent.inject_code_context("");

        // Empty text must not add a code-context message.
        assert_eq!(agent.messages.len(), len_before);
    }

    #[test]
    fn inject_code_context_with_single_message_is_noop() {
        let mut agent = test_agent();
        // A fresh agent holds only its system prompt (len == 1), so injection is skipped.
        let len_before = agent.messages.len();

        agent.inject_code_context("some code");

        assert_eq!(agent.messages.len(), len_before);
    }

    #[test]
    fn context_messages_returns_all_messages() {
        let mut agent = test_agent();
        agent.push_message(text_message(Role::User, "test"));

        let visible = agent.context_messages().len();
        assert_eq!(visible, agent.messages.len());
    }
}