use std::error::Error as StdError;
use std::fmt::{self, Display, Formatter, Write};
use std::str::FromStr;
use std::time::Duration;
use futures::{Async, Future, Poll, Stream};
use http::header::{HeaderValue, CACHE_CONTROL, CONTENT_TYPE};
use hyper::Body;
use serde::Serialize;
use serde_json;
use tokio::{clock::now, timer::Delay};
use self::sealed::{
BoxedServerSentEvent, EitherServerSentEvent, SseError, SseField, SseFormat, SseWrapper,
};
use super::{header, header::MissingHeader};
use filter::One;
use reply::{ReplySealed, Response};
use {Filter, Rejection, Reply};
/// A value that can be written out as one server-sent event.
///
/// Every `SseFormat + Send + 'static` type gets this trait for free through
/// the blanket impl below; the provided methods only wrap `self` so that
/// differently-typed events can flow through one stream.
pub trait ServerSentEvent: SseFormat + Sized + Send + 'static {
    /// Wrap `self` as the left alternative of an `EitherServerSentEvent`.
    fn into_a<B>(self) -> EitherServerSentEvent<Self, B> {
        EitherServerSentEvent::A(self)
    }

    /// Wrap `self` as the right alternative of an `EitherServerSentEvent`.
    fn into_b<A>(self) -> EitherServerSentEvent<A, Self> {
        EitherServerSentEvent::B(self)
    }

    /// Erase the concrete type by boxing it into a `BoxedServerSentEvent`.
    fn boxed(self) -> BoxedServerSentEvent {
        BoxedServerSentEvent(Box::new(self))
    }
}

impl<T> ServerSentEvent for T where T: SseFormat + Send + 'static {}
#[allow(missing_debug_implementations)]
struct SseComment<T>(T);

/// Build a comment (`:<text>`) event part.
///
/// The text is rendered with `Display`. If it contains line breaks (CRLF,
/// CR or LF), every resulting line gets its own `:` prefix. Without that,
/// a client would parse the continuation lines as real fields.
pub fn comment<T>(comment: T) -> impl ServerSentEvent
where
    T: Display + Send + 'static,
{
    SseComment(comment)
}

impl<T: Display> SseFormat for SseComment<T> {
    /// Writes the comment only when asked for `SseField::Comment`.
    ///
    /// An empty comment still produces a single `:` line.
    fn fmt_field(&self, f: &mut Formatter, k: &SseField) -> fmt::Result {
        if let SseField::Comment = k {
            // Render into a buffer so embedded line breaks can be located.
            // `write!` propagates a failing `Display` instead of panicking the
            // way `to_string()` would.
            let mut text = String::new();
            write!(text, "{}", self.0)?;
            let mut rest = text.as_str();
            loop {
                let (line, next) = match rest.find(|c: char| c == '\r' || c == '\n') {
                    Some(pos) => {
                        // Treat CRLF as a single terminator, like SSE parsers do.
                        let width = if rest[pos..].starts_with("\r\n") { 2 } else { 1 };
                        (&rest[..pos], Some(&rest[pos + width..]))
                    }
                    None => (rest, None),
                };
                k.fmt(f)?;
                line.fmt(f)?;
                f.write_char('\n')?;
                match next {
                    Some(tail) => rest = tail,
                    None => break,
                }
            }
        }
        Ok(())
    }
}
#[allow(missing_debug_implementations)]
struct SseEvent<T>(T);

/// Build an `event:<name>` event part from any `Display` value.
///
/// NOTE(review): the value is written verbatim. A line break inside it would
/// end the field early, so callers are presumably expected to pass single-line
/// names. Confirm.
pub fn event<T>(event: T) -> impl ServerSentEvent
where
    T: Display + Send + 'static,
{
    SseEvent(event)
}

impl<T: Display> SseFormat for SseEvent<T> {
    /// Emits `event:<value>\n` for `SseField::Event` and nothing otherwise.
    fn fmt_field(&self, f: &mut Formatter, k: &SseField) -> fmt::Result {
        match k {
            SseField::Event => writeln!(f, "{}{}", k, self.0),
            _ => Ok(()),
        }
    }
}
#[allow(missing_debug_implementations)]
struct SseId<T>(T);

/// Build an `id:<value>` event part from any `Display` value.
///
/// NOTE(review): the value is written verbatim. An embedded line break would
/// split the field, so ids are presumably single-line. Confirm against callers.
pub fn id<T>(id: T) -> impl ServerSentEvent
where
    T: Display + Send + 'static,
{
    SseId(id)
}

