oasis_core_runtime/enclave_rpc/
demux.rs

1//! Session demultiplexer.
2use std::{io::Write, sync::Mutex};
3
4use thiserror::Error;
5use tokio::sync::OwnedMutexGuard;
6
7use super::{
8    session::Builder,
9    sessions::{self, MultiplexedSession, Sessions},
10    types::{Frame, Message, SessionID},
11};
12use crate::common::time::insecure_posix_time;
13
/// Demultiplexer error.
///
/// Each variant is assigned a numeric code by `Error::code`, which is used when the
/// error is converted into `crate::types::Error`.
#[derive(Error, Debug)]
pub enum Error {
    /// The incoming frame could not be CBOR-decoded.
    #[error("malformed payload: {0}")]
    MalformedPayload(#[from] cbor::DecodeError),
    /// A decoded request's method did not match the frame's `untrusted_plaintext`.
    #[error("malformed request method")]
    MalformedRequestMethod,
    /// An error reported by the session set, e.g. by `Sessions::remove_for` while
    /// making room for a new session.
    #[error("sessions error: {0}")]
    SessionsError(#[from] sessions::Error),
    /// Any other error, e.g. one returned while processing or writing session data.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
26
27impl Error {
28    fn code(&self) -> u32 {
29        match self {
30            Error::MalformedPayload(_) => 1,
31            Error::MalformedRequestMethod => 2,
32            Error::SessionsError(_) => 3,
33            Error::Other(_) => 4,
34        }
35    }
36}
37
38impl From<Error> for crate::types::Error {
39    fn from(e: Error) -> Self {
40        Self {
41            module: "demux".to_string(),
42            code: e.code(),
43            message: e.to_string(),
44        }
45    }
46}
47
/// Session demultiplexer.
///
/// Routes incoming frames to the session identified by `(peer id, session id)`,
/// creating responder sessions on demand.
pub struct Demux {
    /// All known sessions. `Vec<u8>` is presumably the peer identifier type (it is the
    /// type of `peer_id` in this file). Guarded by a synchronous mutex that this file
    /// never holds across an `.await`.
    sessions: Mutex<Sessions<Vec<u8>>>,
}
52
53impl Demux {
54    /// Create new session demultiplexer.
55    pub fn new(
56        builder: Builder,
57        max_sessions: usize,
58        max_sessions_per_peer: usize,
59        stale_session_timeout: i64,
60    ) -> Self {
61        Self {
62            sessions: Mutex::new(Sessions::new(
63                builder,
64                max_sessions,
65                max_sessions_per_peer,
66                stale_session_timeout,
67            )),
68        }
69    }
70
71    /// Set the session builder to use.
72    pub fn set_session_builder(&self, builder: Builder) {
73        let mut sessions = self.sessions.lock().unwrap();
74        sessions.set_builder(builder);
75    }
76
77    async fn get_or_create_session(
78        &self,
79        peer_id: Vec<u8>,
80        session_id: SessionID,
81    ) -> Result<OwnedMutexGuard<MultiplexedSession<Vec<u8>>>, Error> {
82        let session = {
83            let mut sessions = self.sessions.lock().unwrap();
84            match sessions.get(&peer_id, &session_id) {
85                Some(session) => session,
86                None => {
87                    let now = insecure_posix_time();
88                    let _ = sessions.remove_for(&peer_id, now)?;
89                    let session = sessions.create_responder(peer_id, session_id);
90                    sessions
91                        .add(session, now)
92                        .expect("there should be space for the new session")
93                }
94            }
95        };
96
97        Ok(session.lock_owned().await)
98    }
99
100    /// Process a frame, returning the locked session guard and decoded message.
101    ///
102    /// Any data that needs to be transmitted back to the peer is written to the passed writer.
103    pub async fn process_frame<W: Write>(
104        &self,
105        peer_id: Vec<u8>,
106        data: Vec<u8>,
107        writer: W,
108    ) -> Result<
109        (
110            OwnedMutexGuard<MultiplexedSession<Vec<u8>>>,
111            Option<Message>,
112        ),
113        Error,
114    > {
115        // Decode frame.
116        let frame: Frame = cbor::from_slice(&data)?;
117        // Get the existing session or create a new one.
118        let mut session = self.get_or_create_session(peer_id, frame.session).await?;
119        // Process session data.
120        match session.process_data(&frame.payload, writer).await {
121            Ok(msg) => {
122                if let Some(Message::Request(ref req)) = msg {
123                    // Make sure that the untrusted_plaintext matches the request's method.
124                    if frame.untrusted_plaintext != req.method {
125                        return Err(Error::MalformedRequestMethod);
126                    }
127                }
128
129                Ok((session, msg))
130            }
131            Err(err) => {
132                // In case the session was closed, remove the session.
133                if session.is_closed() {
134                    let mut sessions = self.sessions.lock().unwrap();
135                    sessions.remove(&session);
136                }
137                Err(Error::Other(err))
138            }
139        }
140    }
141
142    /// Closes the given session.
143    ///
144    /// Any data that needs to be transmitted back to the peer is written to the passed writer.
145    pub fn close<W: Write>(
146        &self,
147        mut session: OwnedMutexGuard<MultiplexedSession<Vec<u8>>>,
148        writer: W,
149    ) -> Result<(), Error> {
150        let mut sessions = self.sessions.lock().unwrap();
151        sessions.remove(&session);
152
153        session.write_message(Message::Close, writer)?;
154        Ok(())
155    }
156
157    /// Resets all open sessions.
158    pub fn reset(&self) {
159        let mut sessions = self.sessions.lock().unwrap();
160        let _ = sessions.drain();
161    }
162}