1use std::{collections::{HashMap, HashSet}, rc::Rc, hash::Hash, fs::File, io::BufReader, cell::RefCell};
2use serde::{Serialize, Deserialize, de::DeserializeOwned};
3use bitvec::prelude::*;
4use log::debug;
5extern crate pretty_env_logger;
6mod indexed_view;
7use crate::wave_function::collapsable_wave_function::collapsable_wave_function::CollapsableNode;
8
9use self::{collapsable_wave_function::collapsable_wave_function::CollapsableWaveFunction, indexed_view::IndexedView};
10mod probability_collection;
11mod probability_tree;
12mod probability_container;
13pub mod collapsable_wave_function;
14mod tests;
15
/// Namespace for helpers that build node state probability tables.
pub struct NodeStateProbability;

impl NodeStateProbability {
    /// Builds a table that gives every provided node state the same weight.
    ///
    /// Each state in `node_states` is cloned into the returned map with a value of `1.0`.
    /// A state that appears more than once produces a single entry.
    pub fn get_equal_probability<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord>(node_states: &Vec<TNodeState>) -> HashMap<TNodeState, f32> {
        node_states
            .iter()
            .cloned()
            .map(|node_state| (node_state, 1.0))
            .collect()
    }
}
30
31#[derive(Debug, Serialize, Deserialize, Clone)]
33pub struct Node<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord> {
34 pub id: String,
35 pub node_state_collection_ids_per_neighbor_node_id: HashMap<String, Vec<String>>,
36 pub node_state_ids: Vec<TNodeState>,
37 pub node_state_ratios: Vec<f32>
38}
39
40impl<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord> Node<TNodeState> {
41 pub fn new(id: String, node_state_ratio_per_node_state_id: HashMap<TNodeState, f32>, node_state_collection_ids_per_neighbor_node_id: HashMap<String, Vec<String>>) -> Self {
42 let mut node_state_ids: Vec<TNodeState> = Vec::new();
43 let mut node_state_ratios: Vec<f32> = Vec::new();
44 for (node_state_id, node_state_ratio) in node_state_ratio_per_node_state_id.iter() {
45 node_state_ids.push(node_state_id.clone());
46 node_state_ratios.push(*node_state_ratio);
47 }
48
49 let mut sort_permutation = permutation::sort(&node_state_ids);
51 sort_permutation.apply_slice_in_place(&mut node_state_ids);
52 sort_permutation.apply_slice_in_place(&mut node_state_ratios);
53
54 Node {
55 id,
56 node_state_collection_ids_per_neighbor_node_id,
57 node_state_ids,
58 node_state_ratios
59 }
60 }
61 pub fn get_id(&self) -> String {
62 self.id.clone()
63 }
64}
65
66#[derive(Debug, Serialize, Deserialize, Clone)]
68pub struct NodeStateCollection<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord> {
69 pub id: String,
70 pub node_state_id: TNodeState,
71 pub node_state_ids: Vec<TNodeState>
72}
73
74impl<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord> NodeStateCollection<TNodeState> {
75 pub fn new(id: String, node_state_id: TNodeState, node_state_ids: Vec<TNodeState>) -> Self {
76 NodeStateCollection {
77 id,
78 node_state_id,
79 node_state_ids
80 }
81 }
82}
83
84#[derive(Serialize, Clone, Deserialize)]
86pub struct WaveFunction<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord> {
87 nodes: Vec<Node<TNodeState>>,
88 node_state_collections: Vec<NodeStateCollection<TNodeState>>
89}
90
91impl<TNodeState: Eq + Hash + Clone + std::fmt::Debug + Ord + Serialize + DeserializeOwned> WaveFunction<TNodeState> {
92 pub fn new(nodes: Vec<Node<TNodeState>>, node_state_collections: Vec<NodeStateCollection<TNodeState>>) -> Self {
93 WaveFunction {
94 nodes,
95 node_state_collections
96 }
97 }
98
99 pub fn get_nodes(&self) -> Vec<Node<TNodeState>> {
100 self.nodes.clone()
101 }
102
103 pub fn get_node_state_collections(&self) -> Vec<NodeStateCollection<TNodeState>> {
104 self.node_state_collections.clone()
105 }
106
107 pub fn validate(&self) -> Result<(), String> {
108 let nodes_length: usize = self.nodes.len();
109
110 let mut node_per_id: HashMap<&str, &Node<TNodeState>> = HashMap::new();
111 let mut node_ids: HashSet<&str> = HashSet::new();
112 self.nodes
113 .iter()
114 .for_each(|node: &Node<TNodeState>| {
115 node_per_id.insert(&node.id, node);
116 node_ids.insert(&node.id);
117 });
118
119 let mut node_state_collection_per_id: HashMap<&str, &NodeStateCollection<TNodeState>> = HashMap::new();
120 self.node_state_collections
121 .iter()
122 .for_each(|node_state_collection| {
123 node_state_collection_per_id.insert(&node_state_collection.id, node_state_collection);
124 });
125
126 for (_, node) in node_per_id.iter() {
128 for (neighbor_node_id_string, _) in node.node_state_collection_ids_per_neighbor_node_id.iter() {
129 let neighbor_node_id: &str = neighbor_node_id_string;
130 if !node_ids.contains(neighbor_node_id) {
131 return Err(format!("Neighbor node {neighbor_node_id} does not exist in main list of nodes."));
132 }
133 }
134 }
135
136 let mut at_least_one_node_connects_to_all_other_nodes: bool = false;
137 for node in self.nodes.iter() {
138 let mut all_traversed_node_ids: HashSet<&str> = HashSet::new();
140 let mut potential_node_ids: Vec<&str> = Vec::new();
141
142 potential_node_ids.push(&node.id);
143
144 while let Some(node_id) = potential_node_ids.pop() {
145 let node = node_per_id.get(node_id).unwrap();
146 for neighbor_node_id_string in node.node_state_collection_ids_per_neighbor_node_id.keys() {
147 let neighbor_node_id: &str = neighbor_node_id_string;
148 if !all_traversed_node_ids.contains(neighbor_node_id) && !potential_node_ids.contains(&neighbor_node_id) {
149 potential_node_ids.push(neighbor_node_id);
150 }
151 }
152 all_traversed_node_ids.insert(node_id);
153 }
154
155 let all_traversed_node_ids_length = all_traversed_node_ids.len();
156 if all_traversed_node_ids_length == nodes_length {
157 at_least_one_node_connects_to_all_other_nodes = true;
158 break;
159 }
160 }
161
162 if !at_least_one_node_connects_to_all_other_nodes {
163 return Err(String::from("Not all nodes connect together. At least one node must be able to traverse to all other nodes."));
164 }
165
166 Ok(())
167 }
168
169 pub fn get_collapsable_wave_function<'a, TCollapsableWaveFunction: CollapsableWaveFunction<'a, TNodeState>>(&'a self, random_seed: Option<u64>) -> TCollapsableWaveFunction {
170 let mut node_per_id: HashMap<&str, &Node<TNodeState>> = HashMap::new();
171 self.nodes
172 .iter()
173 .for_each(|node: &Node<TNodeState>| {
174 node_per_id.insert(&node.id, node);
175 });
176
177 let mut node_state_collection_per_id: HashMap<&str, &NodeStateCollection<TNodeState>> = HashMap::new();
178 self.node_state_collections
179 .iter()
180 .for_each(|node_state_collection| {
181 node_state_collection_per_id.insert(&node_state_collection.id, node_state_collection);
182 });
183
184 let mut neighbor_mask_mapped_view_per_node_id: HashMap<&str, HashMap<&TNodeState, HashMap<&str, BitVec>>> = HashMap::new();
194
195 let mut mask_per_parent_state_per_parent_neighbor_per_node: HashMap<&str, HashMap<&str, HashMap<&TNodeState, BitVec>>> = HashMap::new();
197
198 for child_node in self.nodes.iter() {
200
201 let mut mask_per_parent_state_per_parent_neighbor: HashMap<&str, HashMap<&TNodeState, BitVec>> = HashMap::new();
202
203 for parent_neighbor_node in self.nodes.iter() {
205 if parent_neighbor_node.node_state_collection_ids_per_neighbor_node_id.contains_key(&child_node.id) {
207
208 debug!("constructing mask for {:?}'s child node {:?}.", parent_neighbor_node.id, child_node.id);
209
210 let mut mask_per_parent_state: HashMap<&TNodeState, BitVec> = HashMap::new();
211
212 let node_state_collection_ids: &Vec<String> = parent_neighbor_node.node_state_collection_ids_per_neighbor_node_id.get(&child_node.id).unwrap();
214 for node_state_collection_id in node_state_collection_ids.iter() {
215 let node_state_collection = node_state_collection_per_id.get(node_state_collection_id.as_str()).unwrap();
216 let mut mask: BitVec = BitVec::new();
218 for node_state_id in child_node.node_state_ids.iter() {
219 mask.push(node_state_collection.node_state_ids.contains(node_state_id));
221 }
222 mask_per_parent_state.insert(&node_state_collection.node_state_id, mask);
224 }
225
226 mask_per_parent_state_per_parent_neighbor.insert(&parent_neighbor_node.id, mask_per_parent_state);
227 }
228 }
229
230 mask_per_parent_state_per_parent_neighbor_per_node.insert(&child_node.id, mask_per_parent_state_per_parent_neighbor);
231 }
232
233 for node in self.nodes.iter() {
236
237 let node_id: &str = node.id.as_str();
239
240 let mut mask_per_neighbor_per_state: HashMap<&TNodeState, HashMap<&str, BitVec>> = HashMap::new();
241
242 for (neighbor_node_id, _) in node.node_state_collection_ids_per_neighbor_node_id.iter() {
243 let neighbor_node_id: &str = neighbor_node_id;
244
245 let mask_per_parent_state_per_parent_neighbor = mask_per_parent_state_per_parent_neighbor_per_node.get(neighbor_node_id).unwrap();
247 let mask_per_parent_state = mask_per_parent_state_per_parent_neighbor.get(node_id).unwrap();
248
249 for (node_state_id, mask) in mask_per_parent_state.iter() {
250 mask_per_neighbor_per_state
251 .entry(node_state_id)
252 .or_insert(HashMap::new())
253 .insert(neighbor_node_id, mask.clone());
254 }
255 }
256
257 neighbor_mask_mapped_view_per_node_id.insert(node_id, mask_per_neighbor_per_state);
258 }
259
260 let mut node_state_indexed_view_per_node_id: HashMap<&str, IndexedView<&TNodeState>> = HashMap::new();
261
262 for node in self.nodes.iter() {
264 let node_id: &str = &node.id;
265
266 let referenced_node_state_ids: Vec<&TNodeState> = node.node_state_ids.iter().collect();
269 let cloned_node_state_ratios: Vec<f32> = node.node_state_ratios.clone();
270
271 let node_state_indexed_view = IndexedView::new(referenced_node_state_ids, cloned_node_state_ratios);
272 node_state_indexed_view_per_node_id.insert(node_id, node_state_indexed_view);
274 }
275
276 let mut collapsable_nodes: Vec<Rc<RefCell<CollapsableNode<TNodeState>>>> = Vec::new();
277 let mut collapsable_node_per_id: HashMap<&str, Rc<RefCell<CollapsableNode<TNodeState>>>> = HashMap::new();
278 let random_instance = if let Some(seed) = random_seed {
280 Rc::new(RefCell::new(fastrand::Rng::with_seed(seed)))
281 }
282 else {
283 Rc::new(RefCell::new(fastrand::Rng::new()))
284 };
285 for node in self.nodes.iter() {
286 let node_id: &str = node.id.as_str();
287
288 let node_state_indexed_view: IndexedView<&TNodeState> = node_state_indexed_view_per_node_id.remove(node_id).unwrap();
289 let mask_per_neighbor_per_state = neighbor_mask_mapped_view_per_node_id.remove(node_id).unwrap();
290
291 let mut collapsable_node = CollapsableNode::new(&node.id, &node.node_state_collection_ids_per_neighbor_node_id, mask_per_neighbor_per_state, node_state_indexed_view);
292
293 if random_seed.is_some() {
294 collapsable_node.randomize(&mut random_instance.borrow_mut());
295 }
296
297 collapsable_nodes.push(Rc::new(RefCell::new(collapsable_node)));
298 }
299
300 for wrapped_collapsable_node in collapsable_nodes.iter() {
301 let collapsable_node = wrapped_collapsable_node.borrow();
302 collapsable_node_per_id.insert(collapsable_node.id, wrapped_collapsable_node.clone());
303 }
304
305 for wrapped_collapsable_node in collapsable_nodes.iter() {
306 let mut collapsable_node = wrapped_collapsable_node.borrow_mut();
307 let collapsable_node_id: &str = collapsable_node.id;
308
309 if mask_per_parent_state_per_parent_neighbor_per_node.contains_key(collapsable_node_id) {
310 let mask_per_parent_state_per_parent_neighbor = mask_per_parent_state_per_parent_neighbor_per_node.get(collapsable_node_id).unwrap();
311 for parent_neighbor_node_id in mask_per_parent_state_per_parent_neighbor.keys() {
312 collapsable_node.parent_neighbor_node_ids.push(parent_neighbor_node_id);
313 }
314 if random_seed.is_some() {
315 random_instance.borrow_mut().shuffle(collapsable_node.parent_neighbor_node_ids.as_mut_slice());
316 }
317 else {
318 collapsable_node.parent_neighbor_node_ids.sort();
319 }
320 }
321 }
322
323 TCollapsableWaveFunction::new(collapsable_nodes, collapsable_node_per_id, random_instance)
324 }
325
326 pub fn save_to_file(&self, file_path: &str) {
327 let serialized_self = serde_json::to_string(self).unwrap();
328 std::fs::write(file_path, serialized_self).unwrap();
329 }
330
331 pub fn load_from_file(file_path: &str) -> Self {
332 let file = File::open(file_path).unwrap();
333 let reader = BufReader::new(file);
334 let deserialized_self: WaveFunction<TNodeState> = serde_json::from_reader(reader).unwrap();
335 deserialized_self
336 }
337}