1use std::{fmt, sync::Arc, time::Duration};
7
8use reqwest::{
9 Url,
10 header::{ACCEPT, CONTENT_TYPE},
11};
12use secrecy::ExposeSecret;
13use serde::Deserialize;
14use serde_json::Value;
15use thiserror::Error;
16
17use crate::{
18 config::{ConfigError, ProxyConfig},
19 openai::{
20 ModelCapabilities, ModelListResponse, ModelObject, VeniceModelMetadata,
21 chat::VeniceE2eeChatRequest,
22 },
23};
24
/// Request header carrying the client's public key (hex) on chat completion requests.
pub const HEADER_VENICE_TEE_CLIENT_PUB_KEY: &str = "X-Venice-TEE-Client-Pub-Key";
/// Request header carrying the model's public key (hex) on chat completion requests.
pub const HEADER_VENICE_TEE_MODEL_PUB_KEY: &str = "X-Venice-TEE-Model-Pub-Key";
/// Request header naming the signing algorithm; this client always sends `"ecdsa"`.
pub const HEADER_VENICE_TEE_SIGNING_ALGO: &str = "X-Venice-TEE-Signing-Algo";
28
/// HTTP client for the Venice upstream API.
///
/// `Debug` is implemented by hand further down so the API key is never printed.
#[derive(Clone)]
pub struct VeniceClient {
    /// Underlying reqwest client, built with connect and read timeouts in `new`.
    http: reqwest::Client,
    /// Base URL, normalized by `parse_base_url` so its path ends with `/`.
    base_url: Url,
    /// API key sent as a bearer token on every request.
    api_key: Arc<str>,
    /// Total per-request timeout applied to the models and attestation calls
    /// (the streaming chat call does not set it).
    request_timeout: Duration,
}
37
38impl VeniceClient {
39 pub fn from_config(config: &ProxyConfig) -> Result<Self, VeniceClientError> {
41 let api_key = config.venice_api_key()?;
42 Self::new(
43 &config.venice.base_url,
44 api_key.expose_secret(),
45 config.venice.request_timeout,
46 )
47 }
48
49 pub fn new(
51 base_url: impl AsRef<str>,
52 api_key: impl Into<String>,
53 timeout: Duration,
54 ) -> Result<Self, VeniceClientError> {
55 let base_url = parse_base_url(base_url.as_ref())?;
56 let http = reqwest::Client::builder()
57 .connect_timeout(timeout)
58 .read_timeout(timeout)
59 .build()
60 .map_err(VeniceClientError::client_build)?;
61
62 Ok(Self {
63 http,
64 base_url,
65 api_key: Arc::from(api_key.into()),
66 request_timeout: timeout,
67 })
68 }
69
70 pub async fn list_models(&self) -> Result<ModelListResponse, VeniceClientError> {
72 let url = self.models_url()?;
73 let response = self
74 .http
75 .get(url)
76 .bearer_auth(self.api_key.as_ref())
77 .header(ACCEPT, "application/json")
78 .timeout(self.request_timeout)
79 .send()
80 .await
81 .map_err(VeniceClientError::request_failure)?;
82
83 let response = Self::check_status(response)?;
84
85 let body = response
86 .bytes()
87 .await
88 .map_err(VeniceClientError::request_failure)?;
89 parse_model_list_response(&body)
90 }
91
92 pub async fn create_chat_completion_stream(
94 &self,
95 request: &VeniceE2eeChatRequest,
96 client_public_key_hex: &str,
97 model_public_key_hex: &str,
98 ) -> Result<reqwest::Response, VeniceClientError> {
99 let url = self.chat_completions_url()?;
100 let response = self
101 .http
102 .post(url)
103 .bearer_auth(self.api_key.as_ref())
104 .header(ACCEPT, "text/event-stream")
105 .header(CONTENT_TYPE, "application/json")
106 .header(HEADER_VENICE_TEE_CLIENT_PUB_KEY, client_public_key_hex)
107 .header(HEADER_VENICE_TEE_MODEL_PUB_KEY, model_public_key_hex)
108 .header(HEADER_VENICE_TEE_SIGNING_ALGO, "ecdsa")
109 .json(request)
110 .send()
111 .await
112 .map_err(VeniceClientError::request_failure)?;
113
114 Self::check_status(response)
115 }
116
117 pub async fn fetch_attestation_evidence(
119 &self,
120 model_id: &str,
121 nonce: &str,
122 ) -> Result<Value, VeniceClientError> {
123 let url = self.attestation_url(model_id, nonce)?;
124 let response = self
125 .http
126 .get(url)
127 .bearer_auth(self.api_key.as_ref())
128 .header(ACCEPT, "application/json")
129 .timeout(self.request_timeout)
130 .send()
131 .await
132 .map_err(VeniceClientError::request_failure)?;
133
134 let response = Self::check_status(response)?;
135
136 response
137 .json::<Value>()
138 .await
139 .map_err(VeniceClientError::malformed_attestation_payload)
140 }
141
142 fn check_status(response: reqwest::Response) -> Result<reqwest::Response, VeniceClientError> {
144 let status = response.status();
145
146 if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
147 return Err(VeniceClientError::Authentication {
148 status: status.as_u16(),
149 });
150 }
151
152 if !status.is_success() {
153 return Err(VeniceClientError::UpstreamStatus {
154 status: status.as_u16(),
155 });
156 }
157
158 Ok(response)
159 }
160
161 fn models_url(&self) -> Result<Url, VeniceClientError> {
163 self.endpoint_url("models")
164 }
165
166 fn chat_completions_url(&self) -> Result<Url, VeniceClientError> {
168 self.endpoint_url("chat/completions")
169 }
170
171 fn attestation_url(&self, model_id: &str, nonce: &str) -> Result<Url, VeniceClientError> {
173 let mut url = self.endpoint_url("tee/attestation")?;
174 url.query_pairs_mut()
175 .append_pair("model", model_id)
176 .append_pair("nonce", nonce);
177
178 Ok(url)
179 }
180
181 fn endpoint_url(&self, path: &str) -> Result<Url, VeniceClientError> {
183 self.base_url
184 .join(path)
185 .map_err(|source| VeniceClientError::EndpointUrl {
186 message: source.to_string(),
187 })
188 }
189}
190
191impl fmt::Debug for VeniceClient {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194 f.debug_struct("VeniceClient")
195 .field("base_url", &self.base_url)
196 .field("api_key", &"[redacted]")
197 .finish_non_exhaustive()
198 }
199}
200
201fn parse_base_url(value: &str) -> Result<Url, VeniceClientError> {
203 let mut url = Url::parse(value).map_err(|source| VeniceClientError::InvalidBaseUrl {
204 base_url: value.to_owned(),
205 message: source.to_string(),
206 })?;
207
208 if !url.path().ends_with('/') {
209 let path = format!("{}/", url.path());
210 url.set_path(&path);
211 }
212
213 Ok(url)
214}
215
216fn parse_model_list_response(body: &[u8]) -> Result<ModelListResponse, VeniceClientError> {
218 let payload: VeniceModelListPayload =
219 serde_json::from_slice(body).map_err(VeniceClientError::malformed_payload)?;
220 Ok(payload.into_openai_model_list())
221}
222
223#[derive(Debug, Error)]
225pub enum VeniceClientError {
226 #[error(transparent)]
227 Config(#[from] ConfigError),
228 #[error("invalid Venice base URL {base_url}: {message}")]
229 InvalidBaseUrl { base_url: String, message: String },
230 #[error("failed to build Venice HTTP client: {message}")]
231 ClientBuild { message: String },
232 #[error("failed to build Venice models URL: {message}")]
233 EndpointUrl { message: String },
234 #[error("Venice upstream authentication failed with status {status}")]
235 Authentication { status: u16 },
236 #[error("Venice upstream returned status {status}")]
237 UpstreamStatus { status: u16 },
238 #[error("Venice upstream request timed out")]
239 Timeout,
240 #[error("Venice upstream request failed: {message}")]
241 Request { message: String },
242 #[error("Venice upstream returned malformed model payload: {message}")]
243 MalformedPayload { message: String },
244 #[error("Venice upstream returned malformed attestation payload: {message}")]
245 MalformedAttestationPayload { message: String },
246}
247
248impl VeniceClientError {
249 pub fn api_error_type(&self) -> &'static str {
251 match self {
252 Self::Config(_)
253 | Self::InvalidBaseUrl { .. }
254 | Self::ClientBuild { .. }
255 | Self::EndpointUrl { .. } => "proxy_configuration_error",
256 Self::Authentication { .. } => "proxy_upstream_authentication_error",
257 Self::UpstreamStatus { .. }
258 | Self::Timeout
259 | Self::Request { .. }
260 | Self::MalformedPayload { .. }
261 | Self::MalformedAttestationPayload { .. } => "proxy_upstream_error",
262 }
263 }
264
265 pub fn api_error_code(&self) -> &'static str {
267 match self {
268 Self::Config(ConfigError::MissingApiKey) => "venice_api_key_missing",
269 Self::Config(_)
270 | Self::InvalidBaseUrl { .. }
271 | Self::ClientBuild { .. }
272 | Self::EndpointUrl { .. } => "venice_client_configuration_failed",
273 Self::Authentication { .. } => "upstream_authentication_failed",
274 Self::UpstreamStatus { .. } => "upstream_status_error",
275 Self::Timeout => "upstream_timeout",
276 Self::Request { .. } => "upstream_unavailable",
277 Self::MalformedPayload { .. } | Self::MalformedAttestationPayload { .. } => {
278 "upstream_malformed_response"
279 }
280 }
281 }
282
283 fn client_build(source: reqwest::Error) -> Self {
285 Self::ClientBuild {
286 message: source.to_string(),
287 }
288 }
289
290 fn request_failure(source: reqwest::Error) -> Self {
292 if source.is_timeout() {
293 Self::Timeout
294 } else {
295 Self::Request {
296 message: source.to_string(),
297 }
298 }
299 }
300
301 fn malformed_payload(source: serde_json::Error) -> Self {
303 Self::MalformedPayload {
304 message: source.to_string(),
305 }
306 }
307
308 fn malformed_attestation_payload(source: reqwest::Error) -> Self {
310 Self::MalformedAttestationPayload {
311 message: source.to_string(),
312 }
313 }
314}
315
/// Wire shape of the upstream model-list response; only the `data` array is read.
#[derive(Debug, Deserialize)]
struct VeniceModelListPayload {
    /// Models as returned by the upstream, before any filtering.
    data: Vec<VeniceModel>,
}
321
322impl VeniceModelListPayload {
323 fn into_openai_model_list(self) -> ModelListResponse {
325 let data = self
326 .data
327 .into_iter()
328 .filter_map(VeniceModel::into_openai_model_if_supported)
329 .collect();
330
331 ModelListResponse::new(data)
332 }
333}
334
/// Wire shape of a single model entry in the upstream model list.
#[derive(Debug, Deserialize)]
struct VeniceModel {
    /// Model identifier.
    id: String,
    /// Optional creation value; the conversion substitutes `0` when absent.
    #[serde(default)]
    created: Option<i64>,
    /// Optional owner; the conversion substitutes `"venice.ai"` when absent.
    #[serde(default)]
    owned_by: Option<String>,
    /// Upstream `type` field; only `"text"` models are exposed.
    #[serde(rename = "type")]
    model_type: String,
    /// Required model specification holding the capability flags.
    model_spec: VeniceModelSpec,
}
347
348impl VeniceModel {
349 fn into_openai_model_if_supported(self) -> Option<ModelObject> {
351 let capabilities = self.model_spec.capabilities;
352 if self.model_type != "text"
353 || !capabilities.supports_e2ee
354 || !capabilities.supports_tee_attestation
355 {
356 return None;
357 }
358
359 let venice = VeniceModelMetadata::new(
360 self.id.clone(),
361 capabilities.supports_e2ee,
362 capabilities.supports_tee_attestation,
363 capabilities.supports_reasoning.unwrap_or(false),
364 capabilities.supports_reasoning_effort.unwrap_or(false),
365 );
366 let openai_capabilities = capabilities.to_openai_capabilities();
367
368 Some(ModelObject::new(
369 self.id,
370 self.created.unwrap_or(0),
371 self.owned_by.unwrap_or_else(|| "venice.ai".to_owned()),
372 openai_capabilities,
373 venice,
374 ))
375 }
376}
377
/// Wire shape of a model's `model_spec` object; only `capabilities` is read.
#[derive(Debug, Deserialize)]
struct VeniceModelSpec {
    /// Required capability flags for the model.
    capabilities: VeniceCapabilities,
}
383
/// Capability flags reported by the upstream for a model.
///
/// `supportsE2EE` and `supportsTeeAttestation` are required: a payload that
/// omits either fails to deserialize. Every other flag is optional.
#[derive(Debug, Deserialize)]
struct VeniceCapabilities {
    /// Required flag, read from `supportsE2EE`.
    #[serde(rename = "supportsE2EE")]
    supports_e2ee: bool,
    /// Required flag, read from `supportsTeeAttestation`.
    #[serde(rename = "supportsTeeAttestation")]
    supports_tee_attestation: bool,
    /// Optional flag, read from `supportsFunctionCalling`.
    #[serde(default, rename = "supportsFunctionCalling")]
    supports_function_calling: Option<bool>,
    /// Optional flag; when absent it is derived from web search / code interpreter.
    #[serde(default, rename = "supportsBuiltinTools")]
    supports_builtin_tools: Option<bool>,
    /// Optional flag, read from `supportsWebSearch`.
    #[serde(default, rename = "supportsWebSearch")]
    supports_web_search: Option<bool>,
    /// Optional flag, read from `supportsCodeInterpreter`.
    #[serde(default, rename = "supportsCodeInterpreter")]
    supports_code_interpreter: Option<bool>,
    /// Optional flag, read from `supportsVision`.
    #[serde(default, rename = "supportsVision")]
    supports_vision: Option<bool>,
    /// Optional flag, read from `supportsReasoning`.
    #[serde(default, rename = "supportsReasoning")]
    supports_reasoning: Option<bool>,
    /// Optional flag, read from `supportsReasoningEffort`.
    #[serde(default, rename = "supportsReasoningEffort")]
    supports_reasoning_effort: Option<bool>,
}
406
407impl VeniceCapabilities {
408 fn to_openai_capabilities(&self) -> ModelCapabilities {
410 let web_search = self.supports_web_search.unwrap_or(false);
411 let code_interpreter = self.supports_code_interpreter.unwrap_or(false);
412 let builtin_tools = self
413 .supports_builtin_tools
414 .unwrap_or(web_search || code_interpreter);
415
416 ModelCapabilities {
417 function_calling: self.supports_function_calling.unwrap_or(false),
418 builtin_tools,
419 web_search,
420 code_interpreter,
421 vision: self.supports_vision.unwrap_or(false),
422 reasoning: self.supports_reasoning.unwrap_or(false),
423 reasoning_effort: self.supports_reasoning_effort.unwrap_or(false),
424 }
425 }
426}
427
// Unit tests for model-list parsing, API-key redaction, and streaming behavior.
#[cfg(test)]
mod tests {
    use super::*;

