rusty_systems/
productions.rs1use std::cell::RefCell;
4use std::hash::{Hash, Hasher};
5
6use rand::{Rng, thread_rng};
7
8use crate::{Result};
9use crate::error::{Error, ErrorKind};
10use crate::prelude::*;
11use crate::Symbol;
12
/// How a [`Chance`] obtained (or will obtain) its value.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ChanceKind {
    /// The chance was given explicitly by the user.
    Set,
    /// No chance was given; `Production::body` assigns such bodies an equal
    /// share of the probability left over by the explicitly set chances.
    Derived,
}
21
22#[derive(Debug, Copy, Clone)]
23pub struct Chance {
24 kind: ChanceKind,
25 chance: Option<f32>
26}
27
28impl Chance {
29 pub fn new(chance: f32) -> Self {
31 assert!(chance > 0_f32, "chance should be positive");
32 assert!(chance <= 1.0_f32, "chance should be less than or equal to 1.0");
33
34 Chance {
35 kind: ChanceKind::Set,
36 chance: Some(chance)
37 }
38 }
39
40 #[inline]
43 pub fn empty() -> Self {
44 Chance {
45 kind: ChanceKind::Derived,
46 chance: None
47 }
48 }
49
50 #[inline]
52 pub fn is_derived(&self) -> bool {
53 matches!(self.kind, ChanceKind::Derived)
54 }
55
56 #[inline]
58 pub fn is_user_set(&self) -> bool {
59 matches!(self.kind, ChanceKind::Set)
60 }
61
62 #[inline]
63 pub fn expect(&self, message: &str) -> f32 {
64 self.chance.expect(message)
65 }
66
67 #[inline]
68 pub fn unwrap(&self) -> f32 {
69 self.chance.unwrap()
70 }
71
72 #[inline]
73 pub fn unwrap_or(&self, default: f32) -> f32 {
74 self.chance.unwrap_or(default)
75 }
76}
77
78
79#[derive(Debug, Clone, PartialEq, Eq, Hash)]
80pub struct ProductionHead {
81 pre: Option<ProductionString>,
82 target: Symbol,
83 post: Option<ProductionString>
84}
85
86impl ProductionHead {
87 pub fn build(pre: Option<ProductionString>, target: Symbol, post: Option<ProductionString>) -> Result<Self> {
89 Ok(ProductionHead {
90 pre,
91 target,
92 post
93 })
94 }
95
96 #[inline]
98 pub fn target(&self) -> &Symbol {
99 &self.target
100 }
101
102 #[inline]
103 pub fn pre_context(&self) -> Option<&ProductionString> {
104 self.pre.as_ref()
105 }
106
107 #[inline]
108 pub fn post_context(&self) -> Option<&ProductionString> {
109 self.post.as_ref()
110 }
111
112 pub fn matches(&self, string: &ProductionString, index: usize) -> bool {
115 self.pre_matches(string, index) &&
116 self.post_matches(string, index) &&
117 string.symbols()
118 .get(index)
119 .map(|symbol| self.target == *symbol)
120 .unwrap_or(false)
121 }
122
123 pub fn pre_matches(&self, string: &ProductionString, index: usize) -> bool {
124 if self.pre.is_none() {
125 return true;
126 }
127
128 let left = self.pre.as_ref().unwrap();
129
130 if index == 0 {
131 return left.is_empty();
132 }
133
134 let symbols: Vec<_> = string.symbols()[0..index].iter().rev().collect();
135 if symbols.len() < left.len() {
136 return false;
137 }
138
139 return left.iter().rev().enumerate().all(|(i, t)| t == symbols[i]);
140 }
141
142 pub fn post_matches(&self, string: &ProductionString, index: usize) -> bool {
143 if self.post.is_none() {
144 return true;
145 }
146
147 let right = self.post.as_ref().unwrap();
148
149 if index == string.len() - 1 {
150 return right.is_empty();
151 }
152
153 let symbols = string.symbols()[index + 1 ..].to_vec();
154 if symbols.len() < right.len() {
155 return false;
156 }
157
158 return right.iter().enumerate().all(|(i, t)| *t == symbols[i]);
159 }
160
161}
162
163
164#[derive(Debug, Clone)]
165pub struct ProductionBody {
166 string: ProductionString,
167 chance: Chance
168}
169
170impl ProductionBody {
171 pub fn new(string: ProductionString) -> Self {
174 ProductionBody {
175 string,
176 chance: Chance::empty()
177 }
178 }
179
180 pub fn try_with_chance(chance: f32, string: ProductionString) -> Result<Self> {
183 if !(0.0..=1.0).contains(&chance) {
184 return Err(Error::new(ErrorKind::Parse, "chance should be between 0.0 and 1.0 inclusive"));
185 }
186
187 Ok(ProductionBody {
188 string,
189 chance: Chance::new(chance),
190 })
191 }
192
193 pub fn empty() -> Self {
195 ProductionBody {
196 string: ProductionString::empty(),
197 chance: Chance::empty()
198 }
199 }
200
201 #[inline]
202 pub fn is_empty(&self) -> bool {
203 self.string.is_empty()
204 }
205
206 #[inline]
207 pub fn len(&self) -> usize {
208 self.string.len()
209 }
210
211 #[inline]
212 pub fn string(&self) -> &ProductionString {
213 &self.string
214 }
215
216 #[inline]
217 pub fn chance(&self) -> &Chance {
218 &self.chance
219 }
220}
221
222
223#[derive(Debug, Clone)]
237pub struct Production {
238 head: ProductionHead,
239 body: Vec<ProductionBody>
240}
241
242impl Production {
243 pub fn new(head: ProductionHead, body: ProductionBody) -> Self {
244 Production {
245 head,
246 body: vec![body]
247 }
248 }
249
250 #[inline]
251 pub fn head(&self) -> &ProductionHead {
252 &self.head
253 }
254
255 pub fn body(&self) -> Result<&ProductionBody> {
256 if self.body.is_empty() {
257 return Err(Error::execution("Production has no bodies set"))
258 }
259
260 if self.body.len() == 1 {
262 return Ok(self.body.last().unwrap());
263 }
264
265 let total_chance : f32 = self.body.iter()
266 .map(|b| b.chance.unwrap_or(0.0))
267 .sum();
268
269 if total_chance < 0.0 {
270 return Err(Error::execution("chance should never be negative"));
271 }
272
273 if total_chance > 1.0 {
274 return Err(Error::execution("total chance of production bodies should not be greater than 1.0"));
275 }
276
277 let remaining = self.body.iter().filter(|b| b.chance.is_derived()).count();
278 let default_chance = if remaining == 0 {
279 0_f32
280 } else {
281 (1.0_f32 - total_chance) / (remaining as f32)
282 };
283
284 let mut current = 0_f32;
285 let random : f32 = thread_rng().gen_range(0.0..=1.0);
286
287 for body in &self.body {
288 current += body.chance.unwrap_or(default_chance);
289 if random < current {
290 return Ok(body);
291 }
292 }
293
294 return Ok(self.body.last().unwrap());
296 }
297
298 #[inline]
301 pub fn matches(&self, string: &ProductionString, index: usize) -> bool {
302 self.head().matches(string, index)
303 }
304
305 pub fn add_body(&mut self, body: ProductionBody) {
306 self.body.push(body);
307 }
308
309 pub fn merge(&mut self, other: Self) {
311 other.body.into_iter().for_each(|b| self.add_body(b));
312 }
313
314 pub fn all_bodies(&self) -> &Vec<ProductionBody> {
316 &self.body
317 }
318}
319
320impl PartialEq for Production {
321 fn eq(&self, other: &Self) -> bool {
322 self.head().eq(other.head())
323 }
324}
325
326impl Eq for Production { }
327
328impl Hash for Production {
329 fn hash<H: Hasher>(&self, state: &mut H) {
330 self.head.hash(state);
331 }
332}
333
334
335pub trait ProductionStore {
341 fn add_production(&self, production: Production) -> Result<Production>;
342}
343
344impl ProductionStore for RefCell<Vec<Production>> {
345 fn add_production(&self, production: Production) -> Result<Production> {
346 let mut vec = self.borrow_mut();
347 vec.push(production);
348 vec.last().cloned().ok_or_else(|| Error::general("Unable to add production"))
349 }
350}
351
352
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_prod_string;

    /// Checks `matches_at(index)` against `expected[index]` for every listed
    /// index, in ascending order.
    fn assert_matches_at<F: Fn(usize) -> bool>(matches_at: F, expected: &[bool]) {
        for (index, &want) in expected.iter().enumerate() {
            assert_eq!(matches_at(index), want, "unexpected match result at index {}", index);
        }
    }

    #[test]
    fn production_matches() {
        let system = System::default();

        // No context.
        let production = system.parse_production("X -> F F").unwrap();
        let string = parse_prod_string("X").unwrap();
        assert_matches_at(|i| production.matches(&string, i), &[true]);

        // One-symbol left context.
        let production = system.parse_production("X < X -> F F").unwrap();
        assert_matches_at(|i| production.matches(&string, i), &[false]);

        let string = parse_prod_string("X X").unwrap();
        assert_matches_at(|i| production.matches(&string, i), &[false, true]);

        // Multi-symbol left context.
        let production = system.parse_production("a b < X -> F F").unwrap();
        let string = parse_prod_string("a b X").unwrap();
        assert_matches_at(|i| production.matches(&string, i), &[false, false, true]);

        // One-symbol right context.
        let production = system.parse_production("X > X -> F F").unwrap();
        assert_matches_at(|i| production.matches(&string, i), &[false]);

        let string = parse_prod_string("X X").unwrap();
        assert_matches_at(|i| production.matches(&string, i), &[true, false]);

        // Multi-symbol right context.
        let production = system.parse_production("X > a b -> F F").unwrap();
        let string = parse_prod_string("a X a b").unwrap();
        assert_matches_at(|i| production.matches(&string, i), &[false, true, false, false]);

        // Left context inside a longer string, with a fresh system.
        let system = System::default();
        let string = parse_prod_string("G S S S X").unwrap();
        let production = system.parse_production("G < S -> S G").unwrap();
        assert_matches_at(|i| production.matches(&string, i), &[false, true, false]);
    }
}