wave_function_collapse/
wave_function.rs

1use std::{collections::{HashMap, HashSet}, rc::Rc, hash::Hash, fs::File, io::BufReader, cell::RefCell};
2use serde::{Serialize, Deserialize, de::DeserializeOwned};
3use bitvec::prelude::*;
4use log::debug;
5extern crate pretty_env_logger;
6mod indexed_view;
7use crate::wave_function::collapsable_wave_function::collapsable_wave_function::CollapsableNode;
8
9use self::{collapsable_wave_function::collapsable_wave_function::CollapsableWaveFunction, indexed_view::IndexedView};
10mod probability_collection;
11mod probability_tree;
12mod probability_container;
13pub mod collapsable_wave_function;
14mod tests;
15
/// Namespace struct that houses convenience functions for building node state probability maps.
pub struct NodeStateProbability;

impl NodeStateProbability {
    /// Builds a map that assigns the same weight, `1.0`, to every provided node state.
    ///
    /// A state that appears more than once in `node_states` yields a single entry.
    pub fn get_equal_probability<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord>(node_states: &Vec<TNodeState>) -> HashMap<TNodeState, f32> {
        node_states
            .iter()
            .map(|state| (state.clone(), 1.0))
            .collect()
    }
}
30
31/// This is a node in the graph of the wave function. It can be in any of the provided node states, trying to achieve the cooresponding probability, connected to other nodes as described by the node state collections.
32#[derive(Debug, Serialize, Deserialize, Clone)]
33pub struct Node<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord> {
34    pub id: String,
35    pub node_state_collection_ids_per_neighbor_node_id: HashMap<String, Vec<String>>,
36    pub node_state_ids: Vec<TNodeState>,
37    pub node_state_ratios: Vec<f32>
38}
39
40impl<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord> Node<TNodeState> {
41    pub fn new(id: String, node_state_ratio_per_node_state_id: HashMap<TNodeState, f32>, node_state_collection_ids_per_neighbor_node_id: HashMap<String, Vec<String>>) -> Self {
42        let mut node_state_ids: Vec<TNodeState> = Vec::new();
43        let mut node_state_ratios: Vec<f32> = Vec::new();
44        for (node_state_id, node_state_ratio) in node_state_ratio_per_node_state_id.iter() {
45            node_state_ids.push(node_state_id.clone());
46            node_state_ratios.push(*node_state_ratio);
47        }
48        
49        // sort the node_state_ids and node_state_probabilities
50        let mut sort_permutation = permutation::sort(&node_state_ids);
51        sort_permutation.apply_slice_in_place(&mut node_state_ids);
52        sort_permutation.apply_slice_in_place(&mut node_state_ratios);
53
54        Node {
55            id,
56            node_state_collection_ids_per_neighbor_node_id,
57            node_state_ids,
58            node_state_ratios
59        }
60    }
61    pub fn get_id(&self) -> String {
62        self.id.clone()
63    }
64}
65
66/// This struct represents a relationship between the state of one "original" node to another "neighbor" node, permitting only those node states for the connected neighbor if the original node is in the specific state. This defines the constraints between nodes.
67#[derive(Debug, Serialize, Deserialize, Clone)]
68pub struct NodeStateCollection<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord> {
69    pub id: String,
70    pub node_state_id: TNodeState,
71    pub node_state_ids: Vec<TNodeState>
72}
73
74impl<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord> NodeStateCollection<TNodeState> {
75    pub fn new(id: String, node_state_id: TNodeState, node_state_ids: Vec<TNodeState>) -> Self {
76        NodeStateCollection {
77            id,
78            node_state_id,
79            node_state_ids
80        }
81    }
82}
83
84/// This struct represents the uncollapsed definition of nodes and their relationships to other nodes.
85#[derive(Serialize, Clone, Deserialize)]
86pub struct WaveFunction<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord> {
87    nodes: Vec<Node<TNodeState>>,
88    node_state_collections: Vec<NodeStateCollection<TNodeState>>
89}
90
91impl<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord + Serialize + DeserializeOwned> WaveFunction<TNodeState> {
92    pub fn new(nodes: Vec<Node<TNodeState>>, node_state_collections: Vec<NodeStateCollection<TNodeState>>) -> Self {
93        WaveFunction {
94            nodes,
95            node_state_collections
96        }
97    }
98
99    pub fn get_nodes(&self) -> Vec<Node<TNodeState>> {
100        self.nodes.clone()
101    }
102
103    pub fn get_node_state_collections(&self) -> Vec<NodeStateCollection<TNodeState>> {
104        self.node_state_collections.clone()
105    }
106
107    pub fn validate(&self) -> Result<(), String> {
108        let nodes_length: usize = self.nodes.len();
109
110        let mut node_per_id: HashMap<&str, &Node<TNodeState>> = HashMap::new();
111        let mut node_ids: HashSet<&str> = HashSet::new();
112        self.nodes
113            .iter()
114            .for_each(|node: &Node<TNodeState>| {
115                node_per_id.insert(&node.id, node);
116                node_ids.insert(&node.id);
117            });
118
119        let mut node_state_collection_per_id: HashMap<&str, &NodeStateCollection<TNodeState>> = HashMap::new();
120        self.node_state_collections
121            .iter()
122            .for_each(|node_state_collection| {
123                node_state_collection_per_id.insert(&node_state_collection.id, node_state_collection);
124            });
125
126        // ensure that references neighbors are actually nodes
127        for (_, node) in node_per_id.iter() {
128            for (neighbor_node_id_string, _) in node.node_state_collection_ids_per_neighbor_node_id.iter() {
129                let neighbor_node_id: &str = neighbor_node_id_string;
130                if !node_ids.contains(neighbor_node_id) {
131                    return Err(format!("Neighbor node {neighbor_node_id} does not exist in main list of nodes."));
132                }
133            }
134        }
135
136        let mut at_least_one_node_connects_to_all_other_nodes: bool = false;
137        for node in self.nodes.iter() {
138            // ensure that all nodes connect to all other nodes
139            let mut all_traversed_node_ids: HashSet<&str> = HashSet::new();
140            let mut potential_node_ids: Vec<&str> = Vec::new();
141
142            potential_node_ids.push(&node.id);
143
144            while let Some(node_id) = potential_node_ids.pop() {
145                let node = node_per_id.get(node_id).unwrap();
146                for neighbor_node_id_string in node.node_state_collection_ids_per_neighbor_node_id.keys() {
147                    let neighbor_node_id: &str = neighbor_node_id_string;
148                    if !all_traversed_node_ids.contains(neighbor_node_id) && !potential_node_ids.contains(&neighbor_node_id) {
149                        potential_node_ids.push(neighbor_node_id);
150                    }
151                }
152                all_traversed_node_ids.insert(node_id);
153            }
154
155            let all_traversed_node_ids_length = all_traversed_node_ids.len();
156            if all_traversed_node_ids_length == nodes_length {
157                at_least_one_node_connects_to_all_other_nodes = true;
158                break;
159            }
160        }
161
162        if !at_least_one_node_connects_to_all_other_nodes {
163            return Err(String::from("Not all nodes connect together. At least one node must be able to traverse to all other nodes."));
164        }
165
166        Ok(())
167    }
168
169    pub fn get_collapsable_wave_function<'a, TCollapsableWaveFunction: CollapsableWaveFunction<'a, TNodeState>>(&'a self, random_seed: Option<u64>) -> TCollapsableWaveFunction {
170        let mut node_per_id: HashMap<&str, &Node<TNodeState>> = HashMap::new();
171        self.nodes
172            .iter()
173            .for_each(|node: &Node<TNodeState>| {
174                node_per_id.insert(&node.id, node);
175            });
176
177        let mut node_state_collection_per_id: HashMap<&str, &NodeStateCollection<TNodeState>> = HashMap::new();
178        self.node_state_collections
179            .iter()
180            .for_each(|node_state_collection| {
181                node_state_collection_per_id.insert(&node_state_collection.id, node_state_collection);
182            });
183
184        // for each neighbor node
185        //      for each possible state for this node
186        //          create a mutable bit vector
187        //          for each possible node state for the neighbor node
188        //              get if the neighbor node state is permitted by this node's possible node state
189        //              push the boolean into bit vector
190        //          push bit vector into hashmap of mask per node state per neighbor node
191
192        // neighbor_mask_mapped_view_per_node_id is equivalent to mask_per_child_neighbor_per_state_per_node
193        let mut neighbor_mask_mapped_view_per_node_id: HashMap<&str, HashMap<&TNodeState, HashMap<&str, BitVec>>> = HashMap::new();
194
195        // create, per parent neighbor, a mask for each node (as child of parent neighbor)
196        let mut mask_per_parent_state_per_parent_neighbor_per_node: HashMap<&str, HashMap<&str, HashMap<&TNodeState, BitVec>>> = HashMap::new();
197
198        // for each node
199        for child_node in self.nodes.iter() {
200
201            let mut mask_per_parent_state_per_parent_neighbor: HashMap<&str, HashMap<&TNodeState, BitVec>> = HashMap::new();
202
203            // look for each parent neighbor node
204            for parent_neighbor_node in self.nodes.iter() {
205                // if you find that this is a parent neighbor node
206                if parent_neighbor_node.node_state_collection_ids_per_neighbor_node_id.contains_key(&child_node.id) {
207
208                    debug!("constructing mask for {:?}'s child node {:?}.", parent_neighbor_node.id, child_node.id);
209
210                    let mut mask_per_parent_state: HashMap<&TNodeState, BitVec> = HashMap::new();
211
212                    // get the node state collections that this parent neighbor node forces upon this node
213                    let node_state_collection_ids: &Vec<String> = parent_neighbor_node.node_state_collection_ids_per_neighbor_node_id.get(&child_node.id).unwrap();
214                    for node_state_collection_id in node_state_collection_ids.iter() {
215                        let node_state_collection = node_state_collection_per_id.get(node_state_collection_id.as_str()).unwrap();
216                        // construct a mask for this parent neighbor's node state collection and node state for this child node
217                        let mut mask: BitVec = BitVec::new();
218                        for node_state_id in child_node.node_state_ids.iter() {
219                            // if the node state for the child is permitted by the parent neighbor node state collection
220                            mask.push(node_state_collection.node_state_ids.contains(node_state_id));
221                        }
222                        // store the mask for this child node
223                        mask_per_parent_state.insert(&node_state_collection.node_state_id, mask);
224                    }
225
226                    mask_per_parent_state_per_parent_neighbor.insert(&parent_neighbor_node.id, mask_per_parent_state);
227                }
228            }
229
230            mask_per_parent_state_per_parent_neighbor_per_node.insert(&child_node.id, mask_per_parent_state_per_parent_neighbor);
231        }
232
233        // fill the neighbor_mask_mapped_view_per_node_id now that all masks have been constructed
234        // neighbor_mask_mapped_view_per_node_id is equivalent to mask_per_child_neighbor_per_state_per_node
235        for node in self.nodes.iter() {
236
237            // for this node, find all child neighbors
238            let node_id: &str = node.id.as_str();
239
240            let mut mask_per_neighbor_per_state: HashMap<&TNodeState, HashMap<&str, BitVec>> = HashMap::new();
241
242            for (neighbor_node_id, _) in node.node_state_collection_ids_per_neighbor_node_id.iter() {
243                let neighbor_node_id: &str = neighbor_node_id;
244
245                // get the inverse hashmap of this node to its child neighbor
246                let mask_per_parent_state_per_parent_neighbor = mask_per_parent_state_per_parent_neighbor_per_node.get(neighbor_node_id).unwrap();
247                let mask_per_parent_state = mask_per_parent_state_per_parent_neighbor.get(node_id).unwrap();
248
249                for (node_state_id, mask) in mask_per_parent_state.iter() {
250                    mask_per_neighbor_per_state
251                        .entry(node_state_id)
252                        .or_insert(HashMap::new())
253                        .insert(neighbor_node_id, mask.clone());
254                }
255            }
256
257            neighbor_mask_mapped_view_per_node_id.insert(node_id, mask_per_neighbor_per_state);
258        }
259
260        let mut node_state_indexed_view_per_node_id: HashMap<&str, IndexedView<&TNodeState>> = HashMap::new();
261
262        // store all of the masks that my neighbors will be orienting so that this node can check for restrictions
263        for node in self.nodes.iter() {
264            let node_id: &str = &node.id;
265
266            //debug!("storing for node {node_id} restrictive masks into node state indexed view.");
267
268            let referenced_node_state_ids: Vec<&TNodeState> = node.node_state_ids.iter().collect();
269            let cloned_node_state_ratios: Vec<f32> = node.node_state_ratios.clone();
270
271            let node_state_indexed_view = IndexedView::new(referenced_node_state_ids, cloned_node_state_ratios);
272            //debug!("stored for node {node_id} node state indexed view {:?}", node_state_indexed_view);
273            node_state_indexed_view_per_node_id.insert(node_id, node_state_indexed_view);
274        }
275
276        let mut collapsable_nodes: Vec<Rc<RefCell<CollapsableNode<TNodeState>>>> = Vec::new();
277        let mut collapsable_node_per_id: HashMap<&str, Rc<RefCell<CollapsableNode<TNodeState>>>> = HashMap::new();
278        // contains the mask to apply to the neighbor when this node is in a specific state
279        let random_instance = if let Some(seed) = random_seed {
280            Rc::new(RefCell::new(fastrand::Rng::with_seed(seed)))
281        }
282        else {
283            Rc::new(RefCell::new(fastrand::Rng::new()))
284        };
285        for node in self.nodes.iter() {
286            let node_id: &str = node.id.as_str();
287
288            let node_state_indexed_view: IndexedView<&TNodeState> = node_state_indexed_view_per_node_id.remove(node_id).unwrap();
289            let mask_per_neighbor_per_state = neighbor_mask_mapped_view_per_node_id.remove(node_id).unwrap();
290
291            let mut collapsable_node = CollapsableNode::new(&node.id, &node.node_state_collection_ids_per_neighbor_node_id, mask_per_neighbor_per_state, node_state_indexed_view);
292
293            if random_seed.is_some() {
294                collapsable_node.randomize(&mut random_instance.borrow_mut());
295            }
296
297            collapsable_nodes.push(Rc::new(RefCell::new(collapsable_node)));
298        }
299
300        for wrapped_collapsable_node in collapsable_nodes.iter() {
301            let collapsable_node = wrapped_collapsable_node.borrow();
302            collapsable_node_per_id.insert(collapsable_node.id, wrapped_collapsable_node.clone());
303        }
304
305        for wrapped_collapsable_node in collapsable_nodes.iter() {
306            let mut collapsable_node = wrapped_collapsable_node.borrow_mut();
307            let collapsable_node_id: &str = collapsable_node.id;
308
309            if mask_per_parent_state_per_parent_neighbor_per_node.contains_key(collapsable_node_id) {
310                let mask_per_parent_state_per_parent_neighbor = mask_per_parent_state_per_parent_neighbor_per_node.get(collapsable_node_id).unwrap();
311                for parent_neighbor_node_id in mask_per_parent_state_per_parent_neighbor.keys() {
312                    collapsable_node.parent_neighbor_node_ids.push(parent_neighbor_node_id);
313                }
314                if random_seed.is_some() {
315                    random_instance.borrow_mut().shuffle(collapsable_node.parent_neighbor_node_ids.as_mut_slice());
316                }
317                else {
318                    collapsable_node.parent_neighbor_node_ids.sort();
319                }
320            }
321        }
322
323        TCollapsableWaveFunction::new(collapsable_nodes, collapsable_node_per_id, random_instance)
324    }
325
326    pub fn save_to_file(&self, file_path: &str) {
327        let serialized_self = serde_json::to_string(self).unwrap();
328        std::fs::write(file_path, serialized_self).unwrap();
329    }
330
331    pub fn load_from_file(file_path: &str) -> Self {
332        let file = File::open(file_path).unwrap();
333        let reader = BufReader::new(file);
334        let deserialized_self: WaveFunction<TNodeState> = serde_json::from_reader(reader).unwrap();
335        deserialized_self
336    }
337}