simplicity_derive/
lib.rs

1extern crate proc_macro;
2#[macro_use]
3extern crate quote;
4
5use fnv::FnvHashMap;
6use proc_macro::TokenStream;
7use syn::{Ident, Token};
8use syn::parse::{Parse, ParseStream, Result};
9use itertools::Itertools;
10use permutator::Combination;
11use std::{collections::HashSet, fmt::{self, Display, Formatter}};
12use std::iter::{once, repeat};
13use proc_macro2::TokenStream as TokenStream2;
14
15struct InHypersphere {
16    /// The list to index on
17    list: Ident,
18    /// The indexing function
19    index_fn: Ident,
20    /// The list of indexes
21    indexes: Vec<Ident>,
22}
23
24impl Parse for InHypersphere {
25    fn parse(input: ParseStream) -> Result<Self> {
26        let list: Ident = input.parse()?;
27        input.parse::<Token![,]>()?;
28        let index_fn: Ident = input.parse()?;
29        input.parse::<Token![,]>()?;
30        let indexes = input.parse_terminated::<Ident, Token![,]>(Ident::parse)?;
31        
32        Ok(InHypersphere {
33            list,
34            index_fn,
35            indexes: indexes.into_iter().collect()
36        })
37    }
38}
39
/// A sub-determinant (minor) of the original matrix, identified by the rows
/// and columns it keeps.
///
/// The last row is implicitly included, and so is the last column (the
/// column of 1's); neither is listed in `rows` / `cols`.
#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Determinant {
    /// Indexes of the matrix rows kept by this minor.
    rows: Vec<usize>,
    /// Indexes of the matrix columns kept by this minor.
    cols: Vec<usize>,
}
48
49impl Determinant {
50    fn new(rows: Vec<usize>, cols: Vec<usize>) -> Self {
51        Self { rows, cols }
52    }
53
54    fn nonzero(self, zero_dets: &mut HashSet<Determinant>) -> Option<Self> {
55        if zero_dets.contains(&self) {
56            return None;
57        };
58
59        // Smaller determinants
60        for i in 1..self.cols.len() {
61            if self.rows.combination(i).any(|combo_r|
62                self.cols.combination(i).all(|combo_c|
63                    zero_dets.contains(&Determinant::new(combo_r.iter().copied().copied().collect(),
64                        combo_c.into_iter().copied().collect()))))
65            {
66                // Determinant is 0 because a whole row/rows of subdeterminants are 0
67                return None;
68            }
69        }
70
71        Some(self)
72    }
73
74    /// To be called after prepare_dets_for_cases
75    fn vector_tokens(&self, points: &[Ident]) -> TokenStream2 {
76        let dim = points.len() - 2;
77        let mut cols = self.cols.clone();
78
79        // Magnitude column; replace with missing coordinates
80        if cols[cols.len() - 1] == dim {
81            cols.pop();
82            cols.extend((0..dim).filter(|i| !self.cols.contains(i)));
83        }
84
85        let vector = format_ident!("Vector{}", cols.len());
86
87        self.rows.iter().map(|r| {
88            let point = &points[*r];
89            let mut coords = cols.iter().map(|c| {
90                quote! { #point[#c], }
91            }).collect::<Vec<_>>();
92
93            if coords.len() > 1 {
94                let coords = coords.into_iter().collect::<TokenStream2>();
95                quote! { nalgebra::#vector::new(#coords), }
96            } else {
97                coords.pop().unwrap()
98            }
99        }).collect()
100    }
101
102    fn to_grid(&self, indexes: &[Ident]) -> Vec<String> {
103        let coords = "xyzw".chars().collect::<Vec<_>>();
104        let mut lines = vec![];
105        for row in self.rows.iter().copied().chain(once(indexes.len() - 1)) {
106            let mut line = "│ ".to_string();
107
108            for col in self.cols.iter().copied().chain(once(indexes.len() - 1)) {
109                if col == indexes.len() - 1 {
110                    line += "1 ";
111                } else if col == indexes.len() - 2 {
112                    line += &(0..indexes.len() - 2).map(|i| format!("{}{}²", indexes[row], coords[i])).join("+");
113                    line += "  ";
114                } else {
115                    line += &format!("{}{}  ", indexes[row], coords[col]);
116                }
117            }
118
119            lines.push(line + "│");
120        }
121
122        //let pad = repeat(" ").take(lines[0].chars().count() - 2).collect::<String>();
123        //lines.insert(0, format!("│{}│", pad));
124        //lines.push(format!("│{}│", pad));
125        lines
126    }
127}
128
129#[derive(Clone, Debug, PartialEq, Eq, Hash)]
130struct Term {
131    const_mult: i32,
132    /// Says location of term to multiply by.
133    var_mult: Option<[usize; 2]>,
134    det: Determinant,
135}
136
137impl Term {
138    fn new(const_mult: i32, var_mult: Option<[usize; 2]>, det: Determinant) -> Self {
139        Self { const_mult, var_mult, det }
140    }
141
142    fn nonzero(mut self, zero_dets: &mut HashSet<Determinant>) -> Option<Self> {
143        if let Some(det) = std::mem::take(&mut self.det).nonzero(zero_dets) {
144            self.det = det;
145            Some(self)
146        } else {
147            None
148        }
149    }
150
151    fn to_grid(&self, indexes: &[Ident]) -> Vec<String> {
152        let coords = "xyzw".chars().collect::<Vec<_>>();
153        let mut lines = self.det.to_grid(indexes);
154
155        let mut coeff = if self.const_mult >= 0 {"+ "} else {"- "}.to_owned();
156        if self.const_mult.abs() != 1 {
157            coeff += &self.const_mult.abs().to_string();
158        }
159        if let Some([r, c]) = self.var_mult {
160            coeff += &format!("{}{}", indexes[r], coords[c]);
161        }
162
163        let mid = (lines.len() - 1) / 2;
164        let pad = repeat(" ").take(coeff.chars().count()).collect::<String>();
165        lines[mid] = coeff + &lines[mid];
166        for (i, line) in lines.iter_mut().enumerate() {
167            if i != mid {
168                *line = pad.clone() + line;
169            }
170        }
171        
172        lines
173    }
174}
175
176#[derive(Clone, Debug, Default)]
177struct TermSum {
178    terms: Vec<Term>,
179}
180
181impl TermSum {
182    fn new() -> Self {
183        Self::default()
184    }
185
186    fn without_zero_dets(mut self, dim: usize, zero_dets: &mut HashSet<Determinant>) -> Option<Self> {
187        self.terms = self.terms.into_iter().flat_map(|t| t.nonzero(zero_dets)).collect::<Vec<_>>();
188
189        if self.terms.len() == 1 && self.terms[0].var_mult.is_none() {
190            let det = &self.terms[0].det;
191            zero_dets.insert(det.clone());
192
193            // Special case: coordinates equal, so the magnitudes do as well.
194            if det.cols.len() == 1 && (0..dim).all(|i| 
195                zero_dets.contains(&Determinant::new(vec![det.rows[0]], vec![i])))
196            {
197                zero_dets.insert(Determinant::new(vec![det.rows[0]], vec![dim]));
198            }
199        }
200
201        if self.terms.is_empty() { None } else { Some(self) }
202    }
203
204    fn prepare_dets_for_cases(&mut self, dim: usize) {
205        for term in &mut self.terms {
206            // For convenience of including the last point
207            term.det.rows.push(dim + 1);
208
209            if term.const_mult < 0 {
210                term.const_mult *= -1;
211                let n = term.det.rows.len();
212                term.det.rows.swap(n - 2, n - 1);
213            }
214        }
215    }
216
217    fn case(mut self, points: &[Ident]) -> TokenStream2 {
218        let coords = "xyzw".chars().collect::<Vec<_>>();
219        let dim = points.len() - 2;
220        self.prepare_dets_for_cases(dim);
221
222        if self.terms.len() == 1 && self.terms[0].det.cols.len() == dim + 1 {
223            assert_eq!(self.terms[0].const_mult, 1);
224            assert_eq!(self.terms[0].var_mult, None);
225
226            if dim == 2 {
227                let [i, j, k, l] = [&points[0], &points[1], &points[2], &points[3]];
228                quote! {
229                    let val = rg::in_circle(#i, #j, #k, #l);
230                    if val != 0.0 {
231                        return (val > 0.0) != odd;
232                    }
233                }
234            } else if dim == 3 {
235                let [i, j, k, l, m] = [&points[0], &points[1], &points[2], &points[3], &points[4]];
236                quote! {
237                    let val = rg::in_sphere(#i, #j, #k, #l, #m);
238                    if val != 0.0 {
239                        return (val > 0.0) != odd;
240                    }
241                }
242            } else {
243                panic!("Unsupported # of dimensions: {}", dim)
244            }
245        } else if self.terms.len() == 1 && self.terms[0].det.cols.last() == Some(&dim) {
246            assert_eq!(self.terms[0].const_mult, 1);
247            assert_eq!(self.terms[0].var_mult, None);
248
249            let det = self.terms[0].det.vector_tokens(points);
250            let func = if self.terms[0].det.cols.len() == 1 {
251                format_ident!("magnitude_cmp_{}d", dim)
252            } else {
253                format_ident!(
254                    "sign_det_{}{}",
255                    coords[..self.terms[0].det.cols.len() - 1].iter().map(|c| c.to_string() + "_").join(""),
256                    coords[..dim].iter().map(|c| c.to_string() + "2").join(""),
257                )
258            };
259            quote! {
260                let val = rg::#func(#det);
261                if val != 0.0 {
262                    return (val > 0.0) != odd;
263                }
264            }
265        } else if self.terms.len() == 1 {
266            assert_eq!(self.terms[0].const_mult, 1);
267            assert_eq!(self.terms[0].var_mult, None);
268            
269            if self.terms[0].det.cols.len() == 0 {
270                quote! { !odd }
271            } else if self.terms[0].det.cols.len() == 1 {
272                let coord = self.terms[0].det.cols[0];
273                let p1 = &points[self.terms[0].det.rows[0]];
274                let p2 = &points[self.terms[0].det.rows[1]];
275                quote! {
276                    if #p1[#coord] != #p2[#coord] {
277                        return (#p1[#coord] > #p2[#coord]) != odd;
278                    }
279                }
280            } else {
281                let det = self.terms[0].det.vector_tokens(points);
282                let func = format_ident!("orient_{}d", self.terms[0].det.cols.len());
283                quote! {
284                    let val = rg::#func(#det);
285                    if val != 0.0 {
286                        return (val > 0.0) != odd;
287                    }
288                }
289            }
290        } else if self.terms.len() == 2 && self.terms[0].var_mult.is_none() {
291            assert_eq!(self.terms[0].const_mult, 1);
292            assert_eq!(*self.terms[0].det.cols.last().unwrap(), dim);
293            assert_eq!(self.terms[1].const_mult, 2);
294            assert!(self.terms[1].var_mult.is_some());
295            assert_ne!(*self.terms[1].det.cols.last().unwrap(), dim);
296
297            let det1 = self.terms[0].det.vector_tokens(points);
298            let det2 = self.terms[1].det.vector_tokens(points);
299            let mult = &points[self.terms[1].var_mult.unwrap()[0]];
300            let mult_coord = self.terms[1].var_mult.unwrap()[1];
301            let func = format_ident!(
302                "sign_det_{}{}_plus_2x_det_{}",
303                coords[..self.terms[0].det.cols.len() - 1].iter().map(|c| c.to_string() + "_").join(""),
304                coords[..dim].iter().map(|c| c.to_string() + "2").join(""),
305                coords[..self.terms[1].det.cols.len()].iter().join("_"),
306            );
307            quote! { 
308                let val = rg::#func(#det1 #mult[#mult_coord], #det2);
309                if val != 0.0 {
310                    return (val > 0.0) != odd;
311                }
312            }
313        } else if self.terms.len() == 2 {
314            assert_eq!(self.terms[0].const_mult, 2);
315            assert_ne!(self.terms[0].det.cols.last(), Some(&dim));
316            assert_eq!(self.terms[1].const_mult, 2);
317            assert!(self.terms[1].var_mult.is_some());
318            assert_eq!(self.terms[0].det, self.terms[1].det);
319
320            let mult1 = &points[self.terms[0].var_mult.unwrap()[0]];
321            let mult1_coord = self.terms[0].var_mult.unwrap()[1];
322            let mult2 = &points[self.terms[1].var_mult.unwrap()[0]];
323            let mult2_coord = self.terms[1].var_mult.unwrap()[1];
324            
325            let inner = if self.terms[0].det.cols.len() == 0 {
326                quote! { return negate == odd; }
327            } else if self.terms[0].det.cols.len() == 1 {
328                let coord = self.terms[0].det.cols[0];
329                let p1 = &points[self.terms[0].det.rows[0]];
330                let p2 = &points[self.terms[0].det.rows[1]];
331                quote! {
332                    if #p1[#coord] != #p2[#coord] {
333                        return (#p1[#coord] > #p2[#coord]) != (negate != odd);
334                    }
335                }
336            } else {
337                let det = self.terms[0].det.vector_tokens(points);
338                let func = format_ident!("orient_{}d", self.terms[0].det.cols.len());
339                quote! {
340                    let val = rg::#func(#det);
341                    if val != 0.0 {
342                        return (val > 0.0) != (negate != odd);
343                    }
344                }
345            };
346
347            quote! {
348                if #mult1[#mult1_coord] != -#mult2[#mult2_coord] {
349                    let negate = #mult1[#mult1_coord] < -#mult2[#mult2_coord];
350                    #inner
351                }
352            }
353        } else {
354            panic!("Unsupported determinant: {}", self.to_grid(points).join("\n"))
355        }
356    }
357
358    fn to_grid(&self, indexes: &[Ident]) -> Vec<String> {
359        let mut lines = self.terms[0].to_grid(indexes);
360        for term in &self.terms[1..] {
361            for (i, line) in term.to_grid(indexes).into_iter().enumerate() {
362                lines[i] += &format!(" {}", line);
363            }
364        }
365        lines
366    }
367}
368
/// An ε-factor, stored as the exponent of ε.
///
/// Factors are ordered by comparing these exponents.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
struct EFactor(u64);
372
373impl EFactor {
374    fn new(dim: usize, coords: impl IntoIterator<Item = [usize; 2]>) -> Self {
375        Self(coords.into_iter().map(|[r, c]| 3u64.pow((dim * r + dim - 1 - c) as u32)).sum())
376    }
377
378    fn to_repr(mut self, indexes: &[Ident]) -> String {
379        let coords = "xyzw".chars().collect::<Vec<_>>();
380        let mut res = String::new();
381
382        for index in indexes {
383            for c in 0..indexes.len() - 2 {
384                let rem = self.0 % 3;
385                self.0 /= 3;
386
387                if rem > 0 {
388                    if !res.is_empty() {
389                        res += "·";
390                    }
391                    res += &format!("ε{}{}", index, coords[indexes.len() - 3 - c]);
392                }
393                if rem == 2 {
394                    res += "²";
395                }
396            }
397        }
398
399        res
400    }
401}
402
403impl Display for EFactor {
404    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
405        let mut res = String::new();
406        let mut num = self.0;
407
408        while num > 0 {
409            res += &(num % 3).to_string();
410            num /= 3;
411        }
412
413        res = res.chars().rev().collect();
414        f.pad(&res)
415    }
416}
417
418fn terms(dim: usize) -> Vec<(EFactor, Term)> {
419    let mut terms = vec![];
420
421    // The biggest relevant ε-factor.
422    let big_e = EFactor::new(dim, (0..dim - 1).map(|i| [i, i]).chain(vec![[dim - 1, dim - 1], [dim - 1, dim - 1], [dim, dim - 1]]));
423
424    let all = (0..=dim).collect::<Vec<_>>();
425
426    // General term
427    terms.push((EFactor::new(dim, vec![]), Term::new(1, None, Determinant::new(all.clone(), all.clone()))));
428
429    // Degenerate terms
430    let mut rows = all.clone();
431    let mut cols = all.clone();
432    let mut e_factors = vec![];
433    for i in 1..=dim + 1 {
434        let mut remove = vec![0; 2 * i];
435
436        while remove[0] <= dim - (i - 1) {
437            // Trying not to have a million allocations here
438            rows.clear();
439            rows.extend(all.iter().copied());
440            cols.clear();
441            cols.extend(all.iter().copied());
442            e_factors.clear();
443
444            let mut mult = 1;
445            for rc in remove.chunks_exact(2) {
446                let er = rows.remove(rc[0]);
447                let ec = cols.remove(rc[1]);
448                if (er + ec) % 2 == 1 {
449                    mult *= -1;
450                }
451                e_factors.push([er, ec]);
452            }
453
454            let det = Determinant::new(rows.clone(), cols.clone());
455
456            // Column dim is the magnitude column, so do special things with it.
457            // For example, (x + εx)² + (y + εy)² expands to
458            // (x² + y²) + εx·2x + εx² + εy·2y + εy²
459            if let Some(mag_r) = e_factors.iter().position(|[_, c]| *c == dim).map(|i| e_factors.remove(i)[0]) {
460                for j in 0..dim {
461                    let factor = EFactor::new(dim, e_factors.iter().copied().chain(once([mag_r, j])));
462                    if factor <= big_e {
463                        terms.push((factor, Term::new(mult * 2, Some([mag_r, j]), det.clone())));
464                    }
465
466                    let factor = EFactor::new(dim, e_factors.iter().copied().chain(repeat([mag_r, j]).take(2)));
467                    if factor <= big_e {
468                        terms.push((factor, Term::new(mult, None, det.clone())));
469                    }
470                }
471            } else {
472                let factor = EFactor::new(dim, e_factors.drain(..));
473                if factor <= big_e {
474                    terms.push((factor, Term::new(mult, None, det)));
475                }
476            }
477
478            // Count in base factorial to iterate through permutations
479            // Row index shouldn't decrease so permutations aren't repeated.
480            let mut j = 2 * i - 1;
481            while {
482                remove[j] += 1;
483                if j % 2 == 0 && remove[j] <= dim - (i - 1) {
484                    let row = remove[j];
485                    for n in remove[j + 2..].iter_mut().step_by(2) {
486                        *n = row;
487                    }
488                }
489
490                remove[j] > dim - if j % 2 == 0 {i - 1} else {j / 2} && j > 0
491            } {
492                if j % 2 == 0 {
493                    let row = remove[j - 2];
494                    for n in remove[j..].iter_mut().step_by(2) {
495                        *n = row;
496                    }
497                } else {
498                    remove[j] = 0;
499                };
500
501                j -= 1;
502            }
503        }
504    }
505
506    terms
507}
508
509// Ordered by ε-factor exponent
510fn term_sums(dim: usize) -> Vec<(EFactor, TermSum)> {
511    let mut sums = FnvHashMap::default();
512
513    for (e, term) in terms(dim) {
514        sums.entry(e).or_insert(TermSum::new()).terms.push(term);
515    }
516
517    let mut sums = sums.into_iter().collect::<Vec<_>>();
518    sums.sort_by_key(|(e, _)| *e);
519    sums
520}
521
522fn fn_body(h: InHypersphere, sums: Vec<(EFactor, TermSum)>) -> TokenStream2 {
523    let list = h.list;
524    let index_fn = h.index_fn;
525    let dim = h.indexes.len() - 2;
526
527    let sorted = format_ident!("sorted_{}", h.indexes.len());
528    let index_seq = h.indexes.iter().map(|index| quote!{#index,}).collect::<TokenStream2>();
529
530    let points = h.indexes.iter().map(|index| format_ident!("p{}", index)).collect::<Vec<_>>();
531    let indexing_seq = h.indexes.iter().zip(points.iter()).map(|(index, point)| quote! {
532        let #point = #index_fn(#list, #index);
533    }).collect::<TokenStream2>();
534
535    let mut zero_dets = HashSet::new();
536    let cases = sums.into_iter()
537        .flat_map(|(e, sum)| sum.without_zero_dets(dim, &mut zero_dets).map(|sum| (e, sum)))
538        .map(|(_, sum)| sum.case(&points))
539        .collect::<TokenStream2>();
540
541    let tokens = quote! { 
542        let ([#index_seq], odd) = #sorted([#index_seq]);
543
544        #indexing_seq
545
546        #cases
547    };
548
549    tokens
550}
551
552#[proc_macro]
553pub fn generate_in_hypersphere(input: TokenStream) -> TokenStream {
554    let h = syn::parse_macro_input!(input as InHypersphere);
555
556    let sums = term_sums(h.indexes.len() - 2);
557    //let mut msg = "Cases:\n```".to_owned();
558
559    //let mut zero_dets = HashSet::new();
560    //for (e, sum) in &sums {
561    //    msg += &format!("{}:\n", e.to_repr(&h.indexes));
562
563    //    if let Some(sum) = sum.clone().without_zero_dets(h.indexes.len() - 2, &mut zero_dets) {
564    //        msg += &format!("{}\n", sum.to_grid(&h.indexes).into_iter().join("\n"));
565    //    } else {
566    //        msg += "Impossible!\n";
567    //    }
568    //    msg += "\n";
569    //}
570    //msg += "```";
571
572    //let ident = quote::format_ident!("__test_macro_{}", h.indexes.len() - 2);
573    //let stream = msg.split('\n').map(|line| quote! {
574    //    #[doc = #line]
575    //}).chain(once(quote! {
576    //    pub fn #ident() {}
577    //})).collect::<TokenStream2>();
578
579    TokenStream::from(fn_body(h, sums))
580}