Skip to main content

use_geode/
geode.rs

1use std::collections::HashMap;
2
3use crate::error::GeodeError;
4
5/// Finite hyper-Catalan type vectors of the form `[m2, m3, m4, ...]`.
6#[derive(Debug, Clone, PartialEq, Eq, Hash)]
7pub struct TypeVector {
8    values: Vec<u64>,
9}
10
11impl TypeVector {
12    /// Creates a validated finite type vector.
13    ///
14    /// The vector must contain at least one component, and the total face count
15    /// must still fit in `u64` so `face_count()` remains exact.
16    ///
17    /// # Errors
18    ///
19    /// Returns [`GeodeError::EmptyTypeVector`] when `values` is empty.
20    ///
21    /// Returns [`GeodeError::InvalidInput`] when the total face count would no
22    /// longer fit in `u64`.
23    pub fn new(values: Vec<u64>) -> Result<Self, GeodeError> {
24        if values.is_empty() {
25            return Err(GeodeError::EmptyTypeVector);
26        }
27
28        values
29            .iter()
30            .try_fold(0_u64, |total, value| total.checked_add(*value))
31            .ok_or(GeodeError::InvalidInput)?;
32
33        Ok(Self { values })
34    }
35
36    /// Returns the underlying component slice.
37    #[must_use]
38    pub fn values(&self) -> &[u64] {
39        &self.values
40    }
41
42    /// Returns the current dimension of the finite type vector.
43    #[must_use]
44    pub const fn dimension(&self) -> usize {
45        self.values.len()
46    }
47
48    /// Returns the total face count.
49    #[must_use]
50    pub fn face_count(&self) -> u64 {
51        face_count(self)
52    }
53
54    /// Returns whether all components are zero.
55    #[must_use]
56    pub fn is_zero(&self) -> bool {
57        self.values.iter().all(|value| *value == 0)
58    }
59
60    /// Returns a copy with one checked increment at `index`.
61    ///
62    /// # Errors
63    ///
64    /// Returns [`GeodeError::IndexOutOfBounds`] when `index` is not present.
65    ///
66    /// Returns [`GeodeError::ArithmeticOverflow`] when incrementing the chosen
67    /// component would overflow `u64`.
68    ///
69    /// Returns [`GeodeError::InvalidInput`] when the resulting total face count
70    /// no longer fits in `u64`.
71    pub fn incremented(&self, index: usize) -> Result<Self, GeodeError> {
72        let mut values = self.values.clone();
73        let value = values.get_mut(index).ok_or(GeodeError::IndexOutOfBounds)?;
74        *value = value.checked_add(1).ok_or(GeodeError::ArithmeticOverflow)?;
75        Self::new(values)
76    }
77
78    /// Returns a copy with one decrement at `index`, or `None` when already zero.
79    ///
80    /// # Errors
81    ///
82    /// Returns [`GeodeError::IndexOutOfBounds`] when `index` is not present.
83    ///
84    /// Returns [`GeodeError::InvalidInput`] when the resulting total face count
85    /// no longer fits in `u64`.
86    pub fn decremented(&self, index: usize) -> Result<Option<Self>, GeodeError> {
87        let mut values = self.values.clone();
88        let value = values.get_mut(index).ok_or(GeodeError::IndexOutOfBounds)?;
89
90        if *value == 0 {
91            return Ok(None);
92        }
93
94        *value -= 1;
95        Self::new(values).map(Some)
96    }
97
98    /// Removes trailing zeroes while keeping at least one component.
99    #[must_use]
100    pub fn trimmed(&self) -> Self {
101        let mut values = self.values.clone();
102
103        while values.len() > 1 && values.last() == Some(&0) {
104            values.pop();
105        }
106
107        Self { values }
108    }
109}
110
111/// Returns the total face count `F = m2 + m3 + m4 + ...`.
112#[must_use]
113pub fn face_count(m: &TypeVector) -> u64 {
114    m.values.iter().copied().sum()
115}
116
117/// Returns `E = 1 + 2m2 + 3m3 + 4m4 + ...` using checked `u128` arithmetic.
118///
119/// # Errors
120///
121/// Returns [`GeodeError::ArithmeticOverflow`] when any intermediate weighted
122/// sum exceeds `u128`.
123pub fn polygon_edge_count(m: &TypeVector) -> Result<u128, GeodeError> {
124    weighted_sum_with_offset(m, 1, 2)
125}
126
127/// Returns `V = 2 + m2 + 2m3 + 3m4 + ...` using checked `u128` arithmetic.
128///
129/// # Errors
130///
131/// Returns [`GeodeError::ArithmeticOverflow`] when any intermediate weighted
132/// sum exceeds `u128`.
133pub fn polygon_vertex_count(m: &TypeVector) -> Result<u128, GeodeError> {
134    weighted_sum_with_offset(m, 2, 1)
135}
136
137/// Returns `n!` using checked `u128` arithmetic.
138///
139/// # Errors
140///
141/// Returns [`GeodeError::ArithmeticOverflow`] when the factorial no longer
142/// fits in `u128`.
143pub fn checked_factorial(n: u64) -> Result<u128, GeodeError> {
144    let mut result = 1_u128;
145    let mut factor = 2_u64;
146
147    while factor <= n {
148        result = result
149            .checked_mul(u128::from(factor))
150            .ok_or(GeodeError::ArithmeticOverflow)?;
151        factor += 1;
152    }
153
154    Ok(result)
155}
156
157/// Returns the checked product of each factorial in `values`.
158///
159/// # Errors
160///
161/// Returns [`GeodeError::ArithmeticOverflow`] when any constituent factorial or
162/// the final product no longer fits in `u128`.
163pub fn checked_product_factorials(values: &[u64]) -> Result<u128, GeodeError> {
164    let mut product = 1_u128;
165
166    for value in values {
167        product = product
168            .checked_mul(checked_factorial(*value)?)
169            .ok_or(GeodeError::ArithmeticOverflow)?;
170    }
171
172    Ok(product)
173}
174
175/// Returns an exact integer quotient.
176///
177/// # Errors
178///
179/// Returns [`GeodeError::DivisionNotExact`] when `denominator == 0` or the
180/// division leaves a non-zero remainder.
181pub const fn exact_divide(numerator: u128, denominator: u128) -> Result<u128, GeodeError> {
182    if denominator == 0 || !numerator.is_multiple_of(denominator) {
183        return Err(GeodeError::DivisionNotExact);
184    }
185
186    Ok(numerator / denominator)
187}
188
189/// Returns the finite-type hyper-Catalan coefficient for `m`.
190///
191/// # Errors
192///
193/// Returns [`GeodeError::ArithmeticOverflow`] when structural counts or
194/// factorials exceed the exact integer range.
195///
196/// Returns [`GeodeError::DivisionNotExact`] when the closed form does not divide
197/// exactly in `u128` arithmetic.
198pub fn hyper_catalan(m: &TypeVector) -> Result<u128, GeodeError> {
199    let edge_count = polygon_edge_count(m)?;
200    let vertex_count = polygon_vertex_count(m)?;
201    let numerator = checked_factorial(
202        u64::try_from(edge_count - 1).map_err(|_| GeodeError::ArithmeticOverflow)?,
203    )?;
204    let vertex_factorial = checked_factorial(
205        u64::try_from(vertex_count - 1).map_err(|_| GeodeError::ArithmeticOverflow)?,
206    )?;
207    let face_factorials = checked_product_factorials(m.values())?;
208    let denominator = vertex_factorial
209        .checked_mul(face_factorials)
210        .ok_or(GeodeError::ArithmeticOverflow)?;
211
212    exact_divide(numerator, denominator)
213}
214
215/// Returns the Geode coefficient `G[m]` for small exact inputs.
216///
217/// This direct version uses the defining recurrence and is intended only for
218/// small vectors.
219///
220/// # Errors
221///
222/// Returns [`GeodeError::ArithmeticOverflow`] when the underlying hyper-Catalan
223/// count overflows or the recurrence would subtract past zero.
224///
225/// Returns [`GeodeError::DivisionNotExact`] when the underlying hyper-Catalan
226/// closed form does not divide exactly in `u128` arithmetic.
227pub fn geode(m: &TypeVector) -> Result<u128, GeodeError> {
228    geode_impl(m, None)
229}
230
231/// Returns the Geode coefficient `G[m]` using internal memoization.
232///
233/// # Errors
234///
235/// Returns the same errors as [`geode`].
236pub fn geode_memoized(m: &TypeVector) -> Result<u128, GeodeError> {
237    let mut memo = HashMap::new();
238    geode_impl(m, Some(&mut memo))
239}
240
241/// Returns the one-dimensional Catalan value that matches `[n]`.
242///
243/// # Errors
244///
245/// Returns [`GeodeError::ArithmeticOverflow`] when the corresponding Catalan
246/// number no longer fits in `u128`.
247pub fn catalan_from_geode_dimension(n: u64) -> Result<u128, GeodeError> {
248    use_catalan::catalan(n).map_err(|_| GeodeError::ArithmeticOverflow)
249}
250
251/// Returns the Geode coefficient along the first axis `[n]`.
252///
253/// # Errors
254///
255/// Returns any validation or arithmetic error produced while constructing the
256/// axis vector or evaluating [`geode`].
257pub fn geode_on_first_axis(n: u64) -> Result<u128, GeodeError> {
258    geode(&TypeVector::new(vec![n])?)
259}
260
261/// Returns the diagonal Geode coefficient for `[n, n]`.
262///
263/// # Errors
264///
265/// Returns any validation or arithmetic error produced while constructing the
266/// diagonal vector or evaluating [`geode`].
267pub fn diagonal_geode_2d(n: u64) -> Result<u128, GeodeError> {
268    geode(&TypeVector::new(vec![n, n])?)
269}
270
271/// Returns the diagonal Geode coefficient for `[n, n, n]`.
272///
273/// # Errors
274///
275/// Returns any validation or arithmetic error produced while constructing the
276/// diagonal vector or evaluating [`geode`].
277pub fn diagonal_geode_3d(n: u64) -> Result<u128, GeodeError> {
278    geode(&TypeVector::new(vec![n, n, n])?)
279}
280
281/// Returns the diagonal Geode coefficient for `[n, n, n, n]`.
282///
283/// # Errors
284///
285/// Returns any validation or arithmetic error produced while constructing the
286/// diagonal vector or evaluating [`geode`].
287pub fn diagonal_geode_4d(n: u64) -> Result<u128, GeodeError> {
288    geode(&TypeVector::new(vec![n, n, n, n])?)
289}
290
291fn weighted_sum_with_offset(
292    m: &TypeVector,
293    constant_term: u128,
294    first_weight: u128,
295) -> Result<u128, GeodeError> {
296    let mut total = constant_term;
297
298    for (index, value) in m.values.iter().enumerate() {
299        let index_weight = u128::try_from(index).map_err(|_| GeodeError::ArithmeticOverflow)?;
300        let weight = first_weight
301            .checked_add(index_weight)
302            .ok_or(GeodeError::ArithmeticOverflow)?;
303        let contribution = weight
304            .checked_mul(u128::from(*value))
305            .ok_or(GeodeError::ArithmeticOverflow)?;
306        total = total
307            .checked_add(contribution)
308            .ok_or(GeodeError::ArithmeticOverflow)?;
309    }
310
311    Ok(total)
312}
313
314fn geode_impl(
315    m: &TypeVector,
316    mut memo: Option<&mut HashMap<Vec<u64>, u128>>,
317) -> Result<u128, GeodeError> {
318    if let Some(cache) = &mut memo
319        && let Some(value) = cache.get(m.values())
320    {
321        return Ok(*value);
322    }
323
324    let incremented = m.incremented(0)?;
325    let mut result = hyper_catalan(&incremented)?;
326
327    for index in 1..m.dimension() {
328        if m.values()[index] == 0 {
329            continue;
330        }
331
332        let shifted = incremented
333            .decremented(index)?
334            .ok_or(GeodeError::ArithmeticOverflow)?;
335        let term = geode_impl(&shifted.trimmed(), memo.as_deref_mut())?;
336        result = result
337            .checked_sub(term)
338            .ok_or(GeodeError::ArithmeticOverflow)?;
339    }
340
341    if let Some(cache) = memo {
342        cache.insert(m.values.clone(), result);
343    }
344
345    Ok(result)
346}
347
#[cfg(test)]
mod tests {
    use use_catalan::catalan;

