use std::{
io,
pin::Pin,
task::{ready, Context, Poll},
};
use anyhow::{bail, Result};
use bytes::Bytes;
use iroh_base::{NodeId, SecretKey};
use n0_future::{time::Duration, Sink, Stream};
#[cfg(not(wasm_browser))]
use tokio_util::codec::Framed;
use tracing::debug;
use super::KeyCache;
use crate::protos::relay::{ClientInfo, Frame, MAX_PACKET_SIZE, PROTOCOL_VERSION};
#[cfg(not(wasm_browser))]
use crate::{
client::streams::{MaybeTlsStream, MaybeTlsStreamChained, ProxyStream},
protos::relay::RelayCodec,
};
#[derive(Debug, thiserror::Error)]
pub enum ConnSendError {
#[error("IO error")]
Io(#[from] io::Error),
#[error("Protocol error")]
Protocol(&'static str),
}
#[cfg(wasm_browser)]
impl From<ws_stream_wasm::WsErr> for ConnSendError {
    /// Maps a browser websocket error onto the closest [`io::ErrorKind`].
    ///
    /// The resulting [`ConnSendError::Io`] carries the original error's display text.
    /// Variants without an obvious counterpart map to [`io::ErrorKind::Other`].
    fn from(err: ws_stream_wasm::WsErr) -> Self {
        use ws_stream_wasm::WsErr;

        let kind = match &err {
            WsErr::ConnectionNotOpen => io::ErrorKind::NotConnected,
            WsErr::ReasonStringToLong
            | WsErr::InvalidCloseCode { .. }
            | WsErr::InvalidUrl { .. } => io::ErrorKind::InvalidInput,
            WsErr::UnknownDataType | WsErr::InvalidEncoding => io::ErrorKind::InvalidData,
            WsErr::ConnectionFailed { .. } => io::ErrorKind::ConnectionReset,
            _ => io::ErrorKind::Other,
        };
        let message = err.to_string();
        Self::Io(io::Error::new(kind, message))
    }
}
#[cfg(not(wasm_browser))]
impl From<tokio_websockets::Error> for ConnSendError {
fn from(err: tokio_websockets::Error) -> Self {
let io_err = match err {
tokio_websockets::Error::Io(io_err) => io_err,
_ => std::io::Error::new(std::io::ErrorKind::Other, err.to_string()),
};
Self::Io(io_err)
}
}
/// A client connection to a relay server, over one of several transports.
///
/// Received data is exposed through `Stream` and outgoing data through `Sink`.
#[derive(derive_more::Debug)]
pub(crate) enum Conn {
    /// Relay protocol framed directly over the (possibly TLS) stream using [`RelayCodec`].
    #[cfg(not(wasm_browser))]
    Relay {
        #[debug("Framed<MaybeTlsStreamChained, RelayCodec>")]
        conn: Framed<MaybeTlsStreamChained, RelayCodec>,
    },
    /// Relay frames carried one per binary websocket message (native targets).
    #[cfg(not(wasm_browser))]
    Ws {
        #[debug("WebSocketStream<MaybeTlsStream<ProxyStream>>")]
        conn: tokio_websockets::WebSocketStream<MaybeTlsStream<ProxyStream>>,
        // Passed to `Frame::decode_from_ws_msg` when decoding incoming messages.
        key_cache: KeyCache,
    },
    /// Relay frames carried one per binary websocket message (browser targets).
    #[cfg(wasm_browser)]
    WsBrowser {
        #[debug("WebSocketStream")]
        conn: ws_stream_wasm::WsStream,
        // Passed to `Frame::decode_from_ws_msg` when decoding incoming messages.
        key_cache: KeyCache,
    },
}
impl Conn {
    /// Wraps a browser websocket and performs the handshake before returning it.
    #[cfg(wasm_browser)]
    pub(crate) async fn new_ws_browser(
        conn: ws_stream_wasm::WsStream,
        key_cache: KeyCache,
        secret_key: &SecretKey,
    ) -> Result<Self> {
        Self::WsBrowser { conn, key_cache }
            .handshaken(secret_key)
            .await
    }

    /// Frames `conn` with a [`RelayCodec`] and performs the handshake before returning it.
    #[cfg(not(wasm_browser))]
    pub(crate) async fn new_relay(
        conn: MaybeTlsStreamChained,
        key_cache: KeyCache,
        secret_key: &SecretKey,
    ) -> Result<Self> {
        let framed = Framed::new(conn, RelayCodec::new(key_cache));
        Self::Relay { conn: framed }.handshaken(secret_key).await
    }

    /// Wraps a native websocket and performs the handshake before returning it.
    #[cfg(not(wasm_browser))]
    pub(crate) async fn new_ws(
        conn: tokio_websockets::WebSocketStream<MaybeTlsStream<ProxyStream>>,
        key_cache: KeyCache,
        secret_key: &SecretKey,
    ) -> Result<Self> {
        Self::Ws { conn, key_cache }.handshaken(secret_key).await
    }

    /// Runs `server_handshake` on this freshly built connection and hands it back.
    async fn handshaken(mut self, secret_key: &SecretKey) -> Result<Self> {
        server_handshake(&mut self, secret_key).await?;
        Ok(self)
    }
}
/// Sends this client's [`ClientInfo`] (carrying [`PROTOCOL_VERSION`]) over `writer`.
///
/// The `secret_key` is passed through to `send_client_key`.
///
/// # Errors
///
/// Propagates any error returned by `send_client_key`.
async fn server_handshake(writer: &mut Conn, secret_key: &SecretKey) -> Result<()> {
    debug!("server_handshake: started");

    let info = ClientInfo {
        version: PROTOCOL_VERSION,
    };
    debug!("server_handshake: sending client_key: {:?}", &info);
    crate::protos::relay::send_client_key(&mut *writer, secret_key, &info).await?;

    debug!("server_handshake: done");
    Ok(())
}
impl Stream for Conn {
    type Item = Result<ReceivedMessage>;

    /// Polls for the next message from the relay server.
    ///
    /// On the websocket transports:
    /// - a close message ends the stream (`Ready(None)`);
    /// - other non-binary messages are logged and skipped.
    ///
    /// Skipping re-polls the underlying stream immediately instead of returning
    /// `Poll::Pending`. Returning `Pending` there would leave no waker registered (the inner
    /// stream just returned `Ready`), which can stall the task indefinitely.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match *self {
            #[cfg(not(wasm_browser))]
            Self::Relay { ref mut conn } => match ready!(Pin::new(conn).poll_next(cx)) {
                Some(Ok(frame)) => Poll::Ready(Some(ReceivedMessage::try_from(frame))),
                Some(Err(err)) => Poll::Ready(Some(Err(err))),
                None => Poll::Ready(None),
            },
            #[cfg(not(wasm_browser))]
            Self::Ws {
                ref mut conn,
                ref key_cache,
            } => loop {
                match ready!(Pin::new(&mut *conn).poll_next(cx)) {
                    Some(Ok(msg)) => {
                        if msg.is_close() {
                            return Poll::Ready(None);
                        }
                        if !msg.is_binary() {
                            tracing::warn!(
                                ?msg,
                                "Got websocket message of unsupported type, skipping."
                            );
                            // Poll again: the inner stream may already have more data.
                            continue;
                        }
                        let frame =
                            Frame::decode_from_ws_msg(msg.into_payload().into(), key_cache)?;
                        return Poll::Ready(Some(ReceivedMessage::try_from(frame)));
                    }
                    Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
                    None => return Poll::Ready(None),
                }
            },
            #[cfg(wasm_browser)]
            Self::WsBrowser {
                ref mut conn,
                ref key_cache,
            } => loop {
                match ready!(Pin::new(&mut *conn).poll_next(cx)) {
                    Some(ws_stream_wasm::WsMessage::Binary(vec)) => {
                        let frame = Frame::decode_from_ws_msg(Bytes::from(vec), key_cache)?;
                        return Poll::Ready(Some(ReceivedMessage::try_from(frame)));
                    }
                    Some(msg) => {
                        // Poll again after logging instead of returning `Pending` without a waker.
                        tracing::warn!(
                            ?msg,
                            "Got websocket message of unsupported type, skipping."
                        );
                    }
                    None => return Poll::Ready(None),
                }
            },
        }
    }
}
impl Sink<Frame> for Conn {
    type Error = ConnSendError;

