Skip to main content

tt_shared/
messages.rs

1//! OpenAI-compatible request/response shapes. Canonical wire format across all
2//! providers — adapters translate to/from provider-native formats.
3
4use std::collections::HashMap;
5
6use serde::{Deserialize, Serialize};
7
8use crate::Usage;
9
10// ---------------------------------------------------------------------------
11// tt_extras cache-control types (Fix B / §2.7)
12// ---------------------------------------------------------------------------
13
14/// Cache behaviour requested by the caller via `tt_extras.cache`.
15///
16/// Absent (no `cache` key in `tt_extras`) is treated as [`CacheMode::Normal`].
17///
18/// JSON shape:
19/// ```json
20/// { "cache": { "mode": "bypass", "ttl_secs": 3600 } }
21/// ```
22#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "lowercase")]
24pub enum CacheMode {
25    /// Normal read-write caching (default when key absent).
26    #[default]
27    Normal,
28    /// Skip lookup AND insert — always hit the provider, never populate cache.
29    Bypass,
30    /// Skip lookup, but DO insert (force-refresh stale entry).
31    Refresh,
32    /// Do lookup, but never insert (read-only cache consumer).
33    #[serde(rename = "read-only")]
34    ReadOnly,
35}
36
37/// Typed cache-control extracted from `tt_extras`.
38#[derive(Debug, Clone, Default, Serialize, Deserialize)]
39pub struct CacheControlConfig {
40    /// Requested cache behaviour.
41    #[serde(default)]
42    pub mode: CacheMode,
43    /// Override TTL for cache inserts. `None` = use the gateway default.
44    #[serde(default, skip_serializing_if = "Option::is_none")]
45    pub ttl_secs: Option<u64>,
46}
47
48/// Parse [`CacheControlConfig`] from a request's `tt_extras` map.
49///
50/// Returns `None` when `tt_extras` does not contain a `"cache"` key.
51/// Returns the default config (normal mode, no TTL override) when the key is
52/// present but the value fails to deserialize — so a malformed field degrades
53/// gracefully rather than hard-failing.
54pub fn parse_cache_control(
55    extras: &HashMap<String, serde_json::Value>,
56) -> Option<CacheControlConfig> {
57    let val = extras.get("cache")?;
58    match serde_json::from_value::<CacheControlConfig>(val.clone()) {
59        Ok(cfg) => Some(cfg),
60        Err(e) => {
61            // Log at warn level so operators can see bad payloads; fall back
62            // to normal (don't block the request).
63            tracing::warn!(
64                error = %e,
65                "tt_extras.cache deserialization failed — treating as normal"
66            );
67            Some(CacheControlConfig::default())
68        }
69    }
70}
71
#[cfg(test)]
mod cache_control_tests {
    use super::*;

    /// Build a `tt_extras`-shaped map from a JSON object literal.
    /// Panics if `json` is not a valid JSON object.
    fn extras(json: &str) -> HashMap<String, serde_json::Value> {
        serde_json::from_str(json).unwrap()
    }

    /// Run the parser on `json` and unwrap the result, for cases where a
    /// `cache` key is present and a config must come back.
    fn parse_present(json: &str) -> CacheControlConfig {
        parse_cache_control(&extras(json)).unwrap()
    }

    #[test]
    fn no_cache_key_returns_none() {
        let parsed = parse_cache_control(&extras("{}"));
        assert!(parsed.is_none());
    }

