1use std::{
3    collections::{BTreeSet, HashMap, HashSet},
4    hash::Hash,
5    io::Write,
6    mem,
7    sync::Arc,
8};
9
10use anyhow::Result;
11use rand::{rngs::OsRng, Rng};
12use tokio::sync::OwnedMutexGuard;
13
14use super::{
15    session::{Builder, Session, SessionInfo},
16    types::{Message, SessionID},
17};
18use crate::common::{
19    crypto::signature,
20    namespace::Namespace,
21    sgx::{EnclaveIdentity, QuotePolicy},
22    time::insecure_posix_time,
23};
24
25pub type SharedSession<PeerID> = Arc<tokio::sync::Mutex<MultiplexedSession<PeerID>>>;
27
28pub type SessionByTimeKey<PeerID> = (i64, PeerID, SessionID);
30
31#[derive(Debug, thiserror::Error)]
33pub enum Error {
34    #[error("max concurrent sessions reached")]
35    MaxConcurrentSessions,
36}
37
38pub struct MultiplexedSession<PeerID> {
40    peer_id: PeerID,
42    session_id: SessionID,
44    inner: Session,
46}
47
48impl<PeerID> MultiplexedSession<PeerID> {
49    pub fn get_peer_id(&self) -> &PeerID {
51        &self.peer_id
52    }
53
54    pub fn set_peer_id(&mut self, peer_id: PeerID) {
56        self.peer_id = peer_id;
57    }
58
59    pub fn get_session_id(&self) -> &SessionID {
61        &self.session_id
62    }
63
64    pub fn info(&self) -> Option<Arc<SessionInfo>> {
66        self.inner.session_info()
67    }
68
69    pub fn is_closed(&self) -> bool {
71        self.inner.is_closed()
72    }
73
74    pub async fn process_data<W: Write>(
76        &mut self,
77        data: &[u8],
78        writer: W,
79    ) -> Result<Option<Message>> {
80        self.inner.process_data(data, writer).await
81    }
82
83    pub fn write_message<W: Write>(&mut self, msg: Message, mut writer: W) -> Result<()> {
85        self.inner.write_message(msg, &mut writer)
86    }
87
88    pub fn get_remote_node(&self) -> Result<signature::PublicKey> {
90        self.inner.get_remote_node()
91    }
92
93    pub fn set_remote_node(&mut self, node: signature::PublicKey) -> Result<()> {
95        self.inner.set_remote_node(node)
96    }
97
98    pub fn is_connected(&self) -> bool {
101        self.inner.is_connected()
102    }
103
104    pub fn is_unauthenticated(&self) -> bool {
107        self.inner.is_unauthenticated()
108    }
109
110    pub fn close(&mut self) {
115        self.inner.close()
116    }
117}
118
119pub struct SessionMeta<PeerID: Clone + Ord + Hash> {
121    peer_id: PeerID,
123    session_id: SessionID,
125    last_access_time: i64,
127    inner: SharedSession<PeerID>,
129}
130
131impl<PeerID> SessionMeta<PeerID>
132where
133    PeerID: Clone + Ord + Hash,
134{
135    fn by_time_key(&self) -> SessionByTimeKey<PeerID> {
137        (self.last_access_time, self.peer_id.clone(), self.session_id)
138    }
139}
140
141pub struct Sessions<PeerID: Clone + Ord + Hash> {
143    builder: Builder,
145    max_sessions: usize,
147    max_sessions_per_peer: usize,
149    stale_session_timeout: i64,
151
152    by_peer: HashMap<PeerID, HashMap<SessionID, SessionMeta<PeerID>>>,
154    by_idle_time: BTreeSet<SessionByTimeKey<PeerID>>,
156}
157
158impl<PeerID> Sessions<PeerID>
159where
160    PeerID: Clone + Ord + Hash,
161{
162    pub fn new(
164        builder: Builder,
165        max_sessions: usize,
166        max_sessions_per_peer: usize,
167        stale_session_timeout: i64,
168    ) -> Self {
169        Self {
170            builder,
171            max_sessions,
172            max_sessions_per_peer,
173            stale_session_timeout,
174            by_peer: HashMap::new(),
175            by_idle_time: BTreeSet::new(),
176        }
177    }
178
179    pub fn set_builder(&mut self, builder: Builder) {
181        self.builder = builder;
182    }
183
184    pub fn update_enclaves(
187        &mut self,
188        enclaves: Option<HashSet<EnclaveIdentity>>,
189    ) -> Vec<SharedSession<PeerID>> {
190        if self.builder.get_remote_enclaves() == &enclaves {
191            return vec![];
192        }
193
194        self.builder = mem::take(&mut self.builder).remote_enclaves(enclaves);
195        self.drain()
196    }
197
198    pub fn update_quote_policy(&mut self, policy: QuotePolicy) -> Vec<SharedSession<PeerID>> {
201        let policy = Some(Arc::new(policy));
202        if self.builder.get_quote_policy() == &policy {
203            return vec![];
204        }
205
206        self.builder = mem::take(&mut self.builder).quote_policy(policy);
207        self.drain()
208    }
209
210    pub fn update_runtime_id(&mut self, id: Option<Namespace>) -> Vec<SharedSession<PeerID>> {
213        if self.builder.get_remote_runtime_id() == &id {
214            return vec![];
215        }
216
217        self.builder = mem::take(&mut self.builder).remote_runtime_id(id);
218        self.drain()
219    }
220
221    pub fn create_responder(
223        &mut self,
224        peer_id: PeerID,
225        session_id: SessionID,
226    ) -> MultiplexedSession<PeerID> {
227        if self.builder.get_quote_policy().is_none() {
229            let policy = self
230                .builder
231                .get_local_identity()
232                .as_ref()
233                .and_then(|id| id.quote_policy());
234
235            self.builder = mem::take(&mut self.builder).quote_policy(policy);
236        }
237
238        MultiplexedSession {
239            peer_id: peer_id.clone(),
240            session_id,
241            inner: self.builder.clone().build_responder(),
242        }
243    }
244
245    pub fn create_initiator(&self, peer_id: PeerID) -> MultiplexedSession<PeerID> {
247        let session_id = SessionID::random();
248
249        MultiplexedSession {
250            peer_id: peer_id.clone(),
251            session_id,
252            inner: self.builder.clone().build_initiator(),
253        }
254    }
255
256    pub fn get(
258        &mut self,
259        peer_id: &PeerID,
260        session_id: &SessionID,
261    ) -> Option<SharedSession<PeerID>> {
262        let sessions = self.by_peer.get_mut(peer_id)?;
264
265        let session = sessions.get_mut(session_id)?;
267
268        Self::update_access_time(session, &mut self.by_idle_time);
269
270        Some(session.inner.clone())
271    }
272
273    pub fn find(&mut self, peer_ids: &[PeerID]) -> Option<SharedSession<PeerID>> {
276        match peer_ids.is_empty() {
277            true => self.find_any(),
278            false => self.find_one(peer_ids),
279        }
280    }
281
282    pub fn find_any(&mut self) -> Option<SharedSession<PeerID>> {
284        if self.by_idle_time.is_empty() {
285            return None;
286        }
287
288        for (_, peer_id, session_id) in self.by_idle_time.iter() {
290            let session = self
291                .by_peer
292                .get_mut(peer_id)
293                .unwrap()
294                .get_mut(session_id)
295                .unwrap();
296
297            if session.inner.clone().try_lock_owned().is_ok() {
298                Self::update_access_time(session, &mut self.by_idle_time);
299                return Some(session.inner.clone());
300            }
301        }
302
303        let n = OsRng.gen_range(0..self.by_idle_time.len());
305        let (_, peer_id, session_id) = self.by_idle_time.iter().nth(n).unwrap();
306        let session = self
307            .by_peer
308            .get_mut(peer_id)
309            .unwrap()
310            .get_mut(session_id)
311            .unwrap();
312
313        Self::update_access_time(session, &mut self.by_idle_time);
314
315        Some(session.inner.clone())
316    }
317
318    pub fn find_one(&mut self, peer_ids: &[PeerID]) -> Option<SharedSession<PeerID>> {
320        let mut all_sessions = vec![];
321
322        for peer_id in peer_ids.iter() {
323            let sessions = self.by_peer.get_mut(peer_id)?;
324
325            let session = sessions
327                .values_mut()
328                .filter(|s| s.inner.clone().try_lock_owned().is_ok())
329                .min_by_key(|s| s.last_access_time);
330
331            if let Some(session) = session {
332                Self::update_access_time(session, &mut self.by_idle_time);
333                return Some(session.inner.clone());
334            }
335
336            for session in sessions.values() {
337                all_sessions.push((session.peer_id.clone(), session.session_id));
338            }
339        }
340
341        if all_sessions.is_empty() {
342            return None;
343        }
344
345        let n = OsRng.gen_range(0..all_sessions.len());
347        let (peer_id, session_id) = all_sessions.get(n).unwrap();
348        let session = self
349            .by_peer
350            .get_mut(peer_id)
351            .unwrap()
352            .get_mut(session_id)
353            .unwrap();
354
355        Self::update_access_time(session, &mut self.by_idle_time);
356
357        Some(session.inner.clone())
358    }
359
360    pub fn remove_for(
362        &mut self,
363        peer_id: &PeerID,
364        now: i64,
365    ) -> Result<Option<OwnedMutexGuard<MultiplexedSession<PeerID>>>, Error> {
366        if let Some(session) = self.remove_from(peer_id)? {
367            return Ok(Some(session));
368        }
369        self.remove_one(now)
370    }
371
372    pub fn remove_from(
376        &mut self,
377        peer_id: &PeerID,
378    ) -> Result<Option<OwnedMutexGuard<MultiplexedSession<PeerID>>>, Error> {
379        let sessions = match self.by_peer.get_mut(peer_id) {
381            Some(sessions) => sessions,
382            None => return Ok(None),
383        };
384
385        if sessions.len() < self.max_sessions_per_peer
388            && self.by_idle_time.len() < self.max_sessions
389        {
390            return Ok(None);
391        }
392
393        let remove_session = sessions
395            .iter()
396            .min_by_key(|(_, s)| {
397                if let Ok(_inner) = s.inner.try_lock() {
398                    s.last_access_time
399                } else {
400                    i64::MAX }
402            })
403            .map(|(_, s)| s.inner.clone())
404            .ok_or(Error::MaxConcurrentSessions)?;
405
406        let session = match remove_session.try_lock_owned() {
407            Ok(inner) => inner,
408            Err(_) => return Err(Error::MaxConcurrentSessions), };
410
411        self.remove(&session);
412
413        Ok(Some(session))
414    }
415
416    pub fn remove_one(
419        &mut self,
420        now: i64,
421    ) -> Result<Option<OwnedMutexGuard<MultiplexedSession<PeerID>>>, Error> {
422        if self.by_idle_time.len() < self.max_sessions {
424            return Ok(None);
425        }
426
427        let mut remove_session: Option<OwnedMutexGuard<MultiplexedSession<PeerID>>> = None;
429
430        for (last_process_frame_time, peer_id, session_id) in self.by_idle_time.iter() {
431            if now.saturating_sub(*last_process_frame_time) < self.stale_session_timeout {
432                return Err(Error::MaxConcurrentSessions);
434            }
435
436            if let Some(sessions) = self.by_peer.get(peer_id) {
438                if let Some(session) = sessions.get(session_id) {
439                    if let Ok(session) = session.inner.clone().try_lock_owned() {
440                        remove_session = Some(session);
441                        break;
442                    }
443                }
444            }
445        }
446
447        let session = match remove_session {
449            Some(session) => session,
450            None => return Err(Error::MaxConcurrentSessions), };
452
453        self.remove(&session);
454
455        Ok(Some(session))
456    }
457
458    pub fn add(
460        &mut self,
461        session: MultiplexedSession<PeerID>,
462        now: i64,
463    ) -> Result<SharedSession<PeerID>, Error> {
464        if self.by_idle_time.len() >= self.max_sessions {
465            return Err(Error::MaxConcurrentSessions);
466        }
467
468        let sessions = self.by_peer.entry(session.peer_id.clone()).or_default();
469        if sessions.len() >= self.max_sessions_per_peer {
470            return Err(Error::MaxConcurrentSessions);
471        }
472
473        let peer_id = session.peer_id.clone();
474        let session_id = session.session_id;
475
476        let session = SessionMeta {
477            inner: Arc::new(tokio::sync::Mutex::new(session)),
478            peer_id,
479            session_id,
480            last_access_time: now,
481        };
482        let inner = session.inner.clone();
483
484        self.by_idle_time.insert(session.by_time_key());
485        sessions.insert(session.session_id, session);
486
487        Ok(inner)
488    }
489
490    pub fn remove(&mut self, session: &OwnedMutexGuard<MultiplexedSession<PeerID>>) {
492        let sessions = self.by_peer.get_mut(&session.peer_id).unwrap();
493        let session_meta = sessions.get(&session.session_id).unwrap();
494        let key = session_meta.by_time_key();
495        sessions.remove(&session.session_id);
496        self.by_idle_time.remove(&key);
497
498        if sessions.is_empty() {
500            self.by_peer.remove(&session.peer_id);
501        }
502    }
503
504    pub fn drain(&mut self) -> Vec<SharedSession<PeerID>> {
506        self.by_idle_time.clear();
507
508        let mut all_sessions = vec![];
509        for (_, mut sessions) in self.by_peer.drain() {
510            for (_, session) in sessions.drain() {
511                all_sessions.push(session.inner);
512            }
513        }
514
515        all_sessions
516    }
517
518    fn update_access_time(
519        session: &mut SessionMeta<PeerID>,
520        by_idle_time: &mut BTreeSet<SessionByTimeKey<PeerID>>,
521    ) {
522        by_idle_time.remove(&session.by_time_key());
524
525        session.last_access_time = insecure_posix_time();
527        by_idle_time.insert(session.by_time_key());
528    }
529
530    #[cfg(test)]
532    fn session_count(&self) -> usize {
533        self.by_idle_time.len()
534    }
535
536    #[cfg(test)]
538    fn peer_count(&self) -> usize {
539        self.by_peer.len()
540    }
541}
542
#[cfg(test)]
mod test {
    use crate::enclave_rpc::{session::Builder, types::SessionID};

