oasis_core_runtime/enclave_rpc/
sessions.rs

1//! Session demultiplexer.
2use std::{
3    collections::{BTreeSet, HashMap, HashSet},
4    hash::Hash,
5    io::Write,
6    mem,
7    sync::Arc,
8};
9
10use anyhow::Result;
11use rand::{rngs::OsRng, Rng};
12use tokio::sync::OwnedMutexGuard;
13
14use super::{
15    session::{Builder, Session, SessionInfo},
16    types::{Message, SessionID},
17};
18use crate::common::{
19    crypto::signature,
20    namespace::Namespace,
21    sgx::{EnclaveIdentity, QuotePolicy},
22    time::insecure_posix_time,
23};
24
25/// Shared pointer to a multiplexed session.
26pub type SharedSession<PeerID> = Arc<tokio::sync::Mutex<MultiplexedSession<PeerID>>>;
27
28/// Key for use in the by-idle-time index.
29pub type SessionByTimeKey<PeerID> = (i64, PeerID, SessionID);
30
31/// Sessions error.
32#[derive(Debug, thiserror::Error)]
33pub enum Error {
34    #[error("max concurrent sessions reached")]
35    MaxConcurrentSessions,
36}
37
38/// A multiplexed session.
39pub struct MultiplexedSession<PeerID> {
40    /// Peer identifier (needed for resolution when only given the shared pointer).
41    peer_id: PeerID,
42    /// Session identifier (needed for resolution when only given the shared pointer).
43    session_id: SessionID,
44    /// The actual session.
45    inner: Session,
46}
47
48impl<PeerID> MultiplexedSession<PeerID> {
49    /// Return the session's peer ID.
50    pub fn get_peer_id(&self) -> &PeerID {
51        &self.peer_id
52    }
53
54    /// Set the session's peer ID.
55    pub fn set_peer_id(&mut self, peer_id: PeerID) {
56        self.peer_id = peer_id;
57    }
58
59    /// Return the session ID.
60    pub fn get_session_id(&self) -> &SessionID {
61        &self.session_id
62    }
63
64    /// Session information.
65    pub fn info(&self) -> Option<Arc<SessionInfo>> {
66        self.inner.session_info()
67    }
68
69    /// Whether the session is in closed state.
70    pub fn is_closed(&self) -> bool {
71        self.inner.is_closed()
72    }
73
74    /// Process incoming session data.
75    pub async fn process_data<W: Write>(
76        &mut self,
77        data: &[u8],
78        writer: W,
79    ) -> Result<Option<Message>> {
80        self.inner.process_data(data, writer).await
81    }
82
83    /// Write message to session and generate a response.
84    pub fn write_message<W: Write>(&mut self, msg: Message, mut writer: W) -> Result<()> {
85        self.inner.write_message(msg, &mut writer)
86    }
87
88    /// Return remote node identifier.
89    pub fn get_remote_node(&self) -> Result<signature::PublicKey> {
90        self.inner.get_remote_node()
91    }
92
93    /// Set the remote node identifier.
94    pub fn set_remote_node(&mut self, node: signature::PublicKey) -> Result<()> {
95        self.inner.set_remote_node(node)
96    }
97
98    /// Whether the session handshake has completed and the session
99    /// is in transport mode.
100    pub fn is_connected(&self) -> bool {
101        self.inner.is_connected()
102    }
103
104    /// Whether the session is in unauthenticated transport state. In this state the session can
105    /// only be used to transmit a close notification.
106    pub fn is_unauthenticated(&self) -> bool {
107        self.inner.is_unauthenticated()
108    }
109
110    /// Mark the session as closed.
111    ///
112    /// After the session is closed it can no longer be used to transmit
113    /// or receive messages and any such use will result in an error.
114    pub fn close(&mut self) {
115        self.inner.close()
116    }
117}
118
119/// Structure used for session accounting.
120pub struct SessionMeta<PeerID: Clone + Ord + Hash> {
121    /// Peer identifier.
122    peer_id: PeerID,
123    /// Session identifier.
124    session_id: SessionID,
125    /// Timestamp when the session was last accessed.
126    last_access_time: i64,
127    /// The shared session pointer that needs to be locked for access.
128    inner: SharedSession<PeerID>,
129}
130
131impl<PeerID> SessionMeta<PeerID>
132where
133    PeerID: Clone + Ord + Hash,
134{
135    /// Key for ordering in the by-idle-time index.
136    fn by_time_key(&self) -> SessionByTimeKey<PeerID> {
137        (self.last_access_time, self.peer_id.clone(), self.session_id)
138    }
139}
140
141/// Session indices and management operations.
142pub struct Sessions<PeerID: Clone + Ord + Hash> {
143    /// Session builder.
144    builder: Builder,
145    /// Maximum number of sessions.
146    max_sessions: usize,
147    /// Maximum number of sessions per peer.
148    max_sessions_per_peer: usize,
149    /// Stale session timeout (in seconds).
150    stale_session_timeout: i64,
151
152    /// A map of sessions for each peer.
153    by_peer: HashMap<PeerID, HashMap<SessionID, SessionMeta<PeerID>>>,
154    /// A set of all sessions, ordered by idle time.
155    by_idle_time: BTreeSet<SessionByTimeKey<PeerID>>,
156}
157
158impl<PeerID> Sessions<PeerID>
159where
160    PeerID: Clone + Ord + Hash,
161{
162    /// Create a new session management instance.
163    pub fn new(
164        builder: Builder,
165        max_sessions: usize,
166        max_sessions_per_peer: usize,
167        stale_session_timeout: i64,
168    ) -> Self {
169        Self {
170            builder,
171            max_sessions,
172            max_sessions_per_peer,
173            stale_session_timeout,
174            by_peer: HashMap::new(),
175            by_idle_time: BTreeSet::new(),
176        }
177    }
178
179    /// Set the session builder to use.
180    pub fn set_builder(&mut self, builder: Builder) {
181        self.builder = builder;
182    }
183
184    /// Update remote enclave identity verification in the session builder
185    /// and clear all sessions if the identity has changed.
186    pub fn update_enclaves(
187        &mut self,
188        enclaves: Option<HashSet<EnclaveIdentity>>,
189    ) -> Vec<SharedSession<PeerID>> {
190        if self.builder.get_remote_enclaves() == &enclaves {
191            return vec![];
192        }
193
194        self.builder = mem::take(&mut self.builder).remote_enclaves(enclaves);
195        self.drain()
196    }
197
198    /// Update quote policy used for remote quote verification in the session builder
199    /// and clear all sessions if the policy has changed.
200    pub fn update_quote_policy(&mut self, policy: QuotePolicy) -> Vec<SharedSession<PeerID>> {
201        let policy = Some(Arc::new(policy));
202        if self.builder.get_quote_policy() == &policy {
203            return vec![];
204        }
205
206        self.builder = mem::take(&mut self.builder).quote_policy(policy);
207        self.drain()
208    }
209
210    /// Update remote runtime ID for node identity verification in the session builder
211    /// and clear all sessions if the runtime ID has changed.
212    pub fn update_runtime_id(&mut self, id: Option<Namespace>) -> Vec<SharedSession<PeerID>> {
213        if self.builder.get_remote_runtime_id() == &id {
214            return vec![];
215        }
216
217        self.builder = mem::take(&mut self.builder).remote_runtime_id(id);
218        self.drain()
219    }
220
221    /// Create a new multiplexed responder session.
222    pub fn create_responder(
223        &mut self,
224        peer_id: PeerID,
225        session_id: SessionID,
226    ) -> MultiplexedSession<PeerID> {
227        // If no quote policy is set, use the local one.
228        if self.builder.get_quote_policy().is_none() {
229            let policy = self
230                .builder
231                .get_local_identity()
232                .as_ref()
233                .and_then(|id| id.quote_policy());
234
235            self.builder = mem::take(&mut self.builder).quote_policy(policy);
236        }
237
238        MultiplexedSession {
239            peer_id: peer_id.clone(),
240            session_id,
241            inner: self.builder.clone().build_responder(),
242        }
243    }
244
245    /// Create a new multiplexed initiator session.
246    pub fn create_initiator(&self, peer_id: PeerID) -> MultiplexedSession<PeerID> {
247        let session_id = SessionID::random();
248
249        MultiplexedSession {
250            peer_id: peer_id.clone(),
251            session_id,
252            inner: self.builder.clone().build_initiator(),
253        }
254    }
255
256    /// Fetch an existing session given its identifier.
257    pub fn get(
258        &mut self,
259        peer_id: &PeerID,
260        session_id: &SessionID,
261    ) -> Option<SharedSession<PeerID>> {
262        // Check if peer exists.
263        let sessions = match self.by_peer.get_mut(peer_id) {
264            Some(sessions) => sessions,
265            None => return None,
266        };
267
268        // Check if the session exists. If so, return it.
269        let session = match sessions.get_mut(session_id) {
270            Some(session) => session,
271            None => return None,
272        };
273
274        Self::update_access_time(session, &mut self.by_idle_time);
275
276        Some(session.inner.clone())
277    }
278
279    /// Fetch an existing session from one of the given peers. If no peers
280    /// are provided, a session from any peer will be returned.
281    pub fn find(&mut self, peer_ids: &[PeerID]) -> Option<SharedSession<PeerID>> {
282        match peer_ids.is_empty() {
283            true => self.find_any(),
284            false => self.find_one(peer_ids),
285        }
286    }
287
288    /// Fetch an existing session from any peer.
289    pub fn find_any(&mut self) -> Option<SharedSession<PeerID>> {
290        if self.by_idle_time.is_empty() {
291            return None;
292        }
293
294        // Check if there is a session that is not currently in use.
295        for (_, peer_id, session_id) in self.by_idle_time.iter() {
296            let session = self
297                .by_peer
298                .get_mut(peer_id)
299                .unwrap()
300                .get_mut(session_id)
301                .unwrap();
302
303            if session.inner.clone().try_lock_owned().is_ok() {
304                Self::update_access_time(session, &mut self.by_idle_time);
305                return Some(session.inner.clone());
306            }
307        }
308
309        // If all sessions are in use, return a random one.
310        let n = OsRng.gen_range(0..self.by_idle_time.len());
311        let (_, peer_id, session_id) = self.by_idle_time.iter().nth(n).unwrap();
312        let session = self
313            .by_peer
314            .get_mut(peer_id)
315            .unwrap()
316            .get_mut(session_id)
317            .unwrap();
318
319        Self::update_access_time(session, &mut self.by_idle_time);
320
321        Some(session.inner.clone())
322    }
323
324    /// Fetch an existing session from one of the given peers.
325    pub fn find_one(&mut self, peer_ids: &[PeerID]) -> Option<SharedSession<PeerID>> {
326        let mut all_sessions = vec![];
327
328        for peer_id in peer_ids.iter() {
329            let sessions = match self.by_peer.get_mut(peer_id) {
330                Some(sessions) => sessions,
331                None => return None,
332            };
333
334            // Check if peer has a session that is not currently in use.
335            let session = sessions
336                .values_mut()
337                .filter(|s| s.inner.clone().try_lock_owned().is_ok())
338                .min_by_key(|s| s.last_access_time);
339
340            if let Some(session) = session {
341                Self::update_access_time(session, &mut self.by_idle_time);
342                return Some(session.inner.clone());
343            }
344
345            for session in sessions.values() {
346                all_sessions.push((session.peer_id.clone(), session.session_id));
347            }
348        }
349
350        if all_sessions.is_empty() {
351            return None;
352        }
353
354        // If all sessions are in use, return a random one.
355        let n = OsRng.gen_range(0..all_sessions.len());
356        let (peer_id, session_id) = all_sessions.get(n).unwrap();
357        let session = self
358            .by_peer
359            .get_mut(peer_id)
360            .unwrap()
361            .get_mut(session_id)
362            .unwrap();
363
364        Self::update_access_time(session, &mut self.by_idle_time);
365
366        Some(session.inner.clone())
367    }
368
369    /// Remove one session to free up a slot for the given peer.
370    pub fn remove_for(
371        &mut self,
372        peer_id: &PeerID,
373        now: i64,
374    ) -> Result<Option<OwnedMutexGuard<MultiplexedSession<PeerID>>>, Error> {
375        if let Some(session) = self.remove_from(peer_id)? {
376            return Ok(Some(session));
377        }
378        self.remove_one(now)
379    }
380
381    /// Remove one existing session from the given peer if the peer has reached
382    /// the maximum number of sessions or if the total number of sessions exceeds
383    /// the global session limit.
384    pub fn remove_from(
385        &mut self,
386        peer_id: &PeerID,
387    ) -> Result<Option<OwnedMutexGuard<MultiplexedSession<PeerID>>>, Error> {
388        // Check if peer exists.
389        let sessions = match self.by_peer.get_mut(peer_id) {
390            Some(sessions) => sessions,
391            None => return Ok(None),
392        };
393
394        // Check if the peer has max sessions or if no more sessions are available globally.
395        // If so, remove the oldest or return an error.
396        if sessions.len() < self.max_sessions_per_peer
397            && self.by_idle_time.len() < self.max_sessions
398        {
399            return Ok(None);
400        }
401
402        // Force close the oldest idle session.
403        let remove_session = sessions
404            .iter()
405            .min_by_key(|(_, s)| {
406                if let Ok(_inner) = s.inner.try_lock() {
407                    s.last_access_time
408                } else {
409                    i64::MAX // Session is currently in use.
410                }
411            })
412            .map(|(_, s)| s.inner.clone())
413            .ok_or(Error::MaxConcurrentSessions)?;
414
415        let session = match remove_session.try_lock_owned() {
416            Ok(inner) => inner,
417            Err(_) => return Err(Error::MaxConcurrentSessions), // All sessions are in use.
418        };
419
420        self.remove(&session);
421
422        Ok(Some(session))
423    }
424
425    /// Remove one stale session if the total number of sessions exceeds
426    /// the global session limit.
427    pub fn remove_one(
428        &mut self,
429        now: i64,
430    ) -> Result<Option<OwnedMutexGuard<MultiplexedSession<PeerID>>>, Error> {
431        // Check if there are too many sessions. If so, remove one or return an error.
432        if self.by_idle_time.len() < self.max_sessions {
433            return Ok(None);
434        }
435
436        // Attempt to prune stale sessions, starting with the oldest ones.
437        let mut remove_session: Option<OwnedMutexGuard<MultiplexedSession<PeerID>>> = None;
438
439        for (last_process_frame_time, peer_id, session_id) in self.by_idle_time.iter() {
440            if now.saturating_sub(*last_process_frame_time) < self.stale_session_timeout {
441                // This is the oldest session, all next ones will be more fresh.
442                return Err(Error::MaxConcurrentSessions);
443            }
444
445            // Fetch session and attempt to lock it.
446            if let Some(sessions) = self.by_peer.get(peer_id) {
447                if let Some(session) = sessions.get(session_id) {
448                    if let Ok(session) = session.inner.clone().try_lock_owned() {
449                        remove_session = Some(session);
450                        break;
451                    }
452                }
453            }
454        }
455
456        // Check if we found a session that can be removed.
457        let session = match remove_session {
458            Some(session) => session,
459            None => return Err(Error::MaxConcurrentSessions), // All stale sessions are in use.
460        };
461
462        self.remove(&session);
463
464        Ok(Some(session))
465    }
466
467    /// Add a session if there is an available spot.
468    pub fn add(
469        &mut self,
470        session: MultiplexedSession<PeerID>,
471        now: i64,
472    ) -> Result<SharedSession<PeerID>, Error> {
473        if self.by_idle_time.len() >= self.max_sessions {
474            return Err(Error::MaxConcurrentSessions);
475        }
476
477        let sessions = self.by_peer.entry(session.peer_id.clone()).or_default();
478        if sessions.len() >= self.max_sessions_per_peer {
479            return Err(Error::MaxConcurrentSessions);
480        }
481
482        let peer_id = session.peer_id.clone();
483        let session_id = session.session_id;
484
485        let session = SessionMeta {
486            inner: Arc::new(tokio::sync::Mutex::new(session)),
487            peer_id,
488            session_id,
489            last_access_time: now,
490        };
491        let inner = session.inner.clone();
492
493        self.by_idle_time.insert(session.by_time_key());
494        sessions.insert(session.session_id, session);
495
496        Ok(inner)
497    }
498
499    /// Remove a session that must be currently owned by the caller.
500    pub fn remove(&mut self, session: &OwnedMutexGuard<MultiplexedSession<PeerID>>) {
501        let sessions = self.by_peer.get_mut(&session.peer_id).unwrap();
502        let session_meta = sessions.get(&session.session_id).unwrap();
503        let key = session_meta.by_time_key();
504        sessions.remove(&session.session_id);
505        self.by_idle_time.remove(&key);
506
507        // If peer doesn't have any more sessions, remove the peer.
508        if sessions.is_empty() {
509            self.by_peer.remove(&session.peer_id);
510        }
511    }
512
513    /// Removes and returns all sessions.
514    pub fn drain(&mut self) -> Vec<SharedSession<PeerID>> {
515        self.by_idle_time.clear();
516
517        let mut all_sessions = vec![];
518        for (_, mut sessions) in self.by_peer.drain() {
519            for (_, session) in sessions.drain() {
520                all_sessions.push(session.inner);
521            }
522        }
523
524        all_sessions
525    }
526
527    fn update_access_time(
528        session: &mut SessionMeta<PeerID>,
529        by_idle_time: &mut BTreeSet<SessionByTimeKey<PeerID>>,
530    ) {
531        // Remove old idle time.
532        by_idle_time.remove(&session.by_time_key());
533
534        // Update idle time.
535        session.last_access_time = insecure_posix_time();
536        by_idle_time.insert(session.by_time_key());
537    }
538
539    /// Number of all sessions.
540    #[cfg(test)]
541    fn session_count(&self) -> usize {
542        self.by_idle_time.len()
543    }
544
545    /// Number of all peers.
546    #[cfg(test)]
547    fn peer_count(&self) -> usize {
548        self.by_peer.len()
549    }
550}
551
552#[cfg(test)]
553mod test {
554    use crate::enclave_rpc::{session::Builder, types::SessionID};
555
556    use super::{Error, Sessions};
557
558    fn ids() -> (Vec<Vec<u8>>, Vec<SessionID>) {
559        let peer_ids: Vec<Vec<u8>> = (1..8).map(|x| vec![x]).collect();
560        let session_ids: Vec<SessionID> = (1..8).map(|_| SessionID::random()).collect();
561
562        (peer_ids, session_ids)
563    }
564
565    #[test]
566    fn test_add() {
567        let (peer_ids, session_ids) = ids();
568        let mut sessions = Sessions::new(Builder::default(), 4, 2, 60);
569
570        let test_vector = vec![
571            (&peer_ids[0], &session_ids[0], 1, 1, true),
572            (&peer_ids[0], &session_ids[1], 2, 1, true), // Different session ID.
573            (&peer_ids[0], &session_ids[2], 2, 1, false), // Too many sessions per peer.
574            (&peer_ids[1], &session_ids[0], 3, 2, true), // Different peer ID.
575            (&peer_ids[2], &session_ids[2], 4, 3, true), // Different peer ID and session ID.
576            (&peer_ids[3], &session_ids[3], 4, 3, false), // Too many sessions.
577        ];
578
579        let now = 0;
580        for (peer_id, session_id, num_sessions, num_peers, created) in test_vector {
581            let session = sessions.create_responder(peer_id.clone(), session_id.clone());
582            let res = sessions.add(session, now);
583            match created {
584                true => {
585                    assert!(res.is_ok(), "session should be created");
586                    let s = res.unwrap();
587                    let s_owned = s.try_lock().unwrap();
588                    assert_eq!(&s_owned.peer_id, peer_id);
589                    assert_eq!(&s_owned.session_id, session_id);
590                }
591                false => {
592                    assert!(res.is_err(), "session should not be created");
593                    assert!(matches!(res, Err(Error::MaxConcurrentSessions)));
594                }
595            };
596            assert_eq!(sessions.session_count(), num_sessions);
597            assert_eq!(sessions.peer_count(), num_peers);
598        }
599    }
600
601    #[test]
602    fn test_get() {
603        let (peer_ids, session_ids) = ids();
604        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);
605
606        let test_vector = vec![
607            (&peer_ids[0], &session_ids[0], true),
608            (&peer_ids[0], &session_ids[1], false), // Different peer ID.
609            (&peer_ids[1], &session_ids[0], false), // Different session ID.
610            (&peer_ids[1], &session_ids[1], false), // Different peer ID and session ID.
611        ];
612
613        let now = 0;
614        for (peer_id, session_id, create) in test_vector {
615            if create {
616                let session = sessions.create_responder(peer_id.clone(), session_id.clone());
617                let _ = sessions.add(session, now);
618            }
619
620            let maybe_s = sessions.get(peer_id, session_id);
621            match create {
622                true => {
623                    assert!(maybe_s.is_some(), "session should exist");
624                    let s = maybe_s.unwrap();
625                    let s_owned = s.try_lock_owned().unwrap();
626                    assert_eq!(&s_owned.peer_id, peer_id);
627                    assert_eq!(&s_owned.session_id, session_id);
628                }
629                false => assert!(maybe_s.is_none(), "session should not exist"),
630            }
631        }
632    }
633
634    #[test]
635    fn test_find_any() {
636        let (peer_ids, session_ids) = ids();
637        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);
638
639        let test_vector = vec![
640            (&peer_ids[0], &session_ids[0]),
641            (&peer_ids[0], &session_ids[1]),
642            (&peer_ids[1], &session_ids[2]),
643        ];
644
645        // No sessions.
646        let maybe_s = sessions.find_any();
647        assert!(maybe_s.is_none(), "session should not be found");
648
649        let mut now = 0;
650        for (peer_id, session_id) in test_vector {
651            let session = sessions.create_responder(peer_id.clone(), session_id.clone());
652            let _ = sessions.add(session, now);
653            now += 1
654        }
655
656        // No sessions in use.
657        let maybe_s = sessions.find_any();
658        assert!(maybe_s.is_some(), "session should be found");
659        let s = maybe_s.unwrap();
660        let s1_owned = s.try_lock_owned().unwrap(); // Session now in use.
661        assert_eq!(&s1_owned.peer_id, &peer_ids[0]);
662        assert_eq!(&s1_owned.session_id, &session_ids[0]);
663
664        // One session in use.
665        let maybe_s = sessions.find_any();
666        assert!(maybe_s.is_some(), "session should be found");
667        let s = maybe_s.unwrap();
668        let s2_owned = s.try_lock_owned().unwrap(); // Session now in use.
669        assert_eq!(&s2_owned.peer_id, &peer_ids[0]);
670        assert_eq!(&s2_owned.session_id, &session_ids[1]); // Different session found.
671
672        // Two sessions in use.
673        let maybe_s = sessions.find_any();
674        assert!(maybe_s.is_some(), "session should be found");
675        let s = maybe_s.unwrap();
676        let s3_owned = s.try_lock_owned().unwrap(); // Session now in use.
677        assert_eq!(&s3_owned.peer_id, &peer_ids[1]);
678        assert_eq!(&s3_owned.session_id, &session_ids[2]); // Different session found.
679
680        // All sessions in use.
681        let maybe_s = sessions.find_any();
682        assert!(maybe_s.is_some(), "session should be found");
683        let s = maybe_s.unwrap();
684        let res = s.try_lock_owned(); // Session now in use.
685        assert!(res.is_err(), "session should be in use");
686
687        // Free one session.
688        drop(s2_owned);
689
690        // Two sessions in use.
691        let maybe_s = sessions.find_any();
692        assert!(maybe_s.is_some(), "session should be found");
693        let s = maybe_s.unwrap();
694        let s_owned = s.try_lock_owned().unwrap(); // Session now in use.
695        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
696        assert_eq!(&s_owned.session_id, &session_ids[1]);
697    }
698
699    #[test]
700    fn test_find_one() {
701        let (peer_ids, session_ids) = ids();
702        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);
703
704        let test_vector = vec![
705            (&peer_ids[2], &session_ids[0]), // Incorrect peer.
706            (&peer_ids[0], &session_ids[0]),
707            (&peer_ids[3], &session_ids[1]), // Incorrect peer.
708            (&peer_ids[0], &session_ids[1]),
709            (&peer_ids[3], &session_ids[2]), // Incorrect peer.
710            (&peer_ids[1], &session_ids[2]),
711            (&peer_ids[2], &session_ids[2]), // Incorrect peer.
712        ];
713
714        // No sessions.
715        let maybe_s = sessions.find_one(&peer_ids[0..2]);
716        assert!(maybe_s.is_none(), "session should not be found");
717
718        let mut now = 0;
719        for (peer_id, session_id) in test_vector {
720            let session = sessions.create_responder(peer_id.clone(), session_id.clone());
721            let _ = sessions.add(session, now);
722            now += 1
723        }
724
725        // Peers without sessions.
726        let maybe_s = sessions.find_one(&peer_ids[4..]);
727        assert!(maybe_s.is_none(), "session should not be found");
728
729        // No sessions in use.
730        let maybe_s = sessions.find_one(&peer_ids[0..2]);
731        assert!(maybe_s.is_some(), "session should be found");
732        let s = maybe_s.unwrap();
733        let s1_owned = s.try_lock_owned().unwrap(); // Session now in use.
734        assert_eq!(&s1_owned.peer_id, &peer_ids[0]);
735        assert_eq!(&s1_owned.session_id, &session_ids[0]);
736
737        // One session in use.
738        let maybe_s = sessions.find_one(&peer_ids[0..2]);
739        assert!(maybe_s.is_some(), "session should be found");
740        let s = maybe_s.unwrap();
741        let s2_owned = s.try_lock_owned().unwrap(); // Session now in use.
742        assert_eq!(&s2_owned.peer_id, &peer_ids[0]);
743        assert_eq!(&s2_owned.session_id, &session_ids[1]); // Different session found.
744
745        // Two sessions in use.
746        let maybe_s = sessions.find_one(&peer_ids[0..2]);
747        assert!(maybe_s.is_some(), "session should be found");
748        let s = maybe_s.unwrap();
749        let s3_owned = s.try_lock_owned().unwrap(); // Session now in use.
750        assert_eq!(&s3_owned.peer_id, &peer_ids[1]);
751        assert_eq!(&s3_owned.session_id, &session_ids[2]); // Different session found.
752
753        // All sessions in use.
754        let maybe_s = sessions.find_one(&peer_ids[0..2]);
755        assert!(maybe_s.is_some(), "session should be found");
756        let s = maybe_s.unwrap();
757        let res = s.try_lock_owned(); // Session now in use.
758        assert!(res.is_err(), "session should be in use");
759
760        // Free one session.
761        drop(s2_owned);
762
763        // Two sessions in use.
764        let maybe_s = sessions.find_one(&peer_ids[0..2]);
765        assert!(maybe_s.is_some(), "session should be found");
766        let s = maybe_s.unwrap();
767        let s_owned = s.try_lock_owned().unwrap(); // Session now in use.
768        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
769        assert_eq!(&s_owned.session_id, &session_ids[1]);
770    }
771
772    #[test]
773    fn test_remove_from() {
774        let (peer_ids, session_ids) = ids();
775        let mut sessions = Sessions::new(Builder::default(), 4, 2, 60);
776
777        let test_vector = vec![
778            (&peer_ids[0], &session_ids[0]),
779            (&peer_ids[1], &session_ids[1]),
780            (&peer_ids[2], &session_ids[2]),
781            (&peer_ids[2], &session_ids[3]), // Max sessions per peer reached.
782                                             // Max sessions reached.
783        ];
784
785        let mut now = 0;
786        for (peer_id, session_id) in test_vector.clone() {
787            let session = sessions.create_responder(peer_id.clone(), session_id.clone());
788            let _ = sessions.add(session, now);
789            now += 1;
790        }
791
792        // Removing one session from an unknown peer should have no effect,
793        // even if all global session slots are occupied.
794        let res = sessions.remove_from(&peer_ids[3]);
795        assert!(res.is_ok(), "remove_from should succeed");
796        let maybe_s_owned = res.unwrap();
797        assert!(maybe_s_owned.is_none(), "no sessions should be removed");
798        assert_eq!(sessions.session_count(), 4);
799        assert_eq!(sessions.peer_count(), 3);
800
801        // Removing one session for one of the existing peers should work
802        // as it should force evict an old session.
803        // Note that each peer has 2 available slots, but globally there are
804        // only 4 slots so if global slots are full this should trigger peer
805        // session eviction.
806        let res = sessions.remove_from(&peer_ids[0]);
807        assert!(res.is_ok(), "remove_from should succeed");
808        let maybe_s_owned = res.unwrap();
809        assert!(maybe_s_owned.is_some(), "one session should be removed");
810        let s_owned = maybe_s_owned.unwrap();
811        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
812        assert_eq!(&s_owned.session_id, &session_ids[0]);
813        assert_eq!(sessions.session_count(), 3);
814        assert_eq!(sessions.peer_count(), 2);
815
816        // Removing another session should fail as one global session slot
817        // is available.
818        for peer_id in vec![&peer_ids[0], &peer_ids[1]] {
819            let res = sessions.remove_from(peer_id);
820            assert!(res.is_ok(), "remove_from should succeed");
821            let maybe_s_owned = res.unwrap();
822            assert!(maybe_s_owned.is_none(), "no sessions should be removed");
823            assert_eq!(sessions.session_count(), 3);
824            assert_eq!(sessions.peer_count(), 2);
825        }
826
827        // Removing one session from a peer with max sessions should succeed
828        // even if one global slot is available.
829        let res = sessions.remove_from(&peer_ids[2]);
830        assert!(res.is_ok(), "remove_from should succeed");
831        let maybe_s_owned = res.unwrap();
832        assert!(maybe_s_owned.is_some(), "one session should be removed");
833        let s_owned = maybe_s_owned.unwrap();
834        assert_eq!(&s_owned.peer_id, &peer_ids[2]);
835        assert_eq!(&s_owned.session_id, &session_ids[2]);
836        assert_eq!(sessions.session_count(), 2);
837        assert_eq!(sessions.peer_count(), 2);
838    }
839
840    #[test]
841    fn test_remove_one() {
842        let (peer_ids, session_ids) = ids();
843        let mut sessions = Sessions::new(Builder::default(), 4, 2, 60);
844
845        let test_vector = vec![
846            (&peer_ids[0], &session_ids[0]),
847            (&peer_ids[1], &session_ids[1]),
848            (&peer_ids[2], &session_ids[2]),
849            (&peer_ids[2], &session_ids[3]), // Max sessions reached.
850        ];
851
852        let mut now = 0;
853        for (peer_id, session_id) in test_vector.clone() {
854            let session = sessions.create_responder(peer_id.clone(), session_id.clone());
855            let _ = sessions.add(session, now);
856            now += 1;
857        }
858
859        // Forward time (stale_session_timeout - test_vector.len() - 1).
860        now += 60 - 4 - 1;
861
862        // Removing one session should fail as there are none stale sessions.
863        let res = sessions.remove_one(now);
864        assert!(res.is_err(), "remove_one should fail");
865        assert!(matches!(res, Err(Error::MaxConcurrentSessions)));
866        assert_eq!(sessions.session_count(), 4);
867        assert_eq!(sessions.peer_count(), 3);
868
869        // Forward time.
870        now += 1;
871
872        // Removing one session should succeed as no session slots
873        // are available and there is one stale session.
874        let res = sessions.remove_one(now);
875        assert!(res.is_ok(), "remove_one should succeed");
876        let maybe_s_owned = res.unwrap();
877        assert!(maybe_s_owned.is_some(), "one session should be removed");
878        let s_owned = maybe_s_owned.unwrap();
879        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
880        assert_eq!(&s_owned.session_id, &session_ids[0]);
881        assert_eq!(sessions.session_count(), 3);
882        assert_eq!(sessions.peer_count(), 2);
883
884        // Forward time.
885        now += 100;
886
887        // Removing one session should fail even though there are stale sessions
888        // because there is one session slot available.
889        let res = sessions.remove_one(now);
890        assert!(res.is_ok(), "remove_one should succeed");
891        let maybe_s_owned = res.unwrap();
892        assert!(maybe_s_owned.is_none(), "no sessions should be removed");
893        assert_eq!(sessions.session_count(), 3);
894        assert_eq!(sessions.peer_count(), 2);
895    }
896
897    #[test]
898    fn test_remove() {
899        let (peer_ids, session_ids) = ids();
900        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);
901
902        let test_vector = vec![
903            (&peer_ids[0], &session_ids[0], 3, 2),
904            (&peer_ids[1], &session_ids[1], 2, 1),
905            (&peer_ids[2], &session_ids[2], 1, 1),
906            (&peer_ids[2], &session_ids[3], 0, 0),
907        ];
908
909        let now = 0;
910        for (peer_id, session_id, _, _) in test_vector.clone() {
911            let session = sessions.create_responder(peer_id.clone(), session_id.clone());
912            let _ = sessions.add(session, now);
913        }
914
915        for (peer_id, session_id, num_sessions, num_peers) in test_vector {
916            let maybe_s = sessions.get(peer_id, session_id);
917            assert!(maybe_s.is_some(), "session should exist");
918            let s = maybe_s.unwrap();
919            let s_owned = s.try_lock_owned().unwrap();
920
921            sessions.remove(&s_owned);
922            assert_eq!(sessions.session_count(), num_sessions);
923            assert_eq!(sessions.peer_count(), num_peers);
924        }
925    }
926
927    #[test]
928    fn test_clear() {
929        let (peer_ids, session_ids) = ids();
930        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);
931
932        let test_vector = vec![
933            (&peer_ids[0], &session_ids[0]),
934            (&peer_ids[1], &session_ids[1]),
935            (&peer_ids[2], &session_ids[2]),
936            (&peer_ids[2], &session_ids[3]),
937        ];
938
939        let now = 0;
940        for (peer_id, session_id) in test_vector.clone() {
941            let session = sessions.create_responder(peer_id.clone(), session_id.clone());
942            let _ = sessions.add(session, now);
943        }
944
945        let removed_sessions = sessions.drain();
946        assert_eq!(removed_sessions.len(), 4);
947        assert_eq!(sessions.session_count(), 0);
948        assert_eq!(sessions.peer_count(), 0);
949    }
950}