impl<T: Display> SseFormat for SseId<T> {
    /// Emits `id:<value>\n` for `SseField::Id` and nothing otherwise.
    fn fmt_field(&self, f: &mut Formatter, k: &SseField) -> fmt::Result {
        let wanted = match *k {
            SseField::Id => true,
            _ => false,
        };
        if !wanted {
            return Ok(());
        }
        writeln!(f, "{}{}", k, self.0)
    }
}
#[allow(missing_debug_implementations)]
struct SseRetry(Duration);

/// Build a `retry:<milliseconds>` event part from a `Duration`.
///
/// Sub-millisecond precision is truncated.
pub fn retry(time: Duration) -> impl ServerSentEvent {
    SseRetry(time)
}

impl SseFormat for SseRetry {
    /// Emits `retry:<ms>\n` for `SseField::Retry` and nothing otherwise.
    fn fmt_field(&self, f: &mut Formatter, k: &SseField) -> fmt::Result {
        match k {
            SseField::Retry => {
                let whole_secs = self.0.as_secs();
                let extra_millis = self.0.subsec_nanos() / 1_000_000;
                k.fmt(f)?;
                // Print seconds and zero-padded milliseconds side by side
                // instead of computing `secs * 1000`, a product that could
                // overflow `u64` for huge durations.
                if whole_secs == 0 {
                    extra_millis.fmt(f)?;
                } else {
                    write!(f, "{}{:03}", whole_secs, extra_millis)?;
                }
                f.write_char('\n')
            }
            _ => Ok(()),
        }
    }
}
#[allow(missing_debug_implementations)]
struct SseData<T>(T);

/// Build a `data:` event part from any `Display` value.
///
/// Multi-line text is emitted as one `data:` line per source line. CRLF, a
/// lone CR and LF all count as line breaks, matching what SSE clients accept
/// as terminators. Clients rejoin the lines with `\n`.
pub fn data<T>(data: T) -> impl ServerSentEvent
where
    T: Display + Send + 'static,
{
    SseData(data)
}

impl<T: Display> SseFormat for SseData<T> {
    /// Writes the `data:` lines only when asked for `SseField::Data`.
    ///
    /// Empty data still produces a single `data:` line.
    fn fmt_field(&self, f: &mut Formatter, k: &SseField) -> fmt::Result {
        if let SseField::Data = k {
            // Render into a buffer so line breaks can be located. `write!`
            // propagates a failing `Display` rather than panicking like
            // `to_string()`.
            let mut text = String::new();
            write!(text, "{}", self.0)?;
            let mut rest = text.as_str();
            loop {
                // A bare '\r' must also split the line: left raw, a client
                // would end the field there and drop the rest of the data.
                let (line, next) = match rest.find(|c: char| c == '\r' || c == '\n') {
                    Some(pos) => {
                        let width = if rest[pos..].starts_with("\r\n") { 2 } else { 1 };
                        (&rest[..pos], Some(&rest[pos + width..]))
                    }
                    None => (rest, None),
                };
                k.fmt(f)?;
                line.fmt(f)?;
                f.write_char('\n')?;
                match next {
                    Some(tail) => rest = tail,
                    None => break,
                }
            }
        }
        Ok(())
    }
}
#[allow(missing_debug_implementations)]
struct SseJson<T>(T);

/// Build a `data:` event part holding `data` serialized as JSON.
pub fn json<T>(data: T) -> impl ServerSentEvent
where
    T: Serialize + Send + 'static,
{
    SseJson(data)
}

impl<T: Serialize> SseFormat for SseJson<T> {
    /// Emits `data:<json>\n` for `SseField::Data` and nothing otherwise.
    ///
    /// The value is serialized with `serde_json::to_string`. If that fails,
    /// the error is logged and `fmt::Error` is returned; the `data:` prefix
    /// has already been written by then.
    fn fmt_field(&self, f: &mut Formatter, k: &SseField) -> fmt::Result {
        match k {
            SseField::Data => {
                k.fmt(f)?;
                let encoded = match serde_json::to_string(&self.0) {
                    Ok(encoded) => encoded,
                    Err(error) => {
                        error!("sse::json error {}", error);
                        return Err(fmt::Error);
                    }
                };
                encoded.fmt(f)?;
                f.write_char('\n')
            }
            _ => Ok(()),
        }
    }
}
/// Implements `SseFormat` for tuples whose elements are all `SseFormat`.
///
/// Each requested field is delegated to the elements in tuple order, so
/// `(data("a"), id(1))` renders both parts. The first element error aborts.
macro_rules! tuple_fmt {
    (($($T:ident),+) => ($($idx:tt),+)) => {
        impl<$($T: SseFormat),+> SseFormat for ($($T,)+) {
            fn fmt_field(&self, f: &mut Formatter, k: &SseField) -> fmt::Result {
                $(
                    SseFormat::fmt_field(&self.$idx, f, k)?;
                )+
                Ok(())
            }
        }
    };
}

