use core::{
cmp::PartialEq,
ops::{
Add,
Neg,
Mul,
}
};
use subtle::{
Choice,
ConditionallySelectable,
ConstantTimeEq,
};
use crate::{
Error,
Result,
constants::COMPRESSED_Y_LENGTH,
field::{
FieldImplementation,
FieldElement,
},
scalar::Scalar,
};
#[derive(Clone,Copy,Debug,Default)]
pub struct CurvePoint (
[FieldElement; 4]
);
#[derive(Clone,Copy,Debug,Default)]
pub struct CompressedY(
pub [u8; 32])
;
impl From<&[u8; COMPRESSED_Y_LENGTH]> for CompressedY {
fn from(bytes: &[u8; COMPRESSED_Y_LENGTH]) -> CompressedY {
CompressedY(bytes.clone())
}
}
impl CompressedY {
pub fn decompressed(&self) -> Result<CurvePoint> {
#![allow(non_snake_case)]
let Y = FieldElement::from_bytes_unchecked(self.as_bytes());
let Z = FieldElement::ONE;
let Y_squared = Y.squared();
let u = &Y_squared - &Z;
let v = &(&Y_squared * &FieldElement::D) + &Z;
let v2 = v.squared();
let v4 = v2.squared();
let v7 = &(&v4 * &v2) * &v;
let t = &v7 * &u;
let mut X = &(&(&t.pow2523() * &u) * &v2) * &v;
let chk = &X.squared() * &v;
if chk != u {
X = &X * &FieldElement::I;
}
let chk = &X.squared() * &v;
if chk != u {
return Err(Error::PublicKeyBytesInvalid);
}
if X.parity() != (self.0[31] >> 7) {
X = -&X;
}
let T = &X * &Y;
Ok(CurvePoint([X, Y, Z, T]))
}
}
impl CurvePoint {
pub fn basepoint() -> CurvePoint {
CurvePoint([
FieldElement::BASEPOINT_X,
FieldElement::BASEPOINT_Y,
FieldElement::ONE,
&FieldElement::BASEPOINT_X * &FieldElement::BASEPOINT_Y,
])
}
pub fn neutral_element() -> CurvePoint {
CurvePoint([
FieldElement::ZERO,
FieldElement::ONE,
FieldElement::ONE,
FieldElement::ZERO,
])
}
pub fn compressed(&self) -> CompressedY {
let z_inverse = &self.0[2].inverse();
let x = &self.0[0] * &z_inverse;
let y = &self.0[1] * &z_inverse;
let mut r = y.to_bytes();
r[31] ^= x.parity() << 7;
CompressedY(r)
}
}
impl<'a, 'b> Add<&'b CurvePoint> for &'a CurvePoint {
type Output = CurvePoint;
fn add(self, other: &'b CurvePoint) -> Self::Output {
let p = &self.0;
let q = &other.0;
let a = &p[1] - &p[0];
let t = &q[1] - &q[0];
let a = &a * &t;
let b = &p[0] + &p[1];
let t = &q[0] + &q[1];
let b = &b * &t;
let c = &p[3] * &q[3];
let c = &c * &FieldElement::D2;
let d = &p[2] * &q[2];
let d = &d + &d;
let e = &b - &a;
let f = &d - &c;
let g = &d + &c;
let h = &b + &a;
let coordinates = [
&e * &f,
&h * &g,
&g * &f,
&e * &h,
];
CurvePoint(coordinates)
}
}
impl<'a> Neg for &'a CurvePoint {
type Output = CurvePoint;
fn neg(self) -> CurvePoint {
let p = &self.0;
CurvePoint([-&p[0], p[1], p[2], -&p[3]])
}
}
impl<'a, 'b> Mul<&'b CurvePoint> for &'a Scalar {
type Output = CurvePoint;
fn mul(self, point: &'b CurvePoint) -> CurvePoint {
let mut p = CurvePoint([
FieldElement::ZERO.clone(),
FieldElement::ONE.clone(),
FieldElement::ONE.clone(),
FieldElement::ZERO.clone(),
]);
let mut q = point.clone();
let scalar = self;
for i in (0..=255).rev() {
let b = (((scalar.0[i / 8] >> (i & 7)) & 1) as u8).into();
CurvePoint::conditional_swap(&mut p, &mut q, b);
q = &q + &p;
p = &p + &p;
CurvePoint::conditional_swap(&mut p, &mut q, b);
}
p
}
}
impl ConditionallySelectable for CurvePoint {
fn conditional_select(p: &Self, q: &Self, choice: Choice) -> Self {
let mut selection = Self::default();
for (i, (pi, qi)) in p.0.iter().zip(q.0.iter()).enumerate() {
selection.0[i] = FieldElement::conditional_select(pi, qi, choice);
}
selection
}
fn conditional_swap(p: &mut Self, q: &mut Self, choice: Choice) {
for (pi, qi) in p.0.iter_mut().zip(q.0.iter_mut()) {
FieldElement::conditional_swap(pi, qi, choice);
}
}
}
impl CompressedY {
pub fn as_bytes(&self) -> &[u8; 32] {
&self.0
}
pub fn to_bytes(&self) -> [u8; 32] {
self.0
}
}
impl ConstantTimeEq for CompressedY {
fn ct_eq(&self, other: &Self) -> Choice {
self.0.ct_eq(&other.0)
}
}
impl PartialEq for CompressedY {
fn eq(&self, other: &Self) -> bool {
bool::from(self.ct_eq(other))
}
}
impl ConstantTimeEq for CurvePoint {
fn ct_eq(&self, other: &Self) -> Choice {
let self_compressed = self.compressed();
let other_compressed = other.compressed();
self_compressed.ct_eq(&other_compressed)
}
}
impl PartialEq for CurvePoint {
fn eq(&self, other: &Self) -> bool {
bool::from(self.ct_eq(other))
}
}
#[cfg(test)]
mod tests {
use super::CurvePoint;
use crate::Scalar;
#[test]
fn test_neutral_is_neutral() {
let n = 42;
let s = Scalar::from_bytes(&[
n, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
]);
let ne = CurvePoint::neutral_element();
assert_eq!(ne, &s * &ne);
}
#[test]
fn test_addition_vs_multiplication() {
let p = CurvePoint::basepoint();
let p_plus_p = &p + &p;
let two = Scalar::from_bytes(&[
2u8, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
]);
let two_times_p = &two * &p;
assert_eq!(p_plus_p, two_times_p);
}
#[test]
fn test_some_more() {
let n = 37;
let s = Scalar::from_bytes(&[
n, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
]);
let bp = CurvePoint::basepoint();
let ne = CurvePoint::neutral_element();
let a = (1..=n).fold(ne, |partial_sum, _| &partial_sum + &bp);
let b = &s * &bp;
assert_eq!(a, b);
}
#[test]
fn test_negation() {
let bp = CurvePoint::basepoint();
let minus_bp = -&bp;
let maybe_neutral = &bp + &minus_bp;
assert_eq!(maybe_neutral, CurvePoint::neutral_element());
}
}