1use std::{
3 collections::HashSet,
4 sync::{
5 atomic::{AtomicU32, Ordering},
6 Arc,
7 },
8};
9
10use futures::stream::{FuturesUnordered, StreamExt};
11use lazy_static::lazy_static;
12#[cfg(not(test))]
13use rand::{rngs::OsRng, RngCore};
14
15use thiserror::Error;
16use tokio::sync::OwnedMutexGuard;
17
18use crate::{
19 common::{
20 crypto::signature,
21 namespace::Namespace,
22 sgx::{EnclaveIdentity, QuotePolicy},
23 time::insecure_posix_time,
24 },
25 enclave_rpc::{session::Builder, types},
26 protocol::Protocol,
27};
28
29use super::{
30 sessions::{self, MultiplexedSession, Sessions, SharedSession},
31 transport::{RuntimeTransport, Transport},
32};
33
/// Maximum number of times a failed call attempt is retried (the retry
/// strategy in `RpcClient::call` yields at most this many backoff delays).
const MAX_TRANSPORT_ERROR_RETRIES: usize = 3;

lazy_static! {
    /// Source of client identifiers. Seeded from `RpcClient::random_client_id`
    /// (zero under `cfg(test)`) and incremented for every newly created client.
    static ref NEXT_CLIENT_ID: AtomicU32 = AtomicU32::new(RpcClient::random_client_id());
}
41
/// Errors returned by the RPC client.
#[derive(Error, Debug)]
pub enum RpcClientError {
    /// The remote side handled the call and answered with an error body.
    #[error("call failed: {0}")]
    CallFailed(String),
    /// A message other than a response was received in reply to a request.
    #[error("expected response message, received: {0:?}")]
    ExpectedResponseMessage(types::Message),
    /// A message other than `Close` was received in reply to a close request.
    #[error("expected close message, received: {0:?}")]
    ExpectedCloseMessage(types::Message),
    /// Sending a message, or processing the reply to it, failed.
    #[error("transport error")]
    Transport,
    /// The requested RPC kind is not supported by this client.
    #[error("unsupported RPC kind")]
    UnsupportedRpcKind,
    /// The client was dropped (not constructed in the code visible here).
    #[error("client dropped")]
    Dropped,
    /// A CBOR payload could not be decoded.
    #[error("decode error: {0}")]
    DecodeError(#[from] cbor::DecodeError),
    /// An error reported by the session manager.
    #[error("sessions error: {0}")]
    SessionsError(#[from] sessions::Error),
    /// Any other error.
    #[error("unknown error: {0}")]
    Unknown(#[from] anyhow::Error),
}
64
/// Outcome of an RPC call together with what is needed to report peer
/// feedback for the request that produced it.
pub struct Response<'a, T> {
    /// Transport through which peer feedback is submitted.
    transport: &'a dyn Transport,
    /// ID of the request that produced the response. `None` when no attempt
    /// yielded a response, or once feedback has already been submitted.
    request_id: Option<u64>,
    /// Result of the call.
    inner: Result<T, RpcClientError>,
}
71
72impl<'a, T> Response<'a, T> {
73 pub async fn into_result_with_feedback(mut self) -> Result<T, RpcClientError> {
76 match self.inner {
77 Ok(_) => self.success().await,
78 Err(_) => self.failure().await,
79 }
80
81 self.inner
82 }
83
84 pub fn result(&self) -> &Result<T, RpcClientError> {
86 &self.inner
87 }
88
89 pub fn into_result(self) -> Result<T, RpcClientError> {
91 self.inner
92 }
93
94 pub async fn success(&mut self) {
96 self.send_peer_feedback(types::PeerFeedback::Success).await;
97 }
98
99 pub async fn failure(&mut self) {
101 self.send_peer_feedback(types::PeerFeedback::Failure).await;
102 }
103
104 pub async fn bad_peer(&mut self) {
106 self.send_peer_feedback(types::PeerFeedback::BadPeer).await;
107 }
108
109 async fn send_peer_feedback(&mut self, feedback: types::PeerFeedback) {
111 if let Some(request_id) = self.request_id.take() {
112 let _ = self
114 .transport
115 .submit_peer_feedback(request_id, feedback)
116 .await; }
118 }
119}
120
/// Client for issuing RPC calls to remote nodes, either over Noise sessions
/// or as insecure queries.
pub struct RpcClient {
    /// Transport used to send requests and submit peer feedback.
    transport: Box<dyn Transport>,
    /// Session manager, parameterized by the peer identifier type (node
    /// public keys). Guarded by an async mutex because it is locked from async code.
    sessions: tokio::sync::Mutex<Sessions<signature::PublicKey>>,
    /// Identifier of this client; forms the upper 32 bits of every request ID.
    client_id: u32,
    /// Counter that forms the lower 32 bits of the next request ID.
    next_request_id: AtomicU32,
}
132
133impl RpcClient {
134 fn new(
135 transport: Box<dyn Transport>,
136 builder: Builder,
137 max_sessions: usize,
138 max_sessions_per_peer: usize,
139 stale_session_timeout: i64,
140 ) -> Self {
141 let client_id = NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst); let next_request_id = AtomicU32::new(1);
144
145 let sessions = tokio::sync::Mutex::new(Sessions::new(
146 builder,
147 max_sessions,
148 max_sessions_per_peer,
149 stale_session_timeout,
150 ));
151
152 Self {
153 transport,
154 sessions,
155 client_id,
156 next_request_id,
157 }
158 }
159
160 pub fn new_runtime(
162 protocol: Arc<Protocol>,
163 endpoint: &str,
164 builder: Builder,
165 max_sessions: usize,
166 max_sessions_per_peer: usize,
167 stale_session_timeout: i64,
168 ) -> Self {
169 let transport = Box::new(RuntimeTransport::new(protocol, endpoint));
170
171 Self::new(
172 transport,
173 builder,
174 max_sessions,
175 max_sessions_per_peer,
176 stale_session_timeout,
177 )
178 }
179
180 pub async fn update_enclaves(&self, enclaves: Option<HashSet<EnclaveIdentity>>) {
182 let sessions = {
183 let mut sessions = self.sessions.lock().await;
184 sessions.update_enclaves(enclaves)
185 };
186 self.close_all(sessions).await;
187 }
188
189 pub async fn update_quote_policy(&self, policy: QuotePolicy) {
191 let sessions = {
192 let mut sessions = self.sessions.lock().await;
193 sessions.update_quote_policy(policy)
194 };
195 self.close_all(sessions).await;
196 }
197
198 pub async fn update_runtime_id(&self, id: Option<Namespace>) {
200 let sessions = {
201 let mut sessions = self.sessions.lock().await;
202 sessions.update_runtime_id(id)
203 };
204 self.close_all(sessions).await;
205 }
206
207 pub async fn secure_call<C, O>(
209 &self,
210 method: &'static str,
211 args: C,
212 nodes: Vec<signature::PublicKey>,
213 ) -> Response<O>
214 where
215 C: cbor::Encode,
216 O: cbor::Decode + Send + 'static,
217 {
218 self.call(method, args, types::Kind::NoiseSession, nodes)
219 .await
220 }
221
222 pub async fn insecure_call<C, O>(
224 &self,
225 method: &'static str,
226 args: C,
227 nodes: Vec<signature::PublicKey>,
228 ) -> Response<O>
229 where
230 C: cbor::Encode,
231 O: cbor::Decode + Send + 'static,
232 {
233 self.call(method, args, types::Kind::InsecureQuery, nodes)
234 .await
235 }
236
237 async fn call<C, O>(
238 &self,
239 method: &'static str,
240 args: C,
241 kind: types::Kind,
242 nodes: Vec<signature::PublicKey>,
243 ) -> Response<O>
244 where
245 C: cbor::Encode,
246 O: cbor::Decode + Send + 'static,
247 {
248 let request = types::Request {
249 method: method.to_owned(),
250 args: cbor::to_value(args),
251 };
252
253 let retry_strategy = tokio_retry::strategy::ExponentialBackoff::from_millis(2)
256 .factor(25)
257 .max_delay(std::time::Duration::from_millis(250))
258 .take(MAX_TRANSPORT_ERROR_RETRIES);
259
260 let result = tokio_retry::Retry::spawn(retry_strategy, || {
261 self.execute_call(request.clone(), kind, nodes.clone())
262 })
263 .await;
264
265 let (request_id, inner) = match result {
266 Ok((request_id, response)) => match response.body {
267 types::Body::Success(value) => (
268 Some(request_id),
269 cbor::from_value(value).map_err(Into::into),
270 ),
271 types::Body::Error(error) => {
272 (Some(request_id), Err(RpcClientError::CallFailed(error)))
273 }
274 },
275 Err(err) => (None, Err(err)),
276 };
277
278 Response {
279 transport: &*self.transport,
280 request_id,
281 inner,
282 }
283 }
284
285 async fn execute_call(
286 &self,
287 request: types::Request,
288 kind: types::Kind,
289 nodes: Vec<signature::PublicKey>,
290 ) -> Result<(u64, types::Response), RpcClientError> {
291 match kind {
292 types::Kind::NoiseSession => {
293 let session = self.connect(nodes).await?;
296 let mut session = session.lock_owned().await;
297
298 let result = self.secure_call_raw(request, &mut session).await;
300
301 if result.is_err() {
305 let mut sessions = self.sessions.lock().await;
306 sessions.remove(&session);
307 }
308
309 result
310 }
311 types::Kind::InsecureQuery => {
312 self.insecure_call_raw(request, nodes).await
314 }
315 _ => Err(RpcClientError::UnsupportedRpcKind),
316 }
317 }
318
319 async fn connect(
320 &self,
321 nodes: Vec<signature::PublicKey>,
322 ) -> Result<SharedSession<signature::PublicKey>, RpcClientError> {
323 let mut session = {
325 let mut sessions = self.sessions.lock().await;
326
327 if let Some(session) = sessions.find(&nodes) {
329 return Ok(session);
330 }
331
332 let peer_id = Default::default();
334 sessions.create_initiator(peer_id)
335 };
336
337 let session_id = *session.get_session_id();
339
340 let mut buffer1 = vec![];
342 let mut buffer2 = vec![];
343
344 session
346 .process_data(&[], &mut buffer1)
347 .await
348 .expect("initiation must always succeed");
349
350 let request_id = self.next_request_id();
351 let result: Result<_, RpcClientError> = async {
352 let rsp = self
354 .transport
355 .write_noise_session(request_id, session_id, buffer1, String::new(), nodes)
356 .await
357 .map_err(|_| RpcClientError::Transport)?;
358
359 session.set_peer_id(rsp.node);
363 session
364 .set_remote_node(rsp.node)
365 .expect("remote node should not be set");
366
367 let _ = session
370 .process_data(&rsp.data, &mut buffer2)
371 .await
372 .map_err(|_| RpcClientError::Transport)?;
373
374 Ok(rsp)
375 }
376 .await;
377
378 let feedback = match result {
381 Ok(_) => types::PeerFeedback::Success,
382 Err(_) => types::PeerFeedback::Failure,
383 };
384 let _ = self
385 .transport
386 .submit_peer_feedback(request_id, feedback)
387 .await; let rsp = result?;
391
392 let request_id = self.next_request_id();
393 let result = async {
394 let rsp = self
396 .transport
397 .write_noise_session(
398 request_id,
399 session_id,
400 buffer2,
401 String::new(),
402 vec![rsp.node],
403 )
404 .await
405 .map_err(|_| RpcClientError::Transport)?;
406
407 if session.is_unauthenticated() {
408 return Err(RpcClientError::Transport);
409 }
410
411 Ok(rsp)
412 }
413 .await;
414
415 let feedback = match result {
418 Ok(_) => types::PeerFeedback::Success,
419 Err(_) => types::PeerFeedback::Failure,
420 };
421 let _ = self
422 .transport
423 .submit_peer_feedback(request_id, feedback)
424 .await; if let Err(err) = result {
428 let session = Arc::new(tokio::sync::Mutex::new(session))
430 .lock_owned()
431 .await;
432 let _ = self.close(session).await; return Err(err);
435 }
436
437 let now = insecure_posix_time();
441 let mut sessions = self.sessions.lock().await;
442 let maybe_removed_session = match sessions.remove_for(&rsp.node, now) {
443 Ok(maybe_removed_session) => maybe_removed_session,
444 Err(err) => {
445 drop(sessions); let session = Arc::new(tokio::sync::Mutex::new(session))
449 .lock_owned()
450 .await;
451 let _ = self.close(session).await; return Err(err.into());
454 }
455 };
456 let session = sessions
457 .add(session, now)
458 .expect("there should be space for the new session");
459
460 if let Some(removed_session) = maybe_removed_session {
461 drop(sessions); let _ = self.close(removed_session).await; }
466
467 Ok(session)
468 }
469
470 async fn secure_call_raw(
471 &self,
472 request: types::Request,
473 session: &mut OwnedMutexGuard<MultiplexedSession<signature::PublicKey>>,
474 ) -> Result<(u64, types::Response), RpcClientError> {
475 let method = request.method.clone();
476 let msg = types::Message::Request(request);
477 let session_id = *session.get_session_id();
478
479 let mut buffer = vec![];
481 session
482 .write_message(msg, &mut buffer)
483 .map_err(|_| RpcClientError::Transport)?;
484 let node = session.get_remote_node()?;
485
486 let request_id = self.next_request_id();
487 let result = async {
488 let rsp = self
490 .transport
491 .write_noise_session(request_id, session_id, buffer, method, vec![node])
492 .await
493 .map_err(|_| RpcClientError::Transport)?;
494
495 session.process_data(&rsp.data, vec![]).await
497 }
498 .await;
499
500 if result.is_err() {
503 let _ = self
504 .transport
505 .submit_peer_feedback(request_id, types::PeerFeedback::Failure)
506 .await; }
508
509 let maybe_msg = result?;
511
512 let msg = maybe_msg.expect("message must be decoded if there is no error");
514 let rsp = match msg {
515 types::Message::Response(rsp) => rsp,
516 msg => return Err(RpcClientError::ExpectedResponseMessage(msg)),
517 };
518
519 Ok((request_id, rsp))
520 }
521
522 async fn insecure_call_raw(
523 &self,
524 request: types::Request,
525 nodes: Vec<signature::PublicKey>,
526 ) -> Result<(u64, types::Response), RpcClientError> {
527 let request_id = self.next_request_id();
529 let result = self
530 .transport
531 .write_insecure_query(request_id, cbor::to_vec(request), nodes)
532 .await
533 .map_err(|_| RpcClientError::Transport);
534
535 if result.is_err() {
537 let _ = self
538 .transport
539 .submit_peer_feedback(request_id, types::PeerFeedback::Failure)
540 .await; }
542
543 let rsp = result?;
545
546 let rsp = cbor::from_slice(&rsp.data).map_err(RpcClientError::DecodeError)?;
548
549 Ok((request_id, rsp))
550 }
551
552 async fn close(
554 &self,
555 mut session: OwnedMutexGuard<MultiplexedSession<signature::PublicKey>>,
556 ) -> Result<(), RpcClientError> {
557 if !session.is_connected() && !session.is_unauthenticated() {
558 return Ok(());
559 }
560
561 let session_id = *session.get_session_id();
562 let node = session.get_remote_node()?;
563
564 let mut buffer = vec![];
566 session
567 .write_message(types::Message::Close, &mut buffer)
568 .map_err(|_| RpcClientError::Transport)?;
569
570 let request_id = self.next_request_id();
572 let rsp = self
573 .transport
574 .write_noise_session(request_id, session_id, buffer, String::new(), vec![node])
575 .await
576 .map_err(|_| RpcClientError::Transport)?;
577
578 let msg = session
583 .process_data(&rsp.data, vec![])
584 .await?
585 .expect("message must be decoded if there is no error");
586
587 session.close();
589
590 match msg {
591 types::Message::Close => Ok(()),
592 msg => Err(RpcClientError::ExpectedCloseMessage(msg)),
593 }
594 }
595
596 async fn close_all(&self, sessions: Vec<SharedSession<signature::PublicKey>>) {
598 let futures = FuturesUnordered::new();
599 for session in sessions {
600 let future = async {
601 let locked_session = session.lock_owned().await;
602 let _ = self.close(locked_session).await; };
604 futures.push(future);
605 }
606 futures.collect::<()>().await;
607 }
608
609 fn next_request_id(&self) -> u64 {
611 let next_request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst); ((self.client_id as u64) << 32) + (next_request_id as u64)
613 }
614
615 fn random_client_id() -> u32 {
617 #[cfg(test)]
618 return 0;
619
620 #[cfg(not(test))]
621 OsRng.next_u32()
622 }
623}
624
#[cfg(test)]
mod test {
    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc, Mutex,
    };

    use anyhow::anyhow;
    use async_trait::async_trait;

    use crate::{
        common::crypto::signature,
        enclave_rpc::{demux::Demux, session, transport::EnclaveResponse, types},
    };

    use super::{super::transport::Transport, RpcClient};

    /// In-process transport that serves requests from a local server-side
    /// `Demux`, can be told to fail the next write, and records every peer
    /// feedback submission.
    #[derive(Clone)]
    struct MockTransport {
        // Server side that answers Noise session frames.
        demux: Arc<Demux>,
        // When set, the next write fails once and the flag is cleared.
        next_error: Arc<AtomicBool>,
        // All `(request_id, feedback)` pairs submitted so far.
        peer_feedback_history: Arc<Mutex<Vec<(u64, types::PeerFeedback)>>>,
    }

    impl MockTransport {
        fn new() -> Self {
            Self {
                demux: Arc::new(Demux::new(session::Builder::default(), 4, 4, 60)),
                next_error: Arc::new(AtomicBool::new(false)),
                peer_feedback_history: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Resets the server-side demux (used to make the client's existing
        /// session unusable).
        fn reset(&self) {
            self.demux.reset();
        }

        /// Makes the next `write_message_impl` call fail with a transport error.
        fn induce_transport_error(&self) {
            self.next_error.store(true, Ordering::SeqCst);
        }

        /// Returns and clears the recorded peer feedback.
        fn take_peer_feedback_history(&self) -> Vec<(u64, types::PeerFeedback)> {
            let mut pfh = self.peer_feedback_history.lock().unwrap();
            pfh.drain(..).collect()
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn write_message_impl(
            &self,
            _request_id: u64,
            request: Vec<u8>,
            kind: types::Kind,
            _nodes: Vec<signature::PublicKey>,
        ) -> Result<EnclaveResponse, anyhow::Error> {
            // Fail exactly one write if an error was induced, clearing the flag.
            if self
                .next_error
                .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
                .is_ok()
            {
                return Err(anyhow!("transport error"));
            }

            match kind {
                types::Kind::NoiseSession => {
                    // Let the server-side demux process the frame; reply bytes for frames
                    // without a message (presumably handshake frames) end up in `buffer`.
                    let mut buffer = Vec::new();
                    let (mut session, message) = self
                        .demux
                        .process_frame(vec![], request, &mut buffer)
                        .await?;

                    match message {
                        Some(message) => {
                            let mut buffer = Vec::new();

                            match message {
                                // Echo the request arguments back as a successful response.
                                types::Message::Request(rq) => {
                                    let response = types::Message::Response(types::Response {
                                        body: types::Body::Success(rq.args),
                                    });

                                    session.write_message(response, &mut buffer)?;
                                }
                                // Let the demux close the session and produce the reply.
                                types::Message::Close => {
                                    self.demux.close(session, &mut buffer)?;
                                }
                                _ => panic!("unhandled message type"),
                            };

                            let rsp = EnclaveResponse {
                                data: buffer,
                                node: Default::default(),
                            };
                            Ok(rsp)
                        }
                        None => {
                            let rsp = EnclaveResponse {
                                data: buffer,
                                node: Default::default(),
                            };
                            Ok(rsp)
                        }
                    }
                }
                types::Kind::InsecureQuery => {
                    // Echo the arguments of the plaintext request back as a success.
                    let rq: types::Request = cbor::from_slice(&request).unwrap();
                    let body = types::Body::Success(rq.args);
                    let response = types::Response { body };
                    let rsp = EnclaveResponse {
                        data: cbor::to_vec(response),
                        node: Default::default(),
                    };
                    return Ok(rsp);
                }
                types::Kind::LocalQuery => {
                    panic!("unhandled RPC kind")
                }
            }
        }

        async fn submit_peer_feedback(
            &self,
            request_id: u64,
            peer_feedback: types::PeerFeedback,
        ) -> Result<(), anyhow::Error> {
            self.peer_feedback_history
                .lock()
                .unwrap()
                .push((request_id, peer_feedback));

            Ok(())
        }
    }

    #[test]
    fn test_rpc_client() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        // Enter the Tokio runtime context for the rest of the test.
        let _guard = rt.enter();
        let transport = MockTransport::new();
        let builder = session::Builder::default();
        let client = RpcClient::new(Box::new(transport.clone()), builder, 8, 2, 60);

        // First secure call: two handshake round trips (1, 2), then the call (3).
        let result: u64 = rt
            .block_on(async {
                client
                    .secure_call("test", 42, vec![])
                    .await
                    .into_result_with_feedback()
                    .await
            })
            .unwrap();
        assert_eq!(result, 42, "secure call should work");
        assert_eq!(
            transport.take_peer_feedback_history(),
            vec![
                (1, types::PeerFeedback::Success), // Handshake.
                (2, types::PeerFeedback::Success), // Handshake.
                (3, types::PeerFeedback::Success), // Call.
            ]
        );

        // After a server reset, the call over the existing session fails (4). The
        // session is removed and the retry performs a new handshake (5, 6) and the
        // call (7).
        transport.reset();

        let result: u64 = rt
            .block_on(async {
                client
                    .secure_call("test", 43, vec![])
                    .await
                    .into_result_with_feedback()
                    .await
            })
            .unwrap();
        assert_eq!(result, 43, "secure call should work");
        assert_eq!(
            transport.take_peer_feedback_history(),
            vec![
                (4, types::PeerFeedback::Failure), // Call over stale session.
                (5, types::PeerFeedback::Success), // Handshake.
                (6, types::PeerFeedback::Success), // Handshake.
                (7, types::PeerFeedback::Success), // Call.
            ]
        );

        // An injected transport error fails the call over the existing session (8).
        // The retry performs a new handshake (9, 10) and the call (11).
        transport.induce_transport_error();

        let result: u64 = rt
            .block_on(async {
                client
                    .secure_call("test", 44, vec![])
                    .await
                    .into_result_with_feedback()
                    .await
            })
            .unwrap();
        assert_eq!(result, 44, "secure call should work");
        assert_eq!(
            transport.take_peer_feedback_history(),
            vec![
                (8, types::PeerFeedback::Failure), // Call with transport error.
                (9, types::PeerFeedback::Success), // Handshake.
                (10, types::PeerFeedback::Success), // Handshake.
                (11, types::PeerFeedback::Success), // Call.
            ]
        );

        // Insecure queries need no session: a single request (12).
        let result: u64 = rt
            .block_on(async {
                client
                    .insecure_call("test", 45, vec![])
                    .await
                    .into_result_with_feedback()
                    .await
            })
            .unwrap();
        assert_eq!(result, 45, "insecure call should work");
        assert_eq!(
            transport.take_peer_feedback_history(),
            vec![
                (12, types::PeerFeedback::Success), // Call.
            ]
        );

        // An injected transport error fails the first insecure attempt (13); the
        // retry succeeds (14).
        transport.induce_transport_error();

        let result: u64 = rt
            .block_on(async {
                client
                    .insecure_call("test", 46, vec![])
                    .await
                    .into_result_with_feedback()
                    .await
            })
            .unwrap();
        assert_eq!(result, 46, "insecure call should work");
        assert_eq!(
            transport.take_peer_feedback_history(),
            vec![
                (13, types::PeerFeedback::Failure), // Call with transport error.
                (14, types::PeerFeedback::Success), // Call.
            ]
        );
    }
}