1pub mod templates;
4
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use minijinja::{Environment, ErrorKind, UndefinedBehavior};
9use serde::Serialize;
10use thiserror::Error;
11
12use crate::types::{EvalCase, FewShotExample, Invocation};
13
/// Version identifiers of the prompt templates treated as built-in.
///
/// Every entry is a template name followed by a `_v0` suffix. The slice order
/// is part of the value and is kept exactly as declared.
// NOTE(review): presumably this list mirrors what
// `PromptTemplateRegistry::builtin` registers — confirm against the
// `templates` submodules, which are not visible from this file.
pub const BUILTIN_TEMPLATE_VERSIONS: &[&str] = &[
    "helpfulness_v0",
    "correctness_v0",
    "conciseness_v0",
    "coherence_v0",
    "response_relevance_v0",
    "hallucination_v0",
    "faithfulness_v0",
    "plan_adherence_v0",
    "laziness_v0",
    "goal_success_rate_v0",
    "harmfulness_v0",
    "toxicity_v0",
    "fairness_v0",
    "pii_leakage_v0",
    "prompt_injection_v0",
    "code_injection_v0",
    "rag_groundedness_v0",
    "rag_retrieval_relevance_v0",
    "rag_helpfulness_v0",
    "trajectory_accuracy_v0",
    "trajectory_accuracy_with_ref_v0",
    "task_completion_v0",
    "user_satisfaction_v0",
    "agent_tone_v0",
    "knowledge_retention_v0",
    "language_detection_v0",
    "perceived_error_v0",
    "interactions_v0",
    "code_llm_judge_v0",
    "image_safety_v0",
];
56
57pub trait JudgePromptTemplate: Send + Sync {
59 fn version(&self) -> &str;
61
62 fn render(&self, ctx: &PromptContext) -> Result<String, PromptError>;
64
65 fn family(&self) -> PromptFamily;
67}
68
/// Coarse grouping that a prompt template is assigned to.
///
/// The value is a plain tag: it is `Copy`, comparable, and hashable, so it can
/// be used directly as a map or set key.
// NOTE(review): `PromptTemplateRegistry::builtin` loads templates from
// `templates::{quality, safety, rag, agent, code, multimodal}`; presumably each
// of those modules maps onto the like-named variant below. No matching module
// for `Structured` is visible in this file — confirm where it is used.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PromptFamily {
    Quality,
    Safety,
    Rag,
    Agent,
    Structured,
    Code,
    Multimodal,
}
80
81#[derive(Debug, Clone, Serialize)]
83pub struct PromptContext {
84 pub case: Arc<EvalCase>,
85 pub invocation: Arc<Invocation>,
86 pub few_shot_examples: Vec<FewShotExample>,
87 pub custom: HashMap<String, serde_json::Value>,
88}
89
90impl PromptContext {
91 pub fn new(case: Arc<EvalCase>, invocation: Arc<Invocation>) -> Self {
93 Self {
94 case,
95 invocation,
96 few_shot_examples: Vec::new(),
97 custom: HashMap::new(),
98 }
99 }
100
101 #[must_use]
103 pub fn with_few_shot_examples(mut self, examples: Vec<FewShotExample>) -> Self {
104 self.few_shot_examples = examples;
105 self
106 }
107
108 #[must_use]
110 pub fn with_custom(mut self, custom: HashMap<String, serde_json::Value>) -> Self {
111 self.custom.extend(custom);
112 self
113 }
114}
115
116#[derive(Debug, Error)]
118pub enum PromptError {
119 #[error("missing prompt variable: {name}")]
120 MissingVariable { name: String },
121 #[error("duplicate prompt template version: {version}")]
122 DuplicateTemplate { version: String },
123 #[error("prompt render error: {0}")]
124 RenderError(String),
125}
126
127#[derive(Clone, Default)]
129pub struct PromptTemplateRegistry {
130 templates: HashMap<String, Arc<dyn JudgePromptTemplate>>,
131}
132
133impl PromptTemplateRegistry {
134 pub fn builtin() -> Self {
141 let mut registry = Self::default();
142 for family in [
143 templates::quality::builtins,
144 templates::safety::builtins,
145 templates::rag::builtins,
146 templates::agent::builtins,
147 templates::code::builtins,
148 templates::multimodal::builtins,
149 ] {
150 let entries = family().expect("built-in prompt template failed to compile");
151 for template in entries {
152 registry
153 .register(template)
154 .expect("duplicate built-in prompt template version");
155 }
156 }
157 registry
158 }
159
160 pub fn versions(&self) -> impl Iterator<Item = &str> {
162 self.templates.keys().map(String::as_str)
163 }
164
165 #[must_use]
167 pub fn len(&self) -> usize {
168 self.templates.len()
169 }
170
171 #[must_use]
173 pub fn is_empty(&self) -> bool {
174 self.templates.is_empty()
175 }
176
177 pub fn get(&self, version: &str) -> Option<Arc<dyn JudgePromptTemplate>> {
179 self.templates.get(version).cloned()
180 }
181
182 pub fn register(&mut self, template: Arc<dyn JudgePromptTemplate>) -> Result<(), PromptError> {
184 let version = template.version().to_string();
185 if self.templates.contains_key(&version) {
186 return Err(PromptError::DuplicateTemplate { version });
187 }
188 self.templates.insert(version, template);
189 Ok(())
190 }
191}
192
193#[derive(Debug, Clone)]
195pub struct MinijinjaTemplate {
196 version: String,
197 family: PromptFamily,
198 source: String,
199 undeclared: HashSet<String>,
200}
201
202impl MinijinjaTemplate {
203 pub fn new(
209 version: impl Into<String>,
210 family: PromptFamily,
211 source: impl Into<String>,
212 ) -> Result<Self, PromptError> {
213 let version = version.into();
214 let source = source.into();
215 let mut env = strict_environment();
216 env.add_template_owned(version.clone(), source.clone())
217 .map_err(|err| PromptError::RenderError(err.to_string()))?;
218 let template = env
219 .get_template(&version)
220 .map_err(|err| PromptError::RenderError(err.to_string()))?;
221 let undeclared = template.undeclared_variables(false);
222 if let Some(name) = undeclared
223 .iter()
224 .find(|name| !ALLOWED_ROOT_VARIABLES.contains(&name.as_str()))
225 {
226 return Err(PromptError::MissingVariable { name: name.clone() });
227 }
228
229 Ok(Self {
230 version,
231 family,
232 source,
233 undeclared,
234 })
235 }
236
237 pub fn variables(&self) -> &HashSet<String> {
239 &self.undeclared
240 }
241}
242
243impl JudgePromptTemplate for MinijinjaTemplate {
244 fn version(&self) -> &str {
245 &self.version
246 }
247
248 fn render(&self, ctx: &PromptContext) -> Result<String, PromptError> {
249 let mut env = strict_environment();
250 env.add_template_owned(self.version.clone(), self.source.clone())
251 .map_err(|err| render_error(&err))?;
252 let template = env
253 .get_template(&self.version)
254 .map_err(|err| render_error(&err))?;
255 template.render(ctx).map_err(|err| render_error(&err))
256 }
257
258 fn family(&self) -> PromptFamily {
259 self.family
260 }
261}
262
/// Root variable names a template is allowed to reference.
///
/// `MinijinjaTemplate::new` rejects any template whose undeclared variables
/// include a name outside this list. The names match the fields of
/// `PromptContext`.
const ALLOWED_ROOT_VARIABLES: &[&str] = &["case", "invocation", "few_shot_examples", "custom"];
264
265fn strict_environment() -> Environment<'static> {
266 let mut env = Environment::new();
267 env.set_undefined_behavior(UndefinedBehavior::Strict);
268 env
269}
270
271fn render_error(err: &minijinja::Error) -> PromptError {
272 if err.kind() == ErrorKind::UndefinedError {
273 return PromptError::MissingVariable {
274 name: err.to_string(),
275 };
276 }
277 PromptError::RenderError(err.to_string())
278}