1use std::{
3 collections::{BTreeSet, HashMap, HashSet},
4 hash::Hash,
5 io::Write,
6 mem,
7 sync::Arc,
8};
9
10use anyhow::Result;
11use rand::{rngs::OsRng, Rng};
12use tokio::sync::OwnedMutexGuard;
13
14use super::{
15 session::{Builder, Session, SessionInfo},
16 types::{Message, SessionID},
17};
18use crate::common::{
19 crypto::signature,
20 namespace::Namespace,
21 sgx::{EnclaveIdentity, QuotePolicy},
22 time::insecure_posix_time,
23};
24
25pub type SharedSession<PeerID> = Arc<tokio::sync::Mutex<MultiplexedSession<PeerID>>>;
27
28pub type SessionByTimeKey<PeerID> = (i64, PeerID, SessionID);
30
31#[derive(Debug, thiserror::Error)]
33pub enum Error {
34 #[error("max concurrent sessions reached")]
35 MaxConcurrentSessions,
36}
37
38pub struct MultiplexedSession<PeerID> {
40 peer_id: PeerID,
42 session_id: SessionID,
44 inner: Session,
46}
47
48impl<PeerID> MultiplexedSession<PeerID> {
49 pub fn get_peer_id(&self) -> &PeerID {
51 &self.peer_id
52 }
53
54 pub fn set_peer_id(&mut self, peer_id: PeerID) {
56 self.peer_id = peer_id;
57 }
58
59 pub fn get_session_id(&self) -> &SessionID {
61 &self.session_id
62 }
63
64 pub fn info(&self) -> Option<Arc<SessionInfo>> {
66 self.inner.session_info()
67 }
68
69 pub fn is_closed(&self) -> bool {
71 self.inner.is_closed()
72 }
73
74 pub async fn process_data<W: Write>(
76 &mut self,
77 data: &[u8],
78 writer: W,
79 ) -> Result<Option<Message>> {
80 self.inner.process_data(data, writer).await
81 }
82
83 pub fn write_message<W: Write>(&mut self, msg: Message, mut writer: W) -> Result<()> {
85 self.inner.write_message(msg, &mut writer)
86 }
87
88 pub fn get_remote_node(&self) -> Result<signature::PublicKey> {
90 self.inner.get_remote_node()
91 }
92
93 pub fn set_remote_node(&mut self, node: signature::PublicKey) -> Result<()> {
95 self.inner.set_remote_node(node)
96 }
97
98 pub fn is_connected(&self) -> bool {
101 self.inner.is_connected()
102 }
103
104 pub fn is_unauthenticated(&self) -> bool {
107 self.inner.is_unauthenticated()
108 }
109
110 pub fn close(&mut self) {
115 self.inner.close()
116 }
117}
118
119pub struct SessionMeta<PeerID: Clone + Ord + Hash> {
121 peer_id: PeerID,
123 session_id: SessionID,
125 last_access_time: i64,
127 inner: SharedSession<PeerID>,
129}
130
131impl<PeerID> SessionMeta<PeerID>
132where
133 PeerID: Clone + Ord + Hash,
134{
135 fn by_time_key(&self) -> SessionByTimeKey<PeerID> {
137 (self.last_access_time, self.peer_id.clone(), self.session_id)
138 }
139}
140
141pub struct Sessions<PeerID: Clone + Ord + Hash> {
143 builder: Builder,
145 max_sessions: usize,
147 max_sessions_per_peer: usize,
149 stale_session_timeout: i64,
151
152 by_peer: HashMap<PeerID, HashMap<SessionID, SessionMeta<PeerID>>>,
154 by_idle_time: BTreeSet<SessionByTimeKey<PeerID>>,
156}
157
158impl<PeerID> Sessions<PeerID>
159where
160 PeerID: Clone + Ord + Hash,
161{
162 pub fn new(
164 builder: Builder,
165 max_sessions: usize,
166 max_sessions_per_peer: usize,
167 stale_session_timeout: i64,
168 ) -> Self {
169 Self {
170 builder,
171 max_sessions,
172 max_sessions_per_peer,
173 stale_session_timeout,
174 by_peer: HashMap::new(),
175 by_idle_time: BTreeSet::new(),
176 }
177 }
178
179 pub fn set_builder(&mut self, builder: Builder) {
181 self.builder = builder;
182 }
183
184 pub fn update_enclaves(
187 &mut self,
188 enclaves: Option<HashSet<EnclaveIdentity>>,
189 ) -> Vec<SharedSession<PeerID>> {
190 if self.builder.get_remote_enclaves() == &enclaves {
191 return vec![];
192 }
193
194 self.builder = mem::take(&mut self.builder).remote_enclaves(enclaves);
195 self.drain()
196 }
197
198 pub fn update_quote_policy(&mut self, policy: QuotePolicy) -> Vec<SharedSession<PeerID>> {
201 let policy = Some(Arc::new(policy));
202 if self.builder.get_quote_policy() == &policy {
203 return vec![];
204 }
205
206 self.builder = mem::take(&mut self.builder).quote_policy(policy);
207 self.drain()
208 }
209
210 pub fn update_runtime_id(&mut self, id: Option<Namespace>) -> Vec<SharedSession<PeerID>> {
213 if self.builder.get_remote_runtime_id() == &id {
214 return vec![];
215 }
216
217 self.builder = mem::take(&mut self.builder).remote_runtime_id(id);
218 self.drain()
219 }
220
221 pub fn create_responder(
223 &mut self,
224 peer_id: PeerID,
225 session_id: SessionID,
226 ) -> MultiplexedSession<PeerID> {
227 if self.builder.get_quote_policy().is_none() {
229 let policy = self
230 .builder
231 .get_local_identity()
232 .as_ref()
233 .and_then(|id| id.quote_policy());
234
235 self.builder = mem::take(&mut self.builder).quote_policy(policy);
236 }
237
238 MultiplexedSession {
239 peer_id: peer_id.clone(),
240 session_id,
241 inner: self.builder.clone().build_responder(),
242 }
243 }
244
245 pub fn create_initiator(&self, peer_id: PeerID) -> MultiplexedSession<PeerID> {
247 let session_id = SessionID::random();
248
249 MultiplexedSession {
250 peer_id: peer_id.clone(),
251 session_id,
252 inner: self.builder.clone().build_initiator(),
253 }
254 }
255
256 pub fn get(
258 &mut self,
259 peer_id: &PeerID,
260 session_id: &SessionID,
261 ) -> Option<SharedSession<PeerID>> {
262 let sessions = match self.by_peer.get_mut(peer_id) {
264 Some(sessions) => sessions,
265 None => return None,
266 };
267
268 let session = match sessions.get_mut(session_id) {
270 Some(session) => session,
271 None => return None,
272 };
273
274 Self::update_access_time(session, &mut self.by_idle_time);
275
276 Some(session.inner.clone())
277 }
278
279 pub fn find(&mut self, peer_ids: &[PeerID]) -> Option<SharedSession<PeerID>> {
282 match peer_ids.is_empty() {
283 true => self.find_any(),
284 false => self.find_one(peer_ids),
285 }
286 }
287
288 pub fn find_any(&mut self) -> Option<SharedSession<PeerID>> {
290 if self.by_idle_time.is_empty() {
291 return None;
292 }
293
294 for (_, peer_id, session_id) in self.by_idle_time.iter() {
296 let session = self
297 .by_peer
298 .get_mut(peer_id)
299 .unwrap()
300 .get_mut(session_id)
301 .unwrap();
302
303 if session.inner.clone().try_lock_owned().is_ok() {
304 Self::update_access_time(session, &mut self.by_idle_time);
305 return Some(session.inner.clone());
306 }
307 }
308
309 let n = OsRng.gen_range(0..self.by_idle_time.len());
311 let (_, peer_id, session_id) = self.by_idle_time.iter().nth(n).unwrap();
312 let session = self
313 .by_peer
314 .get_mut(peer_id)
315 .unwrap()
316 .get_mut(session_id)
317 .unwrap();
318
319 Self::update_access_time(session, &mut self.by_idle_time);
320
321 Some(session.inner.clone())
322 }
323
324 pub fn find_one(&mut self, peer_ids: &[PeerID]) -> Option<SharedSession<PeerID>> {
326 let mut all_sessions = vec![];
327
328 for peer_id in peer_ids.iter() {
329 let sessions = match self.by_peer.get_mut(peer_id) {
330 Some(sessions) => sessions,
331 None => return None,
332 };
333
334 let session = sessions
336 .values_mut()
337 .filter(|s| s.inner.clone().try_lock_owned().is_ok())
338 .min_by_key(|s| s.last_access_time);
339
340 if let Some(session) = session {
341 Self::update_access_time(session, &mut self.by_idle_time);
342 return Some(session.inner.clone());
343 }
344
345 for session in sessions.values() {
346 all_sessions.push((session.peer_id.clone(), session.session_id));
347 }
348 }
349
350 if all_sessions.is_empty() {
351 return None;
352 }
353
354 let n = OsRng.gen_range(0..all_sessions.len());
356 let (peer_id, session_id) = all_sessions.get(n).unwrap();
357 let session = self
358 .by_peer
359 .get_mut(peer_id)
360 .unwrap()
361 .get_mut(session_id)
362 .unwrap();
363
364 Self::update_access_time(session, &mut self.by_idle_time);
365
366 Some(session.inner.clone())
367 }
368
369 pub fn remove_for(
371 &mut self,
372 peer_id: &PeerID,
373 now: i64,
374 ) -> Result<Option<OwnedMutexGuard<MultiplexedSession<PeerID>>>, Error> {
375 if let Some(session) = self.remove_from(peer_id)? {
376 return Ok(Some(session));
377 }
378 self.remove_one(now)
379 }
380
381 pub fn remove_from(
385 &mut self,
386 peer_id: &PeerID,
387 ) -> Result<Option<OwnedMutexGuard<MultiplexedSession<PeerID>>>, Error> {
388 let sessions = match self.by_peer.get_mut(peer_id) {
390 Some(sessions) => sessions,
391 None => return Ok(None),
392 };
393
394 if sessions.len() < self.max_sessions_per_peer
397 && self.by_idle_time.len() < self.max_sessions
398 {
399 return Ok(None);
400 }
401
402 let remove_session = sessions
404 .iter()
405 .min_by_key(|(_, s)| {
406 if let Ok(_inner) = s.inner.try_lock() {
407 s.last_access_time
408 } else {
409 i64::MAX }
411 })
412 .map(|(_, s)| s.inner.clone())
413 .ok_or(Error::MaxConcurrentSessions)?;
414
415 let session = match remove_session.try_lock_owned() {
416 Ok(inner) => inner,
417 Err(_) => return Err(Error::MaxConcurrentSessions), };
419
420 self.remove(&session);
421
422 Ok(Some(session))
423 }
424
425 pub fn remove_one(
428 &mut self,
429 now: i64,
430 ) -> Result<Option<OwnedMutexGuard<MultiplexedSession<PeerID>>>, Error> {
431 if self.by_idle_time.len() < self.max_sessions {
433 return Ok(None);
434 }
435
436 let mut remove_session: Option<OwnedMutexGuard<MultiplexedSession<PeerID>>> = None;
438
439 for (last_process_frame_time, peer_id, session_id) in self.by_idle_time.iter() {
440 if now.saturating_sub(*last_process_frame_time) < self.stale_session_timeout {
441 return Err(Error::MaxConcurrentSessions);
443 }
444
445 if let Some(sessions) = self.by_peer.get(peer_id) {
447 if let Some(session) = sessions.get(session_id) {
448 if let Ok(session) = session.inner.clone().try_lock_owned() {
449 remove_session = Some(session);
450 break;
451 }
452 }
453 }
454 }
455
456 let session = match remove_session {
458 Some(session) => session,
459 None => return Err(Error::MaxConcurrentSessions), };
461
462 self.remove(&session);
463
464 Ok(Some(session))
465 }
466
467 pub fn add(
469 &mut self,
470 session: MultiplexedSession<PeerID>,
471 now: i64,
472 ) -> Result<SharedSession<PeerID>, Error> {
473 if self.by_idle_time.len() >= self.max_sessions {
474 return Err(Error::MaxConcurrentSessions);
475 }
476
477 let sessions = self.by_peer.entry(session.peer_id.clone()).or_default();
478 if sessions.len() >= self.max_sessions_per_peer {
479 return Err(Error::MaxConcurrentSessions);
480 }
481
482 let peer_id = session.peer_id.clone();
483 let session_id = session.session_id;
484
485 let session = SessionMeta {
486 inner: Arc::new(tokio::sync::Mutex::new(session)),
487 peer_id,
488 session_id,
489 last_access_time: now,
490 };
491 let inner = session.inner.clone();
492
493 self.by_idle_time.insert(session.by_time_key());
494 sessions.insert(session.session_id, session);
495
496 Ok(inner)
497 }
498
499 pub fn remove(&mut self, session: &OwnedMutexGuard<MultiplexedSession<PeerID>>) {
501 let sessions = self.by_peer.get_mut(&session.peer_id).unwrap();
502 let session_meta = sessions.get(&session.session_id).unwrap();
503 let key = session_meta.by_time_key();
504 sessions.remove(&session.session_id);
505 self.by_idle_time.remove(&key);
506
507 if sessions.is_empty() {
509 self.by_peer.remove(&session.peer_id);
510 }
511 }
512
513 pub fn drain(&mut self) -> Vec<SharedSession<PeerID>> {
515 self.by_idle_time.clear();
516
517 let mut all_sessions = vec![];
518 for (_, mut sessions) in self.by_peer.drain() {
519 for (_, session) in sessions.drain() {
520 all_sessions.push(session.inner);
521 }
522 }
523
524 all_sessions
525 }
526
527 fn update_access_time(
528 session: &mut SessionMeta<PeerID>,
529 by_idle_time: &mut BTreeSet<SessionByTimeKey<PeerID>>,
530 ) {
531 by_idle_time.remove(&session.by_time_key());
533
534 session.last_access_time = insecure_posix_time();
536 by_idle_time.insert(session.by_time_key());
537 }
538
539 #[cfg(test)]
541 fn session_count(&self) -> usize {
542 self.by_idle_time.len()
543 }
544
545 #[cfg(test)]
547 fn peer_count(&self) -> usize {
548 self.by_peer.len()
549 }
550}
551
#[cfg(test)]
mod test {
    use crate::enclave_rpc::{session::Builder, types::SessionID};

