Skip to main content

sapient_tokenizers/
chat.rs

1//! Chat template rendering using Jinja2 via `minijinja`.
2//!
//! Renders HuggingFace chat templates (stored in `tokenizer_config.json`)
//! following the conventions of the Python `transformers` library.
5//!
6//! Supports: Llama 3 / ChatML / Mistral / Phi / Gemma / Qwen / Zephyr templates.
7
8use std::path::Path;
9
10use anyhow::{Context, Result};
11use minijinja::Environment;
12use serde::{Deserialize, Serialize};
13
14// ── ChatRole ──────────────────────────────────────────────────────────────────
15
16#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
17#[serde(rename_all = "lowercase")]
18pub enum ChatRole {
19    System,
20    User,
21    Assistant,
22    Tool,
23}
24
25impl std::fmt::Display for ChatRole {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        match self {
28            ChatRole::System => f.write_str("system"),
29            ChatRole::User => f.write_str("user"),
30            ChatRole::Assistant => f.write_str("assistant"),
31            ChatRole::Tool => f.write_str("tool"),
32        }
33    }
34}
35
36// ── ChatMessage ───────────────────────────────────────────────────────────────
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ChatMessage {
40    pub role: ChatRole,
41    pub content: String,
42}
43
44impl ChatMessage {
45    pub fn system(content: impl Into<String>) -> Self {
46        Self {
47            role: ChatRole::System,
48            content: content.into(),
49        }
50    }
51    pub fn user(content: impl Into<String>) -> Self {
52        Self {
53            role: ChatRole::User,
54            content: content.into(),
55        }
56    }
57    pub fn assistant(content: impl Into<String>) -> Self {
58        Self {
59            role: ChatRole::Assistant,
60            content: content.into(),
61        }
62    }
63}
64
65// ── ChatTemplate ──────────────────────────────────────────────────────────────
66
67/// Renders a conversation to a tokenizable string using the model's
68/// Jinja2 chat template (from `tokenizer_config.json`).
69pub struct ChatTemplate {
70    template_src: String,
71}
72
73impl ChatTemplate {
74    /// Load from a `tokenizer_config.json` file.
75    pub fn from_tokenizer_config(path: &Path) -> Result<Self> {
76        let text = std::fs::read_to_string(path).context("Failed to read tokenizer_config.json")?;
77        let config: serde_json::Value =
78            serde_json::from_str(&text).context("Invalid tokenizer_config.json")?;
79
80        let template_src = config["chat_template"]
81            .as_str()
82            .context("No chat_template found in tokenizer_config.json")?
83            .to_owned();
84
85        Ok(Self { template_src })
86    }
87
88    /// Build with a raw Jinja2 template string.
89    pub fn from_template(template: impl Into<String>) -> Self {
90        Self {
91            template_src: template.into(),
92        }
93    }
94
95    /// Render a list of chat messages to a prompt string.
96    ///
97    /// Set `add_generation_prompt = true` to append the assistant turn header
98    /// (so the model knows to generate).
99    pub fn render(&self, messages: &[ChatMessage], add_generation_prompt: bool) -> Result<String> {
100        let mut env = Environment::new();
101
102        // Register the template.
103        env.add_template("chat", &self.template_src)
104            .map_err(|e| anyhow::anyhow!("Template parse error: {e}"))?;
105
106        let tmpl = env
107            .get_template("chat")
108            .map_err(|e| anyhow::anyhow!("Template load error: {e}"))?;
109
110        // Build context — same variable names as HF Python.
111        let messages_val: Vec<serde_json::Value> = messages
112            .iter()
113            .map(|m| {
114                serde_json::json!({
115                    "role": m.role.to_string(),
116                    "content": m.content,
117                })
118            })
119            .collect();
120
121        let ctx = serde_json::json!({
122            "messages": messages_val,
123            "add_generation_prompt": add_generation_prompt,
124            "bos_token": "<s>",
125            "eos_token": "</s>",
126        });
127
128        tmpl.render(ctx)
129            .map_err(|e| anyhow::anyhow!("Template render error: {e}"))
130    }
131}
132
/// Built-in templates for common models (fallback when tokenizer_config.json
/// doesn't contain a chat_template field).
///
/// Every template iterates `messages` and reads each message's `role` and
/// `content`. Use of `add_generation_prompt` is noted per template.
pub mod builtin {
    /// ChatML format (`<|im_start|>{role}\n{content}<|im_end|>`), as used by
    /// Qwen and other ChatML fine-tunes.
    ///
    /// Every role is rendered verbatim. `add_generation_prompt` appends
    /// `<|im_start|>assistant\n`.
    pub const CHATML: &str = concat!(
        "{% for message in messages %}",
        "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n",
        "{% endfor %}",
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
    );

    /// Llama 3 Instruct format.
    ///
    /// The output starts with a literal `<|begin_of_text|>`.
    /// `add_generation_prompt` appends the assistant header.
    pub const LLAMA3: &str = concat!(
        "<|begin_of_text|>",
        "{% for message in messages %}",
        "<|start_header_id|>{{ message['role'] }}<|end_header_id|>\n\n",
        "{{ message['content'] }}<|eot_id|>",
        "{% endfor %}",
        "{% if add_generation_prompt %}",
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "{% endif %}",
    );

