Skip to main content

tt_shared/
messages.rs

1//! OpenAI-compatible request/response shapes. Canonical wire format across all
2//! providers — adapters translate to/from provider-native formats.
3
4use std::collections::HashMap;
5
6use serde::{Deserialize, Serialize};
7
8use crate::Usage;
9
10// ---------------------------------------------------------------------------
11// tt_extras cache-control types (Fix B / §2.7)
12// ---------------------------------------------------------------------------
13
14/// Cache behaviour requested by the caller via `tt_extras.cache`.
15///
16/// Absent (no `cache` key in `tt_extras`) is treated as [`CacheMode::Normal`].
17///
18/// JSON shape:
19/// ```json
20/// { "cache": { "mode": "bypass", "ttl_secs": 3600 } }
21/// ```
22#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "lowercase")]
24pub enum CacheMode {
25    /// Normal read-write caching (default when key absent).
26    #[default]
27    Normal,
28    /// Skip lookup AND insert — always hit the provider, never populate cache.
29    Bypass,
30    /// Skip lookup, but DO insert (force-refresh stale entry).
31    Refresh,
32    /// Do lookup, but never insert (read-only cache consumer).
33    #[serde(rename = "read-only")]
34    ReadOnly,
35}
36
37/// Typed cache-control extracted from `tt_extras`.
38#[derive(Debug, Clone, Default, Serialize, Deserialize)]
39pub struct CacheControlConfig {
40    /// Requested cache behaviour.
41    #[serde(default)]
42    pub mode: CacheMode,
43    /// Override TTL for cache inserts. `None` = use the gateway default.
44    #[serde(default, skip_serializing_if = "Option::is_none")]
45    pub ttl_secs: Option<u64>,
46}
47
48/// Parse [`CacheControlConfig`] from a request's `tt_extras` map.
49///
50/// Returns `None` when `tt_extras` does not contain a `"cache"` key.
51/// Returns the default config (normal mode, no TTL override) when the key is
52/// present but the value fails to deserialize — so a malformed field degrades
53/// gracefully rather than hard-failing.
54pub fn parse_cache_control(
55    extras: &HashMap<String, serde_json::Value>,
56) -> Option<CacheControlConfig> {
57    let val = extras.get("cache")?;
58    match serde_json::from_value::<CacheControlConfig>(val.clone()) {
59        Ok(cfg) => Some(cfg),
60        Err(e) => {
61            // Log at warn level so operators can see bad payloads; fall back
62            // to normal (don't block the request).
63            tracing::warn!(
64                error = %e,
65                "tt_extras.cache deserialization failed — treating as normal"
66            );
67            Some(CacheControlConfig::default())
68        }
69    }
70}
71
#[cfg(test)]
mod cache_control_tests {
    use super::*;