    use axum::{Router, body::Body, response::IntoResponse, routing::post};
    use tokio::net::TcpListener;

    // Client timeout for tests that do not exercise timeout behavior.
    const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

    // Of three upstream entries, only the text model with both E2EE and TEE
    // attestation survives, and all of its fields are carried over.
    #[test]
    fn maps_supported_venice_text_models_to_openai_shape() {
        let body = br#"
 {
 "data": [
 {
 "id": "e2ee-qwen3-5-122b-a10b",
 "created": 1727966436,
 "owned_by": "venice.ai",
 "type": "text",
 "model_spec": {
 "capabilities": {
 "supportsE2EE": true,
 "supportsTeeAttestation": true,
 "supportsFunctionCalling": true,
 "supportsBuiltinTools": true,
 "supportsWebSearch": true,
 "supportsCodeInterpreter": true,
 "supportsVision": false,
 "supportsReasoning": true,
 "supportsReasoningEffort": true
 }
 }
 },
 {
 "id": "non-e2ee-text",
 "type": "text",
 "model_spec": {
 "capabilities": {
 "supportsE2EE": false,
 "supportsTeeAttestation": true
 }
 }
 },
 {
 "id": "e2ee-image",
 "type": "image",
 "model_spec": {
 "capabilities": {
 "supportsE2EE": true,
 "supportsTeeAttestation": true
 }
 }
 }
 ]
 }
 "#;

        let response = parse_model_list_response(body).expect("valid model payload should parse");

        assert_eq!(response.object, "list");
        // The non-E2EE text model and the image model are filtered out.
        assert_eq!(response.data.len(), 1);
        let model = &response.data[0];
        assert_eq!(model.id, "e2ee-qwen3-5-122b-a10b");
        assert_eq!(model.object, "model");
        assert_eq!(model.created, 1727966436);
        assert_eq!(model.owned_by, "venice.ai");
        assert_eq!(model.name, "e2ee-qwen3-5-122b-a10b");
        assert!(model.info.meta.capabilities.function_calling);
        assert!(model.info.meta.capabilities.builtin_tools);
        assert!(model.info.meta.capabilities.web_search);
        assert!(model.info.meta.capabilities.code_interpreter);
        assert!(!model.info.meta.capabilities.vision);
        assert!(model.info.meta.capabilities.reasoning);
        assert!(model.info.meta.capabilities.reasoning_effort);
        assert_eq!(model.venice.id, "e2ee-qwen3-5-122b-a10b");
        assert!(model.venice.supports_e2ee);
        assert!(model.venice.supports_tee_attestation);
        assert!(model.venice.supports_reasoning);
        assert!(model.venice.supports_reasoning_effort);
    }

