Skip to main content

swink_agent_eval/prompt/
mod.rs

1//! Prompt templates and rendering infrastructure for judge-backed evaluators.
2
3pub mod templates;
4
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use minijinja::{Environment, ErrorKind, UndefinedBehavior};
9use serde::Serialize;
10use thiserror::Error;
11
12use crate::types::{EvalCase, FewShotExample, Invocation};
13
/// Version identifiers of all built-in templates that
/// [`PromptTemplateRegistry::builtin`] is expected to register.
///
/// Kept as a single public constant so that tests and downstream users can
/// assert against the built-in surface instead of duplicating the list (T054).
/// Entries are grouped by evaluator family; the order is part of the constant.
pub const BUILTIN_TEMPLATE_VERSIONS: &[&str] = &[
    // --- Quality (T049) ---
    "helpfulness_v0",
    "correctness_v0",
    "conciseness_v0",
    "coherence_v0",
    "response_relevance_v0",
    "hallucination_v0",
    "faithfulness_v0",
    "plan_adherence_v0",
    "laziness_v0",
    "goal_success_rate_v0",
    // --- Safety (T050) ---
    "harmfulness_v0",
    "toxicity_v0",
    "fairness_v0",
    "pii_leakage_v0",
    "prompt_injection_v0",
    "code_injection_v0",
    // --- RAG (T051) ---
    "rag_groundedness_v0",
    "rag_retrieval_relevance_v0",
    "rag_helpfulness_v0",
    // --- Agent (T052) ---
    "trajectory_accuracy_v0",
    "trajectory_accuracy_with_ref_v0",
    "task_completion_v0",
    "user_satisfaction_v0",
    "agent_tone_v0",
    "knowledge_retention_v0",
    "language_detection_v0",
    "perceived_error_v0",
    "interactions_v0",
    // --- Code (T053) ---
    "code_llm_judge_v0",
    // --- Multimodal (T053) ---
    "image_safety_v0",
];
56
57/// Versioned prompt template consumed by judge-backed evaluators.
58pub trait JudgePromptTemplate: Send + Sync {
59    /// Stable version identifier, for example `correctness_v0`.
60    fn version(&self) -> &str;
61
62    /// Render the prompt for a single evaluator dispatch.
63    fn render(&self, ctx: &PromptContext) -> Result<String, PromptError>;
64
65    /// Evaluator family this template belongs to.
66    fn family(&self) -> PromptFamily;
67}
68
/// Evaluator families that a judge prompt template can belong to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PromptFamily {
    /// Quality evaluators.
    Quality,
    /// Safety evaluators.
    Safety,
    /// Retrieval-augmented-generation (RAG) evaluators.
    Rag,
    /// Agent evaluators.
    Agent,
    /// Structured evaluators.
    Structured,
    /// Code evaluators.
    Code,
    /// Multimodal evaluators.
    Multimodal,
}
80
81/// Data made available to prompt templates.
82#[derive(Debug, Clone, Serialize)]
83pub struct PromptContext {
84    pub case: Arc<EvalCase>,
85    pub invocation: Arc<Invocation>,
86    pub few_shot_examples: Vec<FewShotExample>,
87    pub custom: HashMap<String, serde_json::Value>,
88}
89
90impl PromptContext {
91    /// Construct a context with no few-shot examples or custom variables.
92    pub fn new(case: Arc<EvalCase>, invocation: Arc<Invocation>) -> Self {
93        Self {
94            case,
95            invocation,
96            few_shot_examples: Vec::new(),
97            custom: HashMap::new(),
98        }
99    }
100
101    /// Add few-shot examples to the render context.
102    #[must_use]
103    pub fn with_few_shot_examples(mut self, examples: Vec<FewShotExample>) -> Self {
104        self.few_shot_examples = examples;
105        self
106    }
107
108    /// Add custom template variables under the `custom` namespace.
109    #[must_use]
110    pub fn with_custom(mut self, custom: HashMap<String, serde_json::Value>) -> Self {
111        self.custom.extend(custom);
112        self
113    }
114}
115
116/// Prompt-template construction and rendering errors.
117#[derive(Debug, Error)]
118pub enum PromptError {
119    #[error("missing prompt variable: {name}")]
120    MissingVariable { name: String },
121    #[error("duplicate prompt template version: {version}")]
122    DuplicateTemplate { version: String },
123    #[error("prompt render error: {0}")]
124    RenderError(String),
125}
126
127/// Registry keyed by template version.
128#[derive(Clone, Default)]
129pub struct PromptTemplateRegistry {
130    templates: HashMap<String, Arc<dyn JudgePromptTemplate>>,
131}
132
133impl PromptTemplateRegistry {
134    /// Registry seeded with every built-in template authored under
135    /// `prompt::templates::{quality,safety,rag,agent,code,multimodal}` (T054).
136    ///
137    /// Built-in sources are validated at development time; any registration
138    /// failure here is a programming error and surfaces as a panic rather
139    /// than being silently swallowed.
140    pub fn builtin() -> Self {
141        let mut registry = Self::default();
142        for family in [
143            templates::quality::builtins,
144            templates::safety::builtins,
145            templates::rag::builtins,
146            templates::agent::builtins,
147            templates::code::builtins,
148            templates::multimodal::builtins,
149        ] {
150            let entries = family().expect("built-in prompt template failed to compile");
151            for template in entries {
152                registry
153                    .register(template)
154                    .expect("duplicate built-in prompt template version");
155            }
156        }
157        registry
158    }
159
160    /// Iterate over every registered template version identifier.
161    pub fn versions(&self) -> impl Iterator<Item = &str> {
162        self.templates.keys().map(String::as_str)
163    }
164
165    /// Current number of registered templates.
166    #[must_use]
167    pub fn len(&self) -> usize {
168        self.templates.len()
169    }
170
171    /// Whether the registry currently holds zero templates.
172    #[must_use]
173    pub fn is_empty(&self) -> bool {
174        self.templates.is_empty()
175    }
176
177    /// Look up a template by version identifier.
178    pub fn get(&self, version: &str) -> Option<Arc<dyn JudgePromptTemplate>> {
179        self.templates.get(version).cloned()
180    }
181
182    /// Register a template, rejecting duplicate version identifiers.
183    pub fn register(&mut self, template: Arc<dyn JudgePromptTemplate>) -> Result<(), PromptError> {
184        let version = template.version().to_string();
185        if self.templates.contains_key(&version) {
186            return Err(PromptError::DuplicateTemplate { version });
187        }
188        self.templates.insert(version, template);
189        Ok(())
190    }
191}
192
193/// MiniJinja-backed implementation of [`JudgePromptTemplate`].
194#[derive(Debug, Clone)]
195pub struct MinijinjaTemplate {
196    version: String,
197    family: PromptFamily,
198    source: String,
199    undeclared: HashSet<String>,
200}
201
202impl MinijinjaTemplate {
203    /// Compile and validate a MiniJinja prompt template.
204    ///
205    /// Templates may reference the root variables `case`, `invocation`,
206    /// `few_shot_examples`, and `custom`. Any other root variable is rejected
207    /// at construction time.
208    pub fn new(
209        version: impl Into<String>,
210        family: PromptFamily,
211        source: impl Into<String>,
212    ) -> Result<Self, PromptError> {
213        let version = version.into();
214        let source = source.into();
215        let mut env = strict_environment();
216        env.add_template_owned(version.clone(), source.clone())
217            .map_err(|err| PromptError::RenderError(err.to_string()))?;
218        let template = env
219            .get_template(&version)
220            .map_err(|err| PromptError::RenderError(err.to_string()))?;
221        let undeclared = template.undeclared_variables(false);
222        if let Some(name) = undeclared
223            .iter()
224            .find(|name| !ALLOWED_ROOT_VARIABLES.contains(&name.as_str()))
225        {
226            return Err(PromptError::MissingVariable { name: name.clone() });
227        }
228
229        Ok(Self {
230            version,
231            family,
232            source,
233            undeclared,
234        })
235    }
236
237    /// Root variables discovered while compiling the template.
238    pub fn variables(&self) -> &HashSet<String> {
239        &self.undeclared
240    }
241}
242
243impl JudgePromptTemplate for MinijinjaTemplate {
244    fn version(&self) -> &str {
245        &self.version
246    }
247
248    fn render(&self, ctx: &PromptContext) -> Result<String, PromptError> {
249        let mut env = strict_environment();
250        env.add_template_owned(self.version.clone(), self.source.clone())
251            .map_err(|err| render_error(&err))?;
252        let template = env
253            .get_template(&self.version)
254            .map_err(|err| render_error(&err))?;
255        template.render(ctx).map_err(|err| render_error(&err))
256    }
257
258    fn family(&self) -> PromptFamily {
259        self.family
260    }
261}
262
/// Root variable names a template is permitted to reference.
const ALLOWED_ROOT_VARIABLES: &[&str] = &[
    "case",
    "invocation",
    "few_shot_examples",
    "custom",
];
264
265fn strict_environment() -> Environment<'static> {
266    let mut env = Environment::new();
267    env.set_undefined_behavior(UndefinedBehavior::Strict);
268    env
269}
270
271fn render_error(err: &minijinja::Error) -> PromptError {
272    if err.kind() == ErrorKind::UndefinedError {
273        return PromptError::MissingVariable {
274            name: err.to_string(),
275        };
276    }
277    PromptError::RenderError(err.to_string())
278}