1use std::{
7 io,
8 sync::Arc,
9 time::{SystemTime, UNIX_EPOCH},
10};
11
12use axum::{
13 Json, Router,
14 extract::State,
15 http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri},
16 response::{
17 IntoResponse, Response,
18 sse::{Event, Sse},
19 },
20 routing::{get, post},
21};
22use serde_json::{Value, json};
23use thiserror::Error;
24use tokio::net::TcpListener;
25use tracing::{debug, error, info, warn};
26
27use crate::{
28 attestation::{AttestationError, AttestationVerifier},
29 config::ProxyConfig,
30 e2ee::{E2eeCodec, E2eeCodecError},
31 keys::ProxyInstanceKey,
32 openai::{
33 ErrorResponse,
34 chat::{
35 ChatCompletionRequest, ChatConstructionError, ChatRequestError, NormalizedChatMessage,
36 },
37 },
38 sessions::{AttestedModelState, SessionContext, SessionError, SessionManager, SessionRequest},
39 tools::{ToolEmulationContext, ToolOutputClassification, ValidatedToolCall},
40 venice::{VeniceClient, VeniceClientError},
41};
42
// Header names in the `X-Venice-Proxy-*` family.
// NOTE(review): presumably these are the headers written by `ProxyMetadataHeaders`
// (defined elsewhere in the file) — confirm against its `apply` implementation.
pub const HEADER_PROXY_E2EE: &str = "X-Venice-Proxy-E2EE";
pub const HEADER_PROXY_ATTESTATION_MODE: &str = "X-Venice-Proxy-Attestation-Mode";
pub const HEADER_PROXY_ATTESTED_MODEL: &str = "X-Venice-Proxy-Attested-Model";
pub const HEADER_PROXY_TEE_PROVIDER: &str = "X-Venice-Proxy-TEE-Provider";
pub const HEADER_PROXY_TDX_VERIFIED: &str = "X-Venice-Proxy-TDX-Verified";
pub const HEADER_PROXY_TDX_DEBUG: &str = "X-Venice-Proxy-TDX-Debug";
pub const HEADER_PROXY_NVIDIA_VERIFIED: &str = "X-Venice-Proxy-NVIDIA-Verified";
pub const HEADER_PROXY_KEY_BINDING: &str = "X-Venice-Proxy-Key-Binding";
pub const HEADER_PROXY_SESSION_ID: &str = "X-Venice-Proxy-Session-Id";
pub const HEADER_PROXY_SESSION_SCOPE: &str = "X-Venice-Proxy-Session-Scope";
pub const HEADER_PROXY_TOOL_MODE: &str = "X-Venice-Proxy-Tool-Mode";
pub const HEADER_PROXY_TOOL_RETRIES: &str = "X-Venice-Proxy-Tool-Retries";
pub const HEADER_PROXY_ERROR_CODE: &str = "X-Venice-Proxy-Error-Code";
56
/// Shared state attached to the router with `with_state` and handed to the
/// handlers through axum's `State` extractor.
#[derive(Debug, Clone)]
pub struct AppState {
    /// Proxy configuration, shared behind an `Arc`.
    config: Arc<ProxyConfig>,
    /// Client used for all calls to the Venice upstream.
    venice_client: VeniceClient,
    /// Result of `ProxyInstanceKey::generate_from_config`; when `None` the chat
    /// handler fails with `ProxyError::ProxyInstanceKeyUnavailable`.
    proxy_instance_key: Option<ProxyInstanceKey>,
    /// Session store built from `config.session`.
    session_manager: SessionManager,
    /// Verifier used to attest a model before the first encrypted request.
    attestation_verifier: AttestationVerifier,
}
66
67impl AppState {
68 pub fn new(config: ProxyConfig) -> Result<Self, VeniceClientError> {
70 let venice_client = VeniceClient::from_config(&config)?;
71 Ok(Self::from_parts(config, venice_client))
72 }
73
74 pub fn from_parts(config: ProxyConfig, venice_client: VeniceClient) -> Self {
76 let proxy_instance_key = ProxyInstanceKey::generate_from_config(&config.keys);
77 let session_manager = SessionManager::new(config.session.clone());
78 let attestation_verifier = AttestationVerifier::from_config(&config, venice_client.clone());
79
80 Self {
81 config: Arc::new(config),
82 venice_client,
83 proxy_instance_key,
84 session_manager,
85 attestation_verifier,
86 }
87 }
88
89 pub fn config(&self) -> &ProxyConfig {
91 &self.config
92 }
93
94 pub fn venice_client(&self) -> &VeniceClient {
96 &self.venice_client
97 }
98
99 pub fn proxy_instance_key(&self) -> Option<&ProxyInstanceKey> {
101 self.proxy_instance_key.as_ref()
102 }
103
104 pub fn session_manager(&self) -> &SessionManager {
106 &self.session_manager
107 }
108
109 pub fn attestation_verifier(&self) -> &AttestationVerifier {
111 &self.attestation_verifier
112 }
113}
114
115pub fn router(config: ProxyConfig) -> Result<Router, VeniceClientError> {
118 Ok(router_from_state(AppState::new(config)?))
119}
120
121pub fn router_with_venice_client(config: ProxyConfig, venice_client: VeniceClient) -> Router {
126 router_from_state(AppState::from_parts(config, venice_client))
127}
128
129fn router_from_state(state: AppState) -> Router {
131 Router::new()
132 .route("/v1/models", get(list_models).fallback(method_not_allowed))
133 .route(
134 "/v1/chat/completions",
135 post(create_chat_completion).fallback(method_not_allowed),
136 )
137 .fallback(not_found)
138 .with_state(state)
139}
140
141pub async fn serve(listener: TcpListener, router: Router) -> io::Result<()> {
143 axum::serve(listener, router).await
144}
145
146async fn list_models(State(state): State<AppState>) -> Result<Response, ProxyError> {
148 info!(route = "/v1/models", "listing Venice models");
149 let models = state.venice_client().list_models().await?;
150 let mut response = Json(models).into_response();
151 ProxyMetadataHeaders::from_config(state.config()).apply(response.headers_mut());
152 info!(route = "/v1/models", "Venice models response proxied");
153 Ok(response)
154}
155
156async fn create_chat_completion(
158 State(state): State<AppState>,
159 headers: HeaderMap,
160 Json(body): Json<Value>,
161) -> Result<Response, ProxyError> {
162 let request = ChatCompletionRequest::parse(&body)?;
163 let proxy_instance_key = state
164 .proxy_instance_key()
165 .ok_or(ProxyError::ProxyInstanceKeyUnavailable)?;
166
167 let session_resolution = state
168 .session_manager()
169 .get_or_create(SessionRequest::new(&request.model, &headers).with_body(&body))?;
170 let session_created = session_resolution.created;
171 let session_replaced_expired = session_resolution.replaced_expired;
172 let session_scope = session_resolution.session.scope;
173 let session = ensure_attested_session(&state, session_resolution.session).await?;
174 let model_public_key = session
175 .attested_model_public_key
176 .as_deref()
177 .ok_or(ProxyError::MissingAttestedModelKey)?;
178
179 let codec =
180 E2eeCodec::from_config(&state.config().e2ee).map_err(ChatConstructionError::E2ee)?;
181 let tool_context = ToolEmulationContext::from_request(&state.config().tools, &request)?;
182 let metadata = ProxyMetadataHeaders::for_verified_chat(state.config(), &session);
183
184 info!(
185 route = "/v1/chat/completions",
186 model = %request.model,
187 stream = request.stream,
188 message_count = request.messages.len(),
189 tool_count = request.tools.len(),
190 tool_mode = tool_context.is_some(),
191 session_created,
192 session_replaced_expired = ?session_replaced_expired,
193 session_scope = %session_scope,
194 "chat completion request accepted"
195 );
196
197 if let Some(tool_context) = tool_context {
198 info!(model = %request.model, "using tool-emulated chat completion");
199 return openai_tool_emulated_chat_response(
200 &state,
201 &request,
202 &tool_context,
203 codec,
204 proxy_instance_key.clone(),
205 model_public_key,
206 metadata,
207 )
208 .await;
209 }
210
211 let prepared = request.to_venice_e2ee_request(&codec, model_public_key)?;
212 info!(
213 model = %request.model,
214 client_stream = prepared.client_stream,
215 "forwarding encrypted chat completion to Venice"
216 );
217
218 let upstream = state
219 .venice_client()
220 .create_chat_completion_stream(
221 &prepared.upstream,
222 proxy_instance_key.public_key_hex(),
223 model_public_key,
224 )
225 .await?;
226
227 if prepared.client_stream {
228 info!(model = %request.model, "streaming chat completion response to client");
229 let include_usage_requested = request.stream_options.include_usage.unwrap_or(false);
230 let transformer = OpenAiChatStreamTransformer::new(
231 codec,
232 proxy_instance_key.clone(),
233 request.model.clone(),
234 include_usage_requested,
235 );
236 Ok(chat_sse_response(
237 upstream,
238 transformer,
239 request.model,
240 include_usage_requested,
241 &CHAT_SSE_LOG,
242 metadata,
243 ))
244 } else {
245 info!(model = %request.model, "buffering chat completion response for client");
246 openai_chat_buffered_response(
247 upstream,
248 codec,
249 proxy_instance_key.clone(),
250 request.model,
251 metadata,
252 )
253 .await
254 }
255}
256
257async fn ensure_attested_session(
259 state: &AppState,
260 session: SessionContext,
261) -> Result<SessionContext, ProxyError> {
262 if session.attested_model_public_key.is_some() {
263 info!(model = %session.model_id, session_scope = %session.scope, "using cached model attestation");
264 return Ok(session);
265 }
266
267 info!(model = %session.model_id, session_scope = %session.scope, "fetching model attestation");
268 let attestation = state
269 .attestation_verifier()
270 .verify_model_attestation(&session.model_id)
271 .await?;
272
273 info!(
274 model = %attestation.model_id,
275 tee_provider = attestation.tee_provider.as_deref().unwrap_or("unknown"),
276 tdx_verified = attestation.tdx.verified,
277 nvidia_verified = attestation.nvidia.verified.as_header_value(),
278 "model attestation verified"
279 );
280
281 let state_update = AttestedModelState {
282 model_public_key: attestation.model_public_key,
283 tee_provider: attestation.tee_provider,
284 tdx_debug: attestation.tdx.debug.or(attestation.debug),
285 nvidia_verified: attestation.nvidia.verified.as_header_value().to_owned(),
286 verified_at: attestation.verified_at,
287 };
288
289 Ok(state
290 .session_manager()
291 .set_attested_model_state(&session.session_key, state_update)?)
292}
293
294async fn openai_chat_buffered_response(
296 upstream: reqwest::Response,
297 codec: E2eeCodec,
298 proxy_instance_key: ProxyInstanceKey,
299 fallback_model: String,
300 metadata: ProxyMetadataHeaders,
301) -> Result<Response, ProxyError> {
302 let completion =
303 buffer_openai_chat_completion(upstream, codec, proxy_instance_key, fallback_model).await?;
304 let mut response = Json(completion).into_response();
305 metadata.apply(response.headers_mut());
306 Ok(response)
307}
308
309async fn openai_tool_emulated_chat_response(
311 state: &AppState,
312 request: &ChatCompletionRequest,
313 tool_context: &ToolEmulationContext,
314 codec: E2eeCodec,
315 proxy_instance_key: ProxyInstanceKey,
316 model_public_key: &str,
317 metadata: ProxyMetadataHeaders,
318) -> Result<Response, ProxyError> {
319 info!(
320 model = %request.model,
321 max_retries = tool_context.max_retries(),
322 "starting tool-emulated chat completion"
323 );
324 if request.stream {
325 let upstream = tool_emulated_upstream_stream(
326 state,
327 request,
328 tool_context,
329 &codec,
330 &proxy_instance_key,
331 model_public_key,
332 None,
333 )
334 .await?;
335
336 let include_usage_requested = request.stream_options.include_usage.unwrap_or(false);
337 let transformer = OpenAiToolEmulatedChatStreamTransformer::new(
338 tool_context,
339 codec,
340 proxy_instance_key,
341 request.model.clone(),
342 include_usage_requested,
343 )
344 .map_err(ProxyError::ChatStream)?;
345 return Ok(chat_sse_response(
346 upstream,
347 transformer,
348 request.model.clone(),
349 include_usage_requested,
350 &TOOL_EMULATED_CHAT_SSE_LOG,
351 metadata,
352 ));
353 }
354
355 let mut retries = 0;
356 let mut correction: Option<(String, String)> = None;
357
358 loop {
359 let upstream = tool_emulated_upstream_stream(
360 state,
361 request,
362 tool_context,
363 &codec,
364 &proxy_instance_key,
365 model_public_key,
366 correction.as_ref(),
367 )
368 .await?;
369
370 let completion = match tokio::time::timeout(
371 tool_context.marker_timeout(),
372 buffer_openai_chat_completion(
373 upstream,
374 codec.clone(),
375 proxy_instance_key.clone(),
376 request.model.clone(),
377 ),
378 )
379 .await
380 {
381 Ok(completion) => completion?,
382 Err(_) => {
383 let validation_error = format!(
384 "tool-emulated completion did not finish within {}",
385 humantime::format_duration(tool_context.config().tool_call_marker_timeout)
386 );
387 if retries >= tool_context.max_retries() {
388 return Err(ProxyError::ToolCallRetryExhausted {
389 max_retries: tool_context.max_retries(),
390 last_validation_error: validation_error,
391 });
392 }
393 warn!(
394 model = %request.model,
395 retry = retries + 1,
396 max_retries = tool_context.max_retries(),
397 "tool call marker timed out; retrying with correction"
398 );
399 retries += 1;
400 correction = Some((validation_error, String::new()));
401 continue;
402 }
403 };
404 let assistant_content = completion
405 .get("choices")
406 .and_then(Value::as_array)
407 .and_then(|choices| choices.first())
408 .and_then(|choice| choice.get("message"))
409 .and_then(|message| message.get("content"))
410 .and_then(Value::as_str)
411 .unwrap_or_default();
412
413 let mut metadata = metadata.clone();
414 if retries > 0 {
415 metadata.tool_retries = Some(retries);
416 }
417
418 match tool_context.classify_assistant_output(assistant_content) {
419 ToolOutputClassification::NormalText => {
420 info!(model = %request.model, retries, "tool emulation produced normal text");
421 let mut response = Json(completion).into_response();
422 metadata.apply(response.headers_mut());
423 return Ok(response);
424 }
425 ToolOutputClassification::ToolCalls(tool_calls) => {
426 info!(
427 model = %request.model,
428 tool_calls = tool_calls.len(),
429 retries,
430 "tool emulation produced tool calls"
431 );
432 let body = openai_tool_call_completion(completion, tool_calls);
433 let mut response = Json(body).into_response();
434 metadata.apply(response.headers_mut());
435 return Ok(response);
436 }
437 ToolOutputClassification::InvalidToolCall {
438 error,
439 invalid_output,
440 } => {
441 if retries >= tool_context.max_retries() {
442 warn!(
443 model = %request.model,
444 max_retries = tool_context.max_retries(),
445 validation_error = %error,
446 "tool call validation failed and retries were exhausted"
447 );
448 return Err(ProxyError::ToolCallRetryExhausted {
449 max_retries: tool_context.max_retries(),
450 last_validation_error: error.to_string(),
451 });
452 }
453 warn!(
454 model = %request.model,
455 retry = retries + 1,
456 max_retries = tool_context.max_retries(),
457 validation_error = %error,
458 "tool call validation failed; retrying with correction"
459 );
460 retries += 1;
461 correction = Some((error.to_string(), invalid_output));
462 }
463 }
464 }
465}
466
467async fn tool_emulated_upstream_stream(
470 state: &AppState,
471 request: &ChatCompletionRequest,
472 tool_context: &ToolEmulationContext,
473 codec: &E2eeCodec,
474 proxy_instance_key: &ProxyInstanceKey,
475 model_public_key: &str,
476 correction: Option<&(String, String)>,
477) -> Result<reqwest::Response, ProxyError> {
478 let messages = tool_emulated_messages(request, tool_context, correction);
479 let mut tool_request = request.clone();
480 tool_request.messages = messages;
481
482 let prepared = tool_request.to_venice_e2ee_request(codec, model_public_key)?;
483
484 Ok(state
485 .venice_client()
486 .create_chat_completion_stream(
487 &prepared.upstream,
488 proxy_instance_key.public_key_hex(),
489 model_public_key,
490 )
491 .await?)
492}
493
494fn tool_emulated_messages(
496 request: &ChatCompletionRequest,
497 tool_context: &ToolEmulationContext,
498 correction: Option<&(String, String)>,
499) -> Vec<NormalizedChatMessage> {
500 let mut messages = request.messages.clone();
501 let mut tool_system_content = tool_context.controller_message().content;
502
503 if let Some((validation_error, invalid_output)) = correction {
504 tool_system_content.push_str("\n\n");
505 tool_system_content.push_str(
506 &tool_context
507 .correction_message(validation_error, invalid_output)
508 .content,
509 );
510 }
511
512 append_to_system_message(&mut messages, tool_system_content);
513 messages
514}
515
516fn append_to_system_message(messages: &mut Vec<NormalizedChatMessage>, content: String) {
518 if let Some(system_message) = messages.iter_mut().find(|message| message.role == "system") {
519 system_message.content.push_str("\n\n");
520 system_message.content.push_str(&content);
521 } else {
522 messages.insert(0, NormalizedChatMessage::new("system", content));
523 }
524}
525
526fn openai_tool_call_completion(completion: Value, tool_calls: Vec<ValidatedToolCall>) -> Value {
528 let choice = completion
529 .get("choices")
530 .and_then(Value::as_array)
531 .and_then(|choices| choices.first())
532 .cloned()
533 .unwrap_or(Value::Null);
534 let index = choice.get("index").and_then(Value::as_u64).unwrap_or(0);
535 let tool_call_values: Vec<Value> = tool_calls
536 .iter()
537 .map(ValidatedToolCall::to_openai_value)
538 .collect();
539 let reasoning_content = choice
540 .get("message")
541 .and_then(|message| message.get("reasoning_content"))
542 .and_then(Value::as_str);
543 let mut message = serde_json::Map::new();
544 message.insert("role".to_owned(), Value::String("assistant".to_owned()));
545 message.insert("content".to_owned(), Value::Null);
546 if let Some(reasoning_content) = reasoning_content {
547 message.insert(
548 "reasoning_content".to_owned(),
549 Value::String(reasoning_content.to_owned()),
550 );
551 }
552 message.insert("tool_calls".to_owned(), Value::Array(tool_call_values));
553
554 json!({
555 "id": string_field(&completion, "id").unwrap_or("chatcmpl-local"),
556 "object": "chat.completion",
557 "created": integer_field(&completion, "created").unwrap_or_else(unix_timestamp_now),
558 "model": string_field(&completion, "model").unwrap_or("unknown"),
559 "choices": [{
560 "index": index,
561 "message": Value::Object(message),
562 "finish_reason": "tool_calls",
563 }],
564 "usage": completion.get("usage").cloned().unwrap_or(Value::Null),
565 })
566}
567
568async fn buffer_openai_chat_completion(
570 mut upstream: reqwest::Response,
571 codec: E2eeCodec,
572 proxy_instance_key: ProxyInstanceKey,
573 fallback_model: String,
574) -> Result<Value, ChatStreamError> {
575 info!(model = %fallback_model, "buffering upstream chat stream");
576 let mut parser = SseEventParser::default();
577 let mut transformer =
578 OpenAiChatCompletionBuffer::new(codec, proxy_instance_key, fallback_model.clone());
579 let mut upstream_done = false;
580 let mut chunk_count = 0_u64;
581 let mut event_count = 0_u64;
582
583 while let Some(chunk) = upstream
584 .chunk()
585 .await
586 .map_err(ChatStreamError::upstream_stream)?
587 {
588 chunk_count += 1;
589 let chunk = std::str::from_utf8(&chunk).map_err(ChatStreamError::invalid_utf8)?;
590 let events = parser.push(chunk)?;
591 event_count += events.len() as u64;
592 debug!(
593 model = %fallback_model,
594 chunk_count,
595 parsed_events = events.len(),
596 total_events = event_count,
597 "parsed buffered upstream SSE chunk"
598 );
599
600 for event in events {
601 if transformer.handle_event(event)? {
602 upstream_done = true;
603 break;
604 }
605 }
606
607 if upstream_done {
608 break;
609 }
610 }
611
612 if !upstream_done {
613 warn!(
614 model = %fallback_model,
615 chunk_count,
616 event_count,
617 "buffered upstream stream ended before DONE"
618 );
619 parser.finish()?;
620 return Err(ChatStreamError::malformed_event(
621 "upstream stream ended before data: [DONE]",
622 ));
623 }
624
625 let completion = transformer.into_response();
626 info!(
627 model = %fallback_model,
628 chunk_count,
629 event_count,
630 "buffered upstream chat stream transformed"
631 );
632 Ok(completion)
633}
634
/// Log message texts used by `chat_sse_event_stream`, so each streaming mode
/// can log with its own wording.
struct ChatSseLogMessages {
    /// Logged once when the stream transformation starts.
    start: &'static str,
    /// Logged after each upstream chunk has been parsed into events.
    parsed_chunk: &'static str,
    /// Logged after each upstream event has been transformed.
    transformed_event: &'static str,
    /// Logged when the `Done` output is reached.
    completed: &'static str,
    /// Logged when upstream ends before `Done` was seen.
    ended_early: &'static str,
}

