use std::{
collections::{BTreeSet, HashMap},
io::Write,
sync::{Arc, Mutex},
};
use thiserror::Error;
use tokio::sync::OwnedMutexGuard;
use super::{
session::{Builder, Session, SessionInfo},
types::{Frame, Message, SessionID},
};
use crate::common::time::insecure_posix_time;
/// Errors produced by the session demultiplexer.
#[derive(Error, Debug)]
pub enum Error {
    /// A CBOR payload (e.g. the incoming `Frame`) could not be decoded.
    #[error("malformed payload: {0}")]
    MalformedPayload(#[from] cbor::DecodeError),
    /// A decoded request's method did not match the frame's `untrusted_plaintext` field.
    #[error("malformed request method")]
    MalformedRequestMethod,
    /// No session slot could be obtained: a session limit was reached and no session
    /// could be evicted to make room.
    #[error("max concurrent sessions reached")]
    MaxConcurrentSessions,
    /// Any other error, carried as an `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
impl Error {
fn code(&self) -> u32 {
match self {
Error::MalformedPayload(_) => 1,
Error::MalformedRequestMethod => 2,
Error::MaxConcurrentSessions => 3,
Error::Other(_) => 4,
}
}
}
impl From<Error> for crate::types::Error {
    /// Wraps a demux error into the generic error type, tagged with module `"demux"`.
    ///
    /// The numeric code comes from `Error::code` and the message from the `Display` impl.
    fn from(e: Error) -> Self {
        let code = e.code();
        let message = e.to_string();
        Self {
            module: "demux".to_string(),
            code,
            message,
        }
    }
}
/// Opaque byte identifier of the remote peer that owns a session.
type PeerID = Vec<u8>;
/// A session shared between the registry and whoever is currently processing it.
type SharedSession = Arc<tokio::sync::Mutex<MultiplexedSession>>;
/// Ordering key of `Sessions::by_idle_time`: (last access time, peer, session id).
/// Tuples compare lexicographically, so the least recently accessed session sorts first.
type SessionByTimeKey = (i64, PeerID, SessionID);
/// Registry bookkeeping for a single session.
struct SessionMeta {
    /// Peer that owns the session.
    peer_id: PeerID,
    /// Session identifier; only unique within `peer_id`.
    session_id: SessionID,
    /// `insecure_posix_time()` at which the session was created or last looked up.
    last_access_time: i64,
    /// The shared session itself.
    inner: SharedSession,
}
impl SessionMeta {
    /// Key under which this session is stored in the idle-time index.
    ///
    /// Must be recomputed after `last_access_time` changes.
    fn by_time_key(&self) -> SessionByTimeKey {
        let Self {
            last_access_time,
            peer_id,
            session_id,
            ..
        } = self;
        (*last_access_time, peer_id.clone(), *session_id)
    }
}
/// Registry of live sessions, indexed both by peer and by last access time.
struct Sessions {
    /// Template cloned to build each new responder session.
    builder: Builder,
    /// Maximum number of sessions across all peers.
    max_sessions: usize,
    /// Maximum number of sessions for a single peer.
    max_sessions_per_peer: usize,
    /// Minimum idle time before a session may be evicted once the global limit is reached.
    stale_session_timeout: i64,
    /// Sessions grouped by owning peer; a peer entry is dropped when its last session is removed.
    by_peer: HashMap<PeerID, HashMap<SessionID, SessionMeta>>,
    /// Keys of all tracked sessions, oldest access first.
    /// Its length is the total session count.
    by_idle_time: BTreeSet<SessionByTimeKey>,
}
impl Sessions {
fn new(
builder: Builder,
max_sessions: usize,
max_sessions_per_peer: usize,
stale_session_timeout: i64,
) -> Self {
Self {
builder,
max_sessions,
max_sessions_per_peer,
stale_session_timeout,
by_peer: HashMap::new(),
by_idle_time: BTreeSet::new(),
}
}
fn create_session(
mut builder: Builder,
peer_id: PeerID,
session_id: SessionID,
now: i64,
) -> SessionMeta {
if builder.get_quote_policy().is_none() {
let policy = builder
.get_local_identity()
.as_ref()
.and_then(|id| id.quote_policy());
builder = builder.quote_policy(policy);
}
SessionMeta {
inner: Arc::new(tokio::sync::Mutex::new(MultiplexedSession {
peer_id: peer_id.clone(),
session_id,
inner: builder.build_responder(),
})),
peer_id,
session_id,
last_access_time: now,
}
}
fn get_or_create(
&mut self,
peer_id: PeerID,
session_id: SessionID,
) -> Result<(SharedSession, bool), Error> {
let now = insecure_posix_time();
if let Some(sessions) = self.by_peer.get_mut(&peer_id) {
if let Some(session) = sessions.get_mut(&session_id) {
self.by_idle_time.remove(&session.by_time_key());
session.last_access_time = now;
self.by_idle_time.insert(session.by_time_key());
return Ok((session.inner.clone(), false));
}
if sessions.len() >= self.max_sessions_per_peer
|| self.by_idle_time.len() >= self.max_sessions
{
let inner = sessions
.iter()
.min_by_key(|(_, s)| {
if let Ok(_inner) = s.inner.try_lock() {
s.last_access_time
} else {
i64::MAX }
})
.map(|(_, s)| s.inner.clone())
.ok_or(Error::MaxConcurrentSessions)?;
if let Ok(inner) = inner.try_lock_owned() {
self.remove(&inner);
} else {
return Err(Error::MaxConcurrentSessions);
}
}
}
if self.by_idle_time.len() >= self.max_sessions {
let mut remove_session: Option<OwnedMutexGuard<MultiplexedSession>> = None;
for (last_process_frame_time, peer_id, session_id) in self.by_idle_time.iter() {
if now.saturating_sub(*last_process_frame_time) < self.stale_session_timeout {
return Err(Error::MaxConcurrentSessions);
}
if let Some(sessions) = self.by_peer.get(peer_id) {
if let Some(session) = sessions.get(session_id) {
if let Ok(session) = session.inner.clone().try_lock_owned() {
remove_session = Some(session);
break;
}
}
}
}
if let Some(session) = remove_session {
self.remove(&session);
} else {
return Err(Error::MaxConcurrentSessions);
}
}
let sessions = self.by_peer.entry(peer_id.clone()).or_default();
let session = Self::create_session(self.builder.clone(), peer_id.clone(), session_id, now);
let inner = session.inner.clone();
sessions.insert(session_id, session);
self.by_idle_time.insert((now, peer_id, session_id));
Ok((inner, true))
}
fn remove(&mut self, session: &OwnedMutexGuard<MultiplexedSession>) {
let sessions = self.by_peer.get_mut(&session.peer_id).unwrap();
let session_meta = sessions.get(&session.session_id).unwrap();
let key = session_meta.by_time_key();
sessions.remove(&session.session_id);
self.by_idle_time.remove(&key);
if sessions.is_empty() {
self.by_peer.remove(&session.peer_id);
}
}
fn clear(&mut self) {
self.by_peer.clear();
self.by_idle_time.clear();
}
#[cfg(test)]
fn session_count(&self) -> usize {
self.by_idle_time.len()
}
#[cfg(test)]
fn peer_count(&self) -> usize {
self.by_peer.len()
}
}
/// Routes incoming frames to sessions keyed by (peer, session id).
pub struct Demux {
    /// Session registry. This is a `std::sync::Mutex`, so its guard must not be held
    /// across an `.await`.
    sessions: Mutex<Sessions>,
}
/// A responder session tagged with the peer and session id it is registered under.
pub struct MultiplexedSession {
    /// Peer that owns the session; used to find it again in the registry.
    peer_id: PeerID,
    /// Session identifier within `peer_id`.
    session_id: SessionID,
    /// The underlying session.
    inner: Session,
}
impl MultiplexedSession {
pub fn info(&self) -> Option<Arc<SessionInfo>> {
self.inner.session_info()
}
async fn process_data<W: Write>(
&mut self,
data: Vec<u8>,
writer: W,
) -> Result<Option<Message>, Error> {
Ok(self.inner.process_data(data, writer).await?)
}
pub fn write_message<W: Write>(&mut self, msg: Message, mut writer: W) -> Result<(), Error> {
Ok(self.inner.write_message(msg, &mut writer)?)
}
}
impl Demux {
    /// Creates a demultiplexer with an empty session registry.
    ///
    /// * `builder` - template for new responder sessions.
    /// * `max_sessions` - maximum number of sessions across all peers.
    /// * `max_sessions_per_peer` - maximum number of sessions for a single peer.
    /// * `stale_session_timeout` - idle time after which a session may be evicted once the
    ///   global limit is reached.
    pub fn new(
        builder: Builder,
        max_sessions: usize,
        max_sessions_per_peer: usize,
        stale_session_timeout: i64,
    ) -> Self {
        Self {
            sessions: Mutex::new(Sessions::new(
                builder,
                max_sessions,
                max_sessions_per_peer,
                stale_session_timeout,
            )),
        }
    }

    /// Looks up or creates the session for (`peer_id`, `session_id`), then waits until it is
    /// exclusively locked.
    async fn get_or_create_session(
        &self,
        peer_id: PeerID,
        session_id: SessionID,
    ) -> Result<OwnedMutexGuard<MultiplexedSession>, Error> {
        let (session, _) = {
            // Scope the registry guard so it is released before the `.await` below.
            let mut sessions = self.sessions.lock().unwrap();
            sessions.get_or_create(peer_id, session_id)?
        };
        Ok(session.lock_owned().await)
    }

    /// Decodes `data` as a CBOR `Frame` and feeds its payload to the frame's session.
    ///
    /// On success, returns the locked session together with the message it produced, if any.
    ///
    /// # Errors
    ///
    /// * Decoding failures of `data`.
    /// * `Error::MaxConcurrentSessions` if no session slot could be obtained.
    /// * `Error::MalformedRequestMethod` if a decoded request's method differs from the frame's
    ///   `untrusted_plaintext`.
    /// * Errors from the session. If the session reports itself closed afterwards, it is
    ///   also removed from the registry.
    pub async fn process_frame<W: Write>(
        &self,
        peer_id: PeerID,
        data: Vec<u8>,
        writer: W,
    ) -> Result<(OwnedMutexGuard<MultiplexedSession>, Option<Message>), Error> {
        let frame: Frame = cbor::from_slice(&data)?;
        let mut session = self.get_or_create_session(peer_id, frame.session).await?;

        match session.process_data(frame.payload, writer).await {
            Ok(msg) => {
                // The plaintext method must agree with the decoded request's method.
                if let Some(Message::Request(ref req)) = msg {
                    if frame.untrusted_plaintext != req.method {
                        return Err(Error::MalformedRequestMethod);
                    }
                }
                Ok((session, msg))
            }
            Err(err) => {
                if session.inner.is_closed() {
                    let mut sessions = self.sessions.lock().unwrap();
                    sessions.remove(&session);
                }
                Err(err)
            }
        }
    }

    /// Removes `session` from the registry, then writes `Message::Close` to `writer`.
    ///
    /// The session is removed even if writing the close message fails.
    pub fn close<W: Write>(
        &self,
        mut session: OwnedMutexGuard<MultiplexedSession>,
        writer: W,
    ) -> Result<(), Error> {
        // Only the bookkeeping needs the registry lock. Writing to `writer` may be slow and
        // must not stall lookups of other sessions.
        self.sessions.lock().unwrap().remove(&session);
        session.write_message(Message::Close, writer)
    }

    /// Forgets all sessions in the registry.
    pub fn reset(&self) {
        let mut sessions = self.sessions.lock().unwrap();
        sessions.clear();
    }
}
#[cfg(test)]
mod test {
    use crate::enclave_rpc::{session::Builder, types::SessionID};

    use super::{Error, Sessions};

    /// Returns 15 distinct peer ids and 15 random session ids.
    fn ids() -> (Vec<Vec<u8>>, Vec<SessionID>) {
        let peer_ids: Vec<Vec<u8>> = (1..16).map(|x| vec![x]).collect();
        let session_ids: Vec<SessionID> = (1..16).map(|_| SessionID::random()).collect();

        (peer_ids, session_ids)
    }

    #[test]
    fn test_namespacing() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 16, 4, 60);

        let (s1, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[0])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let s1_owned = s1.try_lock().unwrap();
        assert_eq!(&s1_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s1_owned.session_id, &session_ids[0]);
        drop(s1_owned);
        assert_eq!(sessions.session_count(), 1);
        assert_eq!(sessions.peer_count(), 1);

        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[1])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[2])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[3])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 1);

        let (s1r, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[0])
            .expect("get_or_create should succeed");
        assert!(!created, "session should be reused");
        let s1r_owned = s1r.try_lock().unwrap();
        assert_eq!(&s1r_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s1r_owned.session_id, &session_ids[0]);
        drop(s1r_owned);
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 1);

        // The same session id under a different peer is a different session.
        let (s5, created) = sessions
            .get_or_create(peer_ids[1].clone(), session_ids[0])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created due to namespacing");
        let s5_owned = s5.try_lock().unwrap();
        assert_eq!(&s5_owned.peer_id, &peer_ids[1]);
        assert_eq!(&s5_owned.session_id, &session_ids[0]);
        drop(s5_owned);
        assert_eq!(sessions.session_count(), 5);
        assert_eq!(sessions.peer_count(), 2);
    }

    #[test]
    fn test_max_sessions_per_peer() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 16, 4, 60);

        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[0])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");

        // Make session 0 strictly older than the following ones. Access times presumably
        // have one-second resolution, and a tie would make the LRU pick depend on `HashMap`
        // iteration order.
        std::thread::sleep(std::time::Duration::from_millis(1100));

        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[1])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[2])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[3])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 1);

        // Exceeding the per-peer limit evicts the oldest session (session 0).
        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[4])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 1);

        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[1])
            .expect("get_or_create should succeed");
        assert!(!created, "session should be reused");
        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[2])
            .expect("get_or_create should succeed");
        assert!(!created, "session should be reused");
        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[3])
            .expect("get_or_create should succeed");
        assert!(!created, "session should be reused");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 1);

        // Session 0 was evicted above, so it is created afresh.
        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[0])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 1);
    }

    #[test]
    fn test_max_sessions() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 4, 4, 60);

        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[0])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[1].clone(), session_ids[1])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[2].clone(), session_ids[2])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[3].clone(), session_ids[3])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 4);

        // A new peer cannot evict other peers' sessions that are not yet stale.
        let res = sessions.get_or_create(peer_ids[4].clone(), session_ids[4]);
        assert!(
            matches!(res, Err(Error::MaxConcurrentSessions)),
            "get_or_create should fail"
        );
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 4);

        // An existing peer makes room by evicting one of its own sessions.
        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[5])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 4);
    }

    #[test]
    fn test_max_sessions_prune_stale() {
        let (peer_ids, session_ids) = ids();
        // A zero timeout makes every session immediately stale.
        let mut sessions = Sessions::new(Builder::default(), 4, 4, 0);

        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[0])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[1].clone(), session_ids[1])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[2].clone(), session_ids[2])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[3].clone(), session_ids[3])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 4);

        let (_, created) = sessions
            .get_or_create(peer_ids[4].clone(), session_ids[4])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 4);
    }

    #[test]
    fn test_remove() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 16, 4, 0);

        let (s1, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[0])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (s2, created) = sessions
            .get_or_create(peer_ids[1].clone(), session_ids[1])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[1].clone(), session_ids[2])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        let (_, created) = sessions
            .get_or_create(peer_ids[2].clone(), session_ids[3])
            .expect("get_or_create should succeed");
        assert!(created, "new session should be created");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 3);

        // Removing a peer's only session also removes the peer.
        let s1r = s1.try_lock_owned().unwrap();
        sessions.remove(&s1r);
        assert_eq!(sessions.session_count(), 3);
        assert_eq!(sessions.peer_count(), 2);

        // Removing one of several sessions keeps the peer.
        let s2r = s2.try_lock_owned().unwrap();
        sessions.remove(&s2r);
        assert_eq!(sessions.session_count(), 2);
        assert_eq!(sessions.peer_count(), 2);
    }

    #[test]
    fn test_clear() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 16, 4, 60);

        for (peer_id, session_id) in peer_ids.iter().zip(session_ids.iter()).take(3) {
            let (_, created) = sessions
                .get_or_create(peer_id.clone(), *session_id)
                .expect("get_or_create should succeed");
            assert!(created, "new session should be created");
        }
        assert_eq!(sessions.session_count(), 3);
        assert_eq!(sessions.peer_count(), 3);

        sessions.clear();
        assert_eq!(sessions.session_count(), 0);
        assert_eq!(sessions.peer_count(), 0);

        // A previously known session is created afresh after clearing.
        let (_, created) = sessions
            .get_or_create(peer_ids[0].clone(), session_ids[0])
            .expect("get_or_create should succeed");
        assert!(created, "session should be recreated after clear");
        assert_eq!(sessions.session_count(), 1);
        assert_eq!(sessions.peer_count(), 1);
    }
}