    /// Llama 2 chat format. Mistral Instruct shares the same
    /// `[INST] … [/INST]` turn layout.
    ///
    /// Behavior notes:
    /// - A leading `system` message is placed inside the first user turn as a
    ///   `<<SYS>>` block: `[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]`.
    /// - No BOS token is emitted.
    /// - `add_generation_prompt` is ignored, since user turns already end
    ///   with `[/INST]`.
    /// - Roles other than `user`/`assistant` are skipped.
    pub const LLAMA2: &str = concat!(
        "{% for message in messages %}",
        "{% if message['role'] == 'user' %}",
        "{{ '[INST] ' }}",
        // The system prompt belongs inside the first [INST] block rather than
        // opening a separate one.
        "{% if loop.index0 == 1 and messages[0]['role'] == 'system' %}",
        "{{ '<<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' }}",
        "{% endif %}",
        "{{ message['content'] + ' [/INST]' }}",
        "{% elif message['role'] == 'assistant' %}{{ message['content'] + '</s>' }}",
        "{% endif %}",
        "{% endfor %}",
    );

    /// Gemma instruct format.
    ///
    /// Only `user` and `assistant` messages are rendered; other roles,
    /// including `system`, are skipped. Assistant turns use the `model` role
    /// header, and `add_generation_prompt` appends `<start_of_turn>model\n`.
    pub const GEMMA: &str = concat!(
        "{% for message in messages %}",
        "{% if message['role'] == 'user' %}<start_of_turn>user\n{{ message['content'] }}<end_of_turn>\n",
        "{% elif message['role'] == 'assistant' %}<start_of_turn>model\n{{ message['content'] }}<end_of_turn>\n",
        "{% endif %}",
        "{% endfor %}",
        "{% if add_generation_prompt %}<start_of_turn>model\n{% endif %}",
    );

    /// Zephyr / TinyLlama chat format.
    ///
    /// Renders `system`, `user` and `assistant` messages, each terminated by
    /// `</s>`. Other roles are skipped. `add_generation_prompt` appends
    /// `<|assistant|>\n`.
    pub const ZEPHYR: &str = concat!(
        "{% for message in messages %}",
        "{% if message['role'] == 'system' %}",
        "<|system|>\n{{ message['content'] }}</s>\n",
        "{% elif message['role'] == 'user' %}",
        "<|user|>\n{{ message['content'] }}</s>\n",
        "{% elif message['role'] == 'assistant' %}",
        "<|assistant|>\n{{ message['content'] }}</s>\n",
        "{% endif %}",
        "{% endfor %}",
        "{% if add_generation_prompt %}<|assistant|>\n{% endif %}",
    );
}
193
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chatml_render() {
        let tmpl = ChatTemplate::from_template(builtin::CHATML);
        let messages = vec![
            ChatMessage::system("You are a helpful assistant."),
            ChatMessage::user("Hello!"),
        ];
        let out = tmpl.render(&messages, true).unwrap();
        assert_eq!(
            out,
            concat!(
                "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
                "<|im_start|>user\nHello!<|im_end|>\n",
                "<|im_start|>assistant\n",
            )
        );
    }

    #[test]
    fn chatml_without_generation_prompt() {
        let tmpl = ChatTemplate::from_template(builtin::CHATML);
        let out = tmpl.render(&[ChatMessage::user("Hi")], false).unwrap();
        assert_eq!(out, "<|im_start|>user\nHi<|im_end|>\n");
    }

    #[test]
    fn llama3_render() {
        let tmpl = ChatTemplate::from_template(builtin::LLAMA3);
        let messages = vec![ChatMessage::user("What is 2+2?")];
        let out = tmpl.render(&messages, true).unwrap();
        assert_eq!(
            out,
            concat!(
                "<|begin_of_text|>",
                "<|start_header_id|>user<|end_header_id|>\n\nWhat is 2+2?<|eot_id|>",
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
            )
        );
    }

    #[test]
    fn zephyr_render() {
        let tmpl = ChatTemplate::from_template(builtin::ZEPHYR);
        let messages = vec![ChatMessage::system("Be brief."), ChatMessage::user("Hi")];
        let out = tmpl.render(&messages, true).unwrap();
        assert_eq!(out, "<|system|>\nBe brief.</s>\n<|user|>\nHi</s>\n<|assistant|>\n");
    }

    #[test]
    fn invalid_template_is_an_error() {
        let tmpl = ChatTemplate::from_template("{% for %}");
        assert!(tmpl.render(&[], false).is_err());
    }

    #[test]
    fn role_display_is_lowercase() {
        assert_eq!(ChatRole::System.to_string(), "system");
        assert_eq!(ChatRole::User.to_string(), "user");
        assert_eq!(ChatRole::Assistant.to_string(), "assistant");
        assert_eq!(ChatRole::Tool.to_string(), "tool");
    }
}