1use std::collections::VecDeque;
13use std::time::Instant;
14
15use zeph_core::channel::{ChannelError, ChannelMessage, ToolOutputEvent};
16
/// A single response captured by `BenchmarkChannel`.
///
/// A response is recorded either from one `send` call or from a run of `send_chunk`
/// calls that ends in a flush of the chunk buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CapturedResponse {
    /// Zero-based index of the prompt most recently handed out by `recv` when this
    /// response was captured (0 if no prompt had been received yet).
    pub prompt_index: usize,
    /// Full response text; for streamed responses, all chunks concatenated in order.
    pub text: String,
    /// For streamed responses, the time from the first buffered chunk to the flush.
    /// Always `Duration::ZERO` for responses delivered through a single `send`.
    pub elapsed: std::time::Duration,
    /// Input token count from the last `send_usage` before capture, or 0 if none.
    pub input_tokens: u64,
    /// Output token count from the last `send_usage` before capture, or 0 if none.
    pub output_tokens: u64,
    /// Context window size from the last `send_usage` before capture, or 0 if none.
    pub context_window: u64,
}
54
55pub struct BenchmarkChannel {
85 prompts: VecDeque<String>,
86 responses: Vec<CapturedResponse>,
87 current_index: usize,
88 total: usize,
89 chunk_buffer: String,
91 chunk_start: Option<Instant>,
92 pending_input_tokens: u64,
94 pending_output_tokens: u64,
95 pending_context_window: u64,
96}
97
98impl BenchmarkChannel {
99 #[must_use]
114 pub fn new(prompts: Vec<String>) -> Self {
115 let total = prompts.len();
116 Self {
117 prompts: VecDeque::from(prompts),
118 responses: Vec::new(),
119 current_index: 0,
120 total,
121 chunk_buffer: String::new(),
122 chunk_start: None,
123 pending_input_tokens: 0,
124 pending_output_tokens: 0,
125 pending_context_window: 0,
126 }
127 }
128
129 #[must_use]
140 pub fn total(&self) -> usize {
141 self.total
142 }
143
144 #[must_use]
158 pub fn into_responses(self) -> Vec<CapturedResponse> {
159 self.responses
160 }
161
162 #[must_use]
173 pub fn responses(&self) -> &[CapturedResponse] {
174 &self.responses
175 }
176
177 fn flush_chunk_buffer(&mut self) {
178 if self.chunk_buffer.is_empty() {
179 return;
180 }
181 let elapsed = self
182 .chunk_start
183 .map_or(std::time::Duration::ZERO, |s| s.elapsed());
184 self.responses.push(CapturedResponse {
185 prompt_index: self.current_index.saturating_sub(1),
186 text: std::mem::take(&mut self.chunk_buffer),
187 elapsed,
188 input_tokens: self.pending_input_tokens,
189 output_tokens: self.pending_output_tokens,
190 context_window: self.pending_context_window,
191 });
192 self.chunk_start = None;
193 self.pending_input_tokens = 0;
194 self.pending_output_tokens = 0;
195 self.pending_context_window = 0;
196 }
197}
198
199impl zeph_core::channel::Channel for BenchmarkChannel {
200 async fn recv(&mut self) -> Result<Option<ChannelMessage>, ChannelError> {
201 match self.prompts.pop_front() {
202 Some(text) => {
203 self.current_index += 1;
204 Ok(Some(ChannelMessage {
205 text,
206 attachments: vec![],
207 }))
208 }
209 None => Ok(None),
210 }
211 }
212
213 fn supports_exit(&self) -> bool {
214 false
215 }
216
217 async fn send(&mut self, text: &str) -> Result<(), ChannelError> {
218 self.responses.push(CapturedResponse {
219 prompt_index: self.current_index.saturating_sub(1),
220 text: text.to_owned(),
221 elapsed: std::time::Duration::ZERO,
222 input_tokens: self.pending_input_tokens,
223 output_tokens: self.pending_output_tokens,
224 context_window: self.pending_context_window,
225 });
226 self.pending_input_tokens = 0;
227 self.pending_output_tokens = 0;
228 self.pending_context_window = 0;
229 Ok(())
230 }
231
232 async fn send_chunk(&mut self, chunk: &str) -> Result<(), ChannelError> {
233 if self.chunk_start.is_none() {
234 self.chunk_start = Some(Instant::now());
235 }
236 self.chunk_buffer.push_str(chunk);
237 Ok(())
238 }
239
240 async fn flush_chunks(&mut self) -> Result<(), ChannelError> {
241 self.flush_chunk_buffer();
242 Ok(())
243 }
244
245 async fn send_usage(
246 &mut self,
247 input_tokens: u64,
248 output_tokens: u64,
249 context_window: u64,
250 ) -> Result<(), ChannelError> {
251 self.pending_input_tokens = input_tokens;
252 self.pending_output_tokens = output_tokens;
253 self.pending_context_window = context_window;
254 Ok(())
255 }
256
257 async fn send_tool_output(&mut self, _event: ToolOutputEvent) -> Result<(), ChannelError> {
262 Ok(())
263 }
264}
265
#[cfg(test)]
mod tests {
    use zeph_core::channel::{
        Channel, ElicitationField, ElicitationFieldType, ElicitationRequest, ElicitationResponse,
        ToolOutputEvent,
    };