    // With only the two required flags present, `created`, `owned_by`, and the
    // optional capabilities all take their fallback values.
    #[test]
    fn missing_optional_capability_metadata_defaults_to_false() {
        let body = br#"
 {
 "data": [
 {
 "id": "e2ee-minimal",
 "type": "text",
 "model_spec": {
 "capabilities": {
 "supportsE2EE": true,
 "supportsTeeAttestation": true
 }
 }
 }
 ]
 }
 "#;

        let response =
            parse_model_list_response(body).expect("minimal capability payload should parse");
        let model = response
            .data
            .first()
            .expect("supported model should be present");

        assert_eq!(model.created, 0);
        assert_eq!(model.owned_by, "venice.ai");
        assert_eq!(model.info.meta.capabilities, ModelCapabilities::default());
    }

    // Omitting the required `supportsTeeAttestation` flag fails the whole parse
    // rather than skipping the entry.
    #[test]
    fn malformed_model_payload_is_reported() {
        let body = br#"
 {
 "data": [
 {
 "id": "missing-required-attestation-flag",
 "type": "text",
 "model_spec": {
 "capabilities": {
 "supportsE2EE": true
 }
 }
 }
 ]
 }
 "#;

        let error = parse_model_list_response(body).expect_err("malformed payload should fail");

        assert!(matches!(error, VeniceClientError::MalformedPayload { .. }));
        assert_eq!(error.api_error_code(), "upstream_malformed_response");
    }

