Skip to main content

zagens_core/
chat.rs

//! Chat / message types shared between the runtime core and the TUI shell.
//!
//! These are the wire-format types used to communicate with LLM APIs.
//! They live in `zagens-core` so both the TUI and future shells can depend
//! on them without pulling in TUI-specific code.
6
7use serde::{Deserialize, Serialize};
8
// ── Constants ─────────────────────────────────────────────────────────
10
/// Context window, in tokens, reported for DeepSeek models that are neither
/// identified as `v4` nor carry an explicit `<N>k` size hint in their name.
pub const LEGACY_DEEPSEEK_CONTEXT_WINDOW_TOKENS: u32 = 128_000;
/// Context window, in tokens, reported for DeepSeek models whose name
/// contains `v4`.
pub const DEEPSEEK_V4_CONTEXT_WINDOW_TOKENS: u32 = 1_000_000;
/// Compaction threshold used when no context window is known for a model.
///
/// Equal to `COMPACTION_THRESHOLD_PERCENT` (80%) of
/// `LEGACY_DEEPSEEK_CONTEXT_WINDOW_TOKENS`: `128_000 * 80 / 100 = 102_400`.
pub const DEFAULT_COMPACTION_TOKEN_THRESHOLD: usize = 102_400;
/// Percentage of a model's context window used as its compaction threshold.
const COMPACTION_THRESHOLD_PERCENT: u32 = 80;
15
// ── Message types ─────────────────────────────────────────────────────
17
18/// Request payload for sending a message to the API.
19#[derive(Debug, Serialize, Deserialize, Clone)]
20pub struct MessageRequest {
21    pub model: String,
22    pub messages: Vec<Message>,
23    pub max_tokens: u32,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub system: Option<SystemPrompt>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub tools: Option<Vec<Tool>>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub tool_choice: Option<serde_json::Value>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub metadata: Option<serde_json::Value>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub thinking: Option<serde_json::Value>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub reasoning_effort: Option<String>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub stream: Option<bool>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub temperature: Option<f32>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub top_p: Option<f32>,
42}
43
44/// System prompt representation (plain text or structured blocks).
45#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
46#[serde(untagged)]
47pub enum SystemPrompt {
48    Text(String),
49    Blocks(Vec<SystemBlock>),
50}
51
52/// A structured system prompt block.
53#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
54pub struct SystemBlock {
55    #[serde(rename = "type")]
56    pub block_type: String,
57    pub text: String,
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub cache_control: Option<CacheControl>,
60}
61
62/// A chat message with role and content blocks.
63#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
64pub struct Message {
65    pub role: String,
66    pub content: Vec<ContentBlock>,
67}
68
69/// A single content block inside a message.
70#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
71#[serde(tag = "type")]
72pub enum ContentBlock {
73    #[serde(rename = "text")]
74    Text {
75        text: String,
76        #[serde(skip_serializing_if = "Option::is_none")]
77        cache_control: Option<CacheControl>,
78    },
79    #[serde(rename = "thinking")]
80    Thinking { thinking: String },
81    #[serde(rename = "tool_use")]
82    ToolUse {
83        id: String,
84        name: String,
85        input: serde_json::Value,
86        #[serde(skip_serializing_if = "Option::is_none")]
87        caller: Option<ToolCaller>,
88    },
89    #[serde(rename = "tool_result")]
90    ToolResult {
91        tool_use_id: String,
92        content: String,
93        #[serde(skip_serializing_if = "Option::is_none")]
94        is_error: Option<bool>,
95        #[serde(skip_serializing_if = "Option::is_none")]
96        content_blocks: Option<Vec<serde_json::Value>>,
97    },
98    #[serde(rename = "server_tool_use")]
99    ServerToolUse {
100        id: String,
101        name: String,
102        input: serde_json::Value,
103    },
104    #[serde(rename = "tool_search_tool_result")]
105    ToolSearchToolResult {
106        tool_use_id: String,
107        content: serde_json::Value,
108    },
109    #[serde(rename = "code_execution_tool_result")]
110    CodeExecutionToolResult {
111        tool_use_id: String,
112        content: serde_json::Value,
113    },
114}
115
116/// Cache control metadata for tool definitions and blocks.
117#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
118pub struct CacheControl {
119    #[serde(rename = "type")]
120    pub cache_type: String,
121}
122
123/// Metadata describing who invoked a tool call.
124#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
125pub struct ToolCaller {
126    #[serde(rename = "type")]
127    pub caller_type: String,
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub tool_id: Option<String>,
130}
131
132/// Tool definition exposed to the model.
133#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
134pub struct Tool {
135    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
136    pub tool_type: Option<String>,
137    pub name: String,
138    pub description: String,
139    pub input_schema: serde_json::Value,
140    #[serde(skip_serializing_if = "Option::is_none")]
141    pub allowed_callers: Option<Vec<String>>,
142    #[serde(skip_serializing_if = "Option::is_none")]
143    pub defer_loading: Option<bool>,
144    #[serde(skip_serializing_if = "Option::is_none")]
145    pub input_examples: Option<Vec<serde_json::Value>>,
146    #[serde(skip_serializing_if = "Option::is_none")]
147    pub strict: Option<bool>,
148    #[serde(skip_serializing_if = "Option::is_none")]
149    pub cache_control: Option<CacheControl>,
150}
151
152/// Container metadata for code-execution style server tools.
153#[derive(Debug, Serialize, Deserialize, Clone)]
154pub struct ContainerInfo {
155    pub id: String,
156    #[serde(skip_serializing_if = "Option::is_none")]
157    pub expires_at: Option<String>,
158}
159
160/// Response payload for a message request.
161#[derive(Debug, Serialize, Deserialize, Clone)]
162pub struct MessageResponse {
163    pub id: String,
164    pub r#type: String,
165    pub role: String,
166    pub content: Vec<ContentBlock>,
167    pub model: String,
168    pub stop_reason: Option<String>,
169    pub stop_sequence: Option<String>,
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub container: Option<ContainerInfo>,
172    pub usage: crate::models::Usage,
173}
174
// ── Streaming types ───────────────────────────────────────────────────
176
177/// Streaming event types for SSE responses.
178#[allow(dead_code)]
179#[derive(Debug, Deserialize, Clone)]
180#[serde(tag = "type")]
181pub enum StreamEvent {
182    #[serde(rename = "message_start")]
183    MessageStart { message: MessageResponse },
184    #[serde(rename = "content_block_start")]
185    ContentBlockStart {
186        index: u32,
187        content_block: ContentBlockStart,
188    },
189    #[serde(rename = "content_block_delta")]
190    ContentBlockDelta { index: u32, delta: Delta },
191    #[serde(rename = "content_block_stop")]
192    ContentBlockStop { index: u32 },
193    #[serde(rename = "message_delta")]
194    MessageDelta {
195        delta: MessageDelta,
196        usage: Option<crate::models::Usage>,
197    },
198    #[serde(rename = "message_stop")]
199    MessageStop,
200    #[serde(rename = "ping")]
201    Ping,
202}
203
204/// Content block types used in streaming starts.
205#[allow(dead_code)]
206#[derive(Debug, Deserialize, Clone)]
207#[serde(tag = "type")]
208pub enum ContentBlockStart {
209    #[serde(rename = "text")]
210    Text { text: String },
211    #[serde(rename = "thinking")]
212    Thinking { thinking: String },
213    #[serde(rename = "tool_use")]
214    ToolUse {
215        id: String,
216        name: String,
217        input: serde_json::Value,
218        #[serde(skip_serializing_if = "Option::is_none")]
219        caller: Option<ToolCaller>,
220    },
221    #[serde(rename = "server_tool_use")]
222    ServerToolUse {
223        id: String,
224        name: String,
225        input: serde_json::Value,
226    },
227}
228
229/// A content delta inside a `content_block_delta` event.
230#[allow(clippy::enum_variant_names)]
231#[derive(Debug, Deserialize, Clone)]
232#[serde(tag = "type")]
233pub enum Delta {
234    #[serde(rename = "text_delta")]
235    TextDelta { text: String },
236    #[serde(rename = "thinking_delta")]
237    ThinkingDelta { thinking: String },
238    #[serde(rename = "input_json_delta")]
239    InputJsonDelta { partial_json: String },
240}
241
242/// Delta payload for message-level updates.
243#[allow(dead_code)]
244#[derive(Debug, Deserialize, Clone)]
245pub struct MessageDelta {
246    pub stop_reason: Option<String>,
247    pub stop_sequence: Option<String>,
248}
249
// ── LLM Client trait (P2 PR3b) ───────────────────────────────────────
251
252/// Type alias for boxed stream of SSE events.
253pub type StreamEventBox = std::pin::Pin<
254    Box<dyn futures_util::Stream<Item = anyhow::Result<StreamEvent>> + Send + 'static>,
255>;
256
257/// Unified interface for LLM providers — dyn-compatible via `#[async_trait]`.
258///
259/// Implementations live in the TUI shell (`DeepSeekClient`); the core only
260/// depends on this trait so Engine/turn_loop can be provider-agnostic.
261#[async_trait::async_trait]
262pub trait LlmClient: Send + Sync {
263    fn provider_name(&self) -> &'static str;
264    fn model(&self) -> &str;
265    async fn create_message(&self, request: MessageRequest) -> anyhow::Result<MessageResponse>;
266    async fn create_message_stream(
267        &self,
268        request: MessageRequest,
269    ) -> anyhow::Result<StreamEventBox>;
270    async fn health_check(&self) -> anyhow::Result<bool> {
271        Ok(true)
272    }
273    /// DeepSeek FIM (Fill-in-the-Middle) completion. Returns an error by default;
274    /// `DeepSeekClient` overrides this.
275    async fn fim_completion(
276        &self,
277        _model: &str,
278        _prefix: &str,
279        _suffix: &str,
280        _max_tokens: u32,
281    ) -> anyhow::Result<String> {
282        Err(anyhow::anyhow!("FIM not supported by this provider"))
283    }
284}
285
// ── Context window helpers ────────────────────────────────────────────
287
288/// Map known models to their approximate context window sizes.
289#[must_use]
290pub fn context_window_for_model(model: &str) -> Option<u32> {
291    let lower = model.to_lowercase();
292    if lower.contains("deepseek") {
293        if let Some(explicit_window) = deepseek_context_window_hint(&lower) {
294            return Some(explicit_window);
295        }
296        if lower.contains("v4") {
297            return Some(DEEPSEEK_V4_CONTEXT_WINDOW_TOKENS);
298        }
299        return Some(LEGACY_DEEPSEEK_CONTEXT_WINDOW_TOKENS);
300    }
301    if lower.contains("claude") {
302        return Some(200_000);
303    }
304    None
305}
306
/// Extract an explicit context-window hint such as `32k` or `128k` from an
/// already-lowercased model name.
///
/// Scans left to right for a run of ASCII digits immediately followed by a
/// lowercase `k`. A candidate is accepted only when:
///
/// * the byte before the digits (if any) and the byte after the `k` (if any)
///   are not ASCII alphanumeric, so `x128k` and `128kb` are rejected;
/// * the number parses as a `u32` in `8..=1024`.
///
/// The first accepted candidate is returned as `N * 1000` tokens; rejected
/// candidates are skipped and scanning continues. Returns `None` when no
/// candidate qualifies.
///
/// Only `b'k'` is compared, so the caller must lowercase the input first.
fn deepseek_context_window_hint(model_lower: &str) -> Option<u32> {
    let bytes = model_lower.as_bytes();
    let mut i = 0usize;

    while i < bytes.len() {
        if !bytes[i].is_ascii_digit() {
            i += 1;
            continue;
        }

        // Consume the whole digit run; `i` ends on the first non-digit byte
        // (or at the end of the input), so every `continue` below still makes
        // progress.
        let start = i;
        while i < bytes.len() && bytes[i].is_ascii_digit() {
            i += 1;
        }

        // The run must be immediately followed by `k`.
        if bytes.get(i) != Some(&b'k') {
            continue;
        }

        // Word-boundary guards on both sides of `<digits>k`.
        let before_ok = start == 0 || !bytes[start - 1].is_ascii_alphanumeric();
        let after_ok = !bytes
            .get(i + 1)
            .is_some_and(|b| b.is_ascii_alphanumeric());
        if !(before_ok && after_ok) {
            continue;
        }

        // `start..i` covers ASCII digits only, so the slice is always on
        // char boundaries. Overlong runs fail to parse and are skipped.
        match model_lower[start..i].parse::<u32>() {
            Ok(kilo_tokens) if (8..=1024).contains(&kilo_tokens) => {
                return Some(kilo_tokens.saturating_mul(1000));
            }
            // Out of range or unparsable: keep scanning. `i` sits on the `k`,
            // which the next iteration steps over.
            _ => {}
        }
    }

    None
}
336
337/// Suggested compaction token threshold for a given model.
338#[must_use]
339pub fn compaction_threshold_for_model(model: &str) -> usize {
340    let Some(window) = context_window_for_model(model) else {
341        return DEFAULT_COMPACTION_TOKEN_THRESHOLD;
342    };
343    let threshold = (u64::from(window) * u64::from(COMPACTION_THRESHOLD_PERCENT)) / 100;
344    usize::try_from(threshold).unwrap_or(DEFAULT_COMPACTION_TOKEN_THRESHOLD)
345}
346
347/// Get the default compaction threshold for the model override, or the
348/// configured default.
349#[must_use]
350pub fn compaction_threshold_for_override(override_model: Option<&str>) -> usize {
351    match override_model {
352        Some(model) => compaction_threshold_for_model(model),
353        None => DEFAULT_COMPACTION_TOKEN_THRESHOLD,
354    }
355}