use anyhow::anyhow;
use anyhow::bail;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use data_encoding::HEXLOWER;
use ring::hmac;
use serde::{Deserialize, Serialize};
use serde_json;
use serde_json::{json, Value};
use std::fmt;
use uuid::Uuid;
mod time;
pub mod content;
pub use content::*;
/// A ZeroMQ socket paired with the optional HMAC key used to sign outgoing and
/// verify incoming Jupyter messages on that socket.
pub struct Connection<S> {
    /// The underlying ZeroMQ socket.
    pub socket: S,
    /// Signing key (`Connection::new` creates an HMAC-SHA256 key); `None` disables
    /// signing on send and signature verification on receive.
    pub mac: Option<hmac::Key>,
}
impl<S: zeromq::Socket> Connection<S> {
    /// Wraps `socket` together with an HMAC-SHA256 key derived from `key`.
    ///
    /// An empty `key` yields a connection without a signing key, so messages are
    /// neither signed nor verified.
    pub fn new(socket: S, key: &str) -> Self {
        let mac = (!key.is_empty())
            .then(|| hmac::Key::new(hmac::HMAC_SHA256, key.as_bytes()));
        Self { socket, mac }
    }
}
impl<S: zeromq::SocketSend> Connection<S> {
    /// Serializes, signs (when a key is configured) and sends `message`.
    ///
    /// # Errors
    /// Propagates any serialization or socket error from `JupyterMessage::send`.
    pub async fn send(&mut self, message: JupyterMessage) -> Result<(), anyhow::Error> {
        message.send(self).await
    }
}
impl<S: zeromq::SocketRecv> Connection<S> {
    /// Receives the next multipart message and decodes it into a `JupyterMessage`.
    ///
    /// # Errors
    /// Fails on socket errors, malformed frames, signature mismatch, or JSON decode
    /// errors reported by `JupyterMessage::read`.
    pub async fn read(&mut self) -> Result<JupyterMessage, anyhow::Error> {
        let decoded = JupyterMessage::read(self).await?;
        Ok(decoded)
    }
}
impl<S: zeromq::SocketSend + zeromq::SocketRecv> Connection<S> {
    /// Services one heartbeat round-trip: waits for a request and replies to it.
    ///
    /// The Jupyter messaging spec has heartbeat pings echoed back unchanged, so the
    /// received frames are sent back verbatim rather than replaced with a fixed
    /// payload. This also preserves any routing frames the request carried.
    ///
    /// # Errors
    /// Propagates socket receive/send errors.
    pub async fn single_heartbeat(&mut self) -> Result<(), anyhow::Error> {
        let ping = self.socket.recv().await?;
        self.socket.send(ping).await?;
        Ok(())
    }
}
/// A Jupyter wire message split into its ZeroMQ routing prefix and the frames that
/// follow the `<IDS|MSG>` delimiter and the signature frame.
#[derive(Debug)]
struct RawMessage {
    /// Frames that precede the `<IDS|MSG>` delimiter.
    zmq_identities: Vec<Bytes>,
    /// Frames after the signature: header, parent header, metadata, content, then
    /// any binary buffers.
    jparts: Vec<Bytes>,
}
impl RawMessage {
pub(crate) async fn read<S: zeromq::SocketRecv>(
connection: &mut Connection<S>,
) -> Result<RawMessage, anyhow::Error> {
Self::from_multipart(connection.socket.recv().await?, connection)
}
pub(crate) fn from_multipart<S>(
multipart: zeromq::ZmqMessage,
connection: &Connection<S>,
) -> Result<RawMessage, anyhow::Error> {
let delimiter_index = multipart
.iter()
.position(|part| &part[..] == DELIMITER)
.ok_or_else(|| anyhow!("Missing delimiter"))?;
let mut parts = multipart.into_vec();
let jparts: Vec<_> = parts.drain(delimiter_index + 2..).collect();
let expected_hmac = parts.pop().ok_or_else(|| anyhow!("Missing hmac"))?;
parts.pop();
let zmq_identities = parts;
let raw_message = RawMessage {
zmq_identities,
jparts,
};
if let Some(key) = &connection.mac {
let sig = HEXLOWER.decode(&expected_hmac)?;
let mut msg = Vec::new();
for part in &raw_message.jparts {
msg.extend(part);
}
if let Err(err) = hmac::verify(key, msg.as_ref(), sig.as_ref()) {
bail!("{}", err);
}
}
Ok(raw_message)
}
async fn send<S: zeromq::SocketSend>(
self,
connection: &mut Connection<S>,
) -> Result<(), anyhow::Error> {
let hmac = if let Some(key) = &connection.mac {
let ctx = self.digest(key);
let tag = ctx.sign();
HEXLOWER.encode(tag.as_ref())
} else {
String::new()
};
let mut parts: Vec<bytes::Bytes> = Vec::new();
for part in &self.zmq_identities {
parts.push(part.to_vec().into());
}
parts.push(DELIMITER.into());
parts.push(hmac.as_bytes().to_vec().into());
for part in &self.jparts {
parts.push(part.to_vec().into());
}
let message = zeromq::ZmqMessage::try_from(parts).map_err(|err| anyhow::anyhow!(err))?;
connection.socket.send(message).await?;
Ok(())
}
fn digest(&self, mac: &hmac::Key) -> hmac::Context {
let mut hmac_ctx = hmac::Context::with_key(mac);
for part in &self.jparts {
hmac_ctx.update(part);
}
hmac_ctx
}
}
/// A decoded Jupyter protocol message.
///
/// Serializing this type with serde emits `header`, `parent_header`, `metadata` and
/// `content`; the ZeroMQ routing identities and the binary buffers are skipped.
#[derive(Serialize, Clone)]
pub struct JupyterMessage {
    /// ZeroMQ frames that preceded the `<IDS|MSG>` delimiter when the message was
    /// received; `JupyterMessage::new` copies them from the parent message.
    #[serde(skip_serializing)]
    zmq_identities: Vec<Bytes>,
    /// This message's header.
    pub header: Header,
    /// Header of the parent message, if any.
    pub parent_header: Option<Header>,
    /// Arbitrary JSON metadata.
    pub metadata: Value,
    /// Typed message content.
    pub content: JupyterMessageContent,
    /// Binary frames carried after the content frame.
    #[serde(skip_serializing)]
    pub buffers: Vec<Bytes>,
}
/// Header of a Jupyter message, (de)serialized as a JSON object.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Header {
    /// Message id (`JupyterMessage::new` generates a random UUID v4).
    pub msg_id: String,
    /// Sender's user name (`JupyterMessage::new` uses `"runtimelib"`).
    pub username: String,
    /// Session identifier.
    pub session: String,
    /// Message timestamp in UTC.
    pub date: DateTime<Utc>,
    /// Message type string (`JupyterMessage::new` takes it from the content).
    pub msg_type: String,
    /// Messaging protocol version (`JupyterMessage::new` uses `"5.3"`).
    pub version: String,
}
/// Frame that separates the ZeroMQ routing identities from the signature frame and
/// the message frames.
const DELIMITER: &[u8] = b"<IDS|MSG>";
impl JupyterMessage {
pub(crate) async fn read<S: zeromq::SocketRecv>(
connection: &mut Connection<S>,
) -> Result<JupyterMessage, anyhow::Error> {
Self::from_raw_message(RawMessage::read(connection).await?)
}
fn from_raw_message(raw_message: RawMessage) -> Result<JupyterMessage, anyhow::Error> {
if raw_message.jparts.len() < 4 {
return Err(anyhow!(
"Insufficient message parts {}",
raw_message.jparts.len()
));
}
let header: Header = serde_json::from_slice(&raw_message.jparts[0])?;
let content: Value = serde_json::from_slice(&raw_message.jparts[3])?;
let content = JupyterMessageContent::from_type_and_content(&header.msg_type, content);
let content = match content {
Ok(content) => content,
Err(err) => {
return Err(anyhow!(
"Error deserializing content for msg_type `{}`: {}",
&header.msg_type,
err
));
}
};
let parent_header = serde_json::from_slice(&raw_message.jparts[1]).ok();
let message = JupyterMessage {
zmq_identities: raw_message.zmq_identities,
header,
parent_header,
metadata: serde_json::from_slice(&raw_message.jparts[2])?,
content,
buffers: if raw_message.jparts.len() > 4 {
raw_message.jparts[4..].to_vec()
} else {
vec![]
},
};
Ok(message)
}
pub fn message_type(&self) -> &str {
self.content.message_type()
}
pub fn new(content: JupyterMessageContent, parent: Option<&JupyterMessage>) -> JupyterMessage {
let header = Header {
msg_id: Uuid::new_v4().to_string(),
username: "runtimelib".to_string(),
session: Uuid::new_v4().to_string(),
date: time::utc_now(),
msg_type: content.message_type().to_owned(),
version: "5.3".to_string(),
};
JupyterMessage {
zmq_identities: parent.map_or(Vec::new(), |parent| parent.zmq_identities.clone()),
header,
parent_header: parent.map(|parent| parent.header.clone()),
metadata: json!({}),
content,
buffers: Vec::new(),
}
}
pub fn set_parent(&mut self, parent: JupyterMessage) {
self.parent_header = Some(parent.header.clone());
}
pub async fn send<S: zeromq::SocketSend>(
&self,
connection: &mut Connection<S>,
) -> Result<(), anyhow::Error> {
let mut jparts: Vec<Bytes> = vec![
serde_json::to_vec(&self.header)?.into(),
serde_json::to_vec(&self.parent_header)?.into(),
serde_json::to_vec(&self.metadata)?.into(),
serde_json::to_vec(&self.content)?.into(),
];
jparts.extend_from_slice(&self.buffers);
let raw_message = RawMessage {
zmq_identities: self.zmq_identities.clone(),
jparts,
};
raw_message.send(connection).await
}
}
impl fmt::Debug for JupyterMessage {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"\nHeader: {}",
serde_json::to_string_pretty(&self.header).unwrap()
)?;
writeln!(
f,
"Parent header: {}",
serde_json::to_string_pretty(&self.parent_header).unwrap()
)?;
writeln!(
f,
"Metadata: {}",
serde_json::to_string_pretty(&self.metadata).unwrap()
)?;
writeln!(
f,
"Content: {}\n",
serde_json::to_string_pretty(&self.content).unwrap()
)?;
Ok(())
}
}