    use super::{
        GeodeError, TypeVector, catalan_from_geode_dimension, checked_factorial,
        checked_product_factorials, diagonal_geode_2d, diagonal_geode_3d, diagonal_geode_4d,
        exact_divide, face_count, geode, geode_memoized, geode_on_first_axis, hyper_catalan,
        polygon_edge_count, polygon_vertex_count,
    };

    /// Builds a type vector from a slice so the cases below stay compact.
    fn vector(components: &[u64]) -> Result<TypeVector, GeodeError> {
        TypeVector::new(components.to_vec())
    }

    #[test]
    fn rejects_empty_type_vectors() {
        let outcome = TypeVector::new(Vec::new());

        assert_eq!(outcome, Err(GeodeError::EmptyTypeVector));
    }

    #[test]
    fn accepts_zero_type_vectors() -> Result<(), GeodeError> {
        let zero = vector(&[0])?;

        assert!(zero.is_zero());
        assert_eq!(zero.dimension(), 1);
        assert_eq!(zero.values(), &[0]);

        Ok(())
    }

    #[test]
    fn preserves_dimension_and_face_counts() -> Result<(), GeodeError> {
        let padded = vector(&[2, 1, 0])?;

        // Both the method and the free function report the same total.
        assert_eq!(padded.face_count(), 3);
        assert_eq!(face_count(&padded), 3);

        // Trimming drops the trailing zero; the original keeps all three slots.
        assert_eq!(padded.dimension(), 3);
        assert_eq!(padded.trimmed().dimension(), 2);

        Ok(())
    }