tuple_fmt!((A, B) => (0, 1));
tuple_fmt!((A, B, C) => (0, 1, 2));
tuple_fmt!((A, B, C, D) => (0, 1, 2, 3));
tuple_fmt!((A, B, C, D, E) => (0, 1, 2, 3, 4));
tuple_fmt!((A, B, C, D, E, F) => (0, 1, 2, 3, 4, 5));
tuple_fmt!((A, B, C, D, E, F, G) => (0, 1, 2, 3, 4, 5, 6));
tuple_fmt!((A, B, C, D, E, F, G, H) => (0, 1, 2, 3, 4, 5, 6, 7));
/// Extract the `last-event-id` request header, parsed as `T`.
///
/// A missing header yields `None`. Any other rejection from the header filter
/// is passed through unchanged.
pub fn last_event_id<T>() -> impl Filter<Extract = One<Option<T>>, Error = Rejection>
where
    T: FromStr + Send,
{
    let present = header::header("last-event-id").map(Some);
    present.or_else(|rejection: Rejection| {
        if rejection.find_cause::<MissingHeader>().is_none() {
            return Err(rejection);
        }
        Ok((None,))
    })
}
/// Filter that accepts a server-sent-events request and extracts an `Sse`.
///
/// It combines the crate's `get2` filter (presumably the GET method check)
/// with a `connection` header check. If that header is present it must equal
/// `keep-alive` case-insensitively; if it is missing the request is still
/// accepted.
pub fn sse() -> impl Filter<Extract = One<Sse>, Error = Rejection> + Copy {
    let keep_alive_if_present = header::exact_ignore_case("connection", "keep-alive").or_else(
        |rejection: Rejection| {
            if rejection.find_cause::<MissingHeader>().is_none() {
                return Err(rejection);
            }
            Ok(())
        },
    );
    ::get2().and(keep_alive_if_present).map(|| Sse)
}
/// Marker extracted by the `sse()` filter; turns an event stream into a reply.
pub struct Sse;

impl fmt::Debug for Sse {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Sse")
    }
}

impl Sse {
    /// Wrap `event_stream` into a reply that streams each event as SSE text.
    pub fn reply<S>(self, event_stream: S) -> impl Reply
    where
        S: Stream + Send + 'static,
        S::Item: ServerSentEvent,
        S::Error: StdError + Send + Sync + 'static,
    {
        SseReply {
            event_stream: event_stream,
        }
    }
}
#[allow(missing_debug_implementations)]
struct SseReply<S> {
    event_stream: S,
}

impl<S> ReplySealed for SseReply<S>
where
    S: Stream + Send + 'static,
    S::Item: ServerSentEvent,
    S::Error: StdError + Send + Sync + 'static,
{
    /// Build a streaming response.
    ///
    /// Each event is rendered to text with `SseWrapper::format`. The response
    /// carries `content-type: text/event-stream` and `cache-control: no-cache`.
    /// A stream error is logged and replaced with `SseError`.
    #[inline]
    fn into_response(self) -> Response {
        let SseReply { event_stream } = self;
        let chunks = event_stream
            .map_err(|error| {
                error!("sse stream error: {}", error);
                SseError
            })
            .and_then(|event| SseWrapper::format(&event));

        let mut response = Response::new(Body::wrap_stream(chunks));
        {
            let headers = response.headers_mut();
            headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
            headers.insert(CACHE_CONTROL, HeaderValue::from_static("no-cache"));
        }
        response
    }
}
/// Stream adapter that interleaves keep-alive comments with real events.
#[allow(missing_debug_implementations)]
struct SseKeepAlive<S> {
    event_stream: S,
    max_interval: Duration,
    alive_timer: Delay,
}

/// Wrap `event_stream` so an empty comment is sent whenever no event has been
/// produced for `keep_interval`. If no interval is given, the default is 15
/// seconds.
pub fn keep<S>(
    event_stream: S,
    keep_interval: impl Into<Option<Duration>>,
) -> impl Stream<
    Item = impl ServerSentEvent + Send + 'static,
    Error = impl StdError + Send + Sync + 'static,
> + Send
       + 'static
