1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
//! Client-side LZ4 twist extension implementation.
use {PMLZ4, RSV2, SWE};
use lz4_compress::{compress, decompress};
use slog::Logger;
use std::io;
use twist::client::BaseFrame;
use twist::extension::{Header, PerMessage};
use util;

/// Client-side configuration for the LZ4 per-message extension.
///
/// Holds whether the extension is enabled plus optional slog loggers used
/// for trace output (`stdout`) and error output (`stderr`).
#[derive(Default)]
pub struct Lz4 {
    /// Whether the extension is enabled; set via `set_enabled` or from the
    /// negotiated header in `from_header`.
    enabled: bool,
    /// Optional slog `Logger` used for trace output.
    stdout: Option<Logger>,
    /// Optional slog `Logger` used for error output.
    stderr: Option<Logger>,
}

impl Lz4 {
    /// Set the `enabled` flag.
    pub fn set_enabled(&mut self, enabled: bool) -> &mut Lz4 {
        self.enabled = enabled;
        self
    }

    /// Add a stdout slog `Logger` to this protocol.
    pub fn stdout(&mut self, logger: Logger) -> &mut Lz4 {
        let stdout = logger.new(o!("extension" => "lz4", "module" => "client"));
        self.stdout = Some(stdout);
        self
    }

    /// Add a stderr slog `Logger` to this protocol.
    pub fn stderr(&mut self, logger: Logger) -> &mut Lz4 {
        let stderr = logger.new(o!("extension" => "lz4", "module" => "client"));
        self.stderr = Some(stderr);
        self
    }
}

impl Header for Lz4 {
    /// Enable the extension iff `header` mentions `PMLZ4`; disable it otherwise.
    fn from_header(&mut self, header: &str) -> Result<(), io::Error> {
        try_trace!(self.stdout, "from_header");
        let accepted = header.contains(PMLZ4);
        if accepted {
            try_trace!(self.stdout, "permessage-lz4 is enabled");
        } else {
            try_trace!(self.stdout, "permessage lz4 is disabled");
        }
        self.enabled = accepted;
        Ok(())
    }

    /// Build the header line (`SWE` followed by `PMLZ4`) when enabled;
    /// yields `None` when the extension is disabled.
    fn into_header(&mut self) -> io::Result<Option<String>> {
        try_trace!(self.stdout, "into_header");
        if !self.enabled {
            return Ok(None);
        }
        Ok(Some([SWE, PMLZ4].concat()))
    }
}

impl PerMessage for Lz4 {
    /// Report whether the extension is enabled.
    fn enabled(&self) -> bool {
        try_trace!(self.stdout, "enabled");
        self.enabled
    }

    /// Claim the RSV2 bit in `reserved` when the extension is enabled.
    ///
    /// When disabled, `reserved` is returned unchanged.
    ///
    /// # Errors
    /// Returns an error if the extension is enabled and RSV2 is already set
    /// in `reserved`, i.e. something else has claimed the bit.
    fn reserve_rsv(&self, reserved: u8) -> Result<u8, io::Error> {
        try_trace!(self.stdout, "reserve_rsv");
        if !self.enabled {
            return Ok(reserved);
        }
        if reserved & RSV2 == 0 {
            Ok(reserved | RSV2)
        } else {
            try_error!(self.stderr, "rsv2 bit is already reserved");
            Err(util::other("rsv2 bit is already reserved"))
        }
    }

    /// Decompress the application data of an incoming frame that has RSV2 set,
    /// updating its payload length to match.
    ///
    /// Frames are left untouched when the extension is disabled: in that case
    /// `reserve_rsv` never claimed RSV2, so the bit does not mean "lz4".
    ///
    /// # Errors
    /// Returns an error if the application data cannot be decompressed.
    fn decode(&self, frame: &mut BaseFrame) -> Result<(), io::Error> {
        try_trace!(self.stdout, "decode");
        if !(self.enabled && frame.rsv2()) {
            return Ok(());
        }

        let decompressed = match decompress(frame.application_data()) {
            Ok(decompressed) => decompressed,
            Err(e) => {
                try_error!(self.stderr, "{}", e);
                return Err(util::other("unable to decompress app data"));
            }
        };
        try_trace!(self.stdout, "full\n{}", util::as_hex(&decompressed));

        // usize -> u64 is lossless on all supported (<= 64-bit) targets.
        frame.set_payload_length(decompressed.len() as u64);
        frame.set_application_data(decompressed);
        Ok(())
    }

    /// Compress the application data of an outgoing frame that has RSV2 set,
    /// updating its payload length to match.
    ///
    /// Frames are left untouched when the extension is disabled, for the same
    /// reason as in `decode`.
    fn encode(&self, frame: &mut BaseFrame) -> Result<(), io::Error> {
        try_trace!(self.stdout, "encode");
        if !(self.enabled && frame.rsv2()) {
            return Ok(());
        }

        let compressed = compress(frame.application_data());
        try_trace!(self.stdout, "compressed\n{}", util::as_hex(&compressed));

        // usize -> u64 is lossless on all supported (<= 64-bit) targets.
        frame.set_payload_length(compressed.len() as u64);
        frame.set_application_data(compressed);
        Ok(())
    }
}