    #[test]
    fn computes_polygon_edge_counts() -> Result<(), GeodeError> {
        // E = 1 + 2*m2 + 3*m3: [0] -> 1 and [2, 1] -> 1 + 4 + 3 = 8.
        let cases = [(vec![0], 1_u128), (vec![2, 1], 8)];

        for (components, expected) in cases {
            assert_eq!(polygon_edge_count(&TypeVector::new(components)?)?, expected);
        }

        Ok(())
    }

    #[test]
    fn computes_polygon_vertex_counts() -> Result<(), GeodeError> {
        // V = 2 + m2 + 2*m3: [0] -> 2 and [2, 1] -> 2 + 2 + 2 = 6.
        let cases = [(vec![0], 2_u128), (vec![2, 1], 6)];

        for (components, expected) in cases {
            assert_eq!(
                polygon_vertex_count(&TypeVector::new(components)?)?,
                expected
            );
        }

        Ok(())
    }

    #[test]
    fn increments_and_decrements_components() -> Result<(), GeodeError> {
        let start = vector(&[1, 2, 0])?;

        assert_eq!(start.incremented(1)?, vector(&[1, 3, 0])?);
        assert_eq!(start.decremented(1)?, Some(vector(&[1, 1, 0])?));

        Ok(())
    }

    #[test]
    fn decrementing_zero_component_returns_none() -> Result<(), GeodeError> {
        let with_trailing_zero = vector(&[1, 0])?;

        assert!(with_trailing_zero.decremented(1)?.is_none());

        Ok(())
    }