    use super::*;

    /// Builds a channel holding a single `prompt` and consumes it via `recv`, so that
    /// subsequent responses are attributed to that prompt.
    async fn channel_after_first_recv(prompt: &str) -> BenchmarkChannel {
        let mut channel = BenchmarkChannel::new(vec![prompt.to_owned()]);
        let _ = channel.recv().await.unwrap();
        channel
    }

    #[tokio::test]
    async fn recv_drains_queue_and_returns_none_when_empty() {
        let mut channel = BenchmarkChannel::new(vec!["hello".into(), "world".into()]);
        for expected in ["hello", "world"] {
            let message = channel.recv().await.unwrap().unwrap();
            assert_eq!(message.text, expected);
        }
        let exhausted = channel.recv().await.unwrap();
        assert!(exhausted.is_none());
    }

    #[tokio::test]
    async fn send_accumulates_response() {
        let mut channel = channel_after_first_recv("prompt").await;
        channel.send("response text").await.unwrap();
        let captured = channel.responses();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].text, "response text");
    }

    #[tokio::test]
    async fn confirm_returns_true() {
        let mut channel = BenchmarkChannel::new(Vec::new());
        assert!(channel.confirm("delete?").await.unwrap());
    }

    #[tokio::test]
    async fn elicit_returns_declined() {
        let mut channel = BenchmarkChannel::new(Vec::new());
        let field = ElicitationField {
            name: "field".into(),
            description: None,
            field_type: ElicitationFieldType::String,
            required: true,
        };
        let request = ElicitationRequest {
            server_name: "test-server".into(),
            message: "provide input".into(),
            fields: vec![field],
        };
        let outcome = channel.elicit(request).await.unwrap();
        assert!(matches!(outcome, ElicitationResponse::Declined));
    }

    #[tokio::test]
    async fn send_chunk_and_flush_captures_response() {
        let mut channel = channel_after_first_recv("p").await;
        for piece in ["part1", " part2"] {
            channel.send_chunk(piece).await.unwrap();
        }
        channel.flush_chunks().await.unwrap();
        let captured = channel.responses();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].text, "part1 part2");
    }

    #[tokio::test]
    async fn supports_exit_returns_false() {
        let channel = BenchmarkChannel::new(Vec::new());
        assert!(!channel.supports_exit());
    }

    #[tokio::test]
    async fn send_usage_captured_on_send() {
        let mut channel = channel_after_first_recv("p").await;
        channel.send_usage(10, 20, 128_000).await.unwrap();
        channel.send("answer").await.unwrap();
        let first = &channel.responses()[0];
        assert_eq!(
            (first.input_tokens, first.output_tokens, first.context_window),
            (10, 20, 128_000)
        );
    }

    #[tokio::test]
    async fn send_tool_output_does_not_add_to_responses() {
        let mut channel = channel_after_first_recv("p").await;
        let event = ToolOutputEvent {
            tool_name: "bash".into(),
            display: "some tool output".into(),
            diff: None,
            filter_stats: None,
            kept_lines: None,
            locations: None,
            tool_call_id: "tc-1".into(),
            terminal_id: None,
            is_error: false,
            parent_tool_use_id: None,
            raw_response: None,
            started_at: None,
        };
        channel.send_tool_output(event).await.unwrap();
        assert!(channel.responses().is_empty());
    }
}