/// Wording for the plain streaming chat completion path.
const CHAT_SSE_LOG: ChatSseLogMessages = ChatSseLogMessages {
    start: "starting upstream chat SSE transformation",
    parsed_chunk: "parsed streaming upstream SSE chunk",
    transformed_event: "transformed streaming upstream SSE event",
    completed: "completed upstream chat SSE transformation",
    ended_early: "streaming upstream stream ended before DONE",
};

/// Wording for the tool-emulated streaming path.
const TOOL_EMULATED_CHAT_SSE_LOG: ChatSseLogMessages = ChatSseLogMessages {
    start: "starting tool-emulated upstream chat SSE transformation",
    parsed_chunk: "parsed tool-emulated upstream SSE chunk",
    transformed_event: "transformed tool-emulated upstream SSE event",
    completed: "completed tool-emulated upstream chat SSE transformation",
    ended_early: "tool-emulated upstream stream ended before DONE",
};
660
/// Turns upstream SSE events into client-facing stream outputs; driven by
/// `chat_sse_event_stream`.
trait ChatSseTransformer {
    /// Handles one upstream event and returns the outputs to emit for it
    /// (possibly none). Errors abort the client stream.
    fn handle_event(&mut self, event: RawSseEvent) -> Result<Vec<StreamOutput>, ChatStreamError>;
}
666
667fn chat_sse_response<T>(
669 upstream: reqwest::Response,
670 transformer: T,
671 fallback_model: String,
672 include_usage_requested: bool,
673 log: &'static ChatSseLogMessages,
674 metadata: ProxyMetadataHeaders,
675) -> Response
676where
677 T: ChatSseTransformer + Send + 'static,
678{
679 let stream = chat_sse_event_stream(
680 upstream,
681 transformer,
682 fallback_model,
683 include_usage_requested,
684 log,
685 );
686 let mut response = Sse::new(stream).into_response();
687 metadata.apply(response.headers_mut());
688 response
689}
690
691fn chat_sse_event_stream<T>(
693 mut upstream: reqwest::Response,
694 mut transformer: T,
695 fallback_model: String,
696 include_usage_requested: bool,
697 log: &'static ChatSseLogMessages,
698) -> impl futures_core::Stream<Item = Result<Event, axum::BoxError>>
699where
700 T: ChatSseTransformer + Send + 'static,
701{
702 async_stream::try_stream! {
703 info!(
704 model = %fallback_model,
705 include_usage_requested,
706 "{}", log.start
707 );
708 let mut parser = SseEventParser::default();
709 let mut upstream_done = false;
710 let mut chunk_count = 0_u64;
711 let mut event_count = 0_u64;
712 let mut output_count = 0_u64;
713
714 while let Some(chunk) = upstream
715 .chunk()
716 .await
717 .map_err(ChatStreamError::upstream_stream)
718 .map_err(box_chat_stream_error)?
719 {
720 chunk_count += 1;
721 let chunk = std::str::from_utf8(&chunk)
722 .map_err(ChatStreamError::invalid_utf8)
723 .map_err(box_chat_stream_error)?;
724 let events = parser.push(chunk).map_err(box_chat_stream_error)?;
725 event_count += events.len() as u64;
726 debug!(
727 model = %fallback_model,
728 chunk_count,
729 parsed_events = events.len(),
730 total_events = event_count,
731 "{}", log.parsed_chunk
732 );
733
734 for event in events {
735 let outputs = transformer.handle_event(event).map_err(box_chat_stream_error)?;
736 output_count += outputs.len() as u64;
737 debug!(
738 model = %fallback_model,
739 emitted_outputs = outputs.len(),
740 total_outputs = output_count,
741 "{}", log.transformed_event
742 );
743
744 for output in outputs {
745 match output {
746 StreamOutput::Json(value) => yield Event::default().data(value.to_string()),
747 StreamOutput::Done => {
748 upstream_done = true;
749 info!(
750 model = %fallback_model,
751 chunk_count,
752 event_count,
753 output_count,
754 "{}", log.completed
755 );
756 yield Event::default().data("[DONE]");
757 break;
758 }
759 }
760 }
761
762 if upstream_done {
763 break;
764 }
765 }
766
767 if upstream_done {
768 break;
769 }
770 }
771
772 if !upstream_done {
773 warn!(
774 model = %fallback_model,
775 chunk_count,
776 event_count,
777 output_count,
778 "{}", log.ended_early
779 );
780 parser.finish().map_err(box_chat_stream_error)?;
781 Err::<(), axum::BoxError>(box_chat_stream_error(ChatStreamError::malformed_event(
782 "upstream stream ended before data: [DONE]",
783 )))?;
784 }
785 }
786}
787
788fn box_chat_stream_error(error: ChatStreamError) -> axum::BoxError {
790 error!(error = %error, "chat stream transformation failed");
791 Box::new(error)
792}
793
/// Incremental SSE parser: accumulates text across chunks and emits events as
/// soon as their terminating boundary has arrived.
#[derive(Debug, Default)]
struct SseEventParser {
    /// Text received so far that does not yet form a complete event.
    buffer: String,
}
799
800impl SseEventParser {
801 fn push(&mut self, chunk: &str) -> Result<Vec<RawSseEvent>, ChatStreamError> {
803 self.buffer.push_str(chunk);
804 let mut events = Vec::new();
805
806 while let Some((boundary_start, boundary_len)) = sse_event_boundary(&self.buffer) {
807 let raw = self.buffer[..boundary_start].to_owned();
808 self.buffer.drain(..boundary_start + boundary_len);
809 if let Some(event) = parse_sse_event(&raw)? {
810 events.push(event);
811 }
812 }
813
814 debug!(
815 chunk_bytes = chunk.len(),
816 buffered_bytes = self.buffer.len(),
817 parsed_events = events.len(),
818 "SSE parser processed upstream chunk"
819 );
820 Ok(events)
821 }
822
823 fn finish(&self) -> Result<(), ChatStreamError> {
825 if self.buffer.trim().is_empty() {
826 Ok(())
827 } else {
828 warn!(
829 buffered_bytes = self.buffer.len(),
830 "upstream SSE stream ended with incomplete event"
831 );
832 Err(ChatStreamError::malformed_event(
833 "upstream stream ended with an incomplete SSE event",
834 ))
835 }
836 }
837}
838
/// One parsed SSE event, before any chat-specific interpretation.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RawSseEvent {
    /// The event name, if the event carried one; `classify_upstream_event`
    /// treats a missing name as `"message"`.
    event: Option<String>,
    /// The event payload: `[DONE]` or a JSON document on the chat paths.
    data: String,
}
845
/// Log message texts used by `classify_upstream_event`, so each caller can log
/// with its own wording.
struct UpstreamEventLogMessages {
    /// Debug line logged for every event.
    event: &'static str,
    /// Warning logged for an SSE event named `error`.
    sse_error: &'static str,
    /// Info line logged when `[DONE]` is received.
    done: &'static str,
    /// Optional debug line logged before JSON parsing; `None` skips it.
    parsing: Option<&'static str>,
    /// Warning logged when the JSON payload has an `error` member.
    json_error: &'static str,
    /// Warning logged when the payload has no `choices` array.
    missing_choices: &'static str,
    /// Optional debug line logged once `choices` was found; `None` skips it.
    parsed: Option<&'static str>,
    /// Warning logged when `choices` holds more than one entry.
    unexpected_choice_count: &'static str,
}

/// Wording used by `OpenAiChatCompletionBuffer` while buffering a response.
const BUFFERED_UPSTREAM_EVENT_LOG: UpstreamEventLogMessages = UpstreamEventLogMessages {
    event: "buffering upstream SSE event",
    sse_error: "upstream SSE error event while buffering response",
    done: "received upstream DONE while buffering response",
    parsing: Some("parsing buffered upstream chat JSON chunk"),
    json_error: "upstream JSON error chunk while buffering response",
    missing_choices: "buffered upstream chat chunk is missing choices array",
    parsed: Some("parsed buffered upstream chat chunk"),
    unexpected_choice_count: "unexpected buffered upstream choice count",
};

/// Wording for the streaming path.
// NOTE(review): presumably used by `OpenAiChatStreamTransformer` (outside this view) — confirm.
const STREAMING_UPSTREAM_EVENT_LOG: UpstreamEventLogMessages = UpstreamEventLogMessages {
    event: "transforming streaming upstream SSE event",
    sse_error: "upstream SSE error event while streaming response",
    done: "received upstream DONE while streaming response",
    parsing: Some("parsing streaming upstream chat JSON chunk"),
    json_error: "upstream JSON error chunk while streaming response",
    missing_choices: "streaming upstream chat chunk is missing choices array",
    parsed: Some("parsed streaming upstream chat chunk"),
    unexpected_choice_count: "unexpected streaming upstream choice count",
};

