1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
use std::hash::Hasher;
use std::io::{self, Write};
use std::{cell::Cell, collections::hash_map::DefaultHasher};
use std::{convert::TryInto, hash::Hash};

use super::{LazyMessage, Socket};
use crate::{codec::FrameBuf, stream::Stream};

/// Locally stored form of a subscribed topic, used to prefix-match the first
/// frame of incoming messages.
#[derive(Clone, Debug)]
enum SubscriptionTopic {
    /// An empty topic (matches everything.)
    Empty,

    /// A literal topic is a topic that is exactly 8 bytes long.
    ///
    /// We store the literal to avoid extraneous hashing of such prefixes.
    /// Shorter topics cannot be represented by the fixed-size array and are
    /// stored as [`SubscriptionTopic::Hashed`] instead.
    Literal([u8; 8]),

    /// A hashed topic is the hash of any other non-empty topic, i.e. one that
    /// is shorter or longer than 8 bytes.
    ///
    /// It matches if the hash of the first `length` bytes slice matches
    /// `value`. Only the hash is compared, so a different prefix that hashes
    /// to the same value is also treated as a match.
    Hashed { value: u64, length: u8 },
}

impl From<Stream> for Sub {
    fn from(inner: Stream) -> Self {
        Self { inner: Cell::new(inner), topics: vec![] }
    }
}

/// A ZMQ SUB socket.
///
/// Keeps the underlying stream together with the list of topics that were
/// subscribed to, so that received messages can be prefix-matched locally.
pub struct Sub {
    /// The stream this socket reads from and writes to. Within this file it
    /// is only ever accessed through `Cell::get_mut`.
    inner: Cell<Stream>,
    /// Topics recorded by `subscribe`; `recv` checks the first frame of each
    /// incoming message against every entry.
    topics: Vec<SubscriptionTopic>,
}

impl Sub {
    /// Subscribe to a topic.
    ///
    /// Writes a ZMTP 3.0 subscription message for `topic` to the underlying
    /// stream and, once the write succeeded, records the topic locally so that
    /// [`Sub::recv`] can prefix-match incoming messages against it.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::InvalidInput`] error when `topic` is longer
    /// than 254 bytes, and forwards any error raised while writing the
    /// subscription. In both cases the topic is *not* recorded.
    pub fn subscribe(&mut self, topic: &[u8]) -> io::Result<()> {
        // The single size octet of a short frame counts the leading 0x01
        // subscribe marker as well as the topic itself, so the topic may be at
        // most 254 bytes long.
        const MAX_TOPIC_LEN: usize = u8::MAX as usize - 1;

        if topic.len() > MAX_TOPIC_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "subscription topics can be at most 254 bytes long",
            ));
        }

        // Cannot truncate: the guard above keeps the length below `u8::MAX`.
        let length = topic.len() as u8;

        // Local representation of the topic, used for prefix matching when
        // receiving. Only a topic of exactly 8 bytes converts into `[u8; 8]`;
        // every other non-empty topic is stored as a hash plus its length.
        let slim_topic: Result<[u8; 8], _> = topic.try_into();
        let topic_entry = match (topic.is_empty(), slim_topic) {
            (true, _) => SubscriptionTopic::Empty,
            (false, Ok(slim)) => SubscriptionTopic::Literal(slim),
            (false, Err(_)) => {
                let mut hasher = DefaultHasher::new();
                topic.hash(&mut hasher);
                SubscriptionTopic::Hashed {
                    value: hasher.finish(),
                    length,
                }
            }
        };

        // ZMTP 3.0 subscription: a short message frame whose body is the byte
        // 0x01 followed by the topic. (ZMTP 3.1 uses a dedicated SUBSCRIBE
        // command frame instead, but that is not acceptable for 3.0, which is
        // what we speak by default.)
        let mut subscribe = Vec::with_capacity(3 + topic.len());
        subscribe.push(0x00); // flags octet
        subscribe.push(1 + length); // body size: marker byte + topic
        subscribe.push(0x01); // subscribe marker
        subscribe.extend_from_slice(topic);

        // `write_all` rather than `write`: a short write would leave a
        // truncated frame on the wire.
        self.inner
            .get_mut()
            .ensure_connected()
            .write_all(&subscribe)?;

        // Only remember the topic once the subscription was actually sent, so
        // the local filter never lists a topic the write failed for.
        self.topics.push(topic_entry);

        Ok(())
    }

    /// Receive a message that matches a subscribed topic prefix.
    ///
    /// Reads messages from the stream until the first frame of one of them
    /// matches a recorded topic, then returns *all* frames of that message,
    /// first frame included. Messages that match no topic are read to their
    /// end and discarded. With no recorded topics nothing ever matches.
    ///
    /// # Panics
    ///
    /// Panics if a frame cannot be read from the stream, or if the first frame
    /// of a message cannot be converted into a message frame.
    #[inline]
    pub fn recv(&mut self) -> io::Result<Vec<Vec<u8>>> {
        /// Whether `bytes` (the body of a first frame) starts with `expected`.
        fn topic_prefix_match(expected: &SubscriptionTopic, bytes: &[u8]) -> bool {
            match expected {
                SubscriptionTopic::Empty => true,
                SubscriptionTopic::Literal(literal) => bytes.starts_with(literal),
                SubscriptionTopic::Hashed { value, length } => {
                    // A body shorter than the topic cannot start with it;
                    // `get` avoids the out-of-bounds panic of plain slicing.
                    match bytes.get(..usize::from(*length)) {
                        Some(prefix) => {
                            let mut hasher = DefaultHasher::new();
                            prefix.hash(&mut hasher);
                            hasher.finish() == *value
                        }
                        None => false,
                    }
                }
            }
        }

        let stream = self.inner.get_mut();

        loop {
            // Iterator over the frames of the next message on the stream.
            let mut frames = LazyMessage {
                stream,
                witness: false,
            }
            .fuse();

            let first_frame = frames
                .next()
                .expect("There should always be one frame in a message.")
                .unwrap();
            let frame = first_frame.as_frame().try_into_message().unwrap();

            let is_last = frame.is_last();
            let matched = self
                .topics
                .iter()
                .any(|topic| topic_prefix_match(topic, &frame.body()));

            if matched {
                // The first frame is part of the message regardless of how
                // many frames follow it.
                let mut message: Vec<Vec<u8>> = vec![first_frame.into()];

                if !is_last {
                    message.extend(frames.map(|part| -> Vec<u8> { part.unwrap().into() }));
                }

                return Ok(message);
            }

            if !is_last {
                // Consume the rest of the rejected message; otherwise its
                // remaining frames would be read as the start of the next
                // message on the following iteration.
                for part in frames {
                    drop(part.unwrap());
                }
            }
        }
    }

    /// Receive a multipart message without performing prefix checks.
    ///
    /// Delegates directly to the [`Socket`] trait's `recv`.
    #[inline]
    pub fn recv_unchecked(&mut self) -> io::Result<Vec<Vec<u8>>> {
        <Self as Socket>::recv(self)
    }
}

impl Socket for Sub {
    fn stream(&mut self) -> &mut crate::stream::Stream {
        self.inner.get_mut()
    }
}