1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
//! Checks SSL certificate expiration.
//!
//! This crate connects to a remote server and checks when the SSL certificate it presents expires.
//!
//! # Example
//!
//! ```rust
//! use ssl_expiration::SslExpiration;
//!
//! let expiration = SslExpiration::from_domain_name("google.com").unwrap();
//! if expiration.is_expired() {
//!     // do something if SSL certificate expired
//! }
//! ```

extern crate foreign_types_shared;
extern crate openssl;
extern crate openssl_sys;
#[macro_use]
extern crate error_chain;

use std::os::raw::c_int;
use std::net::{TcpStream, ToSocketAddrs};
use std::error::Error;
use openssl::ssl::{Ssl, SslContext, SslMethod, SslVerifyMode};
use openssl::asn1::Asn1Time;
use openssl_sys::ASN1_TIME;
use foreign_types_shared::{ForeignType,ForeignTypeRef};
use error::Result;


extern "C" {
    fn ASN1_TIME_diff(pday: *mut c_int,
                      psec: *mut c_int,
                      from: *const ASN1_TIME,
                      to: *const ASN1_TIME);
}


pub struct SslExpiration(c_int);


impl SslExpiration {
    /// Creates new SslExpiration from domain name.
    ///
    /// This function will use HTTPS port (443) to check SSL certificate.
    pub fn from_domain_name(domain: &str) -> Result<SslExpiration> {
        SslExpiration::from_addr(format!("{}:443", domain))
    }

    /// Creates new SslExpiration from SocketAddr.
    pub fn from_addr<A: ToSocketAddrs>(addr: A) -> Result<SslExpiration> {
        let context = {
            let mut context = SslContext::builder(SslMethod::tls())?;
            context.set_verify(SslVerifyMode::empty());
            context.build()
        };
        let connector = Ssl::new(&context)?;
        let stream = TcpStream::connect(addr)?;
        let stream = connector.connect(stream)
            .map_err(|e| error::ErrorKind::HandshakeError(e.description().to_owned()))?;
        let cert = stream.ssl()
            .peer_certificate()
            .ok_or("Certificate not found")?;

        let now = Asn1Time::days_from_now(0)?;

        let (mut pday, mut psec) = (0, 0);
        unsafe {
            let ptr_pday: *mut c_int = &mut pday;
            let ptr_psec: *mut c_int = &mut psec;
            ASN1_TIME_diff(ptr_pday,
                           ptr_psec,
                           now.as_ptr(),
                           cert.not_after().as_ptr());
        }

        Ok(SslExpiration(pday * 24 * 60 * 60 - psec))
    }

    /// How many seconds until SSL certificate expires.
    ///
    /// This function will return minus if SSL certificate is already expired.
    pub fn secs(&self) -> i32 {
        self.0
    }

    /// How many days until SSL certificate expires
    ///
    /// This function will return minus if SSL certificate is already expired.
    pub fn days(&self) -> i32 {
        self.0 / 60 / 60 / 24
    }

    /// Returns true if SSL certificate is expired
    pub fn is_expired(&self) -> bool {
        self.0 < 0
    }
}



pub mod error {
    use std::io;
    use openssl;

    error_chain! {
        foreign_links {
            OpenSslErrorStack(openssl::error::ErrorStack);
            IoError(io::Error);
        }
        errors {
            HandshakeError(e: String) {
                description("HandshakeError")
                display("HandshakeError: {}", e)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Network test: google.com is expected to present a currently valid
    /// certificate, and expired.identrustssl.com an expired one.
    #[test]
    fn test_ssl_expiration() {
        let valid = SslExpiration::from_domain_name("google.com").unwrap();
        assert!(!valid.is_expired());

        let stale = SslExpiration::from_domain_name("expired.identrustssl.com").unwrap();
        assert!(stale.is_expired());
    }
}