Skip to main content

terrain_forge/algorithms/
wfc.rs

1use crate::{Algorithm, Grid, Rng, Tile};
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, VecDeque};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6/// Configuration for Wave Function Collapse generation.
7pub struct WfcConfig {
8    /// Weight for floor tiles in random collapse. Default: 0.4.
9    pub floor_weight: f64,
10    /// Size of extracted patterns (NxN). Default: 3.
11    pub pattern_size: usize,
12    /// Enable backtracking on contradiction. Default: true.
13    pub enable_backtracking: bool,
14}
15
16impl Default for WfcConfig {
17    fn default() -> Self {
18        Self {
19            floor_weight: 0.4,
20            pattern_size: 3,
21            enable_backtracking: true,
22        }
23    }
24}
25
26#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
27/// A tile pattern extracted from an example grid.
28pub struct Pattern {
29    tiles: Vec<Vec<Tile>>,
30}
31
32impl Pattern {
33    fn new(size: usize) -> Self {
34        Self {
35            tiles: vec![vec![Tile::Wall; size]; size],
36        }
37    }
38
39    fn from_grid(grid: &Grid<Tile>, x: usize, y: usize, size: usize) -> Option<Self> {
40        let mut tiles = vec![vec![Tile::Wall; size]; size];
41        for (dy, row) in tiles.iter_mut().enumerate() {
42            for (dx, cell) in row.iter_mut().enumerate() {
43                if let Some(tile) = grid.get((x + dx) as i32, (y + dy) as i32) {
44                    *cell = *tile;
45                } else {
46                    return None;
47                }
48            }
49        }
50        Some(Self { tiles })
51    }
52
53    fn rotated(&self) -> Self {
54        let size = self.tiles.len();
55        let mut tiles = vec![vec![Tile::Wall; size]; size];
56        for (y, row) in self.tiles.iter().enumerate() {
57            for (x, &tile) in row.iter().enumerate() {
58                tiles[x][size - 1 - y] = tile;
59            }
60        }
61        Self { tiles }
62    }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66/// Internal state of a WFC solve.
67pub struct WfcState {
68    possibilities: Vec<Vec<Vec<usize>>>,
69    patterns: Vec<Pattern>,
70    #[allow(dead_code)]
71    constraints: HashMap<(usize, i32, i32), Vec<usize>>,
72    width: usize,
73    height: usize,
74}
75
76impl WfcState {
77    fn new(width: usize, height: usize, patterns: Vec<Pattern>) -> Self {
78        let pattern_count = patterns.len();
79        let possibilities = vec![vec![(0..pattern_count).collect(); width]; height];
80
81        Self {
82            possibilities,
83            patterns,
84            constraints: HashMap::new(),
85            width,
86            height,
87        }
88    }
89
90    fn entropy(&self, x: usize, y: usize) -> usize {
91        self.possibilities[y][x].len()
92    }
93
94    fn is_collapsed(&self, x: usize, y: usize) -> bool {
95        self.entropy(x, y) == 1
96    }
97
98    fn collapse(&mut self, x: usize, y: usize, pattern_id: usize) -> bool {
99        if !self.possibilities[y][x].contains(&pattern_id) {
100            return false;
101        }
102        self.possibilities[y][x] = vec![pattern_id];
103        true
104    }
105
106    fn propagate(&mut self) -> bool {
107        let mut queue = VecDeque::new();
108
109        // Add all collapsed cells to queue
110        for y in 0..self.height {
111            for x in 0..self.width {
112                if self.is_collapsed(x, y) {
113                    queue.push_back((x, y));
114                }
115            }
116        }
117
118        while let Some((x, y)) = queue.pop_front() {
119            let current_patterns = self.possibilities[y][x].clone();
120
121            // Check all neighbors
122            for (dx, dy) in [(-1, 0), (1, 0), (0, -1), (0, 1)] {
123                let nx = x as i32 + dx;
124                let ny = y as i32 + dy;
125
126                if nx >= 0 && ny >= 0 && (nx as usize) < self.width && (ny as usize) < self.height {
127                    let nx = nx as usize;
128                    let ny = ny as usize;
129
130                    if self.constrain_neighbor(nx, ny, &current_patterns, dx, dy) {
131                        if self.possibilities[ny][nx].is_empty() {
132                            return false; // Contradiction
133                        }
134                        queue.push_back((nx, ny));
135                    }
136                }
137            }
138        }
139
140        true
141    }
142
143    fn constrain_neighbor(
144        &mut self,
145        x: usize,
146        y: usize,
147        allowed_patterns: &[usize],
148        dx: i32,
149        dy: i32,
150    ) -> bool {
151        let mut changed = false;
152        let mut valid_patterns = Vec::new();
153
154        for &pattern_id in &self.possibilities[y][x] {
155            if self.is_compatible(pattern_id, allowed_patterns, dx, dy) {
156                valid_patterns.push(pattern_id);
157            }
158        }
159
160        if valid_patterns.len() != self.possibilities[y][x].len() {
161            self.possibilities[y][x] = valid_patterns;
162            changed = true;
163        }
164
165        changed
166    }
167
168    fn is_compatible(
169        &self,
170        pattern_id: usize,
171        neighbor_patterns: &[usize],
172        dx: i32,
173        dy: i32,
174    ) -> bool {
175        // Simplified compatibility check - patterns are compatible if they have matching edges
176        for &neighbor_id in neighbor_patterns {
177            if self.patterns_compatible(pattern_id, neighbor_id, dx, dy) {
178                return true;
179            }
180        }
181        false
182    }
183
184    fn patterns_compatible(&self, p1: usize, p2: usize, dx: i32, dy: i32) -> bool {
185        let pattern1 = &self.patterns[p1];
186        let pattern2 = &self.patterns[p2];
187        let size = pattern1.tiles.len();
188
189        // Check edge compatibility based on direction
190        match (dx, dy) {
191            (1, 0) => {
192                // p2 is to the right of p1
193                for y in 0..size {
194                    if pattern1.tiles[y][size - 1] != pattern2.tiles[y][0] {
195                        return false;
196                    }
197                }
198            }
199            (-1, 0) => {
200                // p2 is to the left of p1
201                for y in 0..size {
202                    if pattern1.tiles[y][0] != pattern2.tiles[y][size - 1] {
203                        return false;
204                    }
205                }
206            }
207            (0, 1) => {
208                // p2 is below p1
209                for x in 0..size {
210                    if pattern1.tiles[size - 1][x] != pattern2.tiles[0][x] {
211                        return false;
212                    }
213                }
214            }
215            (0, -1) => {
216                // p2 is above p1
217                for x in 0..size {
218                    if pattern1.tiles[0][x] != pattern2.tiles[size - 1][x] {
219                        return false;
220                    }
221                }
222            }
223            _ => {}
224        }
225
226        true
227    }
228}
229
230/// Extracts tile patterns from example grids for WFC.
231pub struct WfcPatternExtractor;
232
233impl WfcPatternExtractor {
234    /// Extracts all unique NxN patterns (with rotations) from the grid.
235    pub fn extract_patterns(grid: &Grid<Tile>, pattern_size: usize) -> Vec<Pattern> {
236        let mut patterns = Vec::new();
237        let mut pattern_set = std::collections::HashSet::new();
238
239        for y in 0..=grid.height().saturating_sub(pattern_size) {
240            for x in 0..=grid.width().saturating_sub(pattern_size) {
241                if let Some(pattern) = Pattern::from_grid(grid, x, y, pattern_size) {
242                    if pattern_set.insert(pattern.clone()) {
243                        patterns.push(pattern.clone());
244                        // Add rotations
245                        let mut rotated = pattern;
246                        for _ in 0..3 {
247                            rotated = rotated.rotated();
248                            if pattern_set.insert(rotated.clone()) {
249                                patterns.push(rotated.clone());
250                            }
251                        }
252                    }
253                }
254            }
255        }
256
257        // Ensure we have at least basic patterns
258        if patterns.is_empty() {
259            let wall_pattern = Pattern::new(pattern_size);
260            let mut floor_pattern = Pattern::new(pattern_size);
261            for row in &mut floor_pattern.tiles {
262                for tile in row {
263                    *tile = Tile::Floor;
264                }
265            }
266            patterns.push(wall_pattern);
267            patterns.push(floor_pattern);
268        }
269
270        patterns
271    }
272}
273
274#[derive(Debug, Clone, Default)]
275/// Backtracking state manager for WFC.
276pub struct WfcBacktracker {
277    states: Vec<WfcState>,
278}
279
280impl WfcBacktracker {
281    /// Creates a new backtracker.
282    pub fn new() -> Self {
283        Self::default()
284    }
285
286    /// Saves a WFC state snapshot.
287    pub fn save_state(&mut self, state: &WfcState) {
288        self.states.push(state.clone());
289    }
290
291    /// Restores the most recent saved state.
292    pub fn backtrack(&mut self) -> Option<WfcState> {
293        self.states.pop()
294    }
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
298/// Wave Function Collapse terrain generator.
299pub struct Wfc {
300    config: WfcConfig,
301}
302
303impl Wfc {
304    /// Creates a new WFC generator with the given config.
305    pub fn new(config: WfcConfig) -> Self {
306        Self { config }
307    }
308
309    /// Generates terrain using pre-extracted patterns.
310    pub fn generate_with_patterns(&self, grid: &mut Grid<Tile>, patterns: Vec<Pattern>, seed: u64) {
311        let mut rng = Rng::new(seed);
312        let mut state = WfcState::new(grid.width(), grid.height(), patterns);
313        let mut backtracker = WfcBacktracker::new();
314
315        // Set border constraints
316        self.set_border_constraints(&mut state);
317
318        loop {
319            if !state.propagate() {
320                if self.config.enable_backtracking {
321                    if let Some(prev_state) = backtracker.backtrack() {
322                        state = prev_state;
323                        continue;
324                    }
325                }
326                break; // Failed to solve
327            }
328
329            // Find cell with minimum entropy > 1
330            if let Some((x, y)) = self.find_min_entropy_cell(&state) {
331                if self.config.enable_backtracking {
332                    backtracker.save_state(&state);
333                }
334
335                let pattern_id = self.choose_pattern(&state, x, y, &mut rng);
336                if !state.collapse(x, y, pattern_id) {
337                    if self.config.enable_backtracking {
338                        if let Some(prev_state) = backtracker.backtrack() {
339                            state = prev_state;
340                            continue;
341                        }
342                    }
343                    break;
344                }
345            } else {
346                break; // All cells collapsed
347            }
348        }
349
350        self.apply_to_grid(&state, grid);
351    }
352
353    fn set_border_constraints(&self, state: &mut WfcState) {
354        // Force borders to be walls by keeping only wall patterns
355        let wall_patterns: Vec<usize> = state
356            .patterns
357            .iter()
358            .enumerate()
359            .filter(|(_, p)| {
360                p.tiles
361                    .iter()
362                    .all(|row| row.iter().all(|&t| t == Tile::Wall))
363            })
364            .map(|(i, _)| i)
365            .collect();
366
367        if !wall_patterns.is_empty() {
368            for x in 0..state.width {
369                state.possibilities[0][x] = wall_patterns.clone();
370                state.possibilities[state.height - 1][x] = wall_patterns.clone();
371            }
372            for y in 0..state.height {
373                state.possibilities[y][0] = wall_patterns.clone();
374                state.possibilities[y][state.width - 1] = wall_patterns.clone();
375            }
376        }
377    }
378
379    fn find_min_entropy_cell(&self, state: &WfcState) -> Option<(usize, usize)> {
380        let mut min_entropy = usize::MAX;
381        let mut candidates = Vec::new();
382
383        for y in 0..state.height {
384            for x in 0..state.width {
385                let entropy = state.entropy(x, y);
386                if entropy > 1 {
387                    if entropy < min_entropy {
388                        min_entropy = entropy;
389                        candidates.clear();
390                    }
391                    if entropy == min_entropy {
392                        candidates.push((x, y));
393                    }
394                }
395            }
396        }
397
398        candidates.into_iter().next()
399    }
400
401    fn choose_pattern(&self, state: &WfcState, x: usize, y: usize, rng: &mut Rng) -> usize {
402        let patterns = &state.possibilities[y][x];
403        *rng.pick(patterns).unwrap_or(&0)
404    }
405
406    fn apply_to_grid(&self, state: &WfcState, grid: &mut Grid<Tile>) {
407        let pattern_size = if !state.patterns.is_empty() {
408            state.patterns[0].tiles.len()
409        } else {
410            1
411        };
412
413        for y in 0..state.height {
414            for x in 0..state.width {
415                if state.is_collapsed(x, y) {
416                    let pattern_id = state.possibilities[y][x][0];
417                    let pattern = &state.patterns[pattern_id];
418
419                    // Apply center tile of pattern
420                    let center = pattern_size / 2;
421                    let tile = pattern.tiles[center][center];
422                    grid.set(x as i32, y as i32, tile);
423                }
424            }
425        }
426    }
427}
428
429impl Default for Wfc {
430    fn default() -> Self {
431        Self::new(WfcConfig::default())
432    }
433}
434
435impl Algorithm<Tile> for Wfc {
436    fn generate(&self, grid: &mut Grid<Tile>, seed: u64) {
437        // Create basic patterns for default generation
438        let patterns = vec![
439            Pattern {
440                tiles: vec![vec![Tile::Wall; 3]; 3],
441            },
442            Pattern {
443                tiles: vec![vec![Tile::Floor; 3]; 3],
444            },
445            Pattern {
446                tiles: vec![
447                    vec![Tile::Wall, Tile::Wall, Tile::Wall],
448                    vec![Tile::Wall, Tile::Floor, Tile::Wall],
449                    vec![Tile::Wall, Tile::Wall, Tile::Wall],
450                ],
451            },
452            Pattern {
453                tiles: vec![
454                    vec![Tile::Floor, Tile::Floor, Tile::Floor],
455                    vec![Tile::Floor, Tile::Floor, Tile::Floor],
456                    vec![Tile::Wall, Tile::Wall, Tile::Wall],
457                ],
458            },
459        ];
460
461        self.generate_with_patterns(grid, patterns, seed);
462    }
463
464    fn name(&self) -> &'static str {
465        "WFC"
466    }
467}