    #[test]
    fn reports_invalid_indices() -> Result<(), GeodeError> {
        let two_slots = vector(&[1, 2])?;

        assert_eq!(two_slots.incremented(2), Err(GeodeError::IndexOutOfBounds));
        assert_eq!(two_slots.decremented(3), Err(GeodeError::IndexOutOfBounds));

        Ok(())
    }

    #[test]
    fn computes_checked_factorials() {
        let cases: [(u64, u128); 3] = [(0, 1), (1, 1), (5, 120)];

        for (n, expected) in cases {
            assert_eq!(checked_factorial(n), Ok(expected));
        }
    }

    #[test]
    fn computes_factorial_products() {
        // 0! * 1! * 3! = 6 and 2! * 2! = 4.
        let mixed = checked_product_factorials(&[0, 1, 3]);
        let pair = checked_product_factorials(&[2, 2]);

        assert_eq!(mixed, Ok(6));
        assert_eq!(pair, Ok(4));
    }

    #[test]
    fn performs_exact_division() {
        let cases: [(u128, u128, u128); 2] = [(12, 3, 4), (120, 10, 12)];

        for (numerator, denominator, quotient) in cases {
            assert_eq!(exact_divide(numerator, denominator), Ok(quotient));
        }
    }

    #[test]
    fn rejects_non_exact_division() {
        // A remainder and a zero denominator are reported the same way.
        let cases: [(u128, u128); 2] = [(10, 3), (10, 0)];

        for (numerator, denominator) in cases {
            assert_eq!(
                exact_divide(numerator, denominator),
                Err(GeodeError::DivisionNotExact)
            );
        }
    }

