use std::marker::PhantomData;
use std::net::SocketAddr;
use std::{io, mem};

use bytes::Bytes;
use futures::sink::Send;
use futures::{Async, Future, Poll, Sink, Stream};
use http::{header, Method, Request};
use http_codec::client::HttpCodec;
use serde::de::DeserializeOwned;
use serde::ser::Serialize;
use serde_json;
use tokio_codec::{Decoder, Framed};
use tokio_tcp::{ConnectFuture, TcpStream};

use error::Error;

/// An open HTTP connection to an ArangoDB server.
///
/// `auth` is sent verbatim as the `Authorization` header and `path` is the
/// request URI used for every query sent over this connection.
pub struct Connection {
    stream: Framed<TcpStream, HttpCodec>,
    auth: &'static str,
    path: &'static str,
}

impl Connection {
    /// Starts connecting to `addr`. The returned future resolves to a
    /// `Connection` once the TCP connection is established.
    pub fn connect(addr: &SocketAddr, auth: &'static str, path: &'static str) -> ConnectionFuture {
        ConnectionFuture {
            inner: TcpStream::connect(addr),
            auth,
            path,
        }
    }

    /// Sends `query` and its bind variables as a JSON `POST` request.
    ///
    /// The connection is consumed; the returned future hands it back together
    /// with the deserialized response body so it can be reused.
    pub fn query<T, U>(self, query: &str, bind_vars: &T) -> Result<ResponseFuture<U>, Error>
    where
        T: Serialize,
        U: DeserializeOwned,
    {
        let body = serde_json::to_vec(&Query { query, bind_vars })?;

        // The builder can only fail here if `path` is not a valid URI or
        // `auth` is not a valid header value, i.e. a configuration error.
        let req = Request::builder()
            .uri(self.path)
            .method(Method::POST)
            .header(header::AUTHORIZATION, self.auth)
            .header(header::CONTENT_TYPE, "application/json")
            .header(header::CONTENT_LENGTH, body.len())
            .body(Bytes::from(body))
            .unwrap();

        let Connection { stream, auth, path } = self;
        let fut = ResponseFuture {
            state: State::Sending(stream.send(req)),
            auth,
            path,
            phantom: PhantomData,
        };
        Ok(fut)
    }
}

/// Future returned by `Connection::connect`.
pub struct ConnectionFuture {
    inner: ConnectFuture,
    auth: &'static str,
    path: &'static str,
}

impl Future for ConnectionFuture {
    type Item = Connection;
    type Error = Error;

    fn poll(&mut self) -> Poll<Connection, Error> {
        let stream = try_ready!(self.inner.poll());
        // Wrap the raw TCP stream in the HTTP client codec.
        Ok(Async::Ready(Connection {
            stream: HttpCodec::new().framed(stream),
            auth: self.auth,
            path: self.path,
        }))
    }
}

/// JSON request body: `{"query": ..., "bindVars": ...}`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct Query<'a, T: 'a + Serialize> {
    query: &'a str,
    bind_vars: &'a T,
}

/// Progress of a single request/response exchange.
enum State {
    /// Writing the request to the socket.
    Sending(Send<Framed<TcpStream, HttpCodec>>),
    /// Request written; waiting for the response.
    Receiving(Framed<TcpStream, HttpCodec>),
    /// Placeholder while `poll` owns the state, and the terminal state.
    Complete,
}

/// Future returned by `Connection::query`.
pub struct ResponseFuture<T: DeserializeOwned> {
    state: State,
    auth: &'static str,
    path: &'static str,
    phantom: PhantomData<T>,
}

impl<T: DeserializeOwned> Future for ResponseFuture<T> {
    type Item = (Connection, T);
    type Error = Error;

    fn poll(&mut self) -> Poll<(Connection, T), Error> {
        use self::State::*;

        loop {
            // Take ownership of the current state. Every path that returns
            // `NotReady` must put the state back before returning.
            match mem::replace(&mut self.state, Complete) {
                Sending(mut fut) => match fut.poll()? {
                    Async::Ready(stream) => self.state = Receiving(stream),
                    Async::NotReady => {
                        // Restore the in-flight send; otherwise the next poll
                        // would find `Complete` and panic.
                        self.state = Sending(fut);
                        return Ok(Async::NotReady);
                    }
                },
                Receiving(mut stream) => match stream.poll()? {
                    Async::Ready(Some(res)) => {
                        let status = res.status();
                        if !status.is_success() {
                            return Err(Error::StatusCode(status.as_u16()));
                        }

                        let r = serde_json::from_slice(&res.into_body())?;
                        // Hand the connection back so the caller can reuse it.
                        let conn = Connection {
                            stream,
                            auth: self.auth,
                            path: self.path,
                        };
                        return Ok(Async::Ready((conn, r)));
                    }
                    Async::Ready(None) => {
                        return Err(Error::IoError(io::Error::new(
                            io::ErrorKind::Other,
                            "arangodb server closed connection before responding",
                        )));
                    }
                    Async::NotReady => {
                        self.state = Receiving(stream);
                        return Ok(Async::NotReady);
                    }
                },
                Complete => unreachable!(),
            }
        }
    }
}