    #[test]
    fn bypass_mode_parsed() {
        let cfg = parse_present(r#"{"cache":{"mode":"bypass"}}"#);
        assert_eq!(cfg.mode, CacheMode::Bypass);
        assert_eq!(cfg.ttl_secs, None);
    }

    #[test]
    fn refresh_mode_with_ttl() {
        let cfg = parse_present(r#"{"cache":{"mode":"refresh","ttl_secs":3600}}"#);
        assert_eq!(cfg.mode, CacheMode::Refresh);
        assert_eq!(cfg.ttl_secs, Some(3600));
    }

    #[test]
    fn read_only_mode() {
        let cfg = parse_present(r#"{"cache":{"mode":"read-only"}}"#);
        assert_eq!(cfg.mode, CacheMode::ReadOnly);
    }

    #[test]
    fn absent_mode_defaults_to_normal() {
        let cfg = parse_present(r#"{"cache":{}}"#);
        assert_eq!(cfg.mode, CacheMode::Normal);
    }

    #[test]
    fn malformed_value_falls_back_to_default() {
        // A non-object `cache` value cannot deserialize; the parser must
        // still return a config rather than `None`.
        let cfg = parse_present(r#"{"cache":"not-an-object"}"#);
        assert_eq!(cfg.mode, CacheMode::Normal);
    }
}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct ChatCompletionRequest {
121    pub model: String,
122    pub messages: Vec<Message>,
123
124    #[serde(default, skip_serializing_if = "Option::is_none")]
125    pub temperature: Option<f32>,
126    #[serde(default, skip_serializing_if = "Option::is_none")]
127    pub top_p: Option<f32>,
128    #[serde(default, skip_serializing_if = "Option::is_none")]
129    pub max_tokens: Option<u32>,
130    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
131    pub stream: bool,
132    #[serde(default, skip_serializing_if = "Vec::is_empty")]
133    pub tools: Vec<Tool>,
134    #[serde(default, skip_serializing_if = "Option::is_none")]
135    pub tool_choice: Option<ToolChoice>,
136    #[serde(default, skip_serializing_if = "Option::is_none")]
137    pub response_format: Option<ResponseFormat>,
138    #[serde(default, skip_serializing_if = "Vec::is_empty")]
139    pub stop: Vec<String>,
140    #[serde(default, skip_serializing_if = "Option::is_none")]
141    pub presence_penalty: Option<f32>,
142    #[serde(default, skip_serializing_if = "Option::is_none")]
143    pub frequency_penalty: Option<f32>,
144    #[serde(default, skip_serializing_if = "Option::is_none")]
145    pub n: Option<u32>,
146    #[serde(default, skip_serializing_if = "Option::is_none")]
147    pub seed: Option<i64>,
148    #[serde(default, skip_serializing_if = "Option::is_none")]
149    pub user: Option<String>,
150
151    /// TokenTrimmer-internal extras (cache config, route hints, etc.) that are
152    /// stripped before forwarding to the provider.
153    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
154    pub tt_extras: HashMap<String, serde_json::Value>,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
158#[serde(tag = "role", rename_all = "lowercase")]
159pub enum Message {
160    System {
161        content: MessageContent,
162    },
163    User {
164        content: MessageContent,
165        #[serde(default, skip_serializing_if = "Option::is_none")]
166        name: Option<String>,
167    },
168    Assistant {
169        #[serde(default, skip_serializing_if = "Option::is_none")]
170        content: Option<MessageContent>,
171        #[serde(default, skip_serializing_if = "Vec::is_empty")]
172        tool_calls: Vec<ToolCall>,
173        #[serde(default, skip_serializing_if = "Option::is_none")]
174        name: Option<String>,
175    },
176    Tool {
177        content: MessageContent,
178        tool_call_id: String,
179    },
180}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
183#[serde(untagged)]
184pub enum MessageContent {
185    Text(String),
186    Parts(Vec<ContentPart>),
187}
188
189#[derive(Debug, Clone, Serialize, Deserialize)]
190#[serde(tag = "type", rename_all = "snake_case")]
191pub enum ContentPart {
192    Text { text: String },
193    ImageUrl { image_url: ImageUrl },
194    InputAudio { input_audio: InputAudio },
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct ImageUrl {
199    pub url: String,
200    #[serde(default, skip_serializing_if = "Option::is_none")]
201    pub detail: Option<String>,
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct InputAudio {
206    pub data: String,
207    pub format: String,
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct Tool {
212    #[serde(rename = "type")]
213    pub r#type: String,
214    pub function: ToolFunction,
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct ToolFunction {
219    pub name: String,
220    #[serde(default, skip_serializing_if = "Option::is_none")]
221    pub description: Option<String>,
222    pub parameters: serde_json::Value,
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
226#[serde(untagged)]
227pub enum ToolChoice {
228    Auto(String),
229    Specific {
230        #[serde(rename = "type")]
231        r#type: String,
232        function: ToolChoiceFunction,
233    },
234}
235
236#[derive(Debug, Clone, Serialize, Deserialize)]
237pub struct ToolChoiceFunction {
238    pub name: String,
239}
240
241#[derive(Debug, Clone, Serialize, Deserialize)]
242pub struct ToolCall {
243    pub id: String,
244    #[serde(rename = "type")]
245    pub r#type: String,
246    pub function: ToolCallFunction,
247}
248
249#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct ToolCallFunction {
251    pub name: String,
252    /// Stringified JSON arguments — OpenAI convention.
253    pub arguments: String,
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize)]
257pub struct ResponseFormat {
258    #[serde(rename = "type")]
259    pub r#type: String,
260    #[serde(default, skip_serializing_if = "Option::is_none")]
261    pub json_schema: Option<serde_json::Value>,
262}
263
264#[derive(Debug, Clone, Serialize, Deserialize)]
265pub struct ChatCompletionResponse {
266    pub id: String,
267    pub object: String,
268    pub created: i64,
269    pub model: String,
270    pub choices: Vec<Choice>,
271    pub usage: Usage,
272}
273
274#[derive(Debug, Clone, Serialize, Deserialize)]
275pub struct Choice {
276    pub index: u32,
277    pub message: Message,
278    pub finish_reason: Option<String>,
279}
280
281/// One SSE event from a streaming chat completion.
282#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct ChatCompletionChunk {
284    pub id: String,
285    pub object: String,
286    pub created: i64,
287    pub model: String,
288    pub choices: Vec<ChunkChoice>,
289    #[serde(default, skip_serializing_if = "Option::is_none")]
290    pub usage: Option<Usage>,
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct ChunkChoice {
295    pub index: u32,
296    pub delta: ChunkDelta,
297    pub finish_reason: Option<String>,
298}
299
300#[derive(Debug, Clone, Default, Serialize, Deserialize)]
301pub struct ChunkDelta {
302    #[serde(default, skip_serializing_if = "Option::is_none")]
303    pub role: Option<String>,
304    #[serde(default, skip_serializing_if = "Option::is_none")]
305    pub content: Option<String>,
306    #[serde(default, skip_serializing_if = "Vec::is_empty")]
307    pub tool_calls: Vec<ToolCall>,
308}
309
310#[derive(Debug, Clone, Serialize, Deserialize)]
311pub struct EmbeddingsRequest {
312    pub model: String,
313    pub input: EmbeddingInput,
314    #[serde(default, skip_serializing_if = "Option::is_none")]
315    pub dimensions: Option<u32>,
316    #[serde(default, skip_serializing_if = "Option::is_none")]
317    pub encoding_format: Option<String>,
318}
319
320#[derive(Debug, Clone, Serialize, Deserialize)]
321#[serde(untagged)]
322pub enum EmbeddingInput {
323    Single(String),
324    Batch(Vec<String>),
325}
326
327#[derive(Debug, Clone, Serialize, Deserialize)]
328pub struct EmbeddingsResponse {
329    pub object: String,
330    pub data: Vec<EmbeddingData>,
331    pub model: String,
332    pub usage: Usage,
333}
334
335#[derive(Debug, Clone, Serialize, Deserialize)]
336pub struct EmbeddingData {
337    pub object: String,
338    pub index: u32,
339    pub embedding: Vec<f32>,
340}