1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
use crate::arithmetic::bits_256::represent::*;
use crate::common::{Bits, Group, PrimeField, Projective};

/// Adds two curve points held in projective coordinates `(X : Y : Z)`.
///
/// If either operand is the identity, the other operand is returned as-is. Otherwise both
/// points are cross-multiplied onto the shared denominator `Z1·Z2` and compared:
/// * same `x` and same `y`: the points coincide, so the work is handed to [`double_point`];
/// * same `x`, different `y`: for points on the curve `rhs = −lhs`, so the identity is
///   returned;
/// * otherwise the chord formula is evaluated without any field inversion.
///
/// The formulas are the inversion-free ones for homogeneous projective coordinates
/// (affine `x = X/Z`, `y = Y/Z`).
pub fn add_point<P: Projective>(lhs: P, rhs: P) -> P {
    if lhs.is_identity() {
        return rhs;
    }
    if rhs.is_identity() {
        return lhs;
    }

    // Each point rescaled onto the common denominator Z1·Z2.
    let x1z2 = lhs.get_x() * rhs.get_z();
    let x2z1 = rhs.get_x() * lhs.get_z();
    let y1z2 = lhs.get_y() * rhs.get_z();
    let y2z1 = rhs.get_y() * lhs.get_z();

    if x1z2 == x2z1 {
        return if y1z2 == y2z1 {
            double_point(lhs)
        } else {
            <P as Group>::ADDITIVE_IDENTITY
        };
    }

    let z1z2 = lhs.get_z() * rhs.get_z();
    let dy = y1z2 - y2z1;
    let dx = x1z2 - x2z1;
    let dx2 = dx.square();
    let dx3 = dx2 * dx;
    // Numerator shared by the resulting X and Y coordinates.
    let h = dy.square() * z1z2 - dx2 * (x1z2 + x2z1);

    let mut sum = <P as Default>::default();
    sum.set_x(dx * h);
    sum.set_y(dy * (x1z2 * dx2 - h) - y1z2 * dx3);
    sum.set_z(dx3 * z1z2);
    sum
}

/// Doubles a curve point held in projective coordinates `(X : Y : Z)`.
///
/// Returns the additive identity when `point` is the identity or when its `Y` coordinate is
/// zero. Such a point equals its own negation, so its double is the identity.
///
/// The tangent-slope numerator is computed as `3·X²` with no `a·Z²` term. These formulas are
/// therefore only valid for short Weierstrass curves `y² = x³ + b` (coefficient `a = 0`).
pub fn double_point<P: Projective>(point: P) -> P {
    if point.is_identity() || point.get_y().is_zero() {
        return <P as Group>::ADDITIVE_IDENTITY;
    }

    let (x, y, z) = (point.get_x(), point.get_y(), point.get_z());

    let x_sq = x.square();
    // Slope numerator 3·X² (curve coefficient `a` taken to be zero).
    let m = x_sq.double() + x_sq;
    // 2·Y·Z: scales the new X and Y and, cubed, becomes the new Z.
    let s = (y * z).double();
    // 2·s·X·Y, i.e. 4·X·Y²·Z.
    let b = (s * x * y).double();
    let h = m.square() - b.double();
    let s_sq = s.square();

    let mut doubled = <P as Default>::default();
    doubled.set_x(s * h);
    doubled.set_y(m * (b - h) - (s_sq * y.square()).double());
    doubled.set_z(s_sq * s);
    doubled
}

/// Multiplies `base` by a 256-bit scalar and adds the product onto `identity`.
///
/// The bits come from `to_bits(scalar)`. High-order zero bits are discarded first. The
/// remaining bits are consumed least-significant first with the right-to-left double-and-add
/// method: whenever a bit is set, the current multiple of `base` is added to the accumulator,
/// and `base` is doubled after every bit. If `to_bits` lists bits most-significant first
/// (NOTE(review): assumed, since the trim-then-reverse below depends on it), the result is
/// `identity + k·base` for the integer `k` encoded by `scalar`. An all-zero scalar returns
/// `identity` unchanged.
///
/// Not constant time: the number of iterations and the conditional addition both depend on
/// the scalar's bits.
pub fn scalar_point<P: Projective>(mut base: P, scalar: [u64; 4], mut identity: P) -> P {
    // Trim the high-order zero bits so no doublings are spent on them.
    let significant: Bits = to_bits(scalar)
        .into_iter()
        .skip_while(|&b| b == 0)
        .collect();

    // Lowest bit first; before step i, `base` holds 2^i times the caller's base.
    significant.iter().rev().copied().for_each(|bit| {
        if bit == 1 {
            identity += base;
        }
        base = double_point(base);
    });

    identity
}