where
    S: Stream + Send + 'static,
    S::Item: ServerSentEvent + Send,
    S::Error: StdError + Send + Sync + 'static,
{
    let max_interval = match keep_interval.into() {
        Some(interval) => interval,
        None => Duration::from_secs(15),
    };
    SseKeepAlive {
        alive_timer: Delay::new(now() + max_interval),
        event_stream,
        max_interval,
    }
}
impl<S> Stream for SseKeepAlive<S>
where
    S: Stream + Send + 'static,
    S::Item: ServerSentEvent,
    S::Error: StdError + Send + Sync + 'static,
{
    type Item = EitherServerSentEvent<S::Item, SseComment<&'static str>>;
    type Error = SseError;

    /// Poll the inner stream first; consult the keep-alive timer only while it
    /// is idle.
    ///
    /// Forwarding a real event or emitting a keep-alive comment both re-arm
    /// the timer. Errors from the stream or the timer are logged and mapped to
    /// `SseError`.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        let inner = self.event_stream.poll().map_err(|error| {
            error!("sse::keep error: {}", error);
            SseError
        })?;
        match inner {
            Async::Ready(Some(event)) => {
                self.alive_timer.reset(now() + self.max_interval);
                Ok(Async::Ready(Some(EitherServerSentEvent::A(event))))
            }
            Async::Ready(None) => Ok(Async::Ready(None)),
            Async::NotReady => {
                let timer = self.alive_timer.poll().map_err(|error| {
                    error!("sse::keep error: {}", error);
                    SseError
                })?;
                if let Async::NotReady = timer {
                    return Ok(Async::NotReady);
                }
                self.alive_timer.reset(now() + self.max_interval);
                Ok(Async::Ready(Some(EitherServerSentEvent::B(SseComment("")))))
            }
        }
    }
}
mod sealed {
    use super::*;

    /// Opaque error produced when an event cannot be formatted or the
    /// underlying stream fails.
    #[derive(Debug)]
    pub struct SseError;

    impl Display for SseError {
        fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
            f.write_str("sse error")
        }
    }

    impl StdError for SseError {
        fn description(&self) -> &str {
            "sse error"
        }
    }

    /// The line kinds an event can contain.
    #[allow(missing_debug_implementations)]
    pub enum SseField {
        Event,
        Id,
        Data,
        Retry,
        Comment,
    }

    /// `Display` yields the line prefix for each field, colon included.
    impl Display for SseField {
        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
            let prefix = match *self {
                SseField::Event => "event:",
                SseField::Id => "id:",
                SseField::Data => "data:",
                SseField::Retry => "retry:",
                SseField::Comment => ":",
            };
            f.write_str(prefix)
        }
    }

    /// Writes the lines of one field kind. The default implementation writes
    /// nothing.
    pub trait SseFormat {
        fn fmt_field(&self, _f: &mut Formatter, _key: &SseField) -> fmt::Result {
            Ok(())
        }
    }

    /// Adapter that renders a whole event through `Display`.
    #[allow(missing_debug_implementations)]
    pub struct SseWrapper<'a, T: 'a>(&'a T);

    impl<'a, T> SseWrapper<'a, T>
    where
        T: SseFormat + 'a,
    {
        /// Render `event` into an owned `String`.
        ///
        /// Any formatting error is mapped to `SseError`.
        pub fn format(event: &'a T) -> Result<String, SseError> {
            let mut rendered = String::new();
            if write!(rendered, "{}", SseWrapper(event)).is_err() {
                return Err(SseError);
            }
            rendered.shrink_to_fit();
            Ok(rendered)
        }
    }

    impl<'a, T> Display for SseWrapper<'a, T>
    where
        T: SseFormat,
    {
        /// Emit fields in a fixed order (comment, event, data, id, retry),
        /// then the blank line that terminates the event.
        fn fmt(&self, f: &mut Formatter) -> fmt::Result {
            let order = [
                SseField::Comment,
                SseField::Event,
                SseField::Data,
                SseField::Id,
                SseField::Retry,
            ];
            for field in order.iter() {
                self.0.fmt_field(f, field)?;
            }
            f.write_char('\n')
        }
    }

    /// Type-erased event; delegates every field to the boxed value.
    #[allow(missing_debug_implementations)]
    pub struct BoxedServerSentEvent(pub(super) Box<SseFormat + Send>);

    impl SseFormat for BoxedServerSentEvent {
        fn fmt_field(&self, f: &mut Formatter, k: &SseField) -> fmt::Result {
            (*self.0).fmt_field(f, k)
        }
    }

    /// One of two event types; delegates to whichever variant is held.
    #[allow(missing_debug_implementations)]
    pub enum EitherServerSentEvent<A, B> {
        A(A),
        B(B),
    }

    impl<A, B> SseFormat for EitherServerSentEvent<A, B>
    where
        A: SseFormat,
        B: SseFormat,
    {
        fn fmt_field(&self, f: &mut Formatter, k: &SseField) -> fmt::Result {
            match *self {
                EitherServerSentEvent::A(ref left) => left.fmt_field(f, k),
                EitherServerSentEvent::B(ref right) => right.fmt_field(f, k),
            }
        }
    }
}