    // `Debug` output shows the normalized base URL but never the API key.
    #[test]
    fn client_debug_output_redacts_api_key() {
        let client = VeniceClient::new(
            "https://api.venice.ai/api/v1",
            "super-secret-test-key",
            DEFAULT_REQUEST_TIMEOUT,
        )
        .expect("client should build");

        let debug = format!("{client:?}");
        assert!(debug.contains("api.venice.ai"));
        // The trailing slash shows that the base URL was normalized.
        assert!(debug.contains("/api/v1/"));
        assert!(debug.contains("[redacted]"));
        assert!(!debug.contains("super-secret-test-key"));
    }

    // The mock server needs about 100 ms to finish the stream (five 20 ms
    // pauses) while the client timeout is 50 ms. The stream must still complete
    // because each individual gap between chunks is shorter than the timeout.
    #[tokio::test]
    async fn chat_stream_can_outlive_configured_request_timeout_when_chunks_keep_arriving() {
        // Handler that emits five delayed SSE chunks followed by `[DONE]`.
        async fn slow_streaming_chat() -> impl IntoResponse {
            let stream = async_stream::stream! {
                for index in 0..5 {
                    tokio::time::sleep(Duration::from_millis(20)).await;
                    yield Ok::<_, std::io::Error>(format!("data: {index}\n\n"));
                }
                yield Ok::<_, std::io::Error>("data: [DONE]\n\n".to_owned());
            };

            (
                [
                    ("content-type", "text/event-stream"),
                    ("cache-control", "no-cache"),
                ],
                Body::from_stream(stream),
            )
        }

        // Serve the handler on an OS-assigned local port.
        let app = Router::new().route("/api/v1/chat/completions", post(slow_streaming_chat));
        let listener = TcpListener::bind(("127.0.0.1", 0))
            .await
            .expect("mock listener should bind");
        let addr = listener.local_addr().expect("listener should have address");
        tokio::spawn(async move {
            axum::serve(listener, app)
                .await
                .expect("mock server should run");
        });

        let client = VeniceClient::new(
            format!("http://{addr}/api/v1"),
            "test-api-key",
            Duration::from_millis(50),
        )
        .expect("client should build");
        let request = VeniceE2eeChatRequest {
            model: "e2ee-test".to_owned(),
            messages: Vec::new(),
            stream: true,
            stream_options: crate::openai::chat::VeniceStreamOptions {
                include_usage: false,
            },
            venice_parameters: crate::openai::chat::VeniceParameters::default(),
            temperature: None,
            top_p: None,
            max_tokens: None,
            max_completion_tokens: None,
            stop: None,
            reasoning: None,
            reasoning_effort: None,
        };

        let mut response = client
            .create_chat_completion_stream(&request, "client-key", "model-key")
            .await
            .expect("stream response headers should arrive before timeout");
        // Drain the body; any chunk error fails the test.
        let mut chunks = 0;
        while let Some(_chunk) = response
            .chunk()
            .await
            .expect("frequent stream chunks should not hit total timeout")
        {
            chunks += 1;
        }

        assert!(chunks > 1);
    }
}