    use super::{Error, Sessions};

    /// Returns seven distinct peer IDs and seven random session IDs.
    fn ids() -> (Vec<Vec<u8>>, Vec<SessionID>) {
        let peer_ids: Vec<Vec<u8>> = (1..8).map(|x| vec![x]).collect();
        let session_ids: Vec<SessionID> = (1..8).map(|_| SessionID::random()).collect();

        (peer_ids, session_ids)
    }

    #[test]
    fn test_add() {
        let (peer_ids, session_ids) = ids();
        // At most 4 sessions in total and 2 per peer.
        let mut sessions = Sessions::new(Builder::default(), 4, 2, 60);

        // (peer, session, expected sessions, expected peers, should be created)
        let test_vector = vec![
            (&peer_ids[0], &session_ids[0], 1, 1, true),
            (&peer_ids[0], &session_ids[1], 2, 1, true),
            // Per-peer limit reached.
            (&peer_ids[0], &session_ids[2], 2, 1, false),
            (&peer_ids[1], &session_ids[0], 3, 2, true),
            (&peer_ids[2], &session_ids[2], 4, 3, true),
            // Global limit reached.
            (&peer_ids[3], &session_ids[3], 4, 3, false),
        ];

        let now = 0;
        for (peer_id, session_id, num_sessions, num_peers, created) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let res = sessions.add(session, now);
            if created {
                assert!(res.is_ok(), "session should be created");
                let s = res.unwrap();
                let s_owned = s.try_lock().unwrap();
                assert_eq!(&s_owned.peer_id, peer_id);
                assert_eq!(&s_owned.session_id, session_id);
            } else {
                assert!(res.is_err(), "session should not be created");
                assert!(matches!(res, Err(Error::MaxConcurrentSessions)));
            }
            assert_eq!(sessions.session_count(), num_sessions);
            assert_eq!(sessions.peer_count(), num_peers);
        }
    }

    #[test]
    fn test_get() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);

        // (peer, session, create before the lookup)
        let test_vector = vec![
            (&peer_ids[0], &session_ids[0], true),
            (&peer_ids[0], &session_ids[1], false),
            (&peer_ids[1], &session_ids[0], false),
            (&peer_ids[1], &session_ids[1], false),
        ];

        let now = 0;
        for (peer_id, session_id, create) in test_vector {
            if create {
                let session = sessions.create_responder(peer_id.clone(), *session_id);
                let _ = sessions.add(session, now);
            }

            let maybe_s = sessions.get(peer_id, session_id);
            if create {
                assert!(maybe_s.is_some(), "session should exist");
                let s = maybe_s.unwrap();
                let s_owned = s.try_lock_owned().unwrap();
                assert_eq!(&s_owned.peer_id, peer_id);
                assert_eq!(&s_owned.session_id, session_id);
            } else {
                assert!(maybe_s.is_none(), "session should not exist");
            }
        }
    }

    #[test]
    fn test_find_any() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);

        let test_vector = vec![
            (&peer_ids[0], &session_ids[0]),
            (&peer_ids[0], &session_ids[1]),
            (&peer_ids[1], &session_ids[2]),
        ];

        // Nothing tracked yet.
        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_none(), "session should not be found");

        let mut now = 0;
        for (peer_id, session_id) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
            now += 1;
        }

        // Free sessions are handed out least recently accessed first.
        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_some(), "session should be found");
        let s1_owned = maybe_s.unwrap().try_lock_owned().unwrap();
        assert_eq!(&s1_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s1_owned.session_id, &session_ids[0]);

        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_some(), "session should be found");
        let s2_owned = maybe_s.unwrap().try_lock_owned().unwrap();
        assert_eq!(&s2_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s2_owned.session_id, &session_ids[1]);

        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_some(), "session should be found");
        let s3_owned = maybe_s.unwrap().try_lock_owned().unwrap();
        assert_eq!(&s3_owned.peer_id, &peer_ids[1]);
        assert_eq!(&s3_owned.session_id, &session_ids[2]);

        // Every session is locked: some (busy) session is still returned.
        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_some(), "session should be found");
        let res = maybe_s.unwrap().try_lock_owned();
        assert!(res.is_err(), "session should be in use");

        // Releasing one session makes it the one that is found.
        drop(s2_owned);

        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_some(), "session should be found");
        let s_owned = maybe_s.unwrap().try_lock_owned().unwrap();
        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s_owned.session_id, &session_ids[1]);
    }

    #[test]
    fn test_find_one() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);

        // Sessions of the requested peers (0 and 1) are interleaved with
        // sessions of other peers (2 and 3), which must never be returned.
        let test_vector = vec![
            (&peer_ids[2], &session_ids[0]),
            (&peer_ids[0], &session_ids[0]),
            (&peer_ids[3], &session_ids[1]),
            (&peer_ids[0], &session_ids[1]),
            (&peer_ids[3], &session_ids[2]),
            (&peer_ids[1], &session_ids[2]),
            (&peer_ids[2], &session_ids[2]),
        ];

        // Nothing tracked yet.
        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_none(), "session should not be found");

        let mut now = 0;
        for (peer_id, session_id) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
            now += 1;
        }

        // Peers without sessions.
        let maybe_s = sessions.find_one(&peer_ids[4..]);
        assert!(maybe_s.is_none(), "session should not be found");

        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_some(), "session should be found");
        let s1_owned = maybe_s.unwrap().try_lock_owned().unwrap();
        assert_eq!(&s1_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s1_owned.session_id, &session_ids[0]);

        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_some(), "session should be found");
        let s2_owned = maybe_s.unwrap().try_lock_owned().unwrap();
        assert_eq!(&s2_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s2_owned.session_id, &session_ids[1]);

        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_some(), "session should be found");
        let s3_owned = maybe_s.unwrap().try_lock_owned().unwrap();
        assert_eq!(&s3_owned.peer_id, &peer_ids[1]);
        assert_eq!(&s3_owned.session_id, &session_ids[2]);

        // Every candidate is locked: some (busy) session is still returned.
        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_some(), "session should be found");
        let res = maybe_s.unwrap().try_lock_owned();
        assert!(res.is_err(), "session should be in use");

        // Releasing one session makes it the one that is found.
        drop(s2_owned);

        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_some(), "session should be found");
        let s_owned = maybe_s.unwrap().try_lock_owned().unwrap();
        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s_owned.session_id, &session_ids[1]);
    }

    #[test]
    fn test_remove_from() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 4, 2, 60);

        let test_vector = vec![
            (&peer_ids[0], &session_ids[0]),
            (&peer_ids[1], &session_ids[1]),
            (&peer_ids[2], &session_ids[2]),
            (&peer_ids[2], &session_ids[3]),
        ];

        let mut now = 0;
        for (peer_id, session_id) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
            now += 1;
        }

        // Peer without sessions: nothing to remove.
        let res = sessions.remove_from(&peer_ids[3]);
        assert!(res.is_ok(), "remove_from should succeed");
        let maybe_s_owned = res.unwrap();
        assert!(maybe_s_owned.is_none(), "no sessions should be removed");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 3);

        // Global limit reached: the peer's only session is removed.
        let res = sessions.remove_from(&peer_ids[0]);
        assert!(res.is_ok(), "remove_from should succeed");
        let maybe_s_owned = res.unwrap();
        assert!(maybe_s_owned.is_some(), "one session should be removed");
        let s_owned = maybe_s_owned.unwrap();
        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s_owned.session_id, &session_ids[0]);
        assert_eq!(sessions.session_count(), 3);
        assert_eq!(sessions.peer_count(), 2);

        // Below both limits: nothing is removed.
        for peer_id in &peer_ids[0..2] {
            let res = sessions.remove_from(peer_id);
            assert!(res.is_ok(), "remove_from should succeed");
            let maybe_s_owned = res.unwrap();
            assert!(maybe_s_owned.is_none(), "no sessions should be removed");
            assert_eq!(sessions.session_count(), 3);
            assert_eq!(sessions.peer_count(), 2);
        }

        // Per-peer limit reached: the least recently accessed session goes.
        let res = sessions.remove_from(&peer_ids[2]);
        assert!(res.is_ok(), "remove_from should succeed");
        let maybe_s_owned = res.unwrap();
        assert!(maybe_s_owned.is_some(), "one session should be removed");
        let s_owned = maybe_s_owned.unwrap();
        assert_eq!(&s_owned.peer_id, &peer_ids[2]);
        assert_eq!(&s_owned.session_id, &session_ids[2]);
        assert_eq!(sessions.session_count(), 2);
        assert_eq!(sessions.peer_count(), 2);
    }

    #[test]
    fn test_remove_one() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 4, 2, 60);

        let test_vector = vec![
            (&peer_ids[0], &session_ids[0]),
            (&peer_ids[1], &session_ids[1]),
            (&peer_ids[2], &session_ids[2]),
            (&peer_ids[2], &session_ids[3]),
        ];

        let mut now = 0;
        for (peer_id, session_id) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
            now += 1;
        }

        // `now` is 4 here; move it to 59, one unit short of the stale timeout
        // for the oldest session (added at time 0).
        now += 60 - 4 - 1;

        let res = sessions.remove_one(now);
        assert!(res.is_err(), "remove_one should fail");
        assert!(matches!(res, Err(Error::MaxConcurrentSessions)));
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 3);

        // The oldest session is now stale and gets evicted.
        now += 1;

        let res = sessions.remove_one(now);
        assert!(res.is_ok(), "remove_one should succeed");
        let maybe_s_owned = res.unwrap();
        assert!(maybe_s_owned.is_some(), "one session should be removed");
        let s_owned = maybe_s_owned.unwrap();
        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s_owned.session_id, &session_ids[0]);
        assert_eq!(sessions.session_count(), 3);
        assert_eq!(sessions.peer_count(), 2);

        // Below the global limit nothing is evicted, however stale.
        now += 100;

        let res = sessions.remove_one(now);
        assert!(res.is_ok(), "remove_one should succeed");
        let maybe_s_owned = res.unwrap();
        assert!(maybe_s_owned.is_none(), "no sessions should be removed");
        assert_eq!(sessions.session_count(), 3);
        assert_eq!(sessions.peer_count(), 2);
    }

    #[test]
    fn test_remove() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);

        // (peer, session, sessions left after removal, peers left after removal)
        let test_vector = vec![
            (&peer_ids[0], &session_ids[0], 3, 2),
            (&peer_ids[1], &session_ids[1], 2, 1),
            (&peer_ids[2], &session_ids[2], 1, 1),
            (&peer_ids[2], &session_ids[3], 0, 0),
        ];

        let now = 0;
        for (peer_id, session_id, _, _) in test_vector.clone() {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
        }

        for (peer_id, session_id, num_sessions, num_peers) in test_vector {
            let maybe_s = sessions.get(peer_id, session_id);
            assert!(maybe_s.is_some(), "session should exist");
            let s_owned = maybe_s.unwrap().try_lock_owned().unwrap();

            sessions.remove(&s_owned);
            assert_eq!(sessions.session_count(), num_sessions);
            assert_eq!(sessions.peer_count(), num_peers);
        }
    }

    #[test]
    fn test_clear() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);

        let test_vector = vec![
            (&peer_ids[0], &session_ids[0]),
            (&peer_ids[1], &session_ids[1]),
            (&peer_ids[2], &session_ids[2]),
            (&peer_ids[2], &session_ids[3]),
        ];

        let now = 0;
        for (peer_id, session_id) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
        }

        let removed_sessions = sessions.drain();
        assert_eq!(removed_sessions.len(), 4);
        assert_eq!(sessions.session_count(), 0);
        assert_eq!(sessions.peer_count(), 0);
    }
}