1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
use super::*;

/// A single server-sent event as handed out by the SSE streams in this module.
///
/// Values are produced by `parse::parse_sse_chunks`; the exact mapping from
/// wire fields to these members is defined there (NOTE(review): confirm how a
/// missing `event:` / `data:` line is represented).
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Event {
    /// Event id, if the event carried one. `SSEStream` remembers the latest
    /// one and sends it back as `LastEventId` when reconnecting.
    pub id: Option<String>,
    /// The event's `event` (type) field as parsed.
    pub event: String,
    /// The event's `data` payload as parsed.
    pub data: String,
}

/// Adapts a single HTTP response body into a stream of parsed [`Event`]s.
///
/// Each incoming chunk is appended to the carry-over buffer and handed to
/// `parse::parse_sse_chunks`. Parsed events are queued in `events` and the
/// bytes the parser hands back are kept in `buf` for the next chunk.
struct SSEBodyStream {
    /// Underlying HTTP response body.
    body: hyper::Body,
    /// Parsed events not yet yielded; stored reversed so `pop` yields them in order.
    events: Vec<Event>,
    /// Bytes returned by the parser as leftover, prepended to the next chunk.
    buf: Vec<u8>,
}

impl SSEBodyStream {
    /// Wraps `body` with an empty event queue and an empty carry-over buffer.
    fn new(body: hyper::Body) -> Self {
        let events = vec![];
        let buf = vec![];
        SSEBodyStream { body, events, buf }
    }
}

impl Stream for SSEBodyStream {
    type Item = Event;
    type Error = error::Error;

    /// Yields the next parsed event from the body.
    ///
    /// Queued events are returned first. Otherwise body chunks are read and
    /// parsed until at least one event is available or the body itself reports
    /// `NotReady`. The task is then registered for wakeup by the body.
    ///
    /// Returns `Ready(None)` when the body ends; any bytes still buffered at
    /// that point are discarded. A body read error is propagated via
    /// `try_ready!`. A parser error is returned as `ErrorKind::Protocol`.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        loop {
            // Events are stored reversed, so `pop` yields them in arrival order.
            // No explicit task notification is needed when returning `Ready`:
            // the consumer is expected to poll again for the next item.
            if let Some(ev) = self.events.pop() {
                return Ok(Async::Ready(Some(ev)));
            }

            let chunk = match try_ready!(self.body.poll()) {
                Some(chunk) => chunk,
                None => return Ok(Async::Ready(None)),
            };

            // Prepend whatever the parser left over from the previous chunk.
            let mut buf = std::mem::replace(&mut self.buf, Vec::new());
            buf.extend_from_slice(&chunk);

            let (mut events, next_buf) = match parse::parse_sse_chunks(buf) {
                Ok(tup) => tup,
                Err(e) => {
                    bail!(error::ErrorKind::Protocol(e));
                }
            };

            self.buf = next_buf;
            events.reverse();
            self.events = events;

            // If this chunk held only a partial event, keep reading. Returning
            // `NotReady` here would stall: the body just returned `Ready`, so it
            // has not registered a wakeup for this task.
        }
    }
}

/// A reconnecting server-sent-events stream over a hyper client.
///
/// Holds either a pending connection attempt (`fut_req`) or an established
/// body reader (`inner`). It also remembers the last seen event id so a
/// reconnect can resume from it.
pub struct SSEStream<C: hyper::client::Connect> {
    /// Endpoint every (re)connection is made to.
    url: hyper::Uri,
    /// Client used to issue the GET requests.
    client: hyper::Client<C>,
    /// In-flight (delayed) connection attempt, if any.
    fut_req: Option<Box<Future<Item = hyper::Response, Error = hyper::Error>>>,
    /// Reader over the currently connected response body, if any.
    inner: Option<SSEBodyStream>,
    /// Most recent non-empty `Event::id`, sent as `LastEventId` on reconnect.
    last_event_id: Option<String>,
}

impl<C: hyper::client::Connect> SSEStream<C> {
    /// Creates a stream for `url` that issues its requests through `client`.
    ///
    /// No request is sent here. The first call to `poll` finds neither a
    /// pending request nor a connected body, and schedules the initial
    /// connection attempt.
    pub fn new(url: hyper::Uri, client: hyper::Client<C>) -> Self {
        Self {
            url,
            // `client` is already owned by this function, so it is moved in
            // directly rather than cloned.
            client,

            fut_req: None,
            inner: None,

            last_event_id: None,
        }
    }
}

impl<C: hyper::client::Connect> Stream for SSEStream<C> {
    type Item = Event;
    type Error = error::Error;

    /// Yields the next server-sent event, reconnecting transparently.
    ///
    /// Connection failures, body read errors and server-side disconnects are
    /// logged and followed by a new GET request, issued after a 100 ms delay.
    /// If an event id has been seen, the request carries it in a `LastEventId`
    /// header. As a result this stream never yields an error and never yields
    /// `Ready(None)`.
    ///
    /// This is written as a loop rather than by re-entering `poll`
    /// recursively. A run of connection attempts that fail without ever
    /// returning `NotReady` therefore cannot grow the stack.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        loop {
            // Drive a pending (re)connection attempt, if there is one.
            if let Some(mut fut_req) = self.fut_req.take() {
                match fut_req.poll() {
                    Err(_e) => {
                        error!("failed to connect, retry: {:?}", _e);
                        // fall through to the retry case below
                    }
                    Ok(Async::NotReady) => {
                        self.fut_req = Some(fut_req);
                        return Ok(Async::NotReady);
                    }
                    Ok(Async::Ready(resp)) => {
                        info!("sse stream connected: {}", self.url);
                        self.inner = Some(SSEBodyStream::new(resp.body()));
                    }
                }
            }

            // Read from the connected body, if there is one.
            if let Some(ref mut s) = self.inner {
                match s.poll() {
                    Err(_e) => {
                        error!("failed to read body, trying to reconnect: {:?}", _e);
                        // fall through to the retry case below
                    }
                    Ok(Async::NotReady) => return Ok(Async::NotReady),
                    Ok(Async::Ready(None)) => {
                        // server dropped the connection; reconnect below
                    }
                    Ok(Async::Ready(Some(ev))) => {
                        if let Some(ref event_id) = ev.id {
                            self.last_event_id = Some(event_id.clone());
                        }

                        return Ok(Async::Ready(Some(ev)));
                    }
                }
            }

            // Retry case: drop any body reader and schedule a delayed request.
            self.inner = None;
            info!("trying to connect: {}", self.url);
            let mut req = Request::new(hyper::Method::Get, self.url.clone());

            // Let the server resume after the last event we delivered.
            if let Some(ref last_event_id) = self.last_event_id {
                req.headers_mut().set(LastEventId(last_event_id.clone()));
            }

            // `.then` ignores the timer's outcome, so the request is sent even
            // if the delay itself errors.
            let client = self.client.clone();
            let req = tokio_timer::Delay::new(Instant::now() + Duration::from_millis(100))
                .then(move |_| client.request(req));

            self.fut_req = Some(Box::new(req));
            // Loop around to poll the freshly scheduled request, which
            // registers the task for wakeup.
        }
    }
}