    /// Waits until the underlying transport can accept another frame.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match *self {
            #[cfg(not(wasm_browser))]
            Self::Relay { ref mut conn } => Pin::new(conn).poll_ready(cx).map_err(Into::into),
            #[cfg(not(wasm_browser))]
            Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_ready(cx).map_err(Into::into),
            #[cfg(wasm_browser)]
            Self::WsBrowser { ref mut conn, .. } => {
                Pin::new(conn).poll_ready(cx).map_err(Into::into)
            }
        }
    }

    /// Queues `frame` for sending.
    ///
    /// On the websocket transports, the frame is encoded into a single binary message.
    ///
    /// # Errors
    ///
    /// Returns [`ConnSendError::Protocol`] if a `SendPacket` payload is larger than
    /// [`MAX_PACKET_SIZE`]. Transport errors are also propagated.
    fn start_send(mut self: Pin<&mut Self>, frame: Frame) -> Result<(), Self::Error> {
        if let Frame::SendPacket { dst_key: _, packet } = &frame {
            if packet.len() > MAX_PACKET_SIZE {
                return Err(ConnSendError::Protocol("Packet exceeds MAX_PACKET_SIZE"));
            }
        }
        match *self {
            #[cfg(not(wasm_browser))]
            Self::Relay { ref mut conn } => Pin::new(conn).start_send(frame).map_err(Into::into),
            #[cfg(not(wasm_browser))]
            Self::Ws { ref mut conn, .. } => Pin::new(conn)
                .start_send(tokio_websockets::Message::binary(
                    tokio_websockets::Payload::from(frame.encode_for_ws_msg()),
                ))
                .map_err(Into::into),
            #[cfg(wasm_browser)]
            Self::WsBrowser { ref mut conn, .. } => Pin::new(conn)
                .start_send(ws_stream_wasm::WsMessage::Binary(frame.encode_for_ws_msg()))
                .map_err(Into::into),
        }
    }

    /// Flushes all queued frames to the underlying transport.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match *self {
            #[cfg(not(wasm_browser))]
            Self::Relay { ref mut conn } => Pin::new(conn).poll_flush(cx).map_err(Into::into),
            #[cfg(not(wasm_browser))]
            Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_flush(cx).map_err(Into::into),
            #[cfg(wasm_browser)]
            Self::WsBrowser { ref mut conn, .. } => {
                Pin::new(conn).poll_flush(cx).map_err(Into::into)
            }
        }
    }

    /// Closes the underlying transport.
    ///
    /// Every variant delegates to the inner sink's `poll_close`. Previously the websocket
    /// variant only flushed, which left the websocket open. That also disagreed with the
    /// `Sink<SendMessage>` implementation.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match *self {
            #[cfg(not(wasm_browser))]
            Self::Relay { ref mut conn } => Pin::new(conn).poll_close(cx).map_err(Into::into),
            #[cfg(not(wasm_browser))]
            Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_close(cx).map_err(Into::into),
            #[cfg(wasm_browser)]
            Self::WsBrowser { ref mut conn, .. } => {
                Pin::new(conn).poll_close(cx).map_err(Into::into)
            }
        }
    }
}
impl Sink<SendMessage> for Conn {
type Error = ConnSendError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
#[cfg(not(wasm_browser))]
Self::Relay { ref mut conn } => Pin::new(conn).poll_ready(cx).map_err(Into::into),
#[cfg(not(wasm_browser))]
Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_ready(cx).map_err(Into::into),
#[cfg(wasm_browser)]
Self::WsBrowser { ref mut conn, .. } => {
Pin::new(conn).poll_ready(cx).map_err(Into::into)
}
}
}
fn start_send(mut self: Pin<&mut Self>, item: SendMessage) -> Result<(), Self::Error> {
if let SendMessage::SendPacket(_, bytes) = &item {
if bytes.len() > MAX_PACKET_SIZE {
return Err(ConnSendError::Protocol("Packet exceeds MAX_PACKET_SIZE"));
}
}
let frame = Frame::from(item);
match *self {
#[cfg(not(wasm_browser))]
Self::Relay { ref mut conn } => Pin::new(conn).start_send(frame).map_err(Into::into),
#[cfg(not(wasm_browser))]
Self::Ws { ref mut conn, .. } => Pin::new(conn)
.start_send(tokio_websockets::Message::binary(
tokio_websockets::Payload::from(frame.encode_for_ws_msg()),
))
.map_err(Into::into),
#[cfg(wasm_browser)]
Self::WsBrowser { ref mut conn, .. } => Pin::new(conn)
.start_send(ws_stream_wasm::WsMessage::Binary(frame.encode_for_ws_msg()))
.map_err(Into::into),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
#[cfg(not(wasm_browser))]
Self::Relay { ref mut conn } => Pin::new(conn).poll_flush(cx).map_err(Into::into),
#[cfg(not(wasm_browser))]
Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_flush(cx).map_err(Into::into),
#[cfg(wasm_browser)]
Self::WsBrowser { ref mut conn, .. } => {
Pin::new(conn).poll_flush(cx).map_err(Into::into)
}
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match *self {
#[cfg(not(wasm_browser))]
Self::Relay { ref mut conn } => Pin::new(conn).poll_close(cx).map_err(Into::into),
#[cfg(not(wasm_browser))]
Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_close(cx).map_err(Into::into),
#[cfg(wasm_browser)]
Self::WsBrowser { ref mut conn, .. } => {
Pin::new(conn).poll_close(cx).map_err(Into::into)
}
}
}
}
/// A message received from the relay server, decoded from a [`Frame`].
#[derive(derive_more::Debug, Clone)]
pub enum ReceivedMessage {
    /// Packet data relayed from `remote_node_id` (decoded from `Frame::RecvPacket`).
    ReceivedPacket {
        /// The node that sent the packet (the frame's `src_key`).
        remote_node_id: NodeId,
        /// The packet payload; omitted from `Debug` output.
        #[debug(skip)]
        data: Bytes,
    },
    /// Decoded from `Frame::NodeGone` for the given node. Presumably the node left the
    /// relay; confirm against the server implementation.
    NodeGone(NodeId),
    /// A ping carrying 8 bytes of payload.
    Ping([u8; 8]),
    /// A pong carrying 8 bytes of payload.
    Pong([u8; 8]),
    /// Decoded from `Frame::KeepAlive`.
    KeepAlive,
    /// A health report from the server, holding the UTF-8 problem description when present.
    Health {
        /// The reported problem, if any.
        problem: Option<String>,
    },
    /// The server announced a restart (decoded from `Frame::Restarting`).
    ServerRestarting {
        /// Delay the server suggests waiting before reconnecting.
        reconnect_in: Duration,
        /// How long the server suggests continuing reconnection attempts.
        try_for: Duration,
    },
}
impl TryFrom<Frame> for ReceivedMessage {
type Error = anyhow::Error;
fn try_from(frame: Frame) -> std::result::Result<Self, Self::Error> {
match frame {
Frame::KeepAlive => {
Ok(ReceivedMessage::KeepAlive)
}
Frame::NodeGone { node_id } => Ok(ReceivedMessage::NodeGone(node_id)),
Frame::RecvPacket { src_key, content } => {
let packet = ReceivedMessage::ReceivedPacket {
remote_node_id: src_key,
data: content,
};
Ok(packet)
}
Frame::Ping { data } => Ok(ReceivedMessage::Ping(data)),
Frame::Pong { data } => Ok(ReceivedMessage::Pong(data)),
Frame::Health { problem } => {
let problem = std::str::from_utf8(&problem)?.to_owned();
let problem = Some(problem);
Ok(ReceivedMessage::Health { problem })
}
Frame::Restarting {
reconnect_in,
try_for,
} => {
let reconnect_in = Duration::from_millis(reconnect_in as u64);
let try_for = Duration::from_millis(try_for as u64);
Ok(ReceivedMessage::ServerRestarting {
reconnect_in,
try_for,
})
}
_ => bail!("unexpected packet: {:?}", frame.typ()),
}
}
}
/// A message a client can send to the relay server; converted into a [`Frame`] for sending.
#[derive(Debug)]
pub enum SendMessage {
    /// A packet addressed to the given node (becomes `Frame::SendPacket { dst_key, packet }`).
    SendPacket(NodeId, Bytes),
    /// A ping carrying 8 bytes of payload.
    Ping([u8; 8]),
    /// A pong carrying 8 bytes of payload.
    Pong([u8; 8]),
}
impl From<SendMessage> for Frame {
fn from(source: SendMessage) -> Self {
match source {
SendMessage::SendPacket(dst_key, packet) => Frame::SendPacket { dst_key, packet },
SendMessage::Ping(data) => Frame::Ping { data },
SendMessage::Pong(data) => Frame::Pong { data },
}
}
}