//! Wave Function Collapse (WFC) map generation — `terrain_forge/algorithms/wfc.rs`.

1use crate::{Algorithm, Grid, Rng, Tile};
2use std::collections::{HashMap, VecDeque};
3
/// Tunable parameters for the `Wfc` generator.
#[derive(Debug, Clone)]
pub struct WfcConfig {
    /// Floor weighting knob (default `0.4`).
    /// NOTE(review): not read anywhere in this file — confirm whether it is used elsewhere or
    /// still needs wiring into pattern selection.
    pub floor_weight: f64,
    /// Pattern side length (default `3`); presumably meant for pattern extraction.
    /// NOTE(review): not read anywhere in this file.
    pub pattern_size: usize,
    /// When `true` (the default), the solver restores saved states after a contradiction.
    pub enable_backtracking: bool,
}

impl Default for WfcConfig {
    fn default() -> Self {
        WfcConfig {
            enable_backtracking: true,
            pattern_size: 3,
            floor_weight: 0.4,
        }
    }
}

21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub struct Pattern {
23    tiles: Vec<Vec<Tile>>,
24}
25
26impl Pattern {
27    fn new(size: usize) -> Self {
28        Self {
29            tiles: vec![vec![Tile::Wall; size]; size],
30        }
31    }
32
33    fn from_grid(grid: &Grid<Tile>, x: usize, y: usize, size: usize) -> Option<Self> {
34        let mut tiles = vec![vec![Tile::Wall; size]; size];
35        for (dy, row) in tiles.iter_mut().enumerate() {
36            for (dx, cell) in row.iter_mut().enumerate() {
37                if let Some(tile) = grid.get((x + dx) as i32, (y + dy) as i32) {
38                    *cell = *tile;
39                } else {
40                    return None;
41                }
42            }
43        }
44        Some(Self { tiles })
45    }
46
47    fn rotated(&self) -> Self {
48        let size = self.tiles.len();
49        let mut tiles = vec![vec![Tile::Wall; size]; size];
50        for (y, row) in self.tiles.iter().enumerate() {
51            for (x, &tile) in row.iter().enumerate() {
52                tiles[x][size - 1 - y] = tile;
53            }
54        }
55        Self { tiles }
56    }
57}
58
59#[derive(Debug, Clone)]
60pub struct WfcState {
61    possibilities: Vec<Vec<Vec<usize>>>,
62    patterns: Vec<Pattern>,
63    #[allow(dead_code)]
64    constraints: HashMap<(usize, i32, i32), Vec<usize>>,
65    width: usize,
66    height: usize,
67}
68
69impl WfcState {
70    fn new(width: usize, height: usize, patterns: Vec<Pattern>) -> Self {
71        let pattern_count = patterns.len();
72        let possibilities = vec![vec![(0..pattern_count).collect(); width]; height];
73
74        Self {
75            possibilities,
76            patterns,
77            constraints: HashMap::new(),
78            width,
79            height,
80        }
81    }
82
83    fn entropy(&self, x: usize, y: usize) -> usize {
84        self.possibilities[y][x].len()
85    }
86
87    fn is_collapsed(&self, x: usize, y: usize) -> bool {
88        self.entropy(x, y) == 1
89    }
90
91    fn collapse(&mut self, x: usize, y: usize, pattern_id: usize) -> bool {
92        if !self.possibilities[y][x].contains(&pattern_id) {
93            return false;
94        }
95        self.possibilities[y][x] = vec![pattern_id];
96        true
97    }
98
99    fn propagate(&mut self) -> bool {
100        let mut queue = VecDeque::new();
101
102        // Add all collapsed cells to queue
103        for y in 0..self.height {
104            for x in 0..self.width {
105                if self.is_collapsed(x, y) {
106                    queue.push_back((x, y));
107                }
108            }
109        }
110
111        while let Some((x, y)) = queue.pop_front() {
112            let current_patterns = self.possibilities[y][x].clone();
113
114            // Check all neighbors
115            for (dx, dy) in [(-1, 0), (1, 0), (0, -1), (0, 1)] {
116                let nx = x as i32 + dx;
117                let ny = y as i32 + dy;
118
119                if nx >= 0 && ny >= 0 && (nx as usize) < self.width && (ny as usize) < self.height {
120                    let nx = nx as usize;
121                    let ny = ny as usize;
122
123                    if self.constrain_neighbor(nx, ny, &current_patterns, dx, dy) {
124                        if self.possibilities[ny][nx].is_empty() {
125                            return false; // Contradiction
126                        }
127                        queue.push_back((nx, ny));
128                    }
129                }
130            }
131        }
132
133        true
134    }
135
136    fn constrain_neighbor(
137        &mut self,
138        x: usize,
139        y: usize,
140        allowed_patterns: &[usize],
141        dx: i32,
142        dy: i32,
143    ) -> bool {
144        let mut changed = false;
145        let mut valid_patterns = Vec::new();
146
147        for &pattern_id in &self.possibilities[y][x] {
148            if self.is_compatible(pattern_id, allowed_patterns, dx, dy) {
149                valid_patterns.push(pattern_id);
150            }
151        }
152
153        if valid_patterns.len() != self.possibilities[y][x].len() {
154            self.possibilities[y][x] = valid_patterns;
155            changed = true;
156        }
157
158        changed
159    }
160
161    fn is_compatible(
162        &self,
163        pattern_id: usize,
164        neighbor_patterns: &[usize],
165        dx: i32,
166        dy: i32,
167    ) -> bool {
168        // Simplified compatibility check - patterns are compatible if they have matching edges
169        for &neighbor_id in neighbor_patterns {
170            if self.patterns_compatible(pattern_id, neighbor_id, dx, dy) {
171                return true;
172            }
173        }
174        false
175    }
176
177    fn patterns_compatible(&self, p1: usize, p2: usize, dx: i32, dy: i32) -> bool {
178        let pattern1 = &self.patterns[p1];
179        let pattern2 = &self.patterns[p2];
180        let size = pattern1.tiles.len();
181
182        // Check edge compatibility based on direction
183        match (dx, dy) {
184            (1, 0) => {
185                // p2 is to the right of p1
186                for y in 0..size {
187                    if pattern1.tiles[y][size - 1] != pattern2.tiles[y][0] {
188                        return false;
189                    }
190                }
191            }
192            (-1, 0) => {
193                // p2 is to the left of p1
194                for y in 0..size {
195                    if pattern1.tiles[y][0] != pattern2.tiles[y][size - 1] {
196                        return false;
197                    }
198                }
199            }
200            (0, 1) => {
201                // p2 is below p1
202                for x in 0..size {
203                    if pattern1.tiles[size - 1][x] != pattern2.tiles[0][x] {
204                        return false;
205                    }
206                }
207            }
208            (0, -1) => {
209                // p2 is above p1
210                for x in 0..size {
211                    if pattern1.tiles[0][x] != pattern2.tiles[size - 1][x] {
212                        return false;
213                    }
214                }
215            }
216            _ => {}
217        }
218
219        true
220    }
221}
222
223pub struct WfcPatternExtractor;
224
225impl WfcPatternExtractor {
226    pub fn extract_patterns(grid: &Grid<Tile>, pattern_size: usize) -> Vec<Pattern> {
227        let mut patterns = Vec::new();
228        let mut pattern_set = std::collections::HashSet::new();
229
230        for y in 0..=grid.height().saturating_sub(pattern_size) {
231            for x in 0..=grid.width().saturating_sub(pattern_size) {
232                if let Some(pattern) = Pattern::from_grid(grid, x, y, pattern_size) {
233                    if pattern_set.insert(pattern.clone()) {
234                        patterns.push(pattern.clone());
235                        // Add rotations
236                        let mut rotated = pattern;
237                        for _ in 0..3 {
238                            rotated = rotated.rotated();
239                            if pattern_set.insert(rotated.clone()) {
240                                patterns.push(rotated.clone());
241                            }
242                        }
243                    }
244                }
245            }
246        }
247
248        // Ensure we have at least basic patterns
249        if patterns.is_empty() {
250            let wall_pattern = Pattern::new(pattern_size);
251            let mut floor_pattern = Pattern::new(pattern_size);
252            for row in &mut floor_pattern.tiles {
253                for tile in row {
254                    *tile = Tile::Floor;
255                }
256            }
257            patterns.push(wall_pattern);
258            patterns.push(floor_pattern);
259        }
260
261        patterns
262    }
263}
264
265#[derive(Default)]
266pub struct WfcBacktracker {
267    states: Vec<WfcState>,
268}
269
270impl WfcBacktracker {
271    pub fn new() -> Self {
272        Self::default()
273    }
274
275    pub fn save_state(&mut self, state: &WfcState) {
276        self.states.push(state.clone());
277    }
278
279    pub fn backtrack(&mut self) -> Option<WfcState> {
280        self.states.pop()
281    }
282}
283
284pub struct Wfc {
285    config: WfcConfig,
286}
287
288impl Wfc {
289    pub fn new(config: WfcConfig) -> Self {
290        Self { config }
291    }
292
293    pub fn generate_with_patterns(&self, grid: &mut Grid<Tile>, patterns: Vec<Pattern>, seed: u64) {
294        let mut rng = Rng::new(seed);
295        let mut state = WfcState::new(grid.width(), grid.height(), patterns);
296        let mut backtracker = WfcBacktracker::new();
297
298        // Set border constraints
299        self.set_border_constraints(&mut state);
300
301        loop {
302            if !state.propagate() {
303                if self.config.enable_backtracking {
304                    if let Some(prev_state) = backtracker.backtrack() {
305                        state = prev_state;
306                        continue;
307                    }
308                }
309                break; // Failed to solve
310            }
311
312            // Find cell with minimum entropy > 1
313            if let Some((x, y)) = self.find_min_entropy_cell(&state) {
314                if self.config.enable_backtracking {
315                    backtracker.save_state(&state);
316                }
317
318                let pattern_id = self.choose_pattern(&state, x, y, &mut rng);
319                if !state.collapse(x, y, pattern_id) {
320                    if self.config.enable_backtracking {
321                        if let Some(prev_state) = backtracker.backtrack() {
322                            state = prev_state;
323                            continue;
324                        }
325                    }
326                    break;
327                }
328            } else {
329                break; // All cells collapsed
330            }
331        }
332
333        self.apply_to_grid(&state, grid);
334    }
335
336    fn set_border_constraints(&self, state: &mut WfcState) {
337        // Force borders to be walls by keeping only wall patterns
338        let wall_patterns: Vec<usize> = state
339            .patterns
340            .iter()
341            .enumerate()
342            .filter(|(_, p)| {
343                p.tiles
344                    .iter()
345                    .all(|row| row.iter().all(|&t| t == Tile::Wall))
346            })
347            .map(|(i, _)| i)
348            .collect();
349
350        if !wall_patterns.is_empty() {
351            for x in 0..state.width {
352                state.possibilities[0][x] = wall_patterns.clone();
353                state.possibilities[state.height - 1][x] = wall_patterns.clone();
354            }
355            for y in 0..state.height {
356                state.possibilities[y][0] = wall_patterns.clone();
357                state.possibilities[y][state.width - 1] = wall_patterns.clone();
358            }
359        }
360    }
361
362    fn find_min_entropy_cell(&self, state: &WfcState) -> Option<(usize, usize)> {
363        let mut min_entropy = usize::MAX;
364        let mut candidates = Vec::new();
365
366        for y in 0..state.height {
367            for x in 0..state.width {
368                let entropy = state.entropy(x, y);
369                if entropy > 1 {
370                    if entropy < min_entropy {
371                        min_entropy = entropy;
372                        candidates.clear();
373                    }
374                    if entropy == min_entropy {
375                        candidates.push((x, y));
376                    }
377                }
378            }
379        }
380
381        candidates.into_iter().next()
382    }
383
384    fn choose_pattern(&self, state: &WfcState, x: usize, y: usize, rng: &mut Rng) -> usize {
385        let patterns = &state.possibilities[y][x];
386        *rng.pick(patterns).unwrap_or(&0)
387    }
388
389    fn apply_to_grid(&self, state: &WfcState, grid: &mut Grid<Tile>) {
390        let pattern_size = if !state.patterns.is_empty() {
391            state.patterns[0].tiles.len()
392        } else {
393            1
394        };
395
396        for y in 0..state.height {
397            for x in 0..state.width {
398                if state.is_collapsed(x, y) {
399                    let pattern_id = state.possibilities[y][x][0];
400                    let pattern = &state.patterns[pattern_id];
401
402                    // Apply center tile of pattern
403                    let center = pattern_size / 2;
404                    let tile = pattern.tiles[center][center];
405                    grid.set(x as i32, y as i32, tile);
406                }
407            }
408        }
409    }
410}
411
412impl Default for Wfc {
413    fn default() -> Self {
414        Self::new(WfcConfig::default())
415    }
416}
417
418impl Algorithm<Tile> for Wfc {
419    fn generate(&self, grid: &mut Grid<Tile>, seed: u64) {
420        // Create basic patterns for default generation
421        let patterns = vec![
422            Pattern {
423                tiles: vec![vec![Tile::Wall; 3]; 3],
424            },
425            Pattern {
426                tiles: vec![vec![Tile::Floor; 3]; 3],
427            },
428            Pattern {
429                tiles: vec![
430                    vec![Tile::Wall, Tile::Wall, Tile::Wall],
431                    vec![Tile::Wall, Tile::Floor, Tile::Wall],
432                    vec![Tile::Wall, Tile::Wall, Tile::Wall],
433                ],
434            },
435            Pattern {
436                tiles: vec![
437                    vec![Tile::Floor, Tile::Floor, Tile::Floor],
438                    vec![Tile::Floor, Tile::Floor, Tile::Floor],
439                    vec![Tile::Wall, Tile::Wall, Tile::Wall],
440                ],
441            },
442        ];
443
444        self.generate_with_patterns(grid, patterns, seed);
445    }
446
447    fn name(&self) -> &'static str {
448        "WFC"
449    }
450}