1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
//! RPC dispatcher.
use std::collections::HashMap;

use anyhow::{bail, Result};
use thiserror::Error;

use crate::{common::sgx::QuotePolicy, consensus::state::keymanager::Status as KeyManagerStatus};

use super::{
    context::Context,
    types::{Body, Kind, Request, Response},
};

/// Dispatch error.
#[derive(Error, Debug)]
enum DispatchError {
    #[error("method not found: {method:?}")]
    MethodNotFound { method: String },
    #[error("invalid RPC kind: {method:?} ({kind:?})")]
    InvalidRpcKind { method: String, kind: Kind },
}

/// A provider of RPC methods.
///
/// The returned methods can be registered with [`Dispatcher::add_methods`], which accepts a
/// `Vec<Method>`.
pub trait Handler {
    /// Returns every [`Method`] this handler supports.
    ///
    /// The receiver is `&'static self`, so the handler must be borrowable for `'static`.
    fn methods(&'static self) -> Vec<Method>;
}

/// Static description of an RPC API method: its name and the kind of RPC allowed to call it.
#[derive(Clone, Debug)]
pub struct MethodDescriptor {
    /// Name under which the method is registered and looked up.
    pub name: String,
    /// The only [`Kind`] of RPC permitted to call this method; other kinds are rejected.
    pub kind: Kind,
}

/// Implementation of a single RPC method.
///
/// It receives a request of type `Rq` and produces a response of type `Rsp`.
pub trait MethodHandler<Rq, Rsp> {
    /// Runs the method for `request` within `ctx`, returning the response or an error.
    fn handle(&self, ctx: &Context, request: &Rq) -> Result<Rsp>;
}

/// Lets any `Fn(&Context, &Rq) -> Result<Rsp>` closure or function act as a method handler.
impl<Rq, Rsp, F> MethodHandler<Rq, Rsp> for F
where
    F: Fn(&Context, &Rq) -> Result<Rsp> + 'static,
    Rq: 'static,
    Rsp: 'static,
{
    fn handle(&self, ctx: &Context, request: &Rq) -> Result<Rsp> {
        // `&F` is itself callable when `F: Fn`, so forward straight to it.
        self(ctx, request)
    }
}

/// Type-erased dispatcher for a single RPC method.
///
/// The trait has no request/response type parameters. Methods with different signatures can
/// therefore be held uniformly behind `Box<dyn MethodHandlerDispatch>`.
pub trait MethodHandlerDispatch {
    /// Returns the descriptor (name and allowed RPC kind) of the method.
    fn get_descriptor(&self) -> &MethodDescriptor;

    /// Dispatches `request` to the method and returns its response.
    fn dispatch(&self, ctx: &Context, request: Request) -> Result<Response>;
}

/// [`MethodHandlerDispatch`] implementation pairing a method descriptor with a typed handler.
struct MethodHandlerDispatchImpl<Rq, Rsp> {
    /// Describes the method (name and allowed RPC kind).
    descriptor: MethodDescriptor,
    /// Typed handler invoked with the decoded request.
    handler: Box<dyn MethodHandler<Rq, Rsp> + Send + Sync>,
}

impl<Rq, Rsp> MethodHandlerDispatch for MethodHandlerDispatchImpl<Rq, Rsp>
where
    Rq: cbor::Decode + 'static,
    Rsp: cbor::Encode + 'static,
{
    fn get_descriptor(&self) -> &MethodDescriptor {
        &self.descriptor
    }

    /// Handles one request in three steps:
    ///
    /// 1. CBOR-decodes the request arguments into `Rq`.
    /// 2. Runs the handler.
    /// 3. Wraps the CBOR-encoded result in a successful [`Response`].
    ///
    /// Decoding and handler errors are propagated to the caller via `?`.
    fn dispatch(&self, ctx: &Context, request: Request) -> Result<Response> {
        let args: Rq = cbor::from_value(request.args)?;
        let result = self.handler.handle(ctx, &args)?;
        let body = Body::Success(cbor::to_value(result));
        Ok(Response { body })
    }
}

/// A registered RPC method.
///
/// Wraps a type-erased [`MethodHandlerDispatch`], so methods with different request and
/// response types can be stored in the same collection.
pub struct Method {
    /// Type-erased dispatcher for this method.
    dispatcher: Box<dyn MethodHandlerDispatch + Send + Sync>,
}

impl Method {
    /// Creates a new RPC method from its descriptor and a typed handler.
    ///
    /// Request arguments are CBOR-decoded into `Rq` before the handler runs. The handler's
    /// `Rsp` result is CBOR-encoded into the response.
    pub fn new<Rq, Rsp, Handler>(method: MethodDescriptor, handler: Handler) -> Self
    where
        Rq: cbor::Decode + 'static,
        Rsp: cbor::Encode + 'static,
        Handler: MethodHandler<Rq, Rsp> + Send + Sync + 'static,
    {
        Method {
            dispatcher: Box::new(MethodHandlerDispatchImpl {
                descriptor: method,
                handler: Box::new(handler),
            }),
        }
    }