    /// Build a `tt_extras` map from a JSON object literal.
    fn extras(json: &str) -> HashMap<String, serde_json::Value> {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn no_cache_key_returns_none() {
        assert!(parse_cache_control(&extras("{}")).is_none());
    }

    #[test]
    fn bypass_mode_parsed() {
        let cfg = parse_cache_control(&extras(r#"{"cache":{"mode":"bypass"}}"#)).unwrap();
        assert_eq!(cfg.mode, CacheMode::Bypass);
        assert!(cfg.ttl_secs.is_none());
    }

    #[test]
    fn refresh_mode_with_ttl() {
        let cfg = parse_cache_control(&extras(r#"{"cache":{"mode":"refresh","ttl_secs":3600}}"#))
            .unwrap();
        assert_eq!(cfg.mode, CacheMode::Refresh);
        assert_eq!(cfg.ttl_secs, Some(3600));
    }

    #[test]
    fn read_only_mode() {
        let cfg = parse_cache_control(&extras(r#"{"cache":{"mode":"read-only"}}"#)).unwrap();
        assert_eq!(cfg.mode, CacheMode::ReadOnly);
    }

    #[test]
    fn absent_mode_defaults_to_normal() {
        let cfg = parse_cache_control(&extras(r#"{"cache":{}}"#)).unwrap();
        assert_eq!(cfg.mode, CacheMode::Normal);
    }

    #[test]
    fn malformed_value_falls_back_to_default() {
        let cfg = parse_cache_control(&extras(r#"{"cache":"not-an-object"}"#)).unwrap();
        assert_eq!(cfg.mode, CacheMode::Normal);
        assert!(cfg.ttl_secs.is_none());
    }

    #[test]
    fn ttl_without_mode_keeps_ttl() {
        // `mode` is optional on its own; a TTL override alone is honoured.
        let cfg = parse_cache_control(&extras(r#"{"cache":{"ttl_secs":60}}"#)).unwrap();
        assert_eq!(cfg.mode, CacheMode::Normal);
        assert_eq!(cfg.ttl_secs, Some(60));
    }

    #[test]
    fn unknown_mode_falls_back_to_default() {
        // Mode names are case-sensitive; an unrecognised spelling makes the
        // whole object fail to deserialize, so the full default is returned.
        let cfg = parse_cache_control(&extras(r#"{"cache":{"mode":"BYPASS","ttl_secs":5}}"#))
            .unwrap();
        assert_eq!(cfg.mode, CacheMode::Normal);
        assert!(cfg.ttl_secs.is_none());
    }

    #[test]
    fn null_cache_value_falls_back_to_default() {
        // A present-but-null key is "present", so this is `Some(default)`, not `None`.
        let cfg = parse_cache_control(&extras(r#"{"cache":null}"#)).unwrap();
        assert_eq!(cfg.mode, CacheMode::Normal);
        assert!(cfg.ttl_secs.is_none());
    }

    #[test]
    fn mode_round_trips_through_wire_names() {
        for (mode, wire) in [
            (CacheMode::Normal, "normal"),
            (CacheMode::Bypass, "bypass"),
            (CacheMode::Refresh, "refresh"),
            (CacheMode::ReadOnly, "read-only"),
        ] {
            assert_eq!(serde_json::to_value(&mode).unwrap(), serde_json::json!(wire));
            let back: CacheMode = serde_json::from_value(serde_json::json!(wire)).unwrap();
            assert_eq!(back, mode);
        }
    }
}
118
119#[derive(Debug, Clone, Default, Serialize, Deserialize)]
120pub struct ChatCompletionRequest {
121    pub model: String,
122    pub messages: Vec<Message>,
123
124    #[serde(default, skip_serializing_if = "Option::is_none")]
125    pub temperature: Option<f32>,
126    #[serde(default, skip_serializing_if = "Option::is_none")]
127    pub top_p: Option<f32>,
128    #[serde(default, skip_serializing_if = "Option::is_none")]
129    pub max_tokens: Option<u32>,
130    /// Newer OpenAI spend cap for reasoning models (`o3`, `o4-mini`, …). When a
131    /// client sets this it MUST be honored end-to-end — dropping it silently
132    /// removes the caller's output ceiling and changes spend.
133    #[serde(default, skip_serializing_if = "Option::is_none")]
134    pub max_completion_tokens: Option<u32>,
135    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
136    pub stream: bool,
137    /// OpenAI `stream_options` (e.g. `{ "include_usage": true }`). Kept as an
138    /// opaque value so the full object passes through to OpenAI-shaped upstreams
139    /// unchanged.
140    #[serde(default, skip_serializing_if = "Option::is_none")]
141    pub stream_options: Option<serde_json::Value>,
142    #[serde(default, skip_serializing_if = "Vec::is_empty")]
143    pub tools: Vec<Tool>,
144    #[serde(default, skip_serializing_if = "Option::is_none")]
145    pub tool_choice: Option<ToolChoice>,
146    #[serde(default, skip_serializing_if = "Option::is_none")]
147    pub response_format: Option<ResponseFormat>,
148    #[serde(default, skip_serializing_if = "Vec::is_empty")]
149    pub stop: Vec<String>,
150    #[serde(default, skip_serializing_if = "Option::is_none")]
151    pub presence_penalty: Option<f32>,
152    #[serde(default, skip_serializing_if = "Option::is_none")]
153    pub frequency_penalty: Option<f32>,
154    #[serde(default, skip_serializing_if = "Option::is_none")]
155    pub n: Option<u32>,
156    #[serde(default, skip_serializing_if = "Option::is_none")]
157    pub seed: Option<i64>,
158    #[serde(default, skip_serializing_if = "Option::is_none")]
159    pub user: Option<String>,
160    /// Whether the model may emit tool calls in parallel (OpenAI
161    /// `parallel_tool_calls`). Forwarded to OpenAI-shaped upstreams.
162    #[serde(default, skip_serializing_if = "Option::is_none")]
163    pub parallel_tool_calls: Option<bool>,
164    /// Reasoning-effort hint for reasoning models (`"low"`/`"medium"`/`"high"`).
165    /// Materially changes output and cost, so it must reach the upstream when
166    /// supported.
167    #[serde(default, skip_serializing_if = "Option::is_none")]
168    pub reasoning_effort: Option<String>,
169
170    /// TokenTrimmer-internal extras (cache config, route hints, etc.) that are
171    /// stripped before forwarding to the provider.
172    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
173    pub tt_extras: HashMap<String, serde_json::Value>,
174
175    /// Genuinely-unknown / newer OpenAI fields not modeled above. Captured via
176    /// `#[serde(flatten)]` so they passthrough to OpenAI-shaped upstreams
177    /// instead of being silently dropped on deserialize. Never includes the
178    /// named fields above (serde consumes those first).
179    #[serde(flatten, default, skip_serializing_if = "HashMap::is_empty")]
180    pub extra: HashMap<String, serde_json::Value>,
181}
182
/// One chat message, discriminated on the wire by its `role` field:
/// `"system"`, `"user"`, `"assistant"` or `"tool"` (variant names lower-cased).
///
/// Keys not listed on a variant are ignored on deserialize. There is no
/// `extra` passthrough map here, so they are not forwarded.
/// NOTE(review): any other `role` value (e.g. OpenAI's newer `"developer"`)
/// fails to deserialize, and an explicit `"tool_calls": null` on an assistant
/// message is rejected (only a missing key defaults to empty). Confirm both
/// are intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "system"`.
    System {
        content: MessageContent,
    },
    /// `role: "user"`, with an optional participant `name`.
    User {
        content: MessageContent,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "assistant"`. `content` may be missing or `null`. `tool_calls`
    /// is omitted on the wire when empty.
    Assistant {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "tool"`: a tool result, linked to a call by `tool_call_id`.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
207
/// Message body: a plain string or an array of typed content parts.
///
/// `#[serde(untagged)]`: a JSON string deserializes to `Text` and a JSON array
/// to `Parts`. Each serializes back to the same bare shape. Variant order
/// matters for untagged matching.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    Text(String),
    Parts(Vec<ContentPart>),
}
214
/// One element of a multi-part message body, tagged by its `type` field:
/// `"text"`, `"image_url"` or `"input_audio"`.
///
/// NOTE(review): any other `type` value fails deserialization, and the error
/// propagates to the enclosing message. Confirm that is the desired behaviour
/// for newer part types.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// `{"type": "text", "text": …}`
    Text { text: String },
    /// `{"type": "image_url", "image_url": {…}}`
    ImageUrl { image_url: ImageUrl },
    /// `{"type": "input_audio", "input_audio": {…}}`
    InputAudio { input_audio: InputAudio },
}
222
/// Image reference carried by an `image_url` content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
    /// Remote URL or inline `data:` URL (see [`parse_data_url`] for the latter).
    pub url: String,
    /// Optional detail hint, passed through as an opaque string and omitted
    /// when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
229
/// Inline audio carried by an `input_audio` content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputAudio {
    /// Encoded audio payload as a string (presumably base64; TODO confirm).
    pub data: String,
    /// Audio format name; opaque string, not validated here.
    pub format: String,
}
235
/// A tool the model may call, as listed in `ChatCompletionRequest::tools`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Tool kind (the `type` key on the wire); free-form string, not validated.
    #[serde(rename = "type")]
    pub r#type: String,
    /// The function this tool exposes.
    pub function: ToolFunction,
}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct ToolFunction {
245    pub name: String,
246    #[serde(default, skip_serializing_if = "Option::is_none")]
247    pub description: Option<String>,
248    pub parameters: serde_json::Value,
249}
250
/// `tool_choice`: either a bare mode string or an object naming one function.
///
/// `#[serde(untagged)]`: a JSON string deserializes to [`ToolChoice::Auto`]
/// whatever it says (`"auto"`, `"none"`, `"required"`, …). The string is not
/// validated here. An object with `type` and `function` deserializes to
/// [`ToolChoice::Specific`]. The constructors on `ToolChoice` build the
/// canonical forms.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// Bare mode string. Despite the variant name, it also carries `"none"`
    /// and `"required"`.
    Auto(String),
    /// Object form selecting one named function.
    Specific {
        #[serde(rename = "type")]
        r#type: String,
        function: ToolChoiceFunction,
    },
}
261
262impl ToolChoice {
263    /// Let the model decide whether to call a tool (`"auto"`).
264    #[must_use]
265    pub fn auto() -> Self {
266        ToolChoice::Auto("auto".to_string())
267    }
268
269    /// Forbid tool calls — force a plain text answer (`"none"`).
270    #[must_use]
271    pub fn none() -> Self {
272        ToolChoice::Auto("none".to_string())
273    }
274
275    /// Require the model to call some tool (`"required"`).
276    #[must_use]
277    pub fn required() -> Self {
278        ToolChoice::Auto("required".to_string())
279    }
280
281    /// Require the model to call a specific named function.
282    #[must_use]
283    pub fn function(name: impl Into<String>) -> Self {
284        ToolChoice::Specific {
285            r#type: "function".to_string(),
286            function: ToolChoiceFunction { name: name.into() },
287        }
288    }
289}
290
/// Names the function selected by [`ToolChoice::Specific`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    pub name: String,
}
295
/// A tool invocation emitted by the model (e.g. in assistant `tool_calls`).
///
/// All three fields are required on deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call identifier (compare `tool_call_id` on `Message::Tool`).
    pub id: String,
    /// Call kind (the `type` key on the wire); free-form string.
    #[serde(rename = "type")]
    pub r#type: String,
    /// Which function to call, and with what arguments.
    pub function: ToolCallFunction,
}
303
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallFunction {
    pub name: String,
    /// Stringified JSON arguments, per the OpenAI convention. Not parsed or
    /// validated here.
    pub arguments: String,
}
310
/// `response_format`: a `type` string plus an optional JSON-schema payload.
///
/// Both are passed through opaquely; `json_schema` is omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseFormat {
    #[serde(rename = "type")]
    pub r#type: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<serde_json::Value>,
}
318
/// Non-streaming chat completion result.
///
/// NOTE(review): unlike [`ChatCompletionChunk`], this type has no `extra`
/// flatten map, so unknown top-level keys (e.g. `system_fingerprint`) are
/// dropped on a deserialize → serialize round trip. `usage` is also required,
/// so a body without it fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    /// Object type string, passed through as-is.
    pub object: String,
    /// Creation timestamp as provided upstream (presumably Unix seconds; confirm).
    pub created: i64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
328
/// One candidate completion inside a [`ChatCompletionResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
    /// Position of this choice within `choices`.
    pub index: u32,
    /// The generated message (presumably always `Message::Assistant`; the type
    /// does not enforce it).
    pub message: Message,
    /// Provider stop reason; `None` when missing or `null`.
    pub finish_reason: Option<String>,
}
335
/// One SSE event from a streaming chat completion.
///
/// `usage` is optional and omitted when `None`. Unknown top-level keys are
/// kept in `extra` so chunks round-trip unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
    pub id: String,
    /// Object type string, passed through as-is.
    pub object: String,
    /// Creation timestamp as provided upstream (presumably Unix seconds; confirm).
    pub created: i64,
    pub model: String,
    pub choices: Vec<ChunkChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    /// Genuinely-unknown or newer provider chunk fields (e.g.
    /// `system_fingerprint`, `service_tier`). They are captured via
    /// `#[serde(flatten)]` so upstream SSE chunks round-trip through the
    /// gateway unchanged rather than being silently dropped.
    #[serde(flatten, default, skip_serializing_if = "HashMap::is_empty")]
    pub extra: HashMap<String, serde_json::Value>,
}
353
/// One choice inside a [`ChatCompletionChunk`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkChoice {
    /// Position of this choice within `choices`.
    pub index: u32,
    /// Incremental content for this choice.
    pub delta: ChunkDelta,
    /// Provider stop reason; `None` while the stream is still producing output
    /// (serialized as `null`).
    pub finish_reason: Option<String>,
    /// Unknown per-choice fields (e.g. `logprobs`) preserved verbatim.
    #[serde(flatten, default, skip_serializing_if = "HashMap::is_empty")]
    pub extra: HashMap<String, serde_json::Value>,
}
363
/// Incremental message fragment carried by a [`ChunkChoice`].
///
/// Every field is optional on the wire, so an empty delta (`{}`) is valid.
/// NOTE(review): `tool_calls` reuses the full [`ToolCall`] shape, which requires
/// `id`, `type` and `function.name`. OpenAI's streamed tool-call fragments
/// carry an `index` and typically omit those fields after the first fragment.
/// Confirm that upstream chunks are normalised before deserializing into this
/// type.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChunkDelta {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Unknown per-delta fields (e.g. `refusal`) preserved verbatim.
    #[serde(flatten, default, skip_serializing_if = "HashMap::is_empty")]
    pub extra: HashMap<String, serde_json::Value>,
}
376
/// OpenAI-compatible embeddings request.
///
/// NOTE(review): there is no `extra` flatten map here, so keys other than the
/// four below (e.g. `user`) are ignored on deserialize and never forwarded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingsRequest {
    pub model: String,
    /// Text(s) to embed.
    pub input: EmbeddingInput,
    /// Optional output dimensionality; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
    /// Optional encoding format, passed through as a string; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<String>,
}
386
/// Embedding input: a single string or a batch of strings (`untagged`).
///
/// NOTE(review): token-array inputs (arrays of integers) match neither variant
/// and fail to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    Single(String),
    Batch(Vec<String>),
}
393
/// Embeddings result: one [`EmbeddingData`] per input plus token usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingsResponse {
    /// Object type string, passed through as-is.
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}
401
/// One embedding vector and the index of the input it belongs to.
///
/// NOTE(review): `embedding` only accepts a JSON array of numbers. If a client
/// forwards `encoding_format: "base64"`, an upstream that returns a base64
/// string here would fail to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingData {
    /// Object type string, passed through as-is.
    pub object: String,
    /// Position of the corresponding input.
    pub index: u32,
    pub embedding: Vec<f32>,
}
408
/// Parse a base64 `data:` URL into `(media_type, base64_payload)`.
///
/// Returns `None` for:
/// - non-`data:` URLs;
/// - non-base64 data URLs;
/// - an empty payload;
/// - a malformed or empty media type, meaning anything that is not a non-empty
///   `type/subtype` once RFC-2397 parameters such as `;charset=…` are removed.
///
/// Provider adapters use this to forward inline image bytes as the provider's
/// native base64 image part. Without it they would send the whole `data:` URI
/// as a *remote* URL reference, which the upstream rejects.
///
/// The `data:` scheme and the `;base64` marker are matched ASCII
/// case-insensitively, since URL schemes are case-insensitive. The returned
/// media type is trimmed and lower-cased, since MIME types are
/// case-insensitive, so providers always see e.g. `image/png`. The payload is
/// returned verbatim and is not base64-validated. The function never panics:
/// all slicing goes through `str::get`.
#[must_use]
pub fn parse_data_url(url: &str) -> Option<(String, String)> {
    const SCHEME: &str = "data:";
    const BASE64_MARKER: &str = ";base64";

    // `get` yields `None` (instead of panicking) when the cut is not on a char
    // boundary, which simply means the prefix cannot be `data:`.
    let rest = url
        .get(..SCHEME.len())
        .filter(|scheme| scheme.eq_ignore_ascii_case(SCHEME))
        .and_then(|_| url.get(SCHEME.len()..))?;
    let (meta, data) = rest.split_once(',')?;

    // Only base64 payloads are supported (the canonical image transport).
    let marker_start = meta.len().checked_sub(BASE64_MARKER.len())?;
    let media_with_params = meta
        .get(marker_start..)
        .filter(|marker| marker.eq_ignore_ascii_case(BASE64_MARKER))
        .and_then(|_| meta.get(..marker_start))?;

    // Drop any RFC-2397 media-type parameters (e.g. `;charset=utf-8`) — providers
    // expect a bare MIME type like `image/png` in the base64 image part.
    let media_type = media_with_params.split(';').next().unwrap_or("").trim();
    let is_type_subtype = matches!(
        media_type.split_once('/'),
        Some((kind, subtype)) if !kind.is_empty() && !subtype.is_empty()
    );
    if !is_type_subtype || data.is_empty() {
        return None;
    }
    Some((media_type.to_ascii_lowercase(), data.to_string()))
}
429
// Wire-format tests for the request/response types and helpers in this file.
// NOTE(review): despite its name, this module covers the chat request default,
// field passthrough on requests and streaming chunks, `parse_data_url` and the
// `ToolChoice` constructors. It contains no embeddings tests.
#[cfg(test)]
mod embeddings_default_tests {
    use super::*;

    // `Default` must produce an empty request with every optional knob unset.
    #[test]
    fn chat_request_default_is_empty() {
        let r = ChatCompletionRequest::default();
        assert_eq!(r.model, "");
        assert!(r.messages.is_empty());
        assert!(!r.stream);
        assert!(r.tools.is_empty());
        assert!(r.max_tokens.is_none());
    }

    // Explicitly modelled newer OpenAI fields land in their typed fields (not
    // in `extra`) and are re-emitted under the same keys.
    #[test]
    fn typed_compat_fields_roundtrip() {
        let json = serde_json::json!({
            "model": "o3",
            "messages": [{ "role": "user", "content": "hi" }],
            "max_completion_tokens": 4096,
            "stream_options": { "include_usage": true },
            "parallel_tool_calls": false,
            "reasoning_effort": "high",
        });
        let req: ChatCompletionRequest = serde_json::from_value(json).unwrap();
        assert_eq!(req.max_completion_tokens, Some(4096));
        assert_eq!(req.parallel_tool_calls, Some(false));
        assert_eq!(req.reasoning_effort.as_deref(), Some("high"));
        assert_eq!(
            req.stream_options,
            Some(serde_json::json!({ "include_usage": true }))
        );
        // The flatten map must NOT capture the named fields.
        assert!(req.extra.is_empty());

        let out = serde_json::to_value(&req).unwrap();
        assert_eq!(out["max_completion_tokens"], 4096);
        assert_eq!(
            out["stream_options"],
            serde_json::json!({"include_usage": true})
        );
        assert_eq!(out["parallel_tool_calls"], false);
        assert_eq!(out["reasoning_effort"], "high");
    }

    // Unmodelled request keys are captured by `extra` and re-serialized at the
    // top level.
    #[test]
    fn unknown_fields_passthrough_via_flatten() {
        // A genuinely-unknown / newer OpenAI field must survive deserialize and
        // re-serialize verbatim rather than being silently dropped.
        let json = serde_json::json!({
            "model": "gpt-4o",
            "messages": [{ "role": "user", "content": "hi" }],
            "logprobs": true,
            "top_logprobs": 5,
            "service_tier": "auto",
        });
        let req: ChatCompletionRequest = serde_json::from_value(json.clone()).unwrap();
        assert_eq!(req.extra.get("logprobs"), Some(&serde_json::json!(true)));
        assert_eq!(req.extra.get("top_logprobs"), Some(&serde_json::json!(5)));
        assert_eq!(
            req.extra.get("service_tier"),
            Some(&serde_json::json!("auto"))
        );

        let out = serde_json::to_value(&req).unwrap();
        assert_eq!(out["logprobs"], true);
        assert_eq!(out["top_logprobs"], 5);
        assert_eq!(out["service_tier"], "auto");
    }

    // The same passthrough guarantee, at all three nesting levels of a chunk.
    #[test]
    fn streaming_chunk_unknown_fields_passthrough() {
        // Unknown / newer provider fields on a streaming chunk (and on its nested
        // choice/delta) must survive deserialize and re-serialize verbatim rather
        // than being silently dropped on the round-trip passthrough.
        let json = serde_json::json!({
            "id": "chatcmpl-1",
            "object": "chat.completion.chunk",
            "created": 1716598234,
            "model": "gpt-4o",
            "system_fingerprint": "fp_abc123",
            "choices": [{
                "index": 0,
                "delta": { "content": "hi", "refusal": null },
                "finish_reason": null,
                "logprobs": { "content": [] }
            }]
        });
        let chunk: ChatCompletionChunk = serde_json::from_value(json).unwrap();
        assert_eq!(
            chunk.extra.get("system_fingerprint"),
            Some(&serde_json::json!("fp_abc123"))
        );
        assert_eq!(
            chunk.choices[0].extra.get("logprobs"),
            Some(&serde_json::json!({ "content": [] }))
        );
        assert_eq!(
            chunk.choices[0].delta.extra.get("refusal"),
            Some(&serde_json::Value::Null)
        );

        let out = serde_json::to_value(&chunk).unwrap();
        assert_eq!(out["system_fingerprint"], "fp_abc123");
        assert_eq!(
            out["choices"][0]["logprobs"],
            serde_json::json!({ "content": [] })
        );
        assert_eq!(
            out["choices"][0]["delta"]["refusal"],
            serde_json::Value::Null
        );
    }

    // Accept/reject matrix for `parse_data_url`.
    #[test]
    fn parse_data_url_extracts_media_type_and_payload() {
        assert_eq!(
            parse_data_url("data:image/png;base64,iVBORw0KGgo="),
            Some(("image/png".to_string(), "iVBORw0KGgo=".to_string()))
        );
        // Non-data URLs and non-base64 / malformed data URLs return None.
        assert_eq!(parse_data_url("https://example.com/cat.png"), None);
        assert_eq!(parse_data_url("data:image/png,notbase64"), None);
        assert_eq!(parse_data_url("data:;base64,abc"), None);
        assert_eq!(parse_data_url("data:image/png;base64,"), None);
        // Media-type parameters are stripped to a bare MIME type.
        assert_eq!(
            parse_data_url("data:image/png;charset=utf-8;base64,iVBORw0KGgo="),
            Some(("image/png".to_string(), "iVBORw0KGgo=".to_string()))
        );
    }

    // Each `ToolChoice` constructor must serialize to its OpenAI wire form.
    #[test]
    fn tool_choice_constructors_serialize_to_the_wire_form() {
        // The string variants stay an untagged bare string …
        assert_eq!(
            serde_json::to_value(ToolChoice::auto()).unwrap(),
            serde_json::json!("auto")
        );
        assert_eq!(
            serde_json::to_value(ToolChoice::none()).unwrap(),
            serde_json::json!("none")
        );
        assert_eq!(
            serde_json::to_value(ToolChoice::required()).unwrap(),
            serde_json::json!("required")
        );
        // … and `function(name)` produces the object form.
        assert_eq!(
            serde_json::to_value(ToolChoice::function("get_weather")).unwrap(),
            serde_json::json!({ "type": "function", "function": { "name": "get_weather" } })
        );
    }
}