1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use base64::{self, URL_SAFE_NO_PAD};
use error::WebPushError;
use hyper::Uri;
use openssl::hash::MessageDigest;
use openssl::pkey::PKey;
use openssl::sign::Signer as SslSigner;
use serde_json;
use serde_json::{Number, Value};
use std::collections::BTreeMap;
use time;
use vapid::VapidKey;

lazy_static! {
    /// Unpadded base64url encoding of the JOSE header (`typ` = `JWT`,
    /// `alg` = `ES256`) that prefixes every VAPID token; built once on first use.
    static ref JWT_HEADERS: String = {
        let header = json!({
            "typ": "JWT",
            "alg": "ES256"
        });
        let header_json = serde_json::to_string(&header).unwrap();
        base64::encode_config(&header_json, URL_SAFE_NO_PAD)
    };
}

/// A struct representing a VAPID signature. Should be generated using the
/// [VapidSignatureBuilder](struct.VapidSignatureBuilder.html).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VapidSignature {
    /// The signed token: `header.claims.signature`, each part unpadded base64url.
    pub auth_t: String,
    /// The public key, unpadded base64url.
    pub auth_k: String,
}

/// Formats the signature as an aesgcm-style `Authorization` header value,
/// `WebPush <auth_t>`. The public key (`auth_k`) is not part of this value.
///
/// Implemented as `From` (rather than `Into`) so that `String::from(&sig)`
/// works, while `(&sig).into()` keeps working through the standard library's
/// blanket `Into` implementation.
impl<'a> From<&'a VapidSignature> for String {
    fn from(signature: &'a VapidSignature) -> String {
        format!("WebPush {}", signature.auth_t)
    }
}

/// Produces VAPID signatures (ES256-signed JWTs) for Web Push requests.
pub struct VapidSigner {}

impl VapidSigner {
    /// Create a signature with a given key. Sets the default audience from the
    /// endpoint host and sets the expiry in twelve hours. Values can be
    /// overwritten by adding the `aud` and `exp` claims.
    ///
    /// The returned `auth_t` is `header.claims.signature`, each part unpadded
    /// base64url, with the signature in the fixed-width `r || s` form.
    /// `auth_k` is the unpadded base64url encoding of `key.public_key()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the claims cannot be serialized to JSON or if any
    /// OpenSSL key conversion or signing step fails.
    ///
    /// # Panics
    ///
    /// Panics in two cases:
    /// - no `aud` claim is supplied and `endpoint` has no scheme or no host;
    /// - OpenSSL returns something other than a DER-encoded P-256 ECDSA signature.
    pub fn sign(
        key: VapidKey,
        endpoint: &Uri,
        mut claims: BTreeMap<&str, Value>,
    ) -> Result<VapidSignature, WebPushError> {
        claims.entry("aud").or_insert_with(|| {
            // NOTE(review): only scheme and host are used, so an explicit port
            // in the endpoint is not reflected in the default audience; confirm
            // this is acceptable for the push services in use.
            let audience = format!(
                "{}://{}",
                endpoint
                    .scheme_part()
                    .expect("push endpoint URI has no scheme"),
                endpoint.host().expect("push endpoint URI has no host")
            );
            Value::String(audience)
        });

        claims.entry("exp").or_insert_with(|| {
            let expiry = time::now_utc() + time::Duration::hours(12);
            Value::Number(Number::from(expiry.to_timespec().sec))
        });

        let signing_input = format!(
            "{}.{}",
            *JWT_HEADERS,
            base64::encode_config(&serde_json::to_string(&claims)?, URL_SAFE_NO_PAD)
        );

        let public_key = key.public_key();
        let auth_k = base64::encode_config(&public_key, URL_SAFE_NO_PAD);
        let pkey = PKey::from_ec_key(key.0)?;

        let mut signer = SslSigner::new(MessageDigest::sha256(), &pkey)?;
        signer.update(signing_input.as_bytes())?;
        let der_signature = signer.sign_to_vec()?;

        // OpenSSL emits an ASN.1 DER signature, but JWS ES256 requires the
        // raw, fixed-width `r || s` concatenation (32 bytes each).
        let raw_signature = der_signature_to_raw(&der_signature, ES256_SCALAR_LEN)
            .expect("OpenSSL returned a malformed DER ECDSA signature");

        trace!("Public key: {}", auth_k);

        let auth_t = format!(
            "{}.{}",
            signing_input,
            base64::encode_config(&raw_signature, URL_SAFE_NO_PAD)
        );

        Ok(VapidSignature { auth_t, auth_k })
    }
}