    /// Returns the method name.
    fn get_name(&self) -> &String {
        &self.dispatcher.get_descriptor().name
    }

    /// Returns the kind of RPC allowed to call this method.
    fn get_kind(&self) -> Kind {
        self.dispatcher.get_descriptor().kind
    }

    /// Dispatches a request to the method handler.
    ///
    /// Only shared access to the context is required, because the underlying
    /// [`MethodHandlerDispatch::dispatch`] takes `&Context`. Callers holding a
    /// `&mut Context` can still pass it; it coerces to `&Context`.
    fn dispatch(&self, ctx: &Context, request: Request) -> Result<Response> {
        self.dispatcher.dispatch(ctx, request)
    }
}

/// Callback the dispatcher invokes with a key manager status update.
pub type KeyManagerStatusHandler = dyn Fn(KeyManagerStatus) + Send + Sync;

/// Callback the dispatcher invokes with a key manager quote policy update.
pub type KeyManagerQuotePolicyHandler = dyn Fn(QuotePolicy) + Send + Sync;

/// Routes RPC requests to registered methods.
///
/// It also forwards key manager status and quote policy updates to optional callbacks.
#[derive(Default)]
pub struct Dispatcher {
    /// Registered RPC methods, keyed by method name.
    methods: HashMap<String, Method>,
    /// Callback for key manager status updates, if one is set.
    km_status_handler: Option<Box<KeyManagerStatusHandler>>,
    /// Callback for key manager quote policy updates, if one is set.
    km_quote_policy_handler: Option<Box<KeyManagerQuotePolicyHandler>>,
}

impl Dispatcher {
    /// Register a new method in the dispatcher.
    pub fn add_method(&mut self, method: Method) {
        self.methods.insert(method.get_name().clone(), method);
    }

    /// Register new methods in the dispatcher.
    pub fn add_methods(&mut self, methods: Vec<Method>) {
        for method in methods {
            self.add_method(method);
        }
    }

    /// Dispatch request.
    pub fn dispatch(&self, mut ctx: Context, request: Request, kind: Kind) -> Response {
        match self.dispatch_fallible(&mut ctx, request, kind) {
            Ok(response) => response,
            Err(error) => Response {
                body: Body::Error(format!("{error}")),
            },
        }
    }

    fn dispatch_fallible(
        &self,
        ctx: &mut Context,
        request: Request,
        kind: Kind,
    ) -> Result<Response> {
        let method = match self.methods.get(&request.method) {
            Some(method) => method,
            None => bail!(DispatchError::MethodNotFound {
                method: request.method,
            }),
        };

        if method.get_kind() != kind {
            bail!(DispatchError::InvalidRpcKind {
                method: request.method,
                kind,
            });
        };

        method.dispatch(ctx, request)
    }

    /// Handle key manager status update.
    pub fn handle_km_status_update(&self, status: KeyManagerStatus) {
        if let Some(handler) = self.km_status_handler.as_ref() {
            handler(status)
        }
    }

    /// Handle key manager quote policy update.
    pub fn handle_km_quote_policy_update(&self, policy: QuotePolicy) {
        if let Some(handler) = self.km_quote_policy_handler.as_ref() {
            handler(policy)
        }
    }

    /// Update key manager status update handler.
    pub fn set_keymanager_status_update_handler(
        &mut self,
        f: Option<Box<KeyManagerStatusHandler>>,
    ) {
        self.km_status_handler = f;
    }

    /// Update key manager quote policy update handler.
    pub fn set_keymanager_quote_policy_update_handler(
        &mut self,
        f: Option<Box<KeyManagerQuotePolicyHandler>>,
    ) {
        self.km_quote_policy_handler = f;
    }
}