    #[test]
    fn one_dimensional_hyper_catalan_matches_catalan() -> Result<(), GeodeError> {
        for n in 0_u64..=5 {
            let expected = catalan(n).map_err(|_| GeodeError::ArithmeticOverflow)?;

            assert_eq!(hyper_catalan(&vector(&[n])?)?, expected);
            assert_eq!(catalan_from_geode_dimension(n)?, expected);
        }

        Ok(())
    }

    #[test]
    fn computes_small_multidimensional_hyper_catalan_values() -> Result<(), GeodeError> {
        let cases = [
            (vec![1, 1], 5_u128),
            (vec![2, 1], 21),
            (vec![1, 0, 1], 6),
        ];

        for (components, expected) in cases {
            assert_eq!(hyper_catalan(&TypeVector::new(components)?)?, expected);
        }

        Ok(())
    }

    #[test]
    fn computes_small_geode_values() -> Result<(), GeodeError> {
        let cases = [
            (vec![0], 1_u128),
            (vec![1, 0], 2),
            (vec![0, 1], 3),
            (vec![1, 1], 16),
        ];

        for (components, expected) in cases {
            assert_eq!(geode(&TypeVector::new(components)?)?, expected);
        }

        assert_eq!(geode_on_first_axis(1)?, 2);

        Ok(())
    }

    #[test]
    fn memoized_geode_matches_direct_computation() -> Result<(), GeodeError> {
        let samples: [&[u64]; 7] = [
            &[0],
            &[1],
            &[2],
            &[0, 1],
            &[1, 1],
            &[2, 1],
            &[1, 0, 1],
        ];

        for components in samples {
            let sample = vector(components)?;

            assert_eq!(geode_memoized(&sample)?, geode(&sample)?);
        }

        Ok(())
    }

    #[test]
    fn diagonal_helpers_are_deterministic() -> Result<(), GeodeError> {
        // Each helper must agree with `geode` on the explicitly built diagonal.
        assert_eq!(diagonal_geode_2d(0)?, geode(&vector(&[0, 0])?)?);
        assert_eq!(diagonal_geode_2d(1)?, geode(&vector(&[1, 1])?)?);
        assert_eq!(diagonal_geode_3d(1)?, geode(&vector(&[1, 1, 1])?)?);
        assert_eq!(diagonal_geode_4d(1)?, geode(&vector(&[1, 1, 1, 1])?)?);

        Ok(())
    }

    #[test]
    fn reports_overflow_where_practical() {
        // 35! is the first factorial that does not fit in u128.
        let factorial_past_u128 = checked_factorial(35);
        // u64::MAX + 1 faces cannot be counted in u64.
        let face_total_past_u64 = TypeVector::new(vec![u64::MAX, 1]);

        assert_eq!(factorial_past_u128, Err(GeodeError::ArithmeticOverflow));
        assert_eq!(face_total_past_u64, Err(GeodeError::InvalidInput));
    }
}