1use std::{
3    collections::HashSet,
4    sync::{
5        atomic::{AtomicU32, Ordering},
6        Arc,
7    },
8};
9
10use futures::stream::{FuturesUnordered, StreamExt};
11use lazy_static::lazy_static;
12#[cfg(not(test))]
13use rand::{rngs::OsRng, RngCore};
14
15use thiserror::Error;
16use tokio::sync::OwnedMutexGuard;
17
18use crate::{
19    common::{
20        crypto::signature,
21        namespace::Namespace,
22        sgx::{EnclaveIdentity, QuotePolicy},
23        time::insecure_posix_time,
24    },
25    enclave_rpc::{session::Builder, types},
26    protocol::Protocol,
27};
28
29use super::{
30    sessions::{self, MultiplexedSession, Sessions, SharedSession},
31    transport::{RuntimeTransport, Transport},
32};
33
/// Number of delays taken from the backoff strategy in `RpcClient::call`, i.e. the maximum
/// number of times a failed call is retried.
const MAX_TRANSPORT_ERROR_RETRIES: usize = 3;
36
lazy_static! {
    /// Process-wide counter from which every new `RpcClient` takes its client ID.
    ///
    /// Seeded once from `RpcClient::random_client_id()` and incremented by `RpcClient::new`.
    static ref NEXT_CLIENT_ID: AtomicU32 = AtomicU32::new(RpcClient::random_client_id());
}
41
/// Errors returned by the RPC client.
#[derive(Error, Debug)]
pub enum RpcClientError {
    /// The remote endpoint answered with an error body; carries the reported error string.
    #[error("call failed: {0}")]
    CallFailed(String),
    /// A secure call was answered with something other than a response message.
    #[error("expected response message, received: {0:?}")]
    ExpectedResponseMessage(types::Message),
    /// A session close request was answered with something other than a close message.
    #[error("expected close message, received: {0:?}")]
    ExpectedCloseMessage(types::Message),
    /// The transport failed to deliver a request. Also used for handshake processing failures
    /// and for a session that ends the handshake unauthenticated.
    #[error("transport error")]
    Transport,
    /// The requested RPC kind is neither a Noise session call nor an insecure query.
    #[error("unsupported RPC kind")]
    UnsupportedRpcKind,
    /// NOTE(review): not constructed anywhere in this file; presumably kept for callers elsewhere
    /// that report a dropped client -- confirm before removing.
    #[error("client dropped")]
    Dropped,
    /// A CBOR payload could not be decoded.
    #[error("decode error: {0}")]
    DecodeError(#[from] cbor::DecodeError),
    /// An error reported by the session store.
    #[error("sessions error: {0}")]
    SessionsError(#[from] sessions::Error),
    /// Any other error, converted from `anyhow::Error`.
    #[error("unknown error: {0}")]
    Unknown(#[from] anyhow::Error),
}
64
/// Result of an RPC call bundled with what is needed to report peer feedback for it.
pub struct Response<'a, T> {
    /// Transport through which peer feedback is submitted.
    transport: &'a dyn Transport,
    /// ID of the transport request that produced the response. `None` when the call failed
    /// before a response was obtained, or once feedback has been submitted.
    request_id: Option<u64>,
    /// Decoded call result, or the error that occurred.
    inner: Result<T, RpcClientError>,
}
71
72impl<T> Response<'_, T> {
73    pub async fn into_result_with_feedback(mut self) -> Result<T, RpcClientError> {
76        match self.inner {
77            Ok(_) => self.success().await,
78            Err(_) => self.failure().await,
79        }
80
81        self.inner
82    }
83
84    pub fn result(&self) -> &Result<T, RpcClientError> {
86        &self.inner
87    }
88
89    pub fn into_result(self) -> Result<T, RpcClientError> {
91        self.inner
92    }
93
94    pub async fn success(&mut self) {
96        self.send_peer_feedback(types::PeerFeedback::Success).await;
97    }
98
99    pub async fn failure(&mut self) {
101        self.send_peer_feedback(types::PeerFeedback::Failure).await;
102    }
103
104    pub async fn bad_peer(&mut self) {
106        self.send_peer_feedback(types::PeerFeedback::BadPeer).await;
107    }
108
109    async fn send_peer_feedback(&mut self, feedback: types::PeerFeedback) {
111        if let Some(request_id) = self.request_id.take() {
112            let _ = self
114                .transport
115                .submit_peer_feedback(request_id, feedback)
116                .await; }
118    }
119}
120
/// RPC client that issues secure (Noise session) and insecure calls over a transport.
pub struct RpcClient {
    /// Transport used for all requests and for peer feedback.
    transport: Box<dyn Transport>,
    /// Session store, guarded by an async mutex.
    sessions: tokio::sync::Mutex<Sessions<signature::PublicKey>>,
    /// Client ID; forms the upper 32 bits of every request ID.
    client_id: u32,
    /// Counter for the lower 32 bits of request IDs; starts at 1.
    next_request_id: AtomicU32,
}
132
133impl RpcClient {
134    fn new(
135        transport: Box<dyn Transport>,
136        builder: Builder,
137        max_sessions: usize,
138        max_sessions_per_peer: usize,
139        stale_session_timeout: i64,
140    ) -> Self {
141        let client_id = NEXT_CLIENT_ID.fetch_add(1, Ordering::SeqCst); let next_request_id = AtomicU32::new(1);
144
145        let sessions = tokio::sync::Mutex::new(Sessions::new(
146            builder,
147            max_sessions,
148            max_sessions_per_peer,
149            stale_session_timeout,
150        ));
151
152        Self {
153            transport,
154            sessions,
155            client_id,
156            next_request_id,
157        }
158    }
159
160    pub fn new_runtime(
162        protocol: Arc<Protocol>,
163        endpoint: &str,
164        builder: Builder,
165        max_sessions: usize,
166        max_sessions_per_peer: usize,
167        stale_session_timeout: i64,
168    ) -> Self {
169        let transport = Box::new(RuntimeTransport::new(protocol, endpoint));
170
171        Self::new(
172            transport,
173            builder,
174            max_sessions,
175            max_sessions_per_peer,
176            stale_session_timeout,
177        )
178    }
179
180    pub async fn update_enclaves(&self, enclaves: Option<HashSet<EnclaveIdentity>>) {
182        let sessions = {
183            let mut sessions = self.sessions.lock().await;
184            sessions.update_enclaves(enclaves)
185        };
186        self.close_all(sessions).await;
187    }
188
189    pub async fn update_quote_policy(&self, policy: QuotePolicy) {
191        let sessions = {
192            let mut sessions = self.sessions.lock().await;
193            sessions.update_quote_policy(policy)
194        };
195        self.close_all(sessions).await;
196    }
197
198    pub async fn update_runtime_id(&self, id: Option<Namespace>) {
200        let sessions = {
201            let mut sessions = self.sessions.lock().await;
202            sessions.update_runtime_id(id)
203        };
204        self.close_all(sessions).await;
205    }
206
207    pub async fn secure_call<C, O>(
209        &self,
210        method: &'static str,
211        args: C,
212        nodes: Vec<signature::PublicKey>,
213    ) -> Response<O>
214    where
215        C: cbor::Encode,
216        O: cbor::Decode + Send + 'static,
217    {
218        self.call(method, args, types::Kind::NoiseSession, nodes)
219            .await
220    }
221
222    pub async fn insecure_call<C, O>(
224        &self,
225        method: &'static str,
226        args: C,
227        nodes: Vec<signature::PublicKey>,
228    ) -> Response<O>
229    where
230        C: cbor::Encode,
231        O: cbor::Decode + Send + 'static,
232    {
233        self.call(method, args, types::Kind::InsecureQuery, nodes)
234            .await
235    }
236
237    async fn call<C, O>(
238        &self,
239        method: &'static str,
240        args: C,
241        kind: types::Kind,
242        nodes: Vec<signature::PublicKey>,
243    ) -> Response<O>
244    where
245        C: cbor::Encode,
246        O: cbor::Decode + Send + 'static,
247    {
248        let request = types::Request {
249            method: method.to_owned(),
250            args: cbor::to_value(args),
251        };
252
253        let retry_strategy = tokio_retry::strategy::ExponentialBackoff::from_millis(2)
256            .factor(25)
257            .max_delay(std::time::Duration::from_millis(250))
258            .take(MAX_TRANSPORT_ERROR_RETRIES);
259
260        let result = tokio_retry::Retry::spawn(retry_strategy, || {
261            self.execute_call(request.clone(), kind, nodes.clone())
262        })
263        .await;
264
265        let (request_id, inner) = match result {
266            Ok((request_id, response)) => match response.body {
267                types::Body::Success(value) => (
268                    Some(request_id),
269                    cbor::from_value(value).map_err(Into::into),
270                ),
271                types::Body::Error(error) => {
272                    (Some(request_id), Err(RpcClientError::CallFailed(error)))
273                }
274            },
275            Err(err) => (None, Err(err)),
276        };
277
278        Response {
279            transport: &*self.transport,
280            request_id,
281            inner,
282        }
283    }
284
285    async fn execute_call(
286        &self,
287        request: types::Request,
288        kind: types::Kind,
289        nodes: Vec<signature::PublicKey>,
290    ) -> Result<(u64, types::Response), RpcClientError> {
291        match kind {
292            types::Kind::NoiseSession => {
293                let session = self.connect(nodes).await?;
296                let mut session = session.lock_owned().await;
297
298                let result = self.secure_call_raw(request, &mut session).await;
300
301                if result.is_err() {
305                    let mut sessions = self.sessions.lock().await;
306                    sessions.remove(&session);
307                }
308
309                result
310            }
311            types::Kind::InsecureQuery => {
312                self.insecure_call_raw(request, nodes).await
314            }
315            _ => Err(RpcClientError::UnsupportedRpcKind),
316        }
317    }
318
    /// Returns a session to one of `nodes`, establishing a new one if the store has none.
    ///
    /// Establishing a session takes two transport requests. Peer feedback is submitted for each
    /// of them before any error is propagated.
    ///
    /// # Errors
    ///
    /// Returns `RpcClientError::Transport` when a handshake request or handshake processing
    /// fails, or when the session ends up unauthenticated. Errors from `Sessions::remove_for` are
    /// converted and returned as well.
    ///
    /// # Panics
    ///
    /// Panics if producing the initial handshake message fails, if the remote node was already
    /// set on the fresh session, or if the store rejects the new session.
    async fn connect(
        &self,
        nodes: Vec<signature::PublicKey>,
    ) -> Result<SharedSession<signature::PublicKey>, RpcClientError> {
        let mut session = {
            let mut sessions = self.sessions.lock().await;

            // Reuse a session the store already has for these nodes.
            if let Some(session) = sessions.find(&nodes) {
                return Ok(session);
            }

            // The peer is not known yet, so start with the default peer ID; it is replaced once
            // the first response arrives.
            let peer_id = Default::default();
            sessions.create_initiator(peer_id)
        };

        // Copy the session ID so it can be used while the session is borrowed mutably.
        let session_id = *session.get_session_id();

        // Output buffers for the two outgoing handshake messages.
        let mut buffer1 = vec![];
        let mut buffer2 = vec![];

        // Produce the initiator's opening message into `buffer1`; no input is consumed.
        session
            .process_data(&[], &mut buffer1)
            .await
            .expect("initiation must always succeed");

        // First request: send the opening message to any of `nodes`.
        let request_id = self.next_request_id();
        let result: Result<_, RpcClientError> = async {
            let rsp = self
                .transport
                .write_noise_session(request_id, session_id, buffer1, String::new(), nodes)
                .await
                .map_err(|_| RpcClientError::Transport)?;

            // The response tells which node answered; bind the session to it.
            session.set_peer_id(rsp.node);
            session
                .set_remote_node(rsp.node)
                .expect("remote node should not be set");

            // Process the node's reply, producing the next outgoing message into `buffer2`.
            let _ = session
                .process_data(&rsp.data, &mut buffer2)
                .await
                .map_err(|_| RpcClientError::Transport)?;

            Ok(rsp)
        }
        .await;

        // Submit feedback for the first request before propagating any error.
        let feedback = match result {
            Ok(_) => types::PeerFeedback::Success,
            Err(_) => types::PeerFeedback::Failure,
        };
        let _ = self
            .transport
            .submit_peer_feedback(request_id, feedback)
            .await;
        let rsp = result?;

        // Second request: send the message from `buffer2` to the node that answered.
        let request_id = self.next_request_id();
        let result = async {
            let rsp = self
                .transport
                .write_noise_session(
                    request_id,
                    session_id,
                    buffer2,
                    String::new(),
                    vec![rsp.node],
                )
                .await
                .map_err(|_| RpcClientError::Transport)?;

            // An unauthenticated session is treated as a failure.
            if session.is_unauthenticated() {
                return Err(RpcClientError::Transport);
            }

            Ok(rsp)
        }
        .await;

        // Submit feedback for the second request before propagating any error.
        let feedback = match result {
            Ok(_) => types::PeerFeedback::Success,
            Err(_) => types::PeerFeedback::Failure,
        };
        let _ = self
            .transport
            .submit_peer_feedback(request_id, feedback)
            .await;
        if let Err(err) = result {
            // `close` takes an owned mutex guard, so wrap the not-yet-stored session first.
            let session = Arc::new(tokio::sync::Mutex::new(session))
                .lock_owned()
                .await;
            let _ = self.close(session).await;
            return Err(err);
        }

        // Store the new session; `remove_for` may hand back a session that must be closed.
        let now = insecure_posix_time();
        let mut sessions = self.sessions.lock().await;
        let maybe_removed_session = match sessions.remove_for(&rsp.node, now) {
            Ok(maybe_removed_session) => maybe_removed_session,
            Err(err) => {
                // Release the store lock before closing over the transport.
                drop(sessions);
                let session = Arc::new(tokio::sync::Mutex::new(session))
                    .lock_owned()
                    .await;
                let _ = self.close(session).await;
                return Err(err.into());
            }
        };
        let session = sessions
            .add(session, now)
            .expect("there should be space for the new session");

        if let Some(removed_session) = maybe_removed_session {
            // Release the store lock before closing over the transport.
            drop(sessions);
            let _ = self.close(removed_session).await;
        }

        Ok(session)
    }
469
470    async fn secure_call_raw(
471        &self,
472        request: types::Request,
473        session: &mut OwnedMutexGuard<MultiplexedSession<signature::PublicKey>>,
474    ) -> Result<(u64, types::Response), RpcClientError> {
475        let method = request.method.clone();
476        let msg = types::Message::Request(request);
477        let session_id = *session.get_session_id();
478
479        let mut buffer = vec![];
481        session
482            .write_message(msg, &mut buffer)
483            .map_err(|_| RpcClientError::Transport)?;
484        let node = session.get_remote_node()?;
485
486        let request_id = self.next_request_id();
487        let result = async {
488            let rsp = self
490                .transport
491                .write_noise_session(request_id, session_id, buffer, method, vec![node])
492                .await
493                .map_err(|_| RpcClientError::Transport)?;
494
495            session.process_data(&rsp.data, vec![]).await
497        }
498        .await;
499
500        if result.is_err() {
503            let _ = self
504                .transport
505                .submit_peer_feedback(request_id, types::PeerFeedback::Failure)
506                .await; }
508
509        let maybe_msg = result?;
511
512        let msg = maybe_msg.expect("message must be decoded if there is no error");
514        let rsp = match msg {
515            types::Message::Response(rsp) => rsp,
516            msg => return Err(RpcClientError::ExpectedResponseMessage(msg)),
517        };
518
519        Ok((request_id, rsp))
520    }
521
522    async fn insecure_call_raw(
523        &self,
524        request: types::Request,
525        nodes: Vec<signature::PublicKey>,
526    ) -> Result<(u64, types::Response), RpcClientError> {
527        let request_id = self.next_request_id();
529        let result = self
530            .transport
531            .write_insecure_query(request_id, cbor::to_vec(request), nodes)
532            .await
533            .map_err(|_| RpcClientError::Transport);
534
535        if result.is_err() {
537            let _ = self
538                .transport
539                .submit_peer_feedback(request_id, types::PeerFeedback::Failure)
540                .await; }
542
543        let rsp = result?;
545
546        let rsp = cbor::from_slice(&rsp.data).map_err(RpcClientError::DecodeError)?;
548
549        Ok((request_id, rsp))
550    }
551
552    async fn close(
554        &self,
555        mut session: OwnedMutexGuard<MultiplexedSession<signature::PublicKey>>,
556    ) -> Result<(), RpcClientError> {
557        if !session.is_connected() && !session.is_unauthenticated() {
558            return Ok(());
559        }
560
561        let session_id = *session.get_session_id();
562        let node = session.get_remote_node()?;
563
564        let mut buffer = vec![];
566        session
567            .write_message(types::Message::Close, &mut buffer)
568            .map_err(|_| RpcClientError::Transport)?;
569
570        let request_id = self.next_request_id();
572        let rsp = self
573            .transport
574            .write_noise_session(request_id, session_id, buffer, String::new(), vec![node])
575            .await
576            .map_err(|_| RpcClientError::Transport)?;
577
578        let msg = session
583            .process_data(&rsp.data, vec![])
584            .await?
585            .expect("message must be decoded if there is no error");
586
587        session.close();
589
590        match msg {
591            types::Message::Close => Ok(()),
592            msg => Err(RpcClientError::ExpectedCloseMessage(msg)),
593        }
594    }
595
596    async fn close_all(&self, sessions: Vec<SharedSession<signature::PublicKey>>) {
598        let futures = FuturesUnordered::new();
599        for session in sessions {
600            let future = async {
601                let locked_session = session.lock_owned().await;
602                let _ = self.close(locked_session).await; };
604            futures.push(future);
605        }
606        futures.collect::<()>().await;
607    }
608
609    fn next_request_id(&self) -> u64 {
611        let next_request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst); ((self.client_id as u64) << 32) + (next_request_id as u64)
613    }
614
    /// Returns the initial value for `NEXT_CLIENT_ID`: always `0` in test builds, otherwise a
    /// random `u32` drawn from `OsRng`.
    fn random_client_id() -> u32 {
        // Fixed value in tests; the test module asserts on exact request IDs.
        #[cfg(test)]
        return 0;

        #[cfg(not(test))]
        OsRng.next_u32()
    }
623}
624
625#[cfg(test)]
626mod test {
627    use std::sync::{
628        atomic::{AtomicBool, Ordering},
629        Arc, Mutex,
630    };
631
632    use anyhow::anyhow;
633    use async_trait::async_trait;
634
635    use crate::{
636        common::crypto::signature,
637        enclave_rpc::{demux::Demux, session, transport::EnclaveResponse, types},
638    };
639
640    use super::{super::transport::Transport, RpcClient};
641
    /// In-process transport used by the test. Clones share all state through `Arc`s.
    #[derive(Clone)]
    struct MockTransport {
        /// Demultiplexer that processes incoming Noise session frames.
        demux: Arc<Demux>,
        /// When set, the next request fails with a transport error and the flag is cleared.
        next_error: Arc<AtomicBool>,
        /// Every `(request_id, feedback)` pair submitted so far, in submission order.
        peer_feedback_history: Arc<Mutex<Vec<(u64, types::PeerFeedback)>>>,
    }
648
649    impl MockTransport {
650        fn new() -> Self {
651            Self {
652                demux: Arc::new(Demux::new(session::Builder::default(), 4, 4, 60)),
653                next_error: Arc::new(AtomicBool::new(false)),
654                peer_feedback_history: Arc::new(Mutex::new(Vec::new())),
655            }
656        }
657
658        fn reset(&self) {
659            self.demux.reset();
660        }
661
662        fn induce_transport_error(&self) {
663            self.next_error.store(true, Ordering::SeqCst);
664        }
665
666        fn take_peer_feedback_history(&self) -> Vec<(u64, types::PeerFeedback)> {
667            let mut pfh = self.peer_feedback_history.lock().unwrap();
668            pfh.drain(..).collect()
669        }
670    }
671
672    #[async_trait]
673    impl Transport for MockTransport {
674        async fn write_message_impl(
675            &self,
676            _request_id: u64,
677            request: Vec<u8>,
678            kind: types::Kind,
679            _nodes: Vec<signature::PublicKey>,
680        ) -> Result<EnclaveResponse, anyhow::Error> {
681            if self
683                .next_error
684                .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
685                .is_ok()
686            {
687                return Err(anyhow!("transport error"));
688            }
689
690            match kind {
691                types::Kind::NoiseSession => {
692                    let mut buffer = Vec::new();
694                    let (mut session, message) = self
695                        .demux
696                        .process_frame(vec![], request, &mut buffer)
697                        .await?;
698
699                    match message {
700                        Some(message) => {
701                            let mut buffer = Vec::new();
702
703                            match message {
705                                types::Message::Request(rq) => {
706                                    let response = types::Message::Response(types::Response {
708                                        body: types::Body::Success(rq.args),
709                                    });
710
711                                    session.write_message(response, &mut buffer)?;
712                                }
713                                types::Message::Close => {
714                                    self.demux.close(session, &mut buffer)?;
715                                }
716                                _ => panic!("unhandled message type"),
717                            };
718
719                            let rsp = EnclaveResponse {
720                                data: buffer,
721                                node: Default::default(),
722                            };
723                            Ok(rsp)
724                        }
725                        None => {
726                            let rsp = EnclaveResponse {
728                                data: buffer,
729                                node: Default::default(),
730                            };
731                            Ok(rsp)
732                        }
733                    }
734                }
735                types::Kind::InsecureQuery => {
736                    let rq: types::Request = cbor::from_slice(&request).unwrap();
738                    let body = types::Body::Success(rq.args);
739                    let response = types::Response { body };
740                    let rsp = EnclaveResponse {
741                        data: cbor::to_vec(response),
742                        node: Default::default(),
743                    };
744                    return Ok(rsp);
745                }
746                types::Kind::LocalQuery => {
747                    panic!("unhandled RPC kind")
748                }
749            }
750        }
751
752        async fn submit_peer_feedback(
753            &self,
754            request_id: u64,
755            peer_feedback: types::PeerFeedback,
756        ) -> Result<(), anyhow::Error> {
757            self.peer_feedback_history
758                .lock()
759                .unwrap()
760                .push((request_id, peer_feedback));
761
762            Ok(())
763        }
764    }
765
    /// Exercises secure and insecure calls against the mock transport and checks the exact peer
    /// feedback submitted for every transport request.
    ///
    /// The client ID is 0 in test builds, so request IDs equal the request counter.
    #[test]
    fn test_rpc_client() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        // Enter the runtime context for the duration of the test.
        let _guard = rt.enter();
        let transport = MockTransport::new();
        let builder = session::Builder::default();
        let client = RpcClient::new(Box::new(transport.clone()), builder, 8, 2, 60);

        // Basic secure call: a new session is established first.
        let result: u64 = rt
            .block_on(async {
                client
                    .secure_call("test", 42, vec![])
                    .await
                    .into_result_with_feedback()
                    .await
            })
            .unwrap();
        assert_eq!(result, 42, "secure call should work");
        assert_eq!(
            transport.take_peer_feedback_history(),
            vec![
                // Two handshake requests, then the call itself.
                (1, types::PeerFeedback::Success),
                (2, types::PeerFeedback::Success),
                (3, types::PeerFeedback::Success),
            ]
        );

        // Reset the remote side; the session cached by the client is no longer usable.
        transport.reset();

        let result: u64 = rt
            .block_on(async {
                client
                    .secure_call("test", 43, vec![])
                    .await
                    .into_result_with_feedback()
                    .await
            })
            .unwrap();
        assert_eq!(result, 43, "secure call should work");
        assert_eq!(
            transport.take_peer_feedback_history(),
            vec![
                // The attempt on the cached session fails ...
                (4, types::PeerFeedback::Failure),
                // ... and the retry performs a new handshake followed by the call.
                (5, types::PeerFeedback::Success),
                (6, types::PeerFeedback::Success),
                (7, types::PeerFeedback::Success),
            ]
        );

        // Make the next transport request fail.
        transport.induce_transport_error();

        let result: u64 = rt
            .block_on(async {
                client
                    .secure_call("test", 44, vec![])
                    .await
                    .into_result_with_feedback()
                    .await
            })
            .unwrap();
        assert_eq!(result, 44, "secure call should work");
        assert_eq!(
            transport.take_peer_feedback_history(),
            vec![
                // The failed request ...
                (8, types::PeerFeedback::Failure),
                // ... followed by a new handshake and the retried call.
                (9, types::PeerFeedback::Success),
                (10, types::PeerFeedback::Success),
                (11, types::PeerFeedback::Success),
            ]
        );

        // Basic insecure call: a single request, no session.
        let result: u64 = rt
            .block_on(async {
                client
                    .insecure_call("test", 45, vec![])
                    .await
                    .into_result_with_feedback()
                    .await
            })
            .unwrap();
        assert_eq!(result, 45, "insecure call should work");
        assert_eq!(
            transport.take_peer_feedback_history(),
            vec![
                (12, types::PeerFeedback::Success),
            ]
        );

        // Make the next transport request fail.
        transport.induce_transport_error();

        let result: u64 = rt
            .block_on(async {
                client
                    .insecure_call("test", 46, vec![])
                    .await
                    .into_result_with_feedback()
                    .await
            })
            .unwrap();
        assert_eq!(result, 46, "insecure call should work");
        assert_eq!(
            transport.take_peer_feedback_history(),
            vec![
                // The failed request, then the successful retry.
                (13, types::PeerFeedback::Failure),
                (14, types::PeerFeedback::Success),
            ]
        );
    }
881}