/// Wording for the tool-emulated streaming path; the two optional debug lines
/// are disabled here.
// NOTE(review): presumably used by `OpenAiToolEmulatedChatStreamTransformer` (outside this view) — confirm.
const TOOL_EMULATED_UPSTREAM_EVENT_LOG: UpstreamEventLogMessages = UpstreamEventLogMessages {
    event: "transforming tool-emulated streaming upstream SSE event",
    sse_error: "upstream SSE error event while streaming tool-emulated response",
    done: "received upstream DONE while streaming tool-emulated response",
    parsing: None,
    json_error: "upstream JSON error chunk while streaming tool-emulated response",
    missing_choices: "tool-emulated upstream chat chunk is missing choices array",
    parsed: None,
    unexpected_choice_count: "unexpected tool-emulated upstream choice count",
};
892
/// Result of classifying one upstream SSE event.
enum UpstreamEventKind {
    /// The `[DONE]` terminator.
    Done,
    /// A JSON chunk whose `choices` array is empty; carries the whole chunk.
    Usage(Value),
    /// A JSON chunk with exactly one choice: the whole chunk plus a clone of
    /// that single choice.
    Choice { value: Value, choice: Value },
}
902
903fn classify_upstream_event(
907 event: RawSseEvent,
908 log: &UpstreamEventLogMessages,
909) -> Result<UpstreamEventKind, ChatStreamError> {
910 let event_type = event.event.as_deref().unwrap_or("message");
911 let is_done = event.data.trim() == "[DONE]";
912 debug!(event_type, is_done, "{}", log.event);
913
914 if event.event.as_deref() == Some("error") {
915 warn!("{}", log.sse_error);
916 return Err(ChatStreamError::upstream_event(event.data));
917 }
918
919 if is_done {
920 info!("{}", log.done);
921 return Ok(UpstreamEventKind::Done);
922 }
923
924 if let Some(parsing) = log.parsing {
925 debug!("{}", parsing);
926 }
927 let value: Value = serde_json::from_str(&event.data).map_err(ChatStreamError::json_event)?;
928 if let Some(error) = value.get("error") {
929 warn!("{}", log.json_error);
930 return Err(ChatStreamError::upstream_event(error.to_string()));
931 }
932
933 let Some(choices) = value.get("choices").and_then(Value::as_array) else {
934 warn!("{}", log.missing_choices);
935 return Err(ChatStreamError::malformed_event(
936 "upstream chat chunk is missing choices array",
937 ));
938 };
939 if let Some(parsed) = log.parsed {
940 debug!(choice_count = choices.len(), "{}", parsed);
941 }
942
943 if choices.is_empty() {
944 return Ok(UpstreamEventKind::Usage(value));
945 }
946 if choices.len() != 1 {
947 warn!(
948 choice_count = choices.len(),
949 "{}", log.unexpected_choice_count
950 );
951 return Err(ChatStreamError::malformed_event(format!(
952 "expected exactly one upstream choice, got {}",
953 choices.len(),
954 )));
955 }
956
957 let choice = choices[0].clone();
958 Ok(UpstreamEventKind::Choice { value, choice })
959}
960
/// Shared context for transforming upstream chunks: decryption material plus
/// the fallback envelope values used when an upstream chunk omits them.
struct ChunkContext {
    /// Codec used to decrypt response content.
    codec: E2eeCodec,
    /// Proxy key whose private half is passed to the codec for decryption.
    proxy_instance_key: ProxyInstanceKey,
    /// Locally generated id (`chatcmpl-local-<uuid>`), used when upstream has no `id`.
    fallback_id: String,
    /// Value of `unix_timestamp_now()` at construction, used when upstream has no `created`.
    fallback_created: i64,
    /// Model name used when upstream has no `model`.
    fallback_model: String,
}
970
971impl ChunkContext {
972 fn new(codec: E2eeCodec, proxy_instance_key: ProxyInstanceKey, fallback_model: String) -> Self {
974 Self {
975 codec,
976 proxy_instance_key,
977 fallback_id: format!("chatcmpl-local-{}", uuid::Uuid::new_v4()),
978 fallback_created: unix_timestamp_now(),
979 fallback_model,
980 }
981 }
982
983 fn decrypt(&self, content: Option<&str>) -> Result<Option<String>, ChatStreamError> {
985 self.codec
986 .decrypt_response_content(content, self.proxy_instance_key.private_key())
987 .map_err(ChatStreamError::decryption)
988 }
989
990 fn chunk_with_choice(
992 &self,
993 upstream: &Value,
994 index: u64,
995 delta: Value,
996 finish_reason: Value,
997 ) -> Value {
998 json!({
999 "id": string_field(upstream, "id").unwrap_or(&self.fallback_id),
1000 "object": string_field(upstream, "object").unwrap_or("chat.completion.chunk"),
1001 "created": integer_field(upstream, "created").unwrap_or(self.fallback_created),
1002 "model": string_field(upstream, "model").unwrap_or(&self.fallback_model),
1003 "choices": [{
1004 "index": index,
1005 "delta": delta,
1006 "finish_reason": finish_reason,
1007 }],
1008 })
1009 }
1010
1011 fn usage_chunk(&self, upstream: &Value, usage: &Value) -> Value {
1013 json!({
1014 "id": string_field(upstream, "id").unwrap_or(&self.fallback_id),
1015 "object": string_field(upstream, "object").unwrap_or("chat.completion.chunk"),
1016 "created": integer_field(upstream, "created").unwrap_or(self.fallback_created),
1017 "model": string_field(upstream, "model").unwrap_or(&self.fallback_model),
1018 "choices": [],
1019 "usage": usage,
1020 })
1021 }
1022}
1023
/// Accumulates a streamed upstream completion into a single response.
struct OpenAiChatCompletionBuffer {
    /// Decryption context and fallback envelope values.
    ctx: ChunkContext,
    /// Completion id recorded from upstream metadata, if seen.
    id: Option<String>,
    /// Creation timestamp recorded from upstream metadata, if seen.
    created: Option<i64>,
    /// Model name recorded from upstream metadata, if seen.
    model: Option<String>,
    /// Index of the choice being buffered; every chunk must use the same one.
    choice_index: Option<u64>,
    /// Set once any encrypted content or reasoning content was decrypted.
    saw_encrypted_response_field: bool,
    /// Decrypted assistant content accumulated so far.
    content: String,
    /// Decrypted reasoning content accumulated so far.
    reasoning_content: String,
    /// Last non-null finish reason seen; defaults to `"stop"` at `[DONE]`.
    finish_reason: Option<Value>,
    /// Usage object from the usage chunk, if seen.
    usage: Option<Value>,
}
1037
1038impl OpenAiChatCompletionBuffer {
1039 fn new(codec: E2eeCodec, proxy_instance_key: ProxyInstanceKey, fallback_model: String) -> Self {
1041 Self {
1042 ctx: ChunkContext::new(codec, proxy_instance_key, fallback_model),
1043 id: None,
1044 created: None,
1045 model: None,
1046 choice_index: None,
1047 saw_encrypted_response_field: false,
1048 content: String::new(),
1049 reasoning_content: String::new(),
1050 finish_reason: None,
1051 usage: None,
1052 }
1053 }
1054
1055 fn handle_event(&mut self, event: RawSseEvent) -> Result<bool, ChatStreamError> {
1057 match classify_upstream_event(event, &BUFFERED_UPSTREAM_EVENT_LOG)? {
1058 UpstreamEventKind::Done => {
1059 if !self.saw_encrypted_response_field {
1060 self.ctx.decrypt(None)?;
1061 }
1062 if self.finish_reason.is_none() {
1063 self.finish_reason = Some(Value::String("stop".to_owned()));
1064 }
1065 Ok(true)
1066 }
1067 UpstreamEventKind::Usage(value) => {
1068 self.record_metadata(&value);
1069 self.handle_usage_chunk(&value).map(|()| false)
1070 }
1071 UpstreamEventKind::Choice { value, choice } => {
1072 self.record_metadata(&value);
1073 self.handle_choice_chunk(&choice)?;
1074 Ok(false)
1075 }
1076 }
1077 }
1078
1079 fn handle_usage_chunk(&mut self, value: &Value) -> Result<(), ChatStreamError> {
1081 let Some(usage) = value.get("usage") else {
1082 warn!("buffered upstream chunk has no choices and no usage");
1083 return Err(ChatStreamError::malformed_event(
1084 "upstream chunk has no choices and no usage",
1085 ));
1086 };
1087
1088 info!("buffered upstream usage chunk");
1089 self.usage = Some(usage.clone());
1090 Ok(())
1091 }
1092
1093 fn handle_choice_chunk(&mut self, choice: &Value) -> Result<(), ChatStreamError> {
1095 let choice = choice.as_object().ok_or_else(|| {
1096 ChatStreamError::malformed_event("upstream choice must be a JSON object")
1097 })?;
1098 let index = normalized_choice_index(choice.get("index"))?;
1099 match self.choice_index {
1100 Some(existing) if existing != index => {
1101 return Err(ChatStreamError::malformed_event(
1102 "upstream choice index changed while buffering a completion",
1103 ));
1104 }
1105 None => self.choice_index = Some(index),
1106 Some(_) => {}
1107 }
1108
1109 let finish_reason = normalized_finish_reason(choice.get("finish_reason"))?;
1110 let delta = choice.get("delta").unwrap_or(&Value::Null);
1111 let content = encrypted_delta_content(delta)?;
1112 let reasoning_content = encrypted_delta_reasoning_content(delta)?;
1113 debug!(
1114 choice_index = index,
1115 has_encrypted_content = content.is_some(),
1116 has_encrypted_reasoning_content = reasoning_content.is_some(),
1117 has_finish_reason = !finish_reason.is_null(),
1118 "transforming buffered upstream choice chunk"
1119 );
1120
1121 if let Some(content) = content {
1122 let decrypted = self.ctx.decrypt(Some(content))?;
1123 self.saw_encrypted_response_field = true;
1124 debug!(
1125 choice_index = index,
1126 has_decrypted_content = decrypted.is_some(),
1127 "decrypted buffered upstream content chunk"
1128 );
1129 if let Some(content) = decrypted {
1130 self.content.push_str(&content);
1131 }
1132 }
1133
1134 if let Some(reasoning_content) = reasoning_content {
1135 let decrypted = self.ctx.decrypt(Some(reasoning_content))?;
1136 self.saw_encrypted_response_field = true;
1137 debug!(
1138 choice_index = index,
1139 has_decrypted_reasoning_content = decrypted.is_some(),
1140 "decrypted buffered upstream reasoning content chunk"
1141 );
1142 if let Some(reasoning_content) = decrypted {
1143 self.reasoning_content.push_str(&reasoning_content);
1144 }
1145 }
1146
1147 if !finish_reason.is_null() {
1148 self.finish_reason = Some(finish_reason);
1149 }
1150
1151 Ok(())
1152 }
1153
1154 fn record_metadata(&mut self, value: &Value) {
1156 if self.id.is_none()
1157 && let Some(id) = string_field(value, "id")
1158 {
1159 self.id = Some(id.to_owned());
1160 }
1161 if self.created.is_none()
1162 && let Some(created) = integer_field(value, "created")
1163 {
1164 self.created = Some(created);
1165 }
1166 if self.model.is_none()
1167 && let Some(model) = string_field(value, "model")
1168 {
1169 self.model = Some(model.to_owned());
1170 }
1171 }
1172
1173 fn into_response(self) -> Value {
1175 let mut message = serde_json::Map::new();
1176 message.insert("role".to_owned(), Value::String("assistant".to_owned()));
1177 if !self.reasoning_content.is_empty() {
1178 message.insert(
1179 "reasoning_content".to_owned(),
1180 Value::String(self.reasoning_content),
1181 );
1182 }
1183 message.insert("content".to_owned(), Value::String(self.content));
1184
1185 json!({
1186 "id": self.id.unwrap_or(self.ctx.fallback_id),
1187 "object": "chat.completion",
1188 "created": self.created.unwrap_or(self.ctx.fallback_created),
1189 "model": self.model.unwrap_or(self.ctx.fallback_model),
1190 "choices": [{
1191 "index": self.choice_index.unwrap_or(0),
1192 "message": Value::Object(message),
1193 "finish_reason": self.finish_reason.unwrap_or_else(|| Value::String("stop".to_owned())),
1194 }],
1195 "usage": self.usage.unwrap_or(Value::Null),
1196 })
1197 }
1198}
1199
/// Locates the first blank-line delimiter that ends an SSE event.
///
/// Returns `(byte offset of the delimiter, delimiter length in bytes)` for
/// whichever of `"\r\n\r\n"`, `"\n\n"`, or `"\r\r"` occurs earliest in
/// `buffer`, or `None` when no delimiter is present.
fn sse_event_boundary(buffer: &str) -> Option<(usize, usize)> {
    const BLANK_LINE_DELIMITERS: [&str; 3] = ["\r\n\r\n", "\n\n", "\r\r"];

    let mut earliest: Option<(usize, usize)> = None;
    for delimiter in BLANK_LINE_DELIMITERS {
        let Some(position) = buffer.find(delimiter) else {
            continue;
        };
        match earliest {
            // An already-recorded delimiter that starts no later stays.
            Some((best, _)) if best <= position => {}
            _ => earliest = Some((position, delimiter.len())),
        }
    }
    earliest
}
1207
1208fn parse_sse_event(raw: &str) -> Result<Option<RawSseEvent>, ChatStreamError> {
1210 let mut event = None;
1211 let mut data_lines = Vec::new();
1212 let mut saw_non_comment_field = false;
1213
1214 for line in raw.lines() {
1215 let line = line.strip_suffix('\r').unwrap_or(line);
1216 if line.is_empty() || line.starts_with(':') {
1217 continue;
1218 }
1219
1220 saw_non_comment_field = true;
1221 let (field, value) = line.split_once(':').unwrap_or((line, ""));
1222 let value = value.strip_prefix(' ').unwrap_or(value);
1223 match field {
1224 "event" => event = Some(value.to_owned()),
1225 "data" => data_lines.push(value.to_owned()),
1226 "id" | "retry" => {}
1227 other => {
1228 warn!(field = other, "unsupported upstream SSE field");
1229 return Err(ChatStreamError::malformed_event(format!(
1230 "unsupported upstream SSE field {other:?}",
1231 )));
1232 }
1233 }
1234 }
1235
1236 if data_lines.is_empty() {
1237 return if saw_non_comment_field {
1238 warn!("upstream SSE event did not contain a data field");
1239 Err(ChatStreamError::malformed_event(
1240 "upstream SSE event did not contain a data field",
1241 ))
1242 } else {
1243 debug!("ignored upstream SSE comment or heartbeat event");
1244 Ok(None)
1245 };
1246 }
1247
1248 debug!(
1249 event_type = event.as_deref().unwrap_or("message"),
1250 data_line_count = data_lines.len(),
1251 "parsed upstream SSE event"
1252 );
1253
1254 Ok(Some(RawSseEvent {
1255 event,
1256 data: data_lines.join("\n"),
1257 }))
1258}
1259
1260struct OpenAiChatStreamTransformer {
1262 ctx: ChunkContext,
1263 include_usage_requested: bool,
1264 sent_role: bool,
1265 sent_final_finish: bool,
1266}
1267
1268impl OpenAiChatStreamTransformer {
1269 fn new(
1271 codec: E2eeCodec,
1272 proxy_instance_key: ProxyInstanceKey,
1273 fallback_model: String,
1274 include_usage_requested: bool,
1275 ) -> Self {
1276 Self {
1277 ctx: ChunkContext::new(codec, proxy_instance_key, fallback_model),
1278 include_usage_requested,
1279 sent_role: false,
1280 sent_final_finish: false,
1281 }
1282 }
1283
1284 fn handle_choice_chunk(
1286 &mut self,
1287 value: &Value,
1288 choice: &Value,
1289 ) -> Result<Vec<StreamOutput>, ChatStreamError> {
1290 let choice = choice.as_object().ok_or_else(|| {
1291 ChatStreamError::malformed_event("upstream choice must be a JSON object")
1292 })?;
1293 let finish_reason = normalized_finish_reason(choice.get("finish_reason"))?;
1294 let delta = choice.get("delta").unwrap_or(&Value::Null);
1295 let content = encrypted_delta_content(delta)?;
1296 let reasoning_content = encrypted_delta_reasoning_content(delta)?;
1297 debug!(
1298 has_encrypted_content = content.is_some(),
1299 has_encrypted_reasoning_content = reasoning_content.is_some(),
1300 has_finish_reason = !finish_reason.is_null(),
1301 "transforming streaming upstream choice chunk"
1302 );
1303
1304 let mut output = Vec::new();
1305
1306 if content.is_none() && reasoning_content.is_none() {
1307 if !finish_reason.is_null() {
1308 output.push(StreamOutput::Json(self.chunk_with_choice(
1309 value,
1310 choice.get("index"),
1311 json!({}),
1312 finish_reason,
1313 )?));
1314 self.sent_final_finish = true;
1315 }
1316 return Ok(output);
1317 }
1318
1319 let decrypted_content = match content {
1320 Some(content) => self.ctx.decrypt(Some(content))?,
1321 None => None,
1322 };
1323 let decrypted_reasoning_content = match reasoning_content {
1324 Some(reasoning_content) => self.ctx.decrypt(Some(reasoning_content))?,
1325 None => None,
1326 };
1327 debug!(
1328 has_decrypted_content = decrypted_content.is_some(),
1329 has_decrypted_reasoning_content = decrypted_reasoning_content.is_some(),
1330 "decrypted streaming upstream content chunk"
1331 );
1332
1333 if decrypted_content.is_some() || decrypted_reasoning_content.is_some() {
1334 let mut delta = serde_json::Map::new();
1335
1336 if !self.sent_role {
1337 delta.insert("role".to_owned(), Value::String("assistant".to_owned()));
1338 self.sent_role = true;
1339 }
1340
1341 if let Some(reasoning_content) = decrypted_reasoning_content {
1342 delta.insert(
1343 "reasoning_content".to_owned(),
1344 Value::String(reasoning_content),
1345 );
1346 }
1347
1348 if let Some(content) = decrypted_content {
1349 delta.insert("content".to_owned(), Value::String(content));
1350 }
1351
1352 let final_finish = !finish_reason.is_null();
1353 let content_finish_reason = if final_finish {
1354 Value::Null
1355 } else {
1356 finish_reason.clone()
1357 };
1358 output.push(StreamOutput::Json(self.chunk_with_choice(
1359 value,
1360 choice.get("index"),
1361 Value::Object(delta),
1362 content_finish_reason,
1363 )?));
1364 if final_finish {
1365 output.push(StreamOutput::Json(self.chunk_with_choice(
1366 value,
1367 choice.get("index"),
1368 json!({}),
1369 finish_reason,
1370 )?));
1371 self.sent_final_finish = true;
1372 }
1373 return Ok(output);
1374 }
1375
1376 Ok(output)
1377 }
1378
1379 fn handle_usage_chunk(&self, value: &Value) -> Result<Vec<StreamOutput>, ChatStreamError> {
1381 let Some(usage) = value.get("usage") else {
1382 warn!("streaming upstream chunk has no choices and no usage");
1383 return Err(ChatStreamError::malformed_event(
1384 "upstream chunk has no choices and no usage",
1385 ));
1386 };
1387
1388 if !self.include_usage_requested {
1392 debug!("streaming upstream usage chunk ignored because client did not request usage");
1393 return Ok(Vec::new());
1394 }
1395
1396 info!("streaming upstream usage chunk forwarded");
1397 Ok(vec![StreamOutput::Json(self.ctx.usage_chunk(value, usage))])
1398 }
1399
1400 fn finish_chunk(&self) -> Value {
1402 self.ctx
1403 .chunk_with_choice(&Value::Null, 0, json!({}), Value::String("stop".to_owned()))
1404 }
1405
1406 fn chunk_with_choice(
1408 &self,
1409 upstream: &Value,
1410 index: Option<&Value>,
1411 delta: Value,
1412 finish_reason: Value,
1413 ) -> Result<Value, ChatStreamError> {
1414 let index = normalized_choice_index(index)?;
1415 Ok(self
1416 .ctx
1417 .chunk_with_choice(upstream, index, delta, finish_reason))
1418 }
1419}
1420
1421impl ChatSseTransformer for OpenAiChatStreamTransformer {
1422 fn handle_event(&mut self, event: RawSseEvent) -> Result<Vec<StreamOutput>, ChatStreamError> {
1424 match classify_upstream_event(event, &STREAMING_UPSTREAM_EVENT_LOG)? {
1425 UpstreamEventKind::Done => {
1426 let mut output = Vec::new();
1427 if !self.sent_final_finish {
1428 debug!("synthesizing final streaming finish chunk before DONE");
1429 output.push(StreamOutput::Json(self.finish_chunk()));
1430 self.sent_final_finish = true;
1431 }
1432 output.push(StreamOutput::Done);
1433 Ok(output)
1434 }
1435 UpstreamEventKind::Usage(value) => self.handle_usage_chunk(&value),
1436 UpstreamEventKind::Choice { value, choice } => {
1437 self.handle_choice_chunk(&value, &choice)
1438 }
1439 }
1440 }
1441}
1442
/// Opening tag that marks the start of an emulated tool call in decrypted
/// assistant text.
const TOOL_CALL_START_MARKER: &str = "<tool_call>";
1444
1445struct OpenAiToolEmulatedChatStreamTransformer {
1451 ctx: ChunkContext,
1452 tool_context: ToolEmulationContext,
1453 include_usage_requested: bool,
1454 sent_role: bool,
1455 sent_final_finish: bool,
1456 pending_text: String,
1457 tool_buffer: String,
1458 buffering_tool_call: bool,
1459 emitted_tool_calls: bool,
1460}
1461
1462impl OpenAiToolEmulatedChatStreamTransformer {
1463 fn new(
1465 tool_context: &ToolEmulationContext,
1466 codec: E2eeCodec,
1467 proxy_instance_key: ProxyInstanceKey,
1468 fallback_model: String,
1469 include_usage_requested: bool,
1470 ) -> Result<Self, ChatStreamError> {
1471 Ok(Self {
1472 ctx: ChunkContext::new(codec, proxy_instance_key, fallback_model),
1473 tool_context: tool_context.clone(),
1474 include_usage_requested,
1475 sent_role: false,
1476 sent_final_finish: false,
1477 pending_text: String::new(),
1478 tool_buffer: String::new(),
1479 buffering_tool_call: false,
1480 emitted_tool_calls: false,
1481 })
1482 }
1483
1484 fn handle_choice_chunk(
1486 &mut self,
1487 value: &Value,
1488 choice: &Value,
1489 ) -> Result<Vec<StreamOutput>, ChatStreamError> {
1490 let choice = choice.as_object().ok_or_else(|| {
1491 ChatStreamError::malformed_event("upstream choice must be a JSON object")
1492 })?;
1493 let index = normalized_choice_index(choice.get("index"))?;
1494 let finish_reason = normalized_finish_reason(choice.get("finish_reason"))?;
1495 let delta = choice.get("delta").unwrap_or(&Value::Null);
1496 let content = encrypted_delta_content(delta)?;
1497 let reasoning_content = encrypted_delta_reasoning_content(delta)?;
1498
1499 let mut output = Vec::new();
1500
1501 if let Some(reasoning_content) = reasoning_content
1502 && let Some(reasoning_content) = self.ctx.decrypt(Some(reasoning_content))?
1503 && !self.sent_final_finish
1504 {
1505 output.push(self.reasoning_chunk(value, index, reasoning_content));
1506 }
1507
1508 if let Some(content) = content
1509 && let Some(content) = self.ctx.decrypt(Some(content))?
1510 && !self.sent_final_finish
1511 {
1512 output.extend(self.push_decrypted_content(value, index, &content)?);
1513 }
1514
1515 if !finish_reason.is_null() && !self.sent_final_finish {
1516 output.extend(self.finish_buffered_content(value, index, finish_reason)?);
1517 }
1518
1519 Ok(output)
1520 }
1521
1522 fn push_decrypted_content(
1524 &mut self,
1525 upstream: &Value,
1526 index: u64,
1527 content: &str,
1528 ) -> Result<Vec<StreamOutput>, ChatStreamError> {
1529 if self.buffering_tool_call {
1530 self.tool_buffer.push_str(content);
1531 self.ensure_tool_buffer_within_limit()?;
1532 return Ok(Vec::new());
1533 }
1534
1535 self.pending_text.push_str(content);
1536 if let Some(marker_index) = self.pending_text.find(TOOL_CALL_START_MARKER) {
1537 let text = self.pending_text[..marker_index].to_owned();
1538 self.tool_buffer = self.pending_text[marker_index..].to_owned();
1539 self.pending_text.clear();
1540 self.buffering_tool_call = true;
1541 self.ensure_tool_buffer_within_limit()?;
1542 return Ok(self.text_chunk_if_not_empty(upstream, index, text));
1543 }
1544
1545 let streamable_len = streamable_pending_text_len(&self.pending_text);
1546 if streamable_len == 0 {
1547 return Ok(Vec::new());
1548 }
1549
1550 let text = self.pending_text[..streamable_len].to_owned();
1551 self.pending_text.drain(..streamable_len);
1552 Ok(vec![
1553 self.text_field_chunk(upstream, index, "content", text),
1554 ])
1555 }
1556
1557 fn finish_buffered_content(
1559 &mut self,
1560 upstream: &Value,
1561 index: u64,
1562 finish_reason: Value,
1563 ) -> Result<Vec<StreamOutput>, ChatStreamError> {
1564 let mut output = Vec::new();
1565
1566 if self.buffering_tool_call {
1567 output.extend(self.buffered_tool_call_chunks(upstream, index)?);
1568 } else if !self.pending_text.is_empty() {
1569 let text = std::mem::take(&mut self.pending_text);
1570 output.push(self.text_field_chunk(upstream, index, "content", text));
1571 }
1572
1573 let finish_reason = if self.emitted_tool_calls {
1574 Value::String("tool_calls".to_owned())
1575 } else {
1576 finish_reason
1577 };
1578 output.push(StreamOutput::Json(self.ctx.chunk_with_choice(
1579 upstream,
1580 index,
1581 json!({}),
1582 finish_reason,
1583 )));
1584 self.sent_final_finish = true;
1585 Ok(output)
1586 }
1587
1588 fn buffered_tool_call_chunks(
1590 &mut self,
1591 upstream: &Value,
1592 index: u64,
1593 ) -> Result<Vec<StreamOutput>, ChatStreamError> {
1594 self.ensure_tool_buffer_within_limit()?;
1595 match self
1596 .tool_context
1597 .classify_assistant_output(&self.tool_buffer)
1598 {
1599 ToolOutputClassification::ToolCalls(tool_calls) => {
1600 self.emitted_tool_calls = true;
1601 Ok(tool_calls
1602 .iter()
1603 .enumerate()
1604 .map(|(tool_index, tool_call)| {
1605 self.full_tool_call_chunk(upstream, index, tool_index, tool_call)
1606 })
1607 .collect())
1608 }
1609 ToolOutputClassification::NormalText => {
1610 let text = std::mem::take(&mut self.tool_buffer);
1611 self.buffering_tool_call = false;
1612 Ok(self.text_chunk_if_not_empty(upstream, index, text))
1613 }
1614 ToolOutputClassification::InvalidToolCall { error, .. } => {
1615 error!(
1616 validation_error = %error,
1617 payload_bytes = self.tool_buffer.len(),
1618 payload = %self.tool_buffer,
1619 "buffered streamed tool-call payload failed validation"
1620 );
1621 Err(ChatStreamError::malformed_event(format!(
1622 "tool call parsing failed: {error}"
1623 )))
1624 }
1625 }
1626 }
1627
1628 fn ensure_tool_buffer_within_limit(&self) -> Result<(), ChatStreamError> {
1630 if self.tool_buffer.len() > self.tool_context.config().tool_call_max_bytes {
1631 return Err(ChatStreamError::malformed_event(format!(
1632 "tool call output exceeded max size of {} bytes",
1633 self.tool_context.config().tool_call_max_bytes
1634 )));
1635 }
1636 Ok(())
1637 }
1638
1639 fn text_chunk_if_not_empty(
1641 &mut self,
1642 upstream: &Value,
1643 index: u64,
1644 text: String,
1645 ) -> Vec<StreamOutput> {
1646 if text.is_empty() {
1647 Vec::new()
1648 } else {
1649 vec![self.text_field_chunk(upstream, index, "content", text)]
1650 }
1651 }
1652
1653 fn reasoning_chunk(
1655 &mut self,
1656 upstream: &Value,
1657 index: u64,
1658 reasoning_content: String,
1659 ) -> StreamOutput {
1660 self.text_field_chunk(upstream, index, "reasoning_content", reasoning_content)
1661 }
1662
1663 fn text_field_chunk(
1665 &mut self,
1666 upstream: &Value,
1667 index: u64,
1668 field: &'static str,
1669 text: String,
1670 ) -> StreamOutput {
1671 let mut delta = serde_json::Map::new();
1672 self.insert_role_if_needed(&mut delta);
1673 delta.insert(field.to_owned(), Value::String(text));
1674
1675 StreamOutput::Json(self.ctx.chunk_with_choice(
1676 upstream,
1677 index,
1678 Value::Object(delta),
1679 Value::Null,
1680 ))
1681 }
1682
1683 fn insert_role_if_needed(&mut self, delta: &mut serde_json::Map<String, Value>) {
1685 if !self.sent_role {
1686 delta.insert("role".to_owned(), Value::String("assistant".to_owned()));
1687 self.sent_role = true;
1688 }
1689 }
1690
1691 fn full_tool_call_chunk(
1693 &mut self,
1694 upstream: &Value,
1695 index: u64,
1696 tool_index: usize,
1697 tool_call: &ValidatedToolCall,
1698 ) -> StreamOutput {
1699 let mut delta = serde_json::Map::new();
1700 self.insert_role_if_needed(&mut delta);
1701
1702 let mut tool_call_value = tool_call.to_openai_value();
1703
1704 if let Some(tool_call_object) = tool_call_value.as_object_mut() {
1705 tool_call_object.insert("index".to_owned(), json!(tool_index));
1706 }
1707 delta.insert("tool_calls".to_owned(), Value::Array(vec![tool_call_value]));
1708
1709 StreamOutput::Json(self.ctx.chunk_with_choice(
1710 upstream,
1711 index,
1712 Value::Object(delta),
1713 Value::Null,
1714 ))
1715 }
1716
1717 fn handle_usage_chunk(&self, value: &Value) -> Result<Vec<StreamOutput>, ChatStreamError> {
1719 let Some(usage) = value.get("usage") else {
1720 warn!("tool-emulated upstream chunk has no choices and no usage");
1721 return Err(ChatStreamError::malformed_event(
1722 "upstream chunk has no choices and no usage",
1723 ));
1724 };
1725
1726 if !self.include_usage_requested {
1728 return Ok(Vec::new());
1729 }
1730
1731 Ok(vec![StreamOutput::Json(self.ctx.usage_chunk(value, usage))])
1732 }
1733
1734 fn finish_stream(&mut self) -> Result<Vec<StreamOutput>, ChatStreamError> {
1736 let upstream = &Value::Null;
1737 let mut output = Vec::new();
1738
1739 if !self.sent_final_finish {
1740 output.extend(self.finish_buffered_content(
1741 upstream,
1742 0,
1743 Value::String("stop".to_owned()),
1744 )?);
1745 }
1746
1747 output.push(StreamOutput::Done);
1748 Ok(output)
1749 }
1750}
1751
1752fn streamable_pending_text_len(pending_text: &str) -> usize {
1754 let protected_suffix_len = TOOL_CALL_START_MARKER.len().saturating_sub(1);
1755 if pending_text.len() <= protected_suffix_len {
1756 return 0;
1757 }
1758
1759 let mut split_at = pending_text.len() - protected_suffix_len;
1760 while !pending_text.is_char_boundary(split_at) {
1761 split_at -= 1;
1762 }
1763 split_at
1764}
1765
1766impl ChatSseTransformer for OpenAiToolEmulatedChatStreamTransformer {
1767 fn handle_event(&mut self, event: RawSseEvent) -> Result<Vec<StreamOutput>, ChatStreamError> {
1769 match classify_upstream_event(event, &TOOL_EMULATED_UPSTREAM_EVENT_LOG)? {
1770 UpstreamEventKind::Done => self.finish_stream(),
1771 UpstreamEventKind::Usage(value) => self.handle_usage_chunk(&value),
1772 UpstreamEventKind::Choice { value, choice } => {
1773 self.handle_choice_chunk(&value, &choice)
1774 }
1775 }
1776 }
1777}
1778
1779#[derive(Debug, Clone, PartialEq, Eq)]
1781enum StreamOutput {
1782 Json(Value),
1783 Done,
1784}
1785
1786fn normalized_choice_index(index: Option<&Value>) -> Result<u64, ChatStreamError> {
1788 match index {
1789 Some(Value::Number(number)) => number.as_u64().ok_or_else(|| {
1790 ChatStreamError::malformed_event("upstream choice index must be a non-negative integer")
1791 }),
1792 Some(_) => Err(ChatStreamError::malformed_event(
1793 "upstream choice index must be a non-negative integer",
1794 )),
1795 None => Ok(0),
1796 }
1797}
1798
1799fn normalized_finish_reason(value: Option<&Value>) -> Result<Value, ChatStreamError> {
1801 match value {
1802 Some(Value::Null) | None => Ok(Value::Null),
1803 Some(Value::String(reason)) => Ok(Value::String(reason.clone())),
1804 Some(_) => Err(ChatStreamError::malformed_event(
1805 "upstream finish_reason must be a string or null",
1806 )),
1807 }
1808}
1809
1810fn encrypted_delta_content(delta: &Value) -> Result<Option<&str>, ChatStreamError> {
1812 encrypted_delta_text_field(delta, "content")
1813}
1814
1815fn encrypted_delta_reasoning_content(delta: &Value) -> Result<Option<&str>, ChatStreamError> {
1817 encrypted_delta_text_field(delta, "reasoning_content")
1818}
1819
1820fn encrypted_delta_text_field<'a>(
1822 delta: &'a Value,
1823 field: &'static str,
1824) -> Result<Option<&'a str>, ChatStreamError> {
1825 match delta.get(field) {
1826 Some(Value::Null) => {
1827 debug!(field, "ignoring null upstream delta text field");
1828 Ok(None)
1829 }
1830 Some(Value::String(content)) if content.is_empty() => {
1831 debug!(field, "ignoring empty upstream delta text field");
1832 Ok(None)
1833 }
1834 Some(Value::String(content)) => Ok(Some(content.as_str())),
1835 Some(_) => Err(ChatStreamError::malformed_event(format!(
1836 "upstream delta.{field} must be a string or null"
1837 ))),
1838 None => Ok(None),
1839 }
1840}
1841
1842fn string_field<'a>(value: &'a Value, field: &str) -> Option<&'a str> {
1844 value.get(field).and_then(Value::as_str)
1845}
1846
1847fn integer_field(value: &Value, field: &str) -> Option<i64> {
1849 value.get(field).and_then(Value::as_i64)
1850}
1851
/// Returns the current time as whole seconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch. The
/// seconds count saturates at `i64::MAX` rather than wrapping, which an
/// unchecked `as i64` cast would allow.
fn unix_timestamp_now() -> i64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            i64::try_from(elapsed.as_secs()).unwrap_or(i64::MAX)
        })
}
1859
1860async fn method_not_allowed(method: Method, uri: Uri) -> ProxyError {
1862 ProxyError::MethodNotAllowed { method, uri }
1863}
1864
1865async fn not_found(uri: Uri) -> ProxyError {
1867 ProxyError::NotFound { uri }
1868}
1869
1870#[derive(Debug, Error)]
1872pub enum ChatStreamError {
1873 #[error("Venice upstream stream failed: {message}")]
1874 UpstreamStream { message: String },
1875 #[error("Venice upstream stream emitted an error event: {message}")]
1876 UpstreamEvent { message: String },
1877 #[error("Venice upstream stream event is malformed: {message}")]
1878 MalformedEvent { message: String },
1879 #[error("failed to decrypt Venice E2EE response chunk: {source}")]
1880 Decryption { source: E2eeCodecError },
1881}
1882
1883impl ChatStreamError {
1884 fn upstream_stream(source: reqwest::Error) -> Self {
1886 Self::UpstreamStream {
1887 message: source.to_string(),
1888 }
1889 }
1890
1891 fn upstream_event(message: impl Into<String>) -> Self {
1893 Self::UpstreamEvent {
1894 message: message.into(),
1895 }
1896 }
1897
1898 fn malformed_event(message: impl Into<String>) -> Self {
1900 Self::MalformedEvent {
1901 message: message.into(),
1902 }
1903 }
1904
1905 fn invalid_utf8(source: std::str::Utf8Error) -> Self {
1907 Self::MalformedEvent {
1908 message: format!("upstream SSE bytes are not valid UTF-8: {source}"),
1909 }
1910 }
1911
1912 fn json_event(source: serde_json::Error) -> Self {
1914 Self::MalformedEvent {
1915 message: format!("upstream SSE data is not valid JSON: {source}"),
1916 }
1917 }
1918
1919 fn decryption(source: E2eeCodecError) -> Self {
1921 Self::Decryption { source }
1922 }
1923
1924 fn api_error_type(&self) -> &'static str {
1926 match self {
1927 Self::UpstreamStream { .. }
1928 | Self::UpstreamEvent { .. }
1929 | Self::MalformedEvent { .. } => "proxy_upstream_error",
1930 Self::Decryption { .. } => "proxy_e2ee_error",
1931 }
1932 }
1933
1934 fn api_error_code(&self) -> &'static str {
1936 match self {
1937 Self::UpstreamStream { .. } => "upstream_stream_error",
1938 Self::UpstreamEvent { .. } => "upstream_stream_error",
1939 Self::MalformedEvent { .. } => "upstream_malformed_response",
1940 Self::Decryption { .. } => "e2ee_response_decryption_failed",
1941 }
1942 }
1943}
1944
1945#[derive(Debug, Error)]
1947pub enum ProxyError {
1948 #[error(transparent)]
1949 Venice(#[from] VeniceClientError),
1950 #[error(transparent)]
1951 Attestation(#[from] AttestationError),
1952 #[error(transparent)]
1953 Session(#[from] SessionError),
1954 #[error(transparent)]
1955 ChatRequest(#[from] ChatRequestError),
1956 #[error(transparent)]
1957 ChatConstruction(#[from] ChatConstructionError),
1958 #[error(transparent)]
1959 ChatStream(#[from] ChatStreamError),
1960 #[error("The model failed to produce a valid tool call after correction attempts.")]
1961 ToolCallRetryExhausted {
1962 max_retries: u32,
1963 last_validation_error: String,
1964 },
1965 #[error(
1966 "proxy instance key is unavailable; keys.generate_proxy_instance_key_on_startup must be enabled for E2EE chat requests"
1967 )]
1968 ProxyInstanceKeyUnavailable,
1969 #[error("session does not contain an attested model public key after attestation verification")]
1970 MissingAttestedModelKey,
1971 #[error("method {method} is not supported for {uri}")]
1972 MethodNotAllowed { method: Method, uri: Uri },
1973 #[error("route {uri} was not found")]
1974 NotFound { uri: Uri },
1975}
1976
1977impl ProxyError {
1978 fn status(&self) -> StatusCode {
1980 match self {
1981 Self::Venice(_) => StatusCode::BAD_GATEWAY,
1982 Self::Attestation(error) if error.verifier_unavailable() => {
1983 StatusCode::SERVICE_UNAVAILABLE
1984 }
1985 Self::Attestation(_) => StatusCode::BAD_GATEWAY,
1986 Self::Session(
1987 SessionError::MissingSessionIdentifier | SessionError::InvalidHeaderValue { .. },
1988 ) => StatusCode::BAD_REQUEST,
1989 Self::Session(_) => StatusCode::INTERNAL_SERVER_ERROR,
1990 Self::ChatRequest(_) => StatusCode::BAD_REQUEST,
1991 Self::ChatConstruction(_)
1992 | Self::ChatStream(_)
1993 | Self::ToolCallRetryExhausted { .. } => StatusCode::BAD_GATEWAY,
1994 Self::ProxyInstanceKeyUnavailable | Self::MissingAttestedModelKey => {
1995 StatusCode::INTERNAL_SERVER_ERROR
1996 }
1997 Self::MethodNotAllowed { .. } => StatusCode::METHOD_NOT_ALLOWED,
1998 Self::NotFound { .. } => StatusCode::NOT_FOUND,
1999 }
2000 }
2001
2002 fn error_type(&self) -> &'static str {
2004 match self {
2005 Self::Venice(error) => error.api_error_type(),
2006 Self::Attestation(error) => error.api_error_type(),
2007 Self::Session(
2008 SessionError::MissingSessionIdentifier | SessionError::InvalidHeaderValue { .. },
2009 ) => "invalid_request_error",
2010 Self::Session(_) => "proxy_session_error",
2011 Self::ChatRequest(_) => "invalid_request_error",
2012 Self::ChatConstruction(_) => "proxy_e2ee_error",
2013 Self::ChatStream(error) => error.api_error_type(),
2014 Self::ToolCallRetryExhausted { .. } => "proxy_tool_call_error",
2015 Self::ProxyInstanceKeyUnavailable => "proxy_configuration_error",
2016 Self::MissingAttestedModelKey => "proxy_attestation_error",
2017 Self::MethodNotAllowed { .. } | Self::NotFound { .. } => "invalid_request_error",
2018 }
2019 }
2020
2021 fn code(&self) -> &'static str {
2023 match self {
2024 Self::Venice(error) => error.api_error_code(),
2025 Self::Attestation(error) => error.api_error_code(),
2026 Self::Session(SessionError::MissingSessionIdentifier) => "session_identifier_missing",
2027 Self::Session(SessionError::InvalidHeaderValue { .. }) => "invalid_session_header",
2028 Self::Session(_) => "session_error",
2029 Self::ChatRequest(error) => error.api_error_code(),
2030 Self::ChatConstruction(error) => error.api_error_code(),
2031 Self::ChatStream(error) => error.api_error_code(),
2032 Self::ToolCallRetryExhausted { .. } => "invalid_tool_call",
2033 Self::ProxyInstanceKeyUnavailable => "proxy_instance_key_unavailable",
2034 Self::MissingAttestedModelKey => "attestation_failed",
2035 Self::MethodNotAllowed { .. } => "method_not_allowed",
2036 Self::NotFound { .. } => "not_found",
2037 }
2038 }
2039}
2040
2041impl IntoResponse for ProxyError {
2042 fn into_response(self) -> Response {
2044 let status = self.status();
2045 let error_code = self.code();
2046 let error_type = self.error_type();
2047
2048 if status.is_server_error() {
2049 error!(
2050 status = status.as_u16(),
2051 error_code,
2052 error_type,
2053 error = %self,
2054 "proxy request failed"
2055 );
2056 } else {
2057 warn!(
2058 status = status.as_u16(),
2059 error_code,
2060 error_type,
2061 error = %self,
2062 "proxy request rejected"
2063 );
2064 }
2065
2066 let mut response = if let Self::ToolCallRetryExhausted {
2067 max_retries,
2068 last_validation_error,
2069 } = &self
2070 {
2071 let body = json!({
2072 "error": {
2073 "message": self.to_string(),
2074 "type": error_type,
2075 "code": error_code,
2076 "details": {
2077 "max_retries": max_retries,
2078 "last_validation_error": last_validation_error,
2079 },
2080 }
2081 });
2082 (status, Json(body)).into_response()
2083 } else {
2084 let body = ErrorResponse::new(self.to_string(), error_type, error_code);
2085 (status, Json(body)).into_response()
2086 };
2087
2088 apply_error_headers(response.headers_mut(), error_code);
2089 response
2090 }
2091}
2092
/// Values for the `X-Venice-Proxy-*` response headers.
///
/// A field left as `None` produces no header when the set is applied.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProxyMetadataHeaders {
    /// Value of the E2EE header.
    pub e2ee: Option<String>,
    /// Value of the attestation-mode header.
    pub attestation_mode: Option<String>,
    /// Value of the attested-model header.
    pub attested_model: Option<String>,
    /// Value of the TEE-provider header.
    pub tee_provider: Option<String>,
    /// Value of the TDX-verified header, rendered as `true`/`false`.
    pub tdx_verified: Option<bool>,
    /// Value of the TDX-debug header, rendered as `true`/`false`.
    pub tdx_debug: Option<bool>,
    /// Value of the NVIDIA-verified header.
    pub nvidia_verified: Option<String>,
    /// Value of the key-binding header, rendered as `true`/`false`.
    pub key_binding: Option<bool>,
    /// Value of the session-id header.
    pub session_id: Option<String>,
    /// Value of the session-scope header.
    pub session_scope: Option<String>,
    /// Value of the tool-mode header.
    pub tool_mode: Option<String>,
    /// Value of the tool-retries header, rendered in decimal.
    pub tool_retries: Option<u32>,
}
2112
2113impl ProxyMetadataHeaders {
2114 pub fn from_config(config: &ProxyConfig) -> Self {
2117 Self {
2118 attestation_mode: Some(config.attestation.mode.as_str().to_owned()),
2119 tool_mode: Some(config.tools.mode.as_str().to_owned()),
2120 ..Self::default()
2121 }
2122 }
2123
2124 pub fn for_verified_chat(config: &ProxyConfig, session: &SessionContext) -> Self {
2126 let tee_provider = session
2127 .attestation_tee_provider
2128 .clone()
2129 .unwrap_or_else(|| "unknown".to_owned());
2130 let tdx_debug = session.attestation_tdx_debug;
2131 let nvidia_verified = session
2132 .attestation_nvidia_verified
2133 .clone()
2134 .unwrap_or_else(|| "not-present".to_owned());
2135
2136 Self {
2137 e2ee: Some("verified".to_owned()),
2138 attestation_mode: Some(config.attestation.mode.as_str().to_owned()),
2139 attested_model: Some(session.model_id.clone()),
2140 tee_provider: Some(tee_provider),
2141 tdx_verified: config.attestation.require_tdx.then_some(true),
2142 tdx_debug,
2143 nvidia_verified: Some(nvidia_verified),
2144 key_binding: Some(true),
2145 session_id: Some(session.agent_session_id.clone()),
2146 session_scope: Some(session.scope.as_str().to_owned()),
2147 tool_mode: Some(config.tools.mode.as_str().to_owned()),
2148 tool_retries: None,
2149 }
2150 }
2151
2152 pub fn apply(&self, headers: &mut HeaderMap) {
2154 insert_optional_header(headers, HEADER_PROXY_E2EE, self.e2ee.as_deref());
2155 insert_optional_header(
2156 headers,
2157 HEADER_PROXY_ATTESTATION_MODE,
2158 self.attestation_mode.as_deref(),
2159 );
2160 insert_optional_header(
2161 headers,
2162 HEADER_PROXY_ATTESTED_MODEL,
2163 self.attested_model.as_deref(),
2164 );
2165 insert_optional_header(
2166 headers,
2167 HEADER_PROXY_TEE_PROVIDER,
2168 self.tee_provider.as_deref(),
2169 );
2170 insert_optional_bool_header(headers, HEADER_PROXY_TDX_VERIFIED, self.tdx_verified);
2171 insert_optional_bool_header(headers, HEADER_PROXY_TDX_DEBUG, self.tdx_debug);
2172 insert_optional_header(
2173 headers,
2174 HEADER_PROXY_NVIDIA_VERIFIED,
2175 self.nvidia_verified.as_deref(),
2176 );
2177 insert_optional_bool_header(headers, HEADER_PROXY_KEY_BINDING, self.key_binding);
2178 insert_optional_header(headers, HEADER_PROXY_SESSION_ID, self.session_id.as_deref());
2179 insert_optional_header(
2180 headers,
2181 HEADER_PROXY_SESSION_SCOPE,
2182 self.session_scope.as_deref(),
2183 );
2184 insert_optional_header(headers, HEADER_PROXY_TOOL_MODE, self.tool_mode.as_deref());
2185 if let Some(tool_retries) = self.tool_retries {
2186 insert_header(
2187 headers,
2188 HEADER_PROXY_TOOL_RETRIES,
2189 &tool_retries.to_string(),
2190 );
2191 }
2192 }
2193}
2194
2195pub fn apply_error_headers(headers: &mut HeaderMap, error_code: &str) {
2197 insert_header(headers, HEADER_PROXY_ERROR_CODE, error_code);
2198}
2199
2200fn insert_optional_header(headers: &mut HeaderMap, name: &'static str, value: Option<&str>) {
2202 if let Some(value) = value {
2203 insert_header(headers, name, value);
2204 }
2205}
2206
2207fn insert_optional_bool_header(headers: &mut HeaderMap, name: &'static str, value: Option<bool>) {
2209 if let Some(value) = value {
2210 insert_header(headers, name, if value { "true" } else { "false" });
2211 }
2212}
2213
2214fn insert_header(headers: &mut HeaderMap, name: &'static str, value: &str) {
2216 let Ok(name) = HeaderName::from_bytes(name.as_bytes()) else {
2217 return;
2218 };
2219 let Ok(value) = HeaderValue::from_str(value) else {
2220 return;
2221 };
2222 headers.insert(name, value);
2223}
2224
2225#[cfg(test)]
2226mod tests {
2227 use super::*;
2228 use std::{
2229 collections::{HashMap, VecDeque},
2230 sync::{Arc, Mutex},
2231 time::Duration,
2232 };
2233
2234 use axum::{
2235 body::Body,
2236 extract::Query,
2237 http::Request,
2238 routing::{get, post},
2239 };
2240 use serde_json::json;
2241
2242 use crate::config::NvidiaRequirement;
2243 use tower::ServiceExt;
2244
2245 fn test_app() -> Router {
2246 router_with_venice_client(ProxyConfig::default(), test_venice_client())
2247 }
2248
2249 fn test_venice_client() -> VeniceClient {
2250 test_venice_client_for_base_url("http://127.0.0.1:1/api/v1")
2251 }
2252
2253 fn test_venice_client_for_base_url(base_url: impl AsRef<str>) -> VeniceClient {
2254 VeniceClient::new(base_url.as_ref(), "test-api-key", Duration::from_secs(1))
2255 .expect("test Venice client should build")
2256 }
2257
2258 fn chat_config_with_basic_test_attestation() -> ProxyConfig {
2259 let mut config = ProxyConfig::default();
2260 config.attestation.require_tdx = false;
2261 config.attestation.require_nvidia = NvidiaRequirement::Never;
2262 config
2263 }
2264
2265 #[test]
2266 fn app_state_initializes_key_and_session_managers_from_config() {
2267 let state = AppState::from_parts(ProxyConfig::default(), test_venice_client());
2268
2269 let key = state
2270 .proxy_instance_key()
2271 .expect("default config should generate startup key");
2272 assert_eq!(key.public_key_hex().len(), 130);
2273 assert!(state.session_manager().is_empty());
2274 assert_eq!(
2275 state.attestation_verifier().policy(),
2276 &ProxyConfig::default().attestation
2277 );
2278
2279 let mut config = ProxyConfig::default();
2280 config.keys.generate_proxy_instance_key_on_startup = false;
2281 let state = AppState::from_parts(config, test_venice_client());
2282 assert!(state.proxy_instance_key().is_none());
2283 }
2284
2285 async fn error_body(response: Response) -> ErrorResponse {
2286 let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
2287 .await
2288 .expect("response body should buffer");
2289 serde_json::from_slice(&bytes).expect("response should be OpenAI-style error JSON")
2290 }
2291
2292 #[tokio::test]
2293 async fn chat_route_ignores_upstream_role_only_chunk_before_encrypted_content() {
2294 let response = streaming_chat_response(
2295 "chat-route-role-only",
2296 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
2297 vec![
2298 MockStreamFrame::Role,
2299 MockStreamFrame::Text("Hello"),
2300 MockStreamFrame::Finish("stop"),
2301 MockStreamFrame::Done,
2302 ],
2303 )
2304 .await;
2305
2306 assert_eq!(response.status(), StatusCode::OK);
2307 let body = response_body(response).await;
2308 let data = sse_data(&body);
2309 assert_eq!(data.len(), 3);
2310 let first: Value = serde_json::from_str(data[0]).expect("first chunk should be JSON");
2311 assert_eq!(first["choices"][0]["delta"]["role"], "assistant");
2312 assert_eq!(first["choices"][0]["delta"]["content"], "Hello");
2313 assert_eq!(data[2], "[DONE]");
2314 }
2315
2316 #[tokio::test]
2317 async fn chat_route_streams_decrypted_normal_assistant_text() {
2318 let response = streaming_chat_response(
2319 "chat-route-test",
2320 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
2321 vec![
2322 MockStreamFrame::NullContent,
2323 MockStreamFrame::EmptyContent,
2324 MockStreamFrame::Text("Hello"),
2325 MockStreamFrame::Finish("stop"),
2326 MockStreamFrame::Done,
2327 ],
2328 )
2329 .await;
2330
2331 assert_eq!(response.status(), StatusCode::OK);
2332 assert_eq!(
2333 response.headers().get(HEADER_PROXY_E2EE).unwrap(),
2334 "verified"
2335 );
2336 assert_eq!(
2337 response.headers().get(HEADER_PROXY_ATTESTED_MODEL).unwrap(),
2338 "e2ee-test"
2339 );
2340
2341 let body = response_body(response).await;
2342 let data = sse_data(&body);
2343 assert_eq!(data.len(), 3);
2344
2345 let first: Value = serde_json::from_str(data[0]).expect("first chunk should be JSON");
2346 assert_eq!(first["object"], "chat.completion.chunk");
2347 assert_eq!(first["model"], "e2ee-test");
2348 assert_eq!(first["choices"][0]["delta"]["role"], "assistant");
2349 assert_eq!(first["choices"][0]["delta"]["content"], "Hello");
2350 assert!(first["choices"][0]["finish_reason"].is_null());
2351
2352 let final_chunk: Value = serde_json::from_str(data[1]).expect("final chunk should be JSON");
2353 assert_eq!(final_chunk["choices"][0]["delta"], json!({}));
2354 assert_eq!(final_chunk["choices"][0]["finish_reason"], "stop");
2355 assert_eq!(data[2], "[DONE]");
2356 }
2357
2358 #[tokio::test]
2359 async fn chat_route_streams_decrypted_reasoning_content() {
2360 let response = streaming_chat_response(
2361 "chat-route-reasoning-stream",
2362 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"reasoning":{"effort":"high"}}"#,
2363 vec![
2364 MockStreamFrame::Reasoning("Thinking"),
2365 MockStreamFrame::Text("Answer"),
2366 MockStreamFrame::Finish("stop"),
2367 MockStreamFrame::Done,
2368 ],
2369 )
2370 .await;
2371
2372 assert_eq!(response.status(), StatusCode::OK);
2373 let body = response_body(response).await;
2374 let data = sse_data(&body);
2375 assert_eq!(data.len(), 4);
2376 let reasoning: Value =
2377 serde_json::from_str(data[0]).expect("reasoning chunk should be JSON");
2378 let answer: Value = serde_json::from_str(data[1]).expect("answer chunk should be JSON");
2379
2380 assert_eq!(reasoning["choices"][0]["delta"]["role"], "assistant");
2381 assert_eq!(
2382 reasoning["choices"][0]["delta"]["reasoning_content"],
2383 "Thinking"
2384 );
2385 assert!(answer["choices"][0]["delta"].get("role").is_none());
2386 assert_eq!(answer["choices"][0]["delta"]["content"], "Answer");
2387 assert_eq!(data.last().copied(), Some("[DONE]"));
2388 }
2389
2390 #[tokio::test]
2391 async fn chat_route_streams_multiple_decrypted_content_chunks() {
2392 let response = streaming_chat_response(
2393 "chat-route-multiple-chunks",
2394 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
2395 vec![
2396 MockStreamFrame::Text("Hello"),
2397 MockStreamFrame::Text(" world"),
2398 MockStreamFrame::Finish("stop"),
2399 MockStreamFrame::Done,
2400 ],
2401 )
2402 .await;
2403
2404 assert_eq!(response.status(), StatusCode::OK);
2405 let body = response_body(response).await;
2406 let data = sse_data(&body);
2407 let first: Value = serde_json::from_str(data[0]).expect("first chunk should be JSON");
2408 let second: Value = serde_json::from_str(data[1]).expect("second chunk should be JSON");
2409
2410 assert_eq!(first["choices"][0]["delta"]["role"], "assistant");
2411 assert_eq!(first["choices"][0]["delta"]["content"], "Hello");
2412 assert!(second["choices"][0]["delta"].get("role").is_none());
2413 assert_eq!(second["choices"][0]["delta"]["content"], " world");
2414 assert_eq!(data.last().copied(), Some("[DONE]"));
2415 }
2416
2417 #[tokio::test]
2418 async fn chat_route_passes_through_usage_chunk_when_requested_and_upstream_provides_it() {
2419 let response = streaming_chat_response(
2420 "chat-route-usage",
2421 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"stream_options":{"include_usage":true}}"#,
2422 vec![
2423 MockStreamFrame::Text("Hello"),
2424 MockStreamFrame::Finish("stop"),
2425 MockStreamFrame::Usage,
2426 MockStreamFrame::Done,
2427 ],
2428 )
2429 .await;
2430
2431 assert_eq!(response.status(), StatusCode::OK);
2432 let body = response_body(response).await;
2433 let data = sse_data(&body);
2434 assert_eq!(data.len(), 4);
2435 let usage_chunk: Value = serde_json::from_str(data[2]).expect("usage chunk should be JSON");
2436 assert_eq!(usage_chunk["choices"], json!([]));
2437 assert_eq!(usage_chunk["usage"]["total_tokens"], 3);
2438 assert_eq!(data[3], "[DONE]");
2439 }
2440
2441 #[tokio::test]
2442 async fn chat_route_returns_buffered_non_streaming_completion() {
2443 let response = chat_response(
2444 "chat-route-non-streaming-success",
2445 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
2446 vec![
2447 MockStreamFrame::NullContent,
2448 MockStreamFrame::EmptyContent,
2449 MockStreamFrame::Text("Hello"),
2450 MockStreamFrame::Text(" world"),
2451 MockStreamFrame::Finish("stop"),
2452 MockStreamFrame::Done,
2453 ],
2454 )
2455 .await;
2456
2457 assert_eq!(response.status(), StatusCode::OK);
2458 assert_eq!(
2459 response.headers().get(HEADER_PROXY_E2EE).unwrap(),
2460 "verified"
2461 );
2462 let body = json_body(response).await;
2463 assert_eq!(body["object"], "chat.completion");
2464 assert_eq!(body["id"], "chatcmpl-upstream-test");
2465 assert_eq!(body["created"], 1_717_171_717);
2466 assert_eq!(body["model"], "e2ee-test");
2467 assert_eq!(body["choices"][0]["index"], 0);
2468 assert_eq!(body["choices"][0]["message"]["role"], "assistant");
2469 assert_eq!(body["choices"][0]["message"]["content"], "Hello world");
2470 assert_eq!(body["choices"][0]["finish_reason"], "stop");
2471 assert!(body["usage"].is_null());
2472 }
2473
2474 #[tokio::test]
2475 async fn chat_route_returns_buffered_reasoning_content() {
2476 let response = chat_response(
2477 "chat-route-reasoning-non-streaming",
2478 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false,"reasoning_effort":"medium"}"#,
2479 vec![
2480 MockStreamFrame::Reasoning("Think "),
2481 MockStreamFrame::Reasoning("first."),
2482 MockStreamFrame::Text("Answer"),
2483 MockStreamFrame::Finish("stop"),
2484 MockStreamFrame::Done,
2485 ],
2486 )
2487 .await;
2488
2489 assert_eq!(response.status(), StatusCode::OK);
2490 let body = json_body(response).await;
2491 assert_eq!(
2492 body["choices"][0]["message"]["reasoning_content"],
2493 "Think first."
2494 );
2495 assert_eq!(body["choices"][0]["message"]["content"], "Answer");
2496 }
2497
2498 #[tokio::test]
2499 async fn chat_route_treats_omitted_stream_as_buffered_non_streaming() {
2500 let response = chat_response(
2501 "chat-route-omitted-stream",
2502 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}]}"#,
2503 vec![MockStreamFrame::Text("Hello"), MockStreamFrame::Done],
2504 )
2505 .await;
2506
2507 assert_eq!(response.status(), StatusCode::OK);
2508 let body = json_body(response).await;
2509 assert_eq!(body["object"], "chat.completion");
2510 assert_eq!(body["choices"][0]["message"]["content"], "Hello");
2511 assert_eq!(body["choices"][0]["finish_reason"], "stop");
2512 }
2513
2514 #[tokio::test]
2515 async fn chat_route_streams_incremental_tool_call_chunks() {
2516 let response = streaming_chat_response(
2517 "chat-route-tool-stream",
2518 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2519 vec![
2520 MockStreamFrame::Text("<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}\n</tool_call>"),
2521 MockStreamFrame::Finish("stop"),
2522 MockStreamFrame::Done,
2523 ],
2524 )
2525 .await;
2526
2527 assert_eq!(response.status(), StatusCode::OK);
2528 let body = response_body(response).await;
2529 let chunks = sse_json_chunks(&body);
2530
2531 assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");
2532
2533 let tool_calls = streamed_tool_call_deltas(&chunks);
2534 assert!(!tool_calls.is_empty());
2535 let first = tool_calls[0];
2536 assert_eq!(first["index"], 0);
2537 assert!(first["id"].as_str().unwrap().starts_with("call_"));
2538 assert_eq!(first["type"], "function");
2539 assert_eq!(first["function"]["name"], "search_web");
2540 for later in &tool_calls[1..] {
2541 assert!(later.get("id").is_none());
2542 assert!(later.get("type").is_none());
2543 assert!(later["function"].get("name").is_none());
2544 }
2545 assert_eq!(
2546 streamed_tool_call_arguments(&chunks, 0),
2547 r#"{"query":"example"}"#
2548 );
2549
2550 let final_chunk = chunks.last().expect("stream should have chunks");
2551 assert_eq!(final_chunk["choices"][0]["delta"], json!({}));
2552 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2553 }
2554
2555 #[tokio::test]
2556 async fn chat_route_streams_text_then_incremental_tool_call() {
2557 let response = streaming_chat_response(
2558 "chat-route-tool-stream-mixed-text",
2559 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2560 vec![
2561 MockStreamFrame::NullContent,
2562 MockStreamFrame::EmptyContent,
2563 MockStreamFrame::Text("I'll check that. "),
2564 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}"),
2565 MockStreamFrame::Text("</tool_call>"),
2566 MockStreamFrame::Finish("stop"),
2567 MockStreamFrame::Done,
2568 ],
2569 )
2570 .await;
2571
2572 assert_eq!(response.status(), StatusCode::OK);
2573 let body = response_body(response).await;
2574 let chunks = sse_json_chunks(&body);
2575
2576 assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");
2577 assert_eq!(streamed_content(&chunks), "I'll check that. ");
2578
2579 let tool_calls = streamed_tool_call_deltas(&chunks);
2580 assert!(!tool_calls.is_empty());
2581 assert_eq!(tool_calls[0]["function"]["name"], "search_web");
2582 assert_eq!(
2583 streamed_tool_call_arguments(&chunks, 0),
2584 r#"{"query":"example"}"#
2585 );
2586
2587 let final_chunk = chunks.last().expect("stream should have chunks");
2588 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2589 }
2590
2591 #[tokio::test]
2592 async fn chat_route_fails_closed_on_unterminated_streamed_tool_call() {
2593 let response = streaming_chat_response(
2596 "chat-route-tool-stream-missing-close",
2597 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2598 vec![
2599 MockStreamFrame::Text("I'll check that. "),
2600 MockStreamFrame::Text("<tool_call>{\"name\":"),
2601 MockStreamFrame::Finish("stop"),
2602 MockStreamFrame::Done,
2603 ],
2604 )
2605 .await;
2606
2607 assert_stream_body_fails(response).await;
2608 }
2609
2610 #[tokio::test]
2611 async fn chat_route_streams_hermes_format_tool_call_from_glm_model() {
2612 let response = streaming_chat_response(
2615 "chat-route-tool-stream-glm-hermes",
2616 r#"{"model":"e2ee-glm-5-1","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2617 vec![
2618 MockStreamFrame::Text("<tool_call>\n{\"name\":\"search_web\",\"arguments\":"),
2619 MockStreamFrame::Text("{\"query\":\"example\"}}\n</tool_call>"),
2620 MockStreamFrame::Finish("stop"),
2621 MockStreamFrame::Done,
2622 ],
2623 )
2624 .await;
2625
2626 assert_eq!(response.status(), StatusCode::OK);
2627 let body = response_body(response).await;
2628 let chunks = sse_json_chunks(&body);
2629
2630 let tool_calls = streamed_tool_call_deltas(&chunks);
2631 assert!(!tool_calls.is_empty());
2632 assert_eq!(tool_calls[0]["function"]["name"], "search_web");
2633 assert_eq!(
2634 streamed_tool_call_arguments(&chunks, 0),
2635 r#"{"query":"example"}"#
2636 );
2637
2638 let final_chunk = chunks.last().expect("stream should have chunks");
2639 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2640 }
2641
2642 #[tokio::test]
2643 async fn chat_route_recovers_streamed_tool_call_with_truncated_closing_marker() {
2644 let response = streaming_chat_response(
2647 "chat-route-tool-stream-truncated-close",
2648 r#"{"model":"e2ee-glm-4-7-flash-p","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2649 vec![
2650 MockStreamFrame::Text("<tool_call>\n{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}\n"),
2651 MockStreamFrame::Finish("stop"),
2652 MockStreamFrame::Done,
2653 ],
2654 )
2655 .await;
2656
2657 assert_eq!(response.status(), StatusCode::OK);
2658 let body = response_body(response).await;
2659 let chunks = sse_json_chunks(&body);
2660
2661 let tool_calls = streamed_tool_call_deltas(&chunks);
2662 assert!(!tool_calls.is_empty());
2663 assert_eq!(tool_calls[0]["function"]["name"], "search_web");
2664 assert_eq!(
2665 streamed_tool_call_arguments(&chunks, 0),
2666 r#"{"query":"example"}"#
2667 );
2668
2669 let final_chunk = chunks.last().expect("stream should have chunks");
2670 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2671 }
2672
2673 #[tokio::test]
2674 async fn chat_route_streams_multiple_tool_calls_split_across_chunks() {
2675 let response = streaming_chat_response(
2676 "chat-route-tool-stream-multiple-calls",
2677 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2678 vec![
2679 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"first\"}}"),
2680 MockStreamFrame::Text("</tool_call><tool_call>{\"name\":\"search_web\",\"arguments\":"),
2681 MockStreamFrame::Text("{\"query\":\"second\"}}</tool_call>"),
2682 MockStreamFrame::Finish("stop"),
2683 MockStreamFrame::Done,
2684 ],
2685 )
2686 .await;
2687
2688 assert_eq!(response.status(), StatusCode::OK);
2689 let body = response_body(response).await;
2690 let chunks = sse_json_chunks(&body);
2691
2692 assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");
2693 let tool_calls = streamed_tool_call_deltas(&chunks);
2694 let first = tool_calls
2695 .iter()
2696 .find(|tool_call| tool_call["index"] == 0 && tool_call.get("id").is_some())
2697 .expect("first call should have an id-bearing fragment");
2698 let second = tool_calls
2699 .iter()
2700 .find(|tool_call| tool_call["index"] == 1 && tool_call.get("id").is_some())
2701 .expect("second call should have an id-bearing fragment");
2702 assert_eq!(first["function"]["name"], "search_web");
2703 assert_eq!(second["function"]["name"], "search_web");
2704 assert_ne!(first["id"], second["id"]);
2705 assert_eq!(
2706 streamed_tool_call_arguments(&chunks, 0),
2707 r#"{"query":"first"}"#
2708 );
2709 assert_eq!(
2710 streamed_tool_call_arguments(&chunks, 1),
2711 r#"{"query":"second"}"#
2712 );
2713
2714 let final_chunk = chunks.last().expect("stream should have chunks");
2715 assert_eq!(final_chunk["choices"][0]["delta"], json!({}));
2716 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2717 }
2718
2719 #[tokio::test]
2720 async fn chat_route_tool_stream_passes_through_usage_chunk_when_requested() {
2721 let response = streaming_chat_response(
2722 "chat-route-tool-stream-usage",
2723 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"stream_options":{"include_usage":true},"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2724 vec![
2725 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
2726 MockStreamFrame::Finish("stop"),
2727 MockStreamFrame::Usage,
2728 MockStreamFrame::Done,
2729 ],
2730 )
2731 .await;
2732
2733 assert_eq!(response.status(), StatusCode::OK);
2734 let body = response_body(response).await;
2735 let chunks = sse_json_chunks(&body);
2736
2737 let usage_chunk = chunks.last().expect("stream should have chunks");
2739 assert_eq!(usage_chunk["choices"], json!([]));
2740 assert_eq!(usage_chunk["usage"]["total_tokens"], 3);
2741 let finish_chunk = &chunks[chunks.len() - 2];
2742 assert_eq!(finish_chunk["choices"][0]["finish_reason"], "tool_calls");
2743 }
2744
2745 #[tokio::test]
2746 async fn chat_route_fails_closed_when_streamed_tool_call_exceeds_max_bytes() {
2747 let model_public_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
2748 let base_url = spawn_streaming_venice_server(
2749 model_public_key,
2750 true,
2751 vec![
2752 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"this argument body is much longer than the configured cap\"}}</tool_call>"),
2753 MockStreamFrame::Finish("stop"),
2754 MockStreamFrame::Done,
2755 ],
2756 )
2757 .await;
2758 let mut config = chat_config_with_basic_test_attestation();
2759 config.tools.tool_call_max_bytes = 16;
2760
2761 let response = request_chat_with_config(
2762 config,
2763 "chat-route-tool-stream-max-bytes",
2764 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2765 base_url,
2766 )
2767 .await;
2768
2769 assert_stream_body_fails(response).await;
2770 }
2771
2772 #[tokio::test]
2773 async fn chat_route_streams_all_tool_calls_when_parallel_tool_calls_false() {
2774 let response = streaming_chat_response(
2777 "chat-route-tool-stream-parallel-false",
2778 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":true,"parallel_tool_calls":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"],"additionalProperties":false}}}]}"#,
2779 vec![
2780 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"first\"}}</tool_call>"),
2781 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"second\"}}</tool_call>"),
2782 MockStreamFrame::Finish("stop"),
2783 MockStreamFrame::Done,
2784 ],
2785 )
2786 .await;
2787
2788 assert_eq!(response.status(), StatusCode::OK);
2789 let body = response_body(response).await;
2790 let chunks = sse_json_chunks(&body);
2791
2792 assert_eq!(
2793 streamed_tool_call_arguments(&chunks, 0),
2794 r#"{"query":"first"}"#
2795 );
2796 assert_eq!(
2797 streamed_tool_call_arguments(&chunks, 1),
2798 r#"{"query":"second"}"#
2799 );
2800
2801 let final_chunk = chunks.last().expect("stream should have chunks");
2802 assert_eq!(final_chunk["choices"][0]["finish_reason"], "tool_calls");
2803 }
2804
2805 #[tokio::test]
2806 async fn chat_route_returns_non_streaming_tool_call_body_from_mixed_text() {
2807 let response = chat_response(
2808 "chat-route-tool-non-stream-mixed-text",
2809 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#,
2810 vec![
2811 MockStreamFrame::Text("I'll check that. <tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
2812 MockStreamFrame::Done,
2813 ],
2814 )
2815 .await;
2816
2817 assert_eq!(response.status(), StatusCode::OK);
2818 let body = json_body(response).await;
2819 assert_eq!(body["choices"][0]["finish_reason"], "tool_calls");
2820 let tool_call = &body["choices"][0]["message"]["tool_calls"][0];
2821 assert_eq!(tool_call["function"]["name"], "search_web");
2822 assert_eq!(tool_call["function"]["arguments"], r#"{"query":"example"}"#);
2823 }
2824
2825 #[tokio::test]
2826 async fn chat_route_returns_non_streaming_tool_call_body() {
2827 let response = chat_response(
2828 "chat-route-tool-non-stream",
2829 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#,
2830 vec![
2831 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
2832 MockStreamFrame::Done,
2833 ],
2834 )
2835 .await;
2836
2837 assert_eq!(response.status(), StatusCode::OK);
2838 let body = json_body(response).await;
2839 assert_eq!(body["object"], "chat.completion");
2840 assert!(body["choices"][0]["message"]["content"].is_null());
2841 assert_eq!(body["choices"][0]["finish_reason"], "tool_calls");
2842 let tool_call = &body["choices"][0]["message"]["tool_calls"][0];
2843 assert!(tool_call["id"].as_str().unwrap().starts_with("call_"));
2844 assert_eq!(tool_call["type"], "function");
2845 assert_eq!(tool_call["function"]["name"], "search_web");
2846 assert_eq!(tool_call["function"]["arguments"], r#"{"query":"example"}"#);
2847 }
2848
2849 #[tokio::test]
2850 async fn chat_route_returns_non_streaming_multiple_tool_calls() {
2851 let response = chat_response(
2852 "chat-route-tool-non-stream-multiple-calls",
2853 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#,
2854 vec![
2855 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"first\"}}</tool_call>\n<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"second\"}}</tool_call>"),
2856 MockStreamFrame::Done,
2857 ],
2858 )
2859 .await;
2860
2861 assert_eq!(response.status(), StatusCode::OK);
2862 let body = json_body(response).await;
2863 assert_eq!(body["choices"][0]["finish_reason"], "tool_calls");
2864 assert!(body["choices"][0]["message"]["content"].is_null());
2865 let tool_calls = body["choices"][0]["message"]["tool_calls"]
2866 .as_array()
2867 .expect("tool_calls should be an array");
2868 assert_eq!(tool_calls.len(), 2);
2869 assert_eq!(tool_calls[0]["function"]["name"], "search_web");
2870 assert_eq!(
2871 tool_calls[0]["function"]["arguments"],
2872 r#"{"query":"first"}"#
2873 );
2874 assert_eq!(
2875 tool_calls[1]["function"]["arguments"],
2876 r#"{"query":"second"}"#
2877 );
2878 assert_ne!(tool_calls[0]["id"], tool_calls[1]["id"]);
2879 }
2880
2881 #[tokio::test]
2882 async fn chat_route_tool_mode_leaves_normal_text_unaffected() {
2883 let response = streaming_chat_response(
2884 "chat-route-tool-normal-text",
2885 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object"}}}]}"#,
2886 vec![
2887 MockStreamFrame::Text("Hello without tools"),
2888 MockStreamFrame::Finish("stop"),
2889 MockStreamFrame::Done,
2890 ],
2891 )
2892 .await;
2893
2894 assert_eq!(response.status(), StatusCode::OK);
2895 let body = response_body(response).await;
2896 let chunks = sse_json_chunks(&body);
2897 assert_eq!(chunks[0]["choices"][0]["delta"]["role"], "assistant");
2898 assert_eq!(streamed_content(&chunks), "Hello without tools");
2899 assert!(streamed_tool_call_deltas(&chunks).is_empty());
2900 }
2901
2902 #[tokio::test]
2903 async fn chat_route_treats_marker_like_non_protocol_text_as_normal_text() {
2904 let response = streaming_chat_response(
2905 "chat-route-tool-marker-like-text",
2906 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object"}}}]}"#,
2907 vec![
2908 MockStreamFrame::Text("<tool_cal>{not actually a marker}"),
2909 MockStreamFrame::Finish("stop"),
2910 MockStreamFrame::Done,
2911 ],
2912 )
2913 .await;
2914
2915 assert_eq!(response.status(), StatusCode::OK);
2916 let body = response_body(response).await;
2917 let chunks = sse_json_chunks(&body);
2918 assert_eq!(
2919 streamed_content(&chunks),
2920 "<tool_cal>{not actually a marker}"
2921 );
2922 assert!(streamed_tool_call_deltas(&chunks).is_empty());
2923 }
2924
2925 #[tokio::test]
2926 async fn chat_route_retries_invalid_tool_call_and_returns_success() {
2927 let response = chat_response_sequence(
2928 "chat-route-tool-retry-success",
2929 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#,
2930 vec![
2931 vec![
2932 MockStreamFrame::Text("<tool_call>{\"name\":\"unknown\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
2933 MockStreamFrame::Done,
2934 ],
2935 vec![
2936 MockStreamFrame::Text("<tool_call>{\"name\":\"search_web\",\"arguments\":{\"query\":\"example\"}}</tool_call>"),
2937 MockStreamFrame::Done,
2938 ],
2939 ],
2940 )
2941 .await;
2942
2943 assert_eq!(response.status(), StatusCode::OK);
2944 assert_eq!(
2945 response.headers().get(HEADER_PROXY_TOOL_RETRIES).unwrap(),
2946 "1"
2947 );
2948 let body = json_body(response).await;
2949 assert_eq!(body["choices"][0]["finish_reason"], "tool_calls");
2950 assert_eq!(
2951 body["choices"][0]["message"]["tool_calls"][0]["function"]["name"],
2952 "search_web"
2953 );
2954 }
2955
2956 #[tokio::test]
2957 async fn chat_route_returns_retry_failure_error_shape() {
2958 let response = chat_response(
2959 "chat-route-tool-retry-failure",
2960 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"search"}],"stream":false,"tools":[{"type":"function","function":{"name":"search_web","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}}]}"#,
2961 vec![
2962 MockStreamFrame::Text("<tool_call>{\"name\":\"unknown\",\"arguments\":{}}</tool_call>"),
2963 MockStreamFrame::Done,
2964 ],
2965 )
2966 .await;
2967
2968 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
2969 assert_eq!(
2970 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
2971 "invalid_tool_call"
2972 );
2973 let body = json_body(response).await;
2974 assert_eq!(body["error"]["type"], "proxy_tool_call_error");
2975 assert_eq!(body["error"]["code"], "invalid_tool_call");
2976 assert_eq!(body["error"]["details"]["max_retries"], 2);
2977 assert!(
2978 body["error"]["details"]["last_validation_error"]
2979 .as_str()
2980 .unwrap()
2981 .contains("unknown tool name")
2982 );
2983 }
2984
2985 #[tokio::test]
2986 async fn chat_route_non_streaming_fails_closed_on_upstream_error_response() {
2987 let response = chat_response_with_upstream_status(
2988 "chat-route-non-streaming-upstream-error",
2989 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
2990 StatusCode::INTERNAL_SERVER_ERROR,
2991 )
2992 .await;
2993
2994 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
2995 assert_eq!(
2996 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
2997 "upstream_status_error"
2998 );
2999 let body = error_body(response).await;
3000 assert_eq!(body.error.kind, "proxy_upstream_error");
3001 assert_eq!(body.error.code, "upstream_status_error");
3002 }
3003
3004 #[tokio::test]
3005 async fn chat_route_non_streaming_fails_closed_on_malformed_upstream_payload() {
3006 let response = chat_response(
3007 "chat-route-non-streaming-malformed",
3008 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
3009 vec![MockStreamFrame::Raw("data: {\"choices\":\"bad\"}\n\n")],
3010 )
3011 .await;
3012
3013 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
3014 assert_eq!(
3015 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3016 "upstream_malformed_response"
3017 );
3018 let body = error_body(response).await;
3019 assert_eq!(body.error.kind, "proxy_upstream_error");
3020 assert_eq!(body.error.code, "upstream_malformed_response");
3021 }
3022
3023 #[tokio::test]
3024 async fn chat_route_non_streaming_fails_closed_on_missing_encrypted_content() {
3025 let response = chat_response(
3026 "chat-route-non-streaming-missing-content",
3027 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
3028 vec![MockStreamFrame::Finish("stop"), MockStreamFrame::Done],
3029 )
3030 .await;
3031
3032 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
3033 assert_eq!(
3034 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3035 "e2ee_response_decryption_failed"
3036 );
3037 let body = error_body(response).await;
3038 assert_eq!(body.error.kind, "proxy_e2ee_error");
3039 assert_eq!(body.error.code, "e2ee_response_decryption_failed");
3040 }
3041
3042 #[tokio::test]
3043 async fn chat_route_non_streaming_fails_closed_on_decryption_failure() {
3044 let response = chat_response(
3045 "chat-route-non-streaming-decryption-failure",
3046 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
3047 vec![MockStreamFrame::TextForWrongRecipient(" secret"), MockStreamFrame::Done],
3048 )
3049 .await;
3050
3051 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
3052 assert_eq!(
3053 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3054 "e2ee_response_decryption_failed"
3055 );
3056 let body = error_body(response).await;
3057 assert_eq!(body.error.kind, "proxy_e2ee_error");
3058 assert_eq!(body.error.code, "e2ee_response_decryption_failed");
3059 }
3060
3061 #[tokio::test]
3062 async fn chat_route_non_streaming_passes_through_usage_when_available() {
3063 let response = chat_response(
3064 "chat-route-non-streaming-usage",
3065 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
3066 vec![
3067 MockStreamFrame::Text("Hello"),
3068 MockStreamFrame::Finish("stop"),
3069 MockStreamFrame::Usage,
3070 MockStreamFrame::Done,
3071 ],
3072 )
3073 .await;
3074
3075 assert_eq!(response.status(), StatusCode::OK);
3076 let body = json_body(response).await;
3077 assert_eq!(body["choices"][0]["message"]["content"], "Hello");
3078 assert_eq!(body["usage"]["prompt_tokens"], 1);
3079 assert_eq!(body["usage"]["completion_tokens"], 2);
3080 assert_eq!(body["usage"]["total_tokens"], 3);
3081 }
3082
3083 #[tokio::test]
3084 async fn chat_route_fails_closed_on_upstream_stream_error_event() {
3085 let response = streaming_chat_response(
3086 "chat-route-upstream-error",
3087 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
3088 vec![MockStreamFrame::Error("model failed")],
3089 )
3090 .await;
3091
3092 assert_stream_body_fails(response).await;
3093 }
3094
3095 #[tokio::test]
3096 async fn chat_route_fails_closed_on_malformed_upstream_event() {
3097 let response = streaming_chat_response(
3098 "chat-route-malformed-event",
3099 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
3100 vec![MockStreamFrame::Raw("data: {\"choices\":\n\n")],
3101 )
3102 .await;
3103
3104 assert_stream_body_fails(response).await;
3105 }
3106
3107 #[tokio::test]
3108 async fn chat_route_fails_closed_on_decryption_failure_mid_stream() {
3109 let response = streaming_chat_response(
3110 "chat-route-decryption-failure",
3111 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
3112 vec![
3113 MockStreamFrame::Text("Hello"),
3114 MockStreamFrame::TextForWrongRecipient(" secret"),
3115 MockStreamFrame::Done,
3116 ],
3117 )
3118 .await;
3119
3120 assert_stream_body_fails(response).await;
3121 }
3122
3123 #[tokio::test]
3124 async fn chat_route_synthesizes_final_finish_chunk_before_done_when_needed() {
3125 let response = streaming_chat_response(
3126 "chat-route-final-done",
3127 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":true}"#,
3128 vec![MockStreamFrame::Text("Hello"), MockStreamFrame::Done],
3129 )
3130 .await;
3131
3132 assert_eq!(response.status(), StatusCode::OK);
3133 let body = response_body(response).await;
3134 let data = sse_data(&body);
3135 assert_eq!(data.len(), 3);
3136 let final_chunk: Value = serde_json::from_str(data[1]).expect("final chunk should be JSON");
3137 assert_eq!(final_chunk["choices"][0]["delta"], json!({}));
3138 assert_eq!(final_chunk["choices"][0]["finish_reason"], "stop");
3139 assert_eq!(data[2], "[DONE]");
3140 }
3141
3142 #[tokio::test]
3143 async fn chat_route_attestation_failure_prevents_request_construction() {
3144 let model_public_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
3145 let base_url = spawn_attestation_server(model_public_key, false).await;
3146 let app = router_with_venice_client(
3147 chat_config_with_basic_test_attestation(),
3148 test_venice_client_for_base_url(base_url),
3149 );
3150
3151 let response = app
3152 .oneshot(
3153 Request::builder()
3154 .method(Method::POST)
3155 .uri("/v1/chat/completions")
3156 .header("content-type", "application/json")
3157 .header(HEADER_PROXY_SESSION_ID, "chat-route-attestation-failure")
3158 .body(Body::from(
3159 r#"{"model":"e2ee-test","messages":[{"role":"user","content":"hello"}],"stream":false}"#,
3160 ))
3161 .expect("request should build"),
3162 )
3163 .await
3164 .expect("request should complete");
3165
3166 assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
3167 assert_eq!(
3168 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3169 "attestation_upstream_not_verified"
3170 );
3171 let body = error_body(response).await;
3172 assert_eq!(body.error.kind, "proxy_attestation_error");
3173 assert_eq!(body.error.code, "attestation_upstream_not_verified");
3174 }
3175
3176 #[tokio::test]
3177 async fn unknown_route_returns_openai_style_not_found() {
3178 let response = test_app()
3179 .oneshot(
3180 Request::builder()
3181 .uri("/v1/unknown")
3182 .body(Body::empty())
3183 .expect("request should build"),
3184 )
3185 .await
3186 .expect("request should complete");
3187
3188 assert_eq!(response.status(), StatusCode::NOT_FOUND);
3189 assert_eq!(
3190 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3191 "not_found"
3192 );
3193 let body = error_body(response).await;
3194 assert_eq!(body.error.kind, "invalid_request_error");
3195 assert_eq!(body.error.code, "not_found");
3196 }
3197
3198 #[tokio::test]
3199 async fn unsupported_method_returns_openai_style_method_error() {
3200 let response = test_app()
3201 .oneshot(
3202 Request::builder()
3203 .method(Method::POST)
3204 .uri("/v1/models")
3205 .body(Body::empty())
3206 .expect("request should build"),
3207 )
3208 .await
3209 .expect("request should complete");
3210
3211 assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);
3212 assert_eq!(
3213 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3214 "method_not_allowed"
3215 );
3216 let body = error_body(response).await;
3217 assert_eq!(body.error.kind, "invalid_request_error");
3218 assert_eq!(body.error.code, "method_not_allowed");
3219 }
3220
3221 #[tokio::test]
3222 async fn malformed_chat_json_uses_axum_extractor_rejection() {
3223 let response = test_app()
3224 .oneshot(
3225 Request::builder()
3226 .method(Method::POST)
3227 .uri("/v1/chat/completions")
3228 .header("content-type", "application/json")
3229 .body(Body::from("{"))
3230 .expect("request should build"),
3231 )
3232 .await
3233 .expect("request should complete");
3234
3235 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
3236 assert!(response.headers().get(HEADER_PROXY_ERROR_CODE).is_none());
3237 }
3238
3239 #[tokio::test]
3240 async fn non_object_chat_json_returns_structured_invalid_request() {
3241 let response = test_app()
3242 .oneshot(
3243 Request::builder()
3244 .method(Method::POST)
3245 .uri("/v1/chat/completions")
3246 .header("content-type", "application/json")
3247 .body(Body::from("[]"))
3248 .expect("request should build"),
3249 )
3250 .await
3251 .expect("request should complete");
3252
3253 assert_eq!(response.status(), StatusCode::BAD_REQUEST);
3254 assert_eq!(
3255 response.headers().get(HEADER_PROXY_ERROR_CODE).unwrap(),
3256 "invalid_request"
3257 );
3258 let body = error_body(response).await;
3259 assert_eq!(body.error.kind, "invalid_request_error");
3260 assert_eq!(body.error.code, "invalid_request");
3261 }
3262
/// One scripted piece of the mock upstream's SSE reply, turned into wire text by
/// `render_mock_sse`.
#[derive(Debug, Clone)]
enum MockStreamFrame {
    /// Chunk whose delta only sets `role` to `assistant`.
    Role,
    /// Chunk whose delta carries `content: null`.
    NullContent,
    /// Chunk whose delta carries an empty (unencrypted) content string.
    EmptyContent,
    /// Content chunk; the text is encrypted for the requesting client's public key.
    Text(&'static str),
    /// `reasoning_content` chunk; the text is encrypted for the requesting client's public key.
    Reasoning(&'static str),
    /// Content chunk encrypted for a freshly generated key instead of the client's.
    TextForWrongRecipient(&'static str),
    /// Chunk with an empty delta and the given `finish_reason`.
    Finish(&'static str),
    /// Chunk with empty `choices` and a fixed `usage` object.
    Usage,
    /// The literal `data: [DONE]` terminator.
    Done,
    /// An `event: error` frame whose data is a JSON object holding the message.
    Error(&'static str),
    /// Text appended to the stream verbatim.
    Raw(&'static str),
}
3277
3278 async fn streaming_chat_response(
3279 session_id: &'static str,
3280 request_body: &'static str,
3281 frames: Vec<MockStreamFrame>,
3282 ) -> Response {
3283 chat_response(session_id, request_body, frames).await
3284 }
3285
3286 async fn chat_response(
3287 session_id: &'static str,
3288 request_body: &'static str,
3289 frames: Vec<MockStreamFrame>,
3290 ) -> Response {
3291 let model_public_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
3292 let base_url = spawn_streaming_venice_server(model_public_key, true, frames).await;
3293 request_chat(session_id, request_body, base_url).await
3294 }
3295
3296 async fn chat_response_sequence(
3297 session_id: &'static str,
3298 request_body: &'static str,
3299 attempts: Vec<Vec<MockStreamFrame>>,
3300 ) -> Response {
3301 let model_public_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
3302 let base_url =
3303 spawn_streaming_venice_server_sequence(model_public_key, true, attempts).await;
3304 request_chat(session_id, request_body, base_url).await
3305 }
3306
3307 async fn chat_response_with_upstream_status(
3308 session_id: &'static str,
3309 request_body: &'static str,
3310 upstream_status: StatusCode,
3311 ) -> Response {
3312 let model_public_key = ProxyInstanceKey::generate().public_key_hex().to_owned();
3313 let base_url =
3314 spawn_venice_server_with_chat_status(model_public_key, upstream_status).await;
3315 request_chat(session_id, request_body, base_url).await
3316 }
3317
3318 async fn request_chat(
3319 session_id: &'static str,
3320 request_body: &'static str,
3321 base_url: String,
3322 ) -> Response {
3323 request_chat_with_config(
3324 chat_config_with_basic_test_attestation(),
3325 session_id,
3326 request_body,
3327 base_url,
3328 )
3329 .await
3330 }
3331
3332 async fn request_chat_with_config(
3333 config: ProxyConfig,
3334 session_id: &'static str,
3335 request_body: &'static str,
3336 base_url: String,
3337 ) -> Response {
3338 let app = router_with_venice_client(config, test_venice_client_for_base_url(base_url));
3339
3340 app.oneshot(
3341 Request::builder()
3342 .method(Method::POST)
3343 .uri("/v1/chat/completions")
3344 .header("content-type", "application/json")
3345 .header(HEADER_PROXY_SESSION_ID, session_id)
3346 .body(Body::from(request_body))
3347 .expect("request should build"),
3348 )
3349 .await
3350 .expect("request should complete")
3351 }
3352
3353 async fn json_body(response: Response) -> Value {
3354 let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
3355 .await
3356 .expect("response body should buffer");
3357 serde_json::from_slice(&bytes).expect("response should be JSON")
3358 }
3359
3360 async fn response_body(response: Response) -> String {
3361 let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
3362 .await
3363 .expect("response body should buffer");
3364 String::from_utf8(bytes.to_vec()).expect("response body should be UTF-8")
3365 }
3366
3367 async fn assert_stream_body_fails(response: Response) {
3368 assert_eq!(response.status(), StatusCode::OK);
3369 let result = axum::body::to_bytes(response.into_body(), usize::MAX).await;
3370 assert!(
3371 result.is_err(),
3372 "stream body should fail closed instead of completing successfully"
3373 );
3374 }
3375
/// Collects the payload of every line that starts with `data: ` (prefix removed),
/// in order of appearance. Lines without that exact prefix are skipped.
fn sse_data(body: &str) -> Vec<&str> {
    let mut payloads = Vec::new();
    for line in body.lines() {
        if let Some(payload) = line.strip_prefix("data: ") {
            payloads.push(payload);
        }
    }
    payloads
}
3381
3382 fn sse_json_chunks(body: &str) -> Vec<Value> {
3384 let data = sse_data(body);
3385 assert_eq!(data.last().copied(), Some("[DONE]"));
3386 data[..data.len() - 1]
3387 .iter()
3388 .map(|chunk| serde_json::from_str(chunk).expect("SSE chunk should be JSON"))
3389 .collect()
3390 }
3391
3392 fn streamed_content(chunks: &[Value]) -> String {
3394 chunks
3395 .iter()
3396 .filter_map(|chunk| chunk["choices"][0]["delta"]["content"].as_str())
3397 .collect()
3398 }
3399
3400 fn streamed_tool_call_deltas(chunks: &[Value]) -> Vec<&Value> {
3402 chunks
3403 .iter()
3404 .filter_map(|chunk| chunk["choices"][0]["delta"]["tool_calls"].as_array())
3405 .flatten()
3406 .collect()
3407 }
3408
3409 fn streamed_tool_call_arguments(chunks: &[Value], index: u64) -> String {
3411 streamed_tool_call_deltas(chunks)
3412 .iter()
3413 .filter(|tool_call| tool_call["index"] == json!(index))
3414 .filter_map(|tool_call| tool_call["function"]["arguments"].as_str())
3415 .collect()
3416 }
3417
3418 async fn spawn_streaming_venice_server(
3419 model_public_key: String,
3420 verified: bool,
3421 frames: Vec<MockStreamFrame>,
3422 ) -> String {
3423 spawn_streaming_venice_server_sequence(model_public_key, verified, vec![frames]).await
3424 }
3425
3426 async fn spawn_streaming_venice_server_sequence(
3427 model_public_key: String,
3428 verified: bool,
3429 attempts: Vec<Vec<MockStreamFrame>>,
3430 ) -> String {
3431 let chat_attempts = Arc::new(Mutex::new(VecDeque::from(attempts)));
3432 let attestation_key = model_public_key.clone();
3433 let app = Router::new()
3434 .route(
3435 "/api/v1/tee/attestation",
3436 get(move |Query(query): Query<HashMap<String, String>>| {
3437 let model_public_key = attestation_key.clone();
3438 async move {
3439 Json(json!({
3440 "api_version": "aci/1",
3441 "attestation": {
3442 "tee_type": "tdx",
3443 "evidence": {}
3444 },
3445 "verified": verified,
3446 "nonce": query.get("nonce").cloned().unwrap_or_default(),
3447 "model": query.get("model").cloned().unwrap_or_default(),
3448 "tee_provider": "phala",
3449 "signing_public_key": model_public_key,
3450 }))
3451 }
3452 }),
3453 )
3454 .route(
3455 "/api/v1/chat/completions",
3456 post(move |headers: HeaderMap, Json(body): Json<Value>| {
3457 let chat_attempts = chat_attempts.clone();
3458 async move {
3459 let Some(client_public_key) = headers
3460 .get(crate::venice::HEADER_VENICE_TEE_CLIENT_PUB_KEY)
3461 .and_then(|value| value.to_str().ok())
3462 else {
3463 return (
3464 StatusCode::BAD_REQUEST,
3465 [("content-type", "text/plain")],
3466 "missing client key".to_owned(),
3467 );
3468 };
3469 if body.get("stream").and_then(Value::as_bool) != Some(true) {
3470 return (
3471 StatusCode::BAD_REQUEST,
3472 [("content-type", "text/plain")],
3473 "upstream request must stream".to_owned(),
3474 );
3475 }
3476 let messages = body.get("messages").and_then(Value::as_array);
3477 if messages.is_none_or(|messages| {
3478 messages.is_empty()
3479 || !messages.iter().all(|message| {
3480 message.get("role").and_then(Value::as_str).is_some()
3481 && message
3482 .get("content")
3483 .and_then(Value::as_str)
3484 .is_some_and(|content| {
3485 !content.is_empty()
3486 && content
3487 .chars()
3488 .all(|ch| ch.is_ascii_hexdigit())
3489 })
3490 })
3491 }) {
3492 return (
3493 StatusCode::BAD_REQUEST,
3494 [("content-type", "text/plain")],
3495 "messages must be encrypted message objects".to_owned(),
3496 );
3497 }
3498
3499 let frames = {
3500 let mut attempts = chat_attempts
3501 .lock()
3502 .expect("mock chat attempts mutex should not be poisoned");
3503 if attempts.len() > 1 {
3504 attempts.pop_front().expect("attempts length checked above")
3505 } else {
3506 attempts.front().cloned().unwrap_or_default()
3507 }
3508 };
3509
3510 (
3511 StatusCode::OK,
3512 [("content-type", "text/event-stream")],
3513 render_mock_sse(&frames, client_public_key),
3514 )
3515 }
3516 }),
3517 );
3518 let listener = TcpListener::bind(("127.0.0.1", 0))
3519 .await
3520 .expect("mock Venice listener should bind");
3521 let addr = listener
3522 .local_addr()
3523 .expect("mock Venice listener should have local address");
3524
3525 tokio::spawn(async move {
3526 axum::serve(listener, app)
3527 .await
3528 .expect("mock Venice server should run");
3529 });
3530
3531 format!("http://{addr}/api/v1")
3532 }
3533
3534 async fn spawn_venice_server_with_chat_status(
3535 model_public_key: String,
3536 upstream_status: StatusCode,
3537 ) -> String {
3538 let attestation_key = model_public_key.clone();
3539 let app = Router::new()
3540 .route(
3541 "/api/v1/tee/attestation",
3542 get(move |Query(query): Query<HashMap<String, String>>| {
3543 let model_public_key = attestation_key.clone();
3544 async move {
3545 Json(json!({
3546 "api_version": "aci/1",
3547 "attestation": {
3548 "tee_type": "tdx",
3549 "evidence": {}
3550 },
3551 "verified": true,
3552 "nonce": query.get("nonce").cloned().unwrap_or_default(),
3553 "model": query.get("model").cloned().unwrap_or_default(),
3554 "tee_provider": "phala",
3555 "signing_public_key": model_public_key,
3556 }))
3557 }
3558 }),
3559 )
3560 .route(
3561 "/api/v1/chat/completions",
3562 post(move || async move { upstream_status }),
3563 );
3564 let listener = TcpListener::bind(("127.0.0.1", 0))
3565 .await
3566 .expect("mock Venice listener should bind");
3567 let addr = listener
3568 .local_addr()
3569 .expect("mock Venice listener should have local address");
3570
3571 tokio::spawn(async move {
3572 axum::serve(listener, app)
3573 .await
3574 .expect("mock Venice server should run");
3575 });
3576
3577 format!("http://{addr}/api/v1")
3578 }
3579
3580 fn render_mock_sse(frames: &[MockStreamFrame], client_public_key: &str) -> String {
3581 let codec = E2eeCodec::default();
3582 let mut output = String::new();
3583 for frame in frames {
3584 match frame {
3585 MockStreamFrame::Role => {
3586 output.push_str(&format!("data: {}\n\n", upstream_role_chunk()));
3587 }
3588 MockStreamFrame::NullContent => {
3589 output.push_str(&format!("data: {}\n\n", upstream_null_content_chunk()));
3590 }
3591 MockStreamFrame::EmptyContent => {
3592 output.push_str(&format!(
3593 "data: {}\n\n",
3594 upstream_content_chunk(String::new())
3595 ));
3596 }
3597 MockStreamFrame::Text(content) => {
3598 let encrypted = codec
3599 .encrypt_content(content, client_public_key)
3600 .expect("mock content should encrypt")
3601 .into_hex();
3602 output.push_str(&format!("data: {}\n\n", upstream_content_chunk(encrypted)));
3603 }
3604 MockStreamFrame::Reasoning(content) => {
3605 let encrypted = codec
3606 .encrypt_content(content, client_public_key)
3607 .expect("mock reasoning content should encrypt")
3608 .into_hex();
3609 output.push_str(&format!(
3610 "data: {}\n\n",
3611 upstream_reasoning_content_chunk(encrypted)
3612 ));
3613 }
3614 MockStreamFrame::TextForWrongRecipient(content) => {
3615 let wrong_key = ProxyInstanceKey::generate();
3616 let encrypted = codec
3617 .encrypt_content(content, wrong_key.public_key_hex())
3618 .expect("mock content should encrypt")
3619 .into_hex();
3620 output.push_str(&format!("data: {}\n\n", upstream_content_chunk(encrypted)));
3621 }
3622 MockStreamFrame::Finish(reason) => {
3623 output.push_str(&format!("data: {}\n\n", upstream_finish_chunk(reason)));
3624 }
3625 MockStreamFrame::Usage => {
3626 output.push_str(&format!("data: {}\n\n", upstream_usage_chunk()));
3627 }
3628 MockStreamFrame::Done => output.push_str("data: [DONE]\n\n"),
3629 MockStreamFrame::Error(message) => {
3630 output.push_str(&format!(
3631 "event: error\ndata: {}\n\n",
3632 json!({ "message": message })
3633 ));
3634 }
3635 MockStreamFrame::Raw(raw) => output.push_str(raw),
3636 }
3637 }
3638 output
3639 }
3640
3641 fn upstream_role_chunk() -> Value {
3642 json!({
3643 "id": "chatcmpl-upstream-test",
3644 "object": "chat.completion.chunk",
3645 "created": 1_717_171_717,
3646 "model": "e2ee-test",
3647 "choices": [{
3648 "index": 0,
3649 "delta": { "role": "assistant" },
3650 "finish_reason": null,
3651 }],
3652 })
3653 }
3654
3655 fn upstream_content_chunk(encrypted_content: String) -> Value {
3656 json!({
3657 "id": "chatcmpl-upstream-test",
3658 "object": "chat.completion.chunk",
3659 "created": 1_717_171_717,
3660 "model": "e2ee-test",
3661 "choices": [{
3662 "index": 0,
3663 "delta": { "content": encrypted_content },
3664 "finish_reason": null,
3665 }],
3666 })
3667 }
3668
3669 fn upstream_reasoning_content_chunk(encrypted_content: String) -> Value {
3670 json!({
3671 "id": "chatcmpl-upstream-test",
3672 "object": "chat.completion.chunk",
3673 "created": 1_717_171_717,
3674 "model": "e2ee-test",
3675 "choices": [{
3676 "index": 0,
3677 "delta": { "reasoning_content": encrypted_content },
3678 "finish_reason": null,
3679 }],
3680 })
3681 }
3682
3683 fn upstream_null_content_chunk() -> Value {
3684 json!({
3685 "id": "chatcmpl-upstream-test",
3686 "object": "chat.completion.chunk",
3687 "created": 1_717_171_717,
3688 "model": "e2ee-test",
3689 "choices": [{
3690 "index": 0,
3691 "delta": { "content": Value::Null },
3692 "finish_reason": null,
3693 }],
3694 })
3695 }
3696
3697 fn upstream_finish_chunk(reason: &str) -> Value {
3698 json!({
3699 "id": "chatcmpl-upstream-test",
3700 "object": "chat.completion.chunk",
3701 "created": 1_717_171_717,
3702 "model": "e2ee-test",
3703 "choices": [{
3704 "index": 0,
3705 "delta": {},
3706 "finish_reason": reason,
3707 }],
3708 })
3709 }
3710
3711 fn upstream_usage_chunk() -> Value {
3712 json!({
3713 "id": "chatcmpl-upstream-test",
3714 "object": "chat.completion.chunk",
3715 "created": 1_717_171_717,
3716 "model": "e2ee-test",
3717 "choices": [],
3718 "usage": {
3719 "prompt_tokens": 1,
3720 "completion_tokens": 2,
3721 "total_tokens": 3,
3722 },
3723 })
3724 }
3725
3726 async fn spawn_attestation_server(model_public_key: String, verified: bool) -> String {
3727 let app = Router::new().route(
3728 "/api/v1/tee/attestation",
3729 get(move |Query(query): Query<HashMap<String, String>>| {
3730 let model_public_key = model_public_key.clone();
3731 async move {
3732 Json(json!({
3733 "api_version": "aci/1",
3734 "attestation": {
3735 "tee_type": "tdx",
3736 "evidence": {}
3737 },
3738 "verified": verified,
3739 "nonce": query.get("nonce").cloned().unwrap_or_default(),
3740 "model": query.get("model").cloned().unwrap_or_default(),
3741 "tee_provider": "phala",
3742 "signing_public_key": model_public_key,
3743 }))
3744 }
3745 }),
3746 );
3747 let listener = TcpListener::bind(("127.0.0.1", 0))
3748 .await
3749 .expect("mock attestation listener should bind");
3750 let addr = listener
3751 .local_addr()
3752 .expect("mock attestation listener should have local address");
3753
3754 tokio::spawn(async move {
3755 axum::serve(listener, app)
3756 .await
3757 .expect("mock attestation server should run");
3758 });
3759
3760 format!("http://{addr}/api/v1")
3761 }
3762
/// With the default config, applying the metadata headers must set the
/// attestation mode (`independent`) and tool mode (`emulated`) and must not set
/// the E2EE or key-binding headers.
#[test]
fn metadata_header_helper_only_emits_safe_config_headers_by_default() {
    let config = ProxyConfig::default();
    let metadata = ProxyMetadataHeaders::from_config(&config);

    let mut headers = HeaderMap::new();
    metadata.apply(&mut headers);

    let attestation_mode = headers.get(HEADER_PROXY_ATTESTATION_MODE).unwrap();
    assert_eq!(attestation_mode, "independent");
    let tool_mode = headers.get(HEADER_PROXY_TOOL_MODE).unwrap();
    assert_eq!(tool_mode, "emulated");

    assert!(headers.get(HEADER_PROXY_E2EE).is_none());
    assert!(headers.get(HEADER_PROXY_KEY_BINDING).is_none());
}
3779}