/// Byte length of each of the `r` and `s` scalars in an ES256 (P-256) signature.
const ES256_SCALAR_LEN: usize = 32;

/// Converts a DER-encoded ECDSA signature (`SEQUENCE { INTEGER r, INTEGER s }`)
/// into the JWS form (RFC 7518, section 3.4). That form is `r || s`, with each
/// scalar left-padded with zeros to exactly `scalar_len` bytes.
///
/// DER integers are minimal-length, so a scalar can be shorter than
/// `scalar_len` when its leading bytes are zero. It can also be one byte
/// longer, because a `0x00` sign byte is added when the top bit is set. Both
/// cases are normalized here.
///
/// Returns `None` unless `der` is a short-form DER signature containing
/// exactly two integers that each fit in `scalar_len` bytes.
fn der_signature_to_raw(der: &[u8], scalar_len: usize) -> Option<Vec<u8>> {
    // Outer SEQUENCE whose short-form length covers the rest of the input. A
    // P-256 signature is at most 72 bytes, so the long form never applies.
    if der.len() < 2 || der[0] != 0x30 || der[1] >= 0x80 || der[1] as usize != der.len() - 2 {
        return None;
    }

    let (r, rest) = read_der_integer(&der[2..])?;
    let (s, rest) = read_der_integer(rest)?;
    if !rest.is_empty() {
        return None;
    }

    let mut raw = Vec::with_capacity(2 * scalar_len);
    push_padded_scalar(&mut raw, r, scalar_len)?;
    push_padded_scalar(&mut raw, s, scalar_len)?;
    Some(raw)
}

/// Splits one non-empty, short-form DER INTEGER off the front of `input`.
/// Returns the integer's content bytes and the bytes that follow it.
fn read_der_integer(input: &[u8]) -> Option<(&[u8], &[u8])> {
    if input.len() < 2 || input[0] != 0x02 || input[1] == 0 || input[1] >= 0x80 {
        return None;
    }
    let len = input[1] as usize;
    if input.len() < 2 + len {
        return None;
    }
    Some((&input[2..2 + len], &input[2 + len..]))
}

/// Appends the big-endian integer `bytes` to `out`, left-padded with zeros to
/// exactly `width` bytes. Leading zero bytes, including a DER sign byte, are
/// dropped first. Returns `None` if the value still does not fit in `width`.
fn push_padded_scalar(out: &mut Vec<u8>, bytes: &[u8], width: usize) -> Option<()> {
    let first_significant = bytes
        .iter()
        .position(|&b| b != 0)
        .unwrap_or(bytes.len());
    let digits = &bytes[first_significant..];
    if digits.len() > width {
        return None;
    }
    let padded_len = out.len() + width - digits.len();
    out.resize(padded_len, 0);
    out.extend_from_slice(digits);
    Some(())
}

#[cfg(test)]
mod tests {
    use vapid::VapidSignature;

    /// The aesgcm header value is `WebPush ` followed by the token (`auth_t`);
    /// the public key (`auth_k`) must not appear in it.
    #[test]
    fn test_vapid_signature_aesgcm_format() {
        let signature = VapidSignature {
            auth_t: "foo".to_string(),
            auth_k: "bar".to_string(),
        };

        let header: String = (&signature).into();

        assert_eq!(header, "WebPush foo");
    }
}