swarm_engine_eval/environments/
maze.rs1use std::collections::HashMap;
43use std::sync::RwLock;
44
45use serde::{Deserialize, Serialize};
46
47use swarm_engine_core::agent::WorkResult;
48use swarm_engine_core::environment::Environment;
49use swarm_engine_core::types::{Action, WorkerId};
50
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
57pub struct Position {
58 pub x: i32,
59 pub y: i32,
60}
61
62impl Position {
63 pub fn new(x: i32, y: i32) -> Self {
64 Self { x, y }
65 }
66
67 pub fn moved(&self, direction: &str) -> Self {
69 match direction.to_lowercase().as_str() {
70 "north" | "n" | "up" => Self::new(self.x, self.y - 1),
71 "south" | "s" | "down" => Self::new(self.x, self.y + 1),
72 "east" | "e" | "right" => Self::new(self.x + 1, self.y),
73 "west" | "w" | "left" => Self::new(self.x - 1, self.y),
74 _ => *self,
75 }
76 }
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
81pub enum Cell {
82 Wall,
84 Floor,
86 Start,
88 Goal,
90}
91
92impl Cell {
93 pub fn is_passable(&self) -> bool {
95 matches!(self, Cell::Floor | Cell::Start | Cell::Goal)
96 }
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct MazeMap {
102 pub grid: Vec<Vec<Cell>>,
104 pub width: usize,
106 pub height: usize,
108 pub start: Position,
110 pub goal: Position,
112}
113
114impl MazeMap {
115 pub fn parse(s: &str) -> Self {
122 let mut grid = Vec::new();
123 let mut start = Position::new(0, 0);
124 let mut goal = Position::new(0, 0);
125
126 for (y, line) in s.lines().enumerate() {
127 let trimmed = line.trim();
128 if trimmed.is_empty() {
129 continue;
130 }
131
132 let mut row = Vec::new();
133 for (x, ch) in trimmed.chars().enumerate() {
134 let cell = match ch {
135 '#' => Cell::Wall,
136 '.' => Cell::Floor,
137 'S' => {
138 start = Position::new(x as i32, y as i32);
139 Cell::Start
140 }
141 'G' => {
142 goal = Position::new(x as i32, y as i32);
143 Cell::Goal
144 }
145 _ => Cell::Floor,
146 };
147 row.push(cell);
148 }
149 grid.push(row);
150 }
151
152 let height = grid.len();
153 let width = grid.first().map(|r| r.len()).unwrap_or(0);
154
155 let mut actual_y = 0;
157 for line in s.lines() {
158 let trimmed = line.trim();
159 if trimmed.is_empty() {
160 continue;
161 }
162 for (x, ch) in trimmed.chars().enumerate() {
163 if ch == 'S' {
164 start = Position::new(x as i32, actual_y);
165 } else if ch == 'G' {
166 goal = Position::new(x as i32, actual_y);
167 }
168 }
169 actual_y += 1;
170 }
171
172 Self {
173 grid,
174 width,
175 height,
176 start,
177 goal,
178 }
179 }
180
181 pub fn get(&self, pos: Position) -> Option<Cell> {
183 if pos.x < 0 || pos.y < 0 {
184 return None;
185 }
186 self.grid
187 .get(pos.y as usize)
188 .and_then(|row| row.get(pos.x as usize))
189 .copied()
190 }
191
192 pub fn is_passable(&self, pos: Position) -> bool {
194 self.get(pos).map(|c| c.is_passable()).unwrap_or(false)
195 }
196
197 pub fn is_goal(&self, pos: Position) -> bool {
199 self.get(pos) == Some(Cell::Goal)
200 }
201}
202
203#[derive(Debug)]
209struct MazeState {
210 agents: HashMap<WorkerId, Position>,
212 reached_goal: Vec<WorkerId>,
214}
215
216pub struct MazeEnvironment {
222 map: MazeMap,
224 state: RwLock<MazeState>,
226 worker_count: usize,
228}
229
230impl MazeEnvironment {
231 pub fn new(map: MazeMap, worker_count: usize) -> Self {
233 let mut agents = HashMap::new();
234
235 for i in 0..worker_count {
237 agents.insert(WorkerId(i), map.start);
238 }
239
240 Self {
241 map,
242 state: RwLock::new(MazeState {
243 agents,
244 reached_goal: Vec::new(),
245 }),
246 worker_count,
247 }
248 }
249
250 pub fn from_str(map_str: &str, worker_count: usize) -> Self {
252 let map = MazeMap::parse(map_str);
253 Self::new(map, worker_count)
254 }
255
256 fn handle_move(&self, worker_id: WorkerId, action: &Action) -> WorkResult {
261 let direction = action
262 .params
263 .args
264 .get("target")
265 .map(|s| s.as_str())
266 .unwrap_or("north");
267
268 let mut state = self.state.write().unwrap();
269
270 let current_pos = match state.agents.get(&worker_id) {
271 Some(pos) => *pos,
272 None => return WorkResult::env_failure("Worker not found in maze"),
273 };
274
275 let new_pos = current_pos.moved(direction);
276
277 if !self.map.is_passable(new_pos) {
279 return WorkResult::env_failure(format!(
280 "Cannot move {}: wall or out of bounds",
281 direction
282 ));
283 }
284
285 state.agents.insert(worker_id, new_pos);
287
288 if self.map.is_goal(new_pos) {
290 if !state.reached_goal.contains(&worker_id) {
291 state.reached_goal.push(worker_id);
292 }
293
294 let all_reached = state.reached_goal.len() >= self.worker_count;
296 if all_reached {
297 return WorkResult::done_success(format!(
298 "Moved {} to goal! All workers reached goal!",
299 direction
300 ));
301 } else {
302 return WorkResult::env_success(format!("Moved {} to goal!", direction));
303 }
304 }
305
306 WorkResult::env_success(format!(
307 "Moved {} to ({}, {})",
308 direction, new_pos.x, new_pos.y
309 ))
310 }
311
312 fn handle_look(&self, worker_id: WorkerId) -> WorkResult {
314 let state = self.state.read().unwrap();
315
316 let current_pos = match state.agents.get(&worker_id) {
317 Some(pos) => *pos,
318 None => return WorkResult::env_failure("Worker not found in maze"),
319 };
320
321 let mut surroundings = HashMap::new();
323 for (dir, offset) in &[
324 ("north", (0, -1)),
325 ("south", (0, 1)),
326 ("east", (1, 0)),
327 ("west", (-1, 0)),
328 ] {
329 let check_pos = Position::new(current_pos.x + offset.0, current_pos.y + offset.1);
330 let cell_info = match self.map.get(check_pos) {
331 Some(Cell::Wall) => "wall",
332 Some(Cell::Floor) => "floor",
333 Some(Cell::Start) => "start",
334 Some(Cell::Goal) => "goal",
335 None => "void",
336 };
337 surroundings.insert(*dir, cell_info);
338 }
339
340 let data = serde_json::json!({
341 "position": { "x": current_pos.x, "y": current_pos.y },
342 "surroundings": surroundings,
343 "at_goal": self.map.is_goal(current_pos),
344 });
345
346 WorkResult::env_success_structured(data)
347 }
348
349 fn handle_wait(&self, _worker_id: WorkerId) -> WorkResult {
350 WorkResult::env_success("Waiting...")
351 }
352}
353
354impl Environment for MazeEnvironment {
355 fn step(&self, worker_id: WorkerId, action: &Action) -> WorkResult {
356 match action.name.as_str() {
357 "Move" => self.handle_move(worker_id, action),
358 "Look" => self.handle_look(worker_id),
359 "Wait" => self.handle_wait(worker_id),
360 _ => WorkResult::unsupported(&action.name),
361 }
362 }
363
364 fn reset(&self) {
365 let mut state = self.state.write().unwrap();
366
367 state.agents.clear();
369 for i in 0..self.worker_count {
370 state.agents.insert(WorkerId(i), self.map.start);
371 }
372 state.reached_goal.clear();
373 }
374
375 fn name(&self) -> &str {
376 "MazeEnvironment"
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// 5x5 maze with `S` at (1, 1) and `G` at (3, 3). The leading newline and
    /// indentation exercise the parser's blank-line skipping and trimming.
    const SIMPLE_MAZE: &str = "
        #####
        #S..#
        #.#.#
        #..G#
        #####
        ";

    /// Builds a `Move` action whose `target` argument is `direction`.
    fn move_action(direction: &str) -> Action {
        Action {
            name: "Move".to_string(),
            params: swarm_engine_core::types::ActionParams {
                target: None,
                args: [("target".to_string(), direction.to_string())]
                    .into_iter()
                    .collect(),
                data: vec![],
            },
        }
    }

    /// `true` for a successful `Acted` or `Done` result, `false` otherwise.
    fn is_success(result: &WorkResult) -> bool {
        match result {
            WorkResult::Acted { action_result, .. } => action_result.success,
            WorkResult::Done { success, .. } => *success,
            _ => false,
        }
    }

    #[test]
    fn test_maze_map_from_str() {
        let map = MazeMap::parse(SIMPLE_MAZE);
        assert_eq!((map.width, map.height), (5, 5));
        assert_eq!(map.start, Position::new(1, 1));
        assert_eq!(map.goal, Position::new(3, 3));
    }

    #[test]
    fn test_maze_passable() {
        let map = MazeMap::parse(SIMPLE_MAZE);
        // The start tile and the floor to its east are open.
        for open in [Position::new(1, 1), Position::new(2, 1)] {
            assert!(map.is_passable(open), "{:?} should be passable", open);
        }
        // The corner wall and the interior pillar block movement.
        for blocked in [Position::new(0, 0), Position::new(2, 2)] {
            assert!(!map.is_passable(blocked), "{:?} should be blocked", blocked);
        }
    }

    #[test]
    fn test_maze_environment_move() {
        let env = MazeEnvironment::from_str(SIMPLE_MAZE, 1);
        let worker = WorkerId(0);

        let result = env.step(worker, &move_action("east"));
        assert!(is_success(&result));

        // Start is (1, 1); one step east lands on (2, 1).
        let state = env.state.read().unwrap();
        assert_eq!(state.agents.get(&worker), Some(&Position::new(2, 1)));
    }

    #[test]
    fn test_maze_environment_wall_collision() {
        let env = MazeEnvironment::from_str(SIMPLE_MAZE, 1);

        // (1, 0), directly north of the start, is part of the outer wall.
        let result = env.step(WorkerId(0), &move_action("north"));
        assert!(!is_success(&result));
    }

    #[test]
    fn test_maze_environment_look() {
        let env = MazeEnvironment::from_str(SIMPLE_MAZE, 1);
        let look = Action {
            name: "Look".to_string(),
            params: Default::default(),
        };

        let result = env.step(WorkerId(0), &look);
        assert!(is_success(&result));
        match result {
            WorkResult::Acted { action_result, .. } => {
                assert!(action_result.output.is_some());
            }
            _ => panic!("Expected WorkResult::Acted"),
        }
    }
}