swink_agent_eval/simulation/
actor.rs1#![forbid(unsafe_code)]
8
9use std::sync::{
10 Arc,
11 atomic::{AtomicUsize, Ordering},
12};
13
14use crate::judge::{JudgeClient, JudgeError, JudgeVerdict};
15
/// Persona description for a simulated conversation participant.
///
/// Every field is public, so a profile created with [`ActorProfile::new`] can
/// be refined afterwards by assigning `traits` and `context` directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActorProfile {
    /// Name used in the opening `You are …` line of the prompt.
    pub name: String,
    /// Trait descriptors; rendered comma-separated, and only when non-empty.
    pub traits: Vec<String>,
    /// Background text; rendered only when non-empty.
    pub context: String,
    /// Goal text rendered on the `Goal:` line of the prompt.
    pub goal: String,
}

impl ActorProfile {
    /// Creates a profile with the given name and goal, no traits, and an
    /// empty context.
    #[must_use]
    pub fn new(name: impl Into<String>, goal: impl Into<String>) -> Self {
        let (name, goal) = (name.into(), goal.into());
        Self {
            name,
            traits: Vec::default(),
            context: String::default(),
            goal,
        }
    }

    /// Renders the profile as a newline-terminated prompt.
    ///
    /// The name and goal lines are always present. A `Context:` line follows
    /// only when `context` is non-empty, and a `Traits:` line only when
    /// `traits` is non-empty.
    #[must_use]
    pub fn as_system_prompt(&self) -> String {
        let mut lines = vec![
            format!("You are {}.", self.name),
            format!("Goal: {}", self.goal),
        ];
        if !self.context.is_empty() {
            lines.push(format!("Context: {}", self.context));
        }
        if !self.traits.is_empty() {
            lines.push(format!("Traits: {}", self.traits.join(", ")));
        }

        // Each line, including the last one, ends with a newline.
        let mut rendered = lines.join("\n");
        rendered.push('\n');
        rendered
    }
}
53
/// One message produced by the simulated actor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActorTurn {
    /// Text of the actor's message.
    pub message: String,
    /// The completion signal that was matched on this turn, or `None` when
    /// no completion was reported.
    pub goal_completed: Option<String>,
}
61
62pub struct ActorSimulator {
64 profile: ActorProfile,
65 judge: Arc<dyn JudgeClient>,
66 model_id: String,
67 greeting_pool: Vec<String>,
68 max_turns: u32,
69 goal_completion_signal: Option<String>,
70 greeting_cursor: AtomicUsize,
71}
72
73impl ActorSimulator {
74 pub const DEFAULT_MAX_TURNS: u32 = 10;
76
77 #[must_use]
78 pub fn new(
79 profile: ActorProfile,
80 judge: Arc<dyn JudgeClient>,
81 model_id: impl Into<String>,
82 ) -> Self {
83 Self {
84 profile,
85 judge,
86 model_id: model_id.into(),
87 greeting_pool: vec!["Hello.".to_string()],
88 max_turns: Self::DEFAULT_MAX_TURNS,
89 goal_completion_signal: None,
90 greeting_cursor: AtomicUsize::new(0),
91 }
92 }
93
94 #[must_use]
96 pub fn with_greeting_pool(mut self, pool: Vec<String>) -> Self {
97 self.greeting_pool = if pool.is_empty() {
98 vec!["Hello.".to_string()]
99 } else {
100 pool
101 };
102 self
103 }
104
105 #[must_use]
106 pub const fn with_max_turns(mut self, max_turns: u32) -> Self {
107 self.max_turns = max_turns;
108 self
109 }
110
111 #[must_use]
112 pub fn with_goal_completion_signal(mut self, signal: impl Into<String>) -> Self {
113 self.goal_completion_signal = Some(signal.into());
114 self
115 }
116
117 #[must_use]
118 pub const fn profile(&self) -> &ActorProfile {
119 &self.profile
120 }
121
122 #[must_use]
123 pub const fn max_turns(&self) -> u32 {
124 self.max_turns
125 }
126
127 #[must_use]
128 pub fn goal_completion_signal(&self) -> Option<&str> {
129 self.goal_completion_signal.as_deref()
130 }
131
132 #[must_use]
133 pub fn model_id(&self) -> &str {
134 &self.model_id
135 }
136
137 pub fn greeting(&self) -> ActorTurn {
139 let idx = self.greeting_cursor.fetch_add(1, Ordering::Relaxed);
140 let message = self.greeting_pool[idx % self.greeting_pool.len()].clone();
141 ActorTurn {
142 message,
143 goal_completed: None,
144 }
145 }
146
147 pub async fn next_turn(&self, assistant_message: &str) -> Result<ActorTurn, JudgeError> {
149 let prompt = self.render_prompt(assistant_message);
150 let verdict = self.judge.judge(&prompt).await?;
151 Ok(self.turn_from_verdict(verdict))
152 }
153
154 fn render_prompt(&self, assistant_message: &str) -> String {
155 let mut prompt = self.profile.as_system_prompt();
156 prompt.push_str("Assistant said: ");
157 prompt.push_str(assistant_message);
158 prompt.push('\n');
159 prompt.push_str("Reply with your next message. ");
160 if let Some(signal) = &self.goal_completion_signal {
161 prompt.push_str(&format!(
162 "If the goal is complete, reply with label `{signal}`."
163 ));
164 }
165 prompt
166 }
167
168 fn turn_from_verdict(&self, verdict: JudgeVerdict) -> ActorTurn {
169 let goal_completed = match (&verdict.label, &self.goal_completion_signal) {
170 (Some(label), Some(signal)) if label == signal => Some(signal.clone()),
171 _ => None,
172 };
173 ActorTurn {
174 message: verdict.reason.unwrap_or_else(|| "…".to_string()),
175 goal_completed,
176 }
177 }
178}
179
180impl std::fmt::Debug for ActorSimulator {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 f.debug_struct("ActorSimulator")
183 .field("profile", &self.profile)
184 .field("max_turns", &self.max_turns)
185 .field("goal_completion_signal", &self.goal_completion_signal)
186 .finish_non_exhaustive()
187 }
188}