    use super::{Error, Sessions};

    /// Returns seven distinct single-byte peer IDs and seven random session IDs.
    fn ids() -> (Vec<Vec<u8>>, Vec<SessionID>) {
        let peer_ids: Vec<Vec<u8>> = (1..8).map(|x| vec![x]).collect();
        let session_ids: Vec<SessionID> = (1..8).map(|_| SessionID::random()).collect();

        (peer_ids, session_ids)
    }

    #[test]
    fn test_add() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 4, 2, 60);

        // (peer, session, expected sessions, expected peers, should be created)
        let test_vector = vec![
            (&peer_ids[0], &session_ids[0], 1, 1, true),
            (&peer_ids[0], &session_ids[1], 2, 1, true),
            // Rejected: the per-peer limit (2) is reached.
            (&peer_ids[0], &session_ids[2], 2, 1, false),
            (&peer_ids[1], &session_ids[0], 3, 2, true),
            (&peer_ids[2], &session_ids[2], 4, 3, true),
            // Rejected: the global limit (4) is reached.
            (&peer_ids[3], &session_ids[3], 4, 3, false),
        ];

        let now = 0;
        for (peer_id, session_id, num_sessions, num_peers, created) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let res = sessions.add(session, now);
            if created {
                assert!(res.is_ok(), "session should be created");
                let s = res.unwrap();
                let s_owned = s.try_lock().unwrap();
                assert_eq!(&s_owned.peer_id, peer_id);
                assert_eq!(&s_owned.session_id, session_id);
            } else {
                assert!(res.is_err(), "session should not be created");
                assert!(matches!(res, Err(Error::MaxConcurrentSessions)));
            }
            assert_eq!(sessions.session_count(), num_sessions);
            assert_eq!(sessions.peer_count(), num_peers);
        }
    }

    #[test]
    fn test_get() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);

        // (peer, session, whether the session is added before the lookup)
        let test_vector = vec![
            (&peer_ids[0], &session_ids[0], true),
            // Unknown session of a known peer.
            (&peer_ids[0], &session_ids[1], false),
            // Known session ID, but of another peer.
            (&peer_ids[1], &session_ids[0], false),
            // Unknown peer and session.
            (&peer_ids[1], &session_ids[1], false),
        ];

        let now = 0;
        for (peer_id, session_id, create) in test_vector {
            if create {
                let session = sessions.create_responder(peer_id.clone(), *session_id);
                let _ = sessions.add(session, now);
            }

            let maybe_s = sessions.get(peer_id, session_id);
            if create {
                assert!(maybe_s.is_some(), "session should exist");
                let s = maybe_s.unwrap();
                let s_owned = s.try_lock_owned().unwrap();
                assert_eq!(&s_owned.peer_id, peer_id);
                assert_eq!(&s_owned.session_id, session_id);
            } else {
                assert!(maybe_s.is_none(), "session should not exist");
            }
        }
    }

    #[test]
    fn test_find() {
        let (peer_ids, session_ids) = ids();
        let mut sessions: Sessions<Vec<u8>> = Sessions::new(Builder::default(), 8, 2, 60);
        let no_peers: &[Vec<u8>] = &[];

        // Nothing can be found while no sessions are tracked.
        assert!(sessions.find(no_peers).is_none(), "session should not be found");
        assert!(
            sessions.find(&peer_ids[0..1]).is_none(),
            "session should not be found"
        );

        let session = sessions.create_responder(peer_ids[0].clone(), session_ids[0]);
        assert!(sessions.add(session, 0).is_ok(), "session should be created");

        // An empty peer list matches sessions of any peer.
        let s = sessions.find(no_peers).expect("session should be found");
        let s_owned = s.try_lock_owned().unwrap();
        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s_owned.session_id, &session_ids[0]);
        drop(s_owned);

        // A peer list naming only peers without sessions matches nothing.
        assert!(
            sessions.find(&peer_ids[1..2]).is_none(),
            "session should not be found"
        );
    }

    #[test]
    fn test_find_any() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);

        let test_vector = vec![
            (&peer_ids[0], &session_ids[0]),
            (&peer_ids[0], &session_ids[1]),
            (&peer_ids[1], &session_ids[2]),
        ];

        // No sessions yet.
        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_none(), "session should not be found");

        let mut now = 0;
        for (peer_id, session_id) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
            now += 1
        }

        // Idle sessions are returned least recently used first.
        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_some(), "session should be found");
        let s = maybe_s.unwrap();
        let s1_owned = s.try_lock_owned().unwrap();
        assert_eq!(&s1_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s1_owned.session_id, &session_ids[0]);

        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_some(), "session should be found");
        let s = maybe_s.unwrap();
        let s2_owned = s.try_lock_owned().unwrap();
        assert_eq!(&s2_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s2_owned.session_id, &session_ids[1]);

        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_some(), "session should be found");
        let s = maybe_s.unwrap();
        let s3_owned = s.try_lock_owned().unwrap();
        assert_eq!(&s3_owned.peer_id, &peer_ids[1]);
        assert_eq!(&s3_owned.session_id, &session_ids[2]);

        // All sessions are locked, so a random (locked) one is returned.
        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_some(), "session should be found");
        let s = maybe_s.unwrap();
        let res = s.try_lock_owned();
        assert!(res.is_err(), "session should be in use");

        // Releasing a session makes it the preferred candidate again.
        drop(s2_owned);

        let maybe_s = sessions.find_any();
        assert!(maybe_s.is_some(), "session should be found");
        let s = maybe_s.unwrap();
        let s_owned = s.try_lock_owned().unwrap();
        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s_owned.session_id, &session_ids[1]);
    }

    #[test]
    fn test_find_one() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);

        // Sessions of other peers are interleaved to check they are never returned.
        let test_vector = vec![
            (&peer_ids[2], &session_ids[0]),
            (&peer_ids[0], &session_ids[0]),
            (&peer_ids[3], &session_ids[1]),
            (&peer_ids[0], &session_ids[1]),
            (&peer_ids[3], &session_ids[2]),
            (&peer_ids[1], &session_ids[2]),
            (&peer_ids[2], &session_ids[2]),
        ];

        // No sessions yet.
        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_none(), "session should not be found");

        let mut now = 0;
        for (peer_id, session_id) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
            now += 1
        }

        // None of these peers has a session.
        let maybe_s = sessions.find_one(&peer_ids[4..]);
        assert!(maybe_s.is_none(), "session should not be found");

        // The first peer's idle sessions are returned least recently used first.
        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_some(), "session should be found");
        let s = maybe_s.unwrap();
        let s1_owned = s.try_lock_owned().unwrap();
        assert_eq!(&s1_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s1_owned.session_id, &session_ids[0]);

        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_some(), "session should be found");
        let s = maybe_s.unwrap();
        let s2_owned = s.try_lock_owned().unwrap();
        assert_eq!(&s2_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s2_owned.session_id, &session_ids[1]);

        // The first peer is fully in use, so the second peer is tried.
        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_some(), "session should be found");
        let s = maybe_s.unwrap();
        let s3_owned = s.try_lock_owned().unwrap();
        assert_eq!(&s3_owned.peer_id, &peer_ids[1]);
        assert_eq!(&s3_owned.session_id, &session_ids[2]);

        // Every matching session is locked, so a random (locked) one is returned.
        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_some(), "session should be found");
        let s = maybe_s.unwrap();
        let res = s.try_lock_owned();
        assert!(res.is_err(), "session should be in use");

        // Releasing a session makes it the preferred candidate again.
        drop(s2_owned);

        let maybe_s = sessions.find_one(&peer_ids[0..2]);
        assert!(maybe_s.is_some(), "session should be found");
        let s = maybe_s.unwrap();
        let s_owned = s.try_lock_owned().unwrap();
        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s_owned.session_id, &session_ids[1]);
    }

    #[test]
    fn test_remove_from() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 4, 2, 60);

        let test_vector = vec![
            (&peer_ids[0], &session_ids[0]),
            (&peer_ids[1], &session_ids[1]),
            (&peer_ids[2], &session_ids[2]),
            (&peer_ids[2], &session_ids[3]),
        ];

        let mut now = 0;
        for (peer_id, session_id) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
            now += 1;
        }

        // A peer without sessions has nothing to evict.
        let res = sessions.remove_from(&peer_ids[3]);
        assert!(res.is_ok(), "remove_from should succeed");
        let maybe_s_owned = res.unwrap();
        assert!(maybe_s_owned.is_none(), "no sessions should be removed");
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 3);

        // Peer 0 is below its own limit, but the global limit (4) is reached.
        let res = sessions.remove_from(&peer_ids[0]);
        assert!(res.is_ok(), "remove_from should succeed");
        let maybe_s_owned = res.unwrap();
        assert!(maybe_s_owned.is_some(), "one session should be removed");
        let s_owned = maybe_s_owned.unwrap();
        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s_owned.session_id, &session_ids[0]);
        assert_eq!(sessions.session_count(), 3);
        assert_eq!(sessions.peer_count(), 2);

        // Neither limit is reached for these peers any more.
        for peer_id in &peer_ids[0..2] {
            let res = sessions.remove_from(peer_id);
            assert!(res.is_ok(), "remove_from should succeed");
            let maybe_s_owned = res.unwrap();
            assert!(maybe_s_owned.is_none(), "no sessions should be removed");
            assert_eq!(sessions.session_count(), 3);
            assert_eq!(sessions.peer_count(), 2);
        }

        // Peer 2 is at its per-peer limit (2); its least recently used session goes.
        let res = sessions.remove_from(&peer_ids[2]);
        assert!(res.is_ok(), "remove_from should succeed");
        let maybe_s_owned = res.unwrap();
        assert!(maybe_s_owned.is_some(), "one session should be removed");
        let s_owned = maybe_s_owned.unwrap();
        assert_eq!(&s_owned.peer_id, &peer_ids[2]);
        assert_eq!(&s_owned.session_id, &session_ids[2]);
        assert_eq!(sessions.session_count(), 2);
        assert_eq!(sessions.peer_count(), 2);
    }

    #[test]
    fn test_remove_one() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 4, 2, 60);

        let test_vector = vec![
            (&peer_ids[0], &session_ids[0]),
            (&peer_ids[1], &session_ids[1]),
            (&peer_ids[2], &session_ids[2]),
            (&peer_ids[2], &session_ids[3]),
        ];

        let mut now = 0;
        for (peer_id, session_id) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
            now += 1;
        }

        // `now` is 4 after the loop; move to one second before the oldest session
        // (accessed at 0) becomes stale with a timeout of 60.
        now += 60 - 4 - 1;

        let res = sessions.remove_one(now);
        assert!(res.is_err(), "remove_one should fail");
        assert!(matches!(res, Err(Error::MaxConcurrentSessions)));
        assert_eq!(sessions.session_count(), 4);
        assert_eq!(sessions.peer_count(), 3);

        // The oldest session is now stale and gets evicted.
        now += 1;

        let res = sessions.remove_one(now);
        assert!(res.is_ok(), "remove_one should succeed");
        let maybe_s_owned = res.unwrap();
        assert!(maybe_s_owned.is_some(), "one session should be removed");
        let s_owned = maybe_s_owned.unwrap();
        assert_eq!(&s_owned.peer_id, &peer_ids[0]);
        assert_eq!(&s_owned.session_id, &session_ids[0]);
        assert_eq!(sessions.session_count(), 3);
        assert_eq!(sessions.peer_count(), 2);

        // Below the global limit nothing is evicted, however stale the sessions are.
        now += 100;

        let res = sessions.remove_one(now);
        assert!(res.is_ok(), "remove_one should succeed");
        let maybe_s_owned = res.unwrap();
        assert!(maybe_s_owned.is_none(), "no sessions should be removed");
        assert_eq!(sessions.session_count(), 3);
        assert_eq!(sessions.peer_count(), 2);
    }

    #[test]
    fn test_remove() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);

        // (peer, session, sessions left after removal, peers left after removal)
        let test_vector = vec![
            (&peer_ids[0], &session_ids[0], 3, 2),
            (&peer_ids[1], &session_ids[1], 2, 1),
            (&peer_ids[2], &session_ids[2], 1, 1),
            (&peer_ids[2], &session_ids[3], 0, 0),
        ];

        let now = 0;
        for &(peer_id, session_id, _, _) in &test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
        }

        for (peer_id, session_id, num_sessions, num_peers) in test_vector {
            let maybe_s = sessions.get(peer_id, session_id);
            assert!(maybe_s.is_some(), "session should exist");
            let s = maybe_s.unwrap();
            let s_owned = s.try_lock_owned().unwrap();

            sessions.remove(&s_owned);
            assert_eq!(sessions.session_count(), num_sessions);
            assert_eq!(sessions.peer_count(), num_peers);
        }
    }

    #[test]
    fn test_clear() {
        let (peer_ids, session_ids) = ids();
        let mut sessions = Sessions::new(Builder::default(), 8, 2, 60);

        let test_vector = vec![
            (&peer_ids[0], &session_ids[0]),
            (&peer_ids[1], &session_ids[1]),
            (&peer_ids[2], &session_ids[2]),
            (&peer_ids[2], &session_ids[3]),
        ];

        let now = 0;
        for (peer_id, session_id) in test_vector {
            let session = sessions.create_responder(peer_id.clone(), *session_id);
            let _ = sessions.add(session, now);
        }

        // Draining returns every session and leaves both indices empty.
        let removed_sessions = sessions.drain();
        assert_eq!(removed_sessions.len(), 4);
        assert_eq!(sessions.session_count(), 0);
        assert_eq!(sessions.peer_count(), 0);
    }
}