use std::collections::HashMap;
use anyhow::{bail, Result};
use thiserror::Error;
use crate::{common::sgx::QuotePolicy, consensus::state::keymanager::Status as KeyManagerStatus};
use super::{
context::Context,
types::{Body, Kind, Request, Response},
};
#[derive(Error, Debug)]
enum DispatchError {
#[error("method not found: {method:?}")]
MethodNotFound { method: String },
#[error("invalid RPC kind: {method:?} ({kind:?})")]
InvalidRpcKind { method: String, kind: Kind },
}
/// A provider of RPC methods.
///
/// The returned methods have the shape expected by `Dispatcher::add_methods`.
pub trait Handler {
    /// Returns every method this handler exposes.
    ///
    /// The receiver is `&'static self`, so it can only be called on a value that
    /// lives for the rest of the program.
    fn methods(&'static self) -> Vec<Method>;
}
/// Describes an RPC method: the name it is looked up by and the kind it accepts.
#[derive(Debug, Clone)]
pub struct MethodDescriptor {
    /// Name the method is registered and looked up under.
    pub name: String,
    /// RPC kind the method must be invoked with; requests of another kind are rejected.
    pub kind: Kind,
}
/// Strongly-typed handler for a single RPC method.
///
/// `Rq` is the request argument type and `Rsp` the response type.
pub trait MethodHandler<Rq, Rsp> {
    /// Handles `request` within `ctx`, returning the response or an error.
    fn handle(&self, ctx: &Context, request: &Rq) -> Result<Rsp>;
}
/// Lets any `'static` function or closure with a matching signature act as a
/// [`MethodHandler`] directly.
impl<Rq: 'static, Rsp: 'static, F> MethodHandler<Rq, Rsp> for F
where
    F: Fn(&Context, &Rq) -> Result<Rsp> + 'static,
{
    fn handle(&self, ctx: &Context, request: &Rq) -> Result<Rsp> {
        // `&F` is callable whenever `F: Fn`, so the closure is invoked through the reference.
        self(ctx, request)
    }
}
/// Type-erased view of a method handler.
///
/// It has no type parameters, so handlers with different request and response
/// types can sit behind the same trait object.
pub trait MethodHandlerDispatch {
    /// Returns the descriptor (name and kind) of the underlying method.
    fn get_descriptor(&self) -> &MethodDescriptor;

    /// Handles a raw `request` within `ctx`, producing a response or an error.
    fn dispatch(&self, ctx: &Context, request: Request) -> Result<Response>;
}
/// Concrete [`MethodHandlerDispatch`] that pairs a descriptor with a typed handler.
struct MethodHandlerDispatchImpl<Rq, Rsp> {
    /// Descriptor handed out by `get_descriptor`.
    descriptor: MethodDescriptor,
    /// Typed handler that receives the decoded request on each dispatch.
    handler: Box<dyn MethodHandler<Rq, Rsp> + Send + Sync>,
}
impl<Rq, Rsp> MethodHandlerDispatch for MethodHandlerDispatchImpl<Rq, Rsp>
where
    Rq: cbor::Decode + 'static,
    Rsp: cbor::Encode + 'static,
{
    fn get_descriptor(&self) -> &MethodDescriptor {
        &self.descriptor
    }

    /// Decodes `request.args` into `Rq` and runs the typed handler.
    ///
    /// On success, the handler's `Rsp` is CBOR-encoded into a `Body::Success` response.
    ///
    /// # Errors
    /// Returns an error if the arguments cannot be decoded as `Rq`, or the error
    /// returned by the handler.
    fn dispatch(&self, ctx: &Context, request: Request) -> Result<Response> {
        let args: Rq = cbor::from_value(request.args)?;
        let output = self.handler.handle(ctx, &args)?;
        let body = Body::Success(cbor::to_value(output));
        Ok(Response { body })
    }
}
/// A single RPC method that can be registered with the dispatcher.
pub struct Method {
    /// Type-erased handler together with its descriptor.
    dispatcher: Box<dyn MethodHandlerDispatch + Send + Sync>,
}
impl Method {
    /// Creates a method from its `descriptor` and a typed `handler`.
    ///
    /// The handler is wrapped so that its `Rq`/`Rsp` types are erased. Methods with
    /// different request and response types can then be stored side by side.
    pub fn new<Rq, Rsp, Handler>(method: MethodDescriptor, handler: Handler) -> Self
    where
        Rq: cbor::Decode + 'static,
        Rsp: cbor::Encode + 'static,
        Handler: MethodHandler<Rq, Rsp> + Send + Sync + 'static,
    {
        Method {
            dispatcher: Box::new(MethodHandlerDispatchImpl {
                descriptor: method,
                handler: Box::new(handler),
            }),
        }
    }

    /// Name this method is registered under.
    fn get_name(&self) -> &String {
        &self.dispatcher.get_descriptor().name
    }

    /// RPC kind this method accepts.
    fn get_kind(&self) -> Kind {
        self.dispatcher.get_descriptor().kind
    }

    /// Invokes the method on `request` within `ctx`.
    ///
    /// The context is only forwarded as a shared reference, so a shared borrow is
    /// all that is required. Callers holding a `&mut Context` still compile, because
    /// `&mut Context` coerces to `&Context` at the call site.
    fn dispatch(&self, ctx: &Context, request: Request) -> Result<Response> {
        self.dispatcher.dispatch(ctx, request)
    }
}
/// Callback type that receives key manager status updates.
pub type KeyManagerStatusHandler = dyn Fn(KeyManagerStatus) + Send + Sync;

/// Callback type that receives key manager quote policy updates.
pub type KeyManagerQuotePolicyHandler = dyn Fn(QuotePolicy) + Send + Sync;
/// RPC dispatcher.
///
/// Looks up registered methods by name and forwards key manager updates to
/// optional callbacks. `Default` yields an empty dispatcher with no callbacks.
#[derive(Default)]
pub struct Dispatcher {
    /// Registered methods, keyed by method name.
    methods: HashMap<String, Method>,
    /// Optional callback for key manager status updates.
    km_status_handler: Option<Box<KeyManagerStatusHandler>>,
    /// Optional callback for key manager quote policy updates.
    km_quote_policy_handler: Option<Box<KeyManagerQuotePolicyHandler>>,
}
impl Dispatcher {
pub fn add_method(&mut self, method: Method) {
self.methods.insert(method.get_name().clone(), method);
}
pub fn add_methods(&mut self, methods: Vec<Method>) {
for method in methods {
self.add_method(method);
}
}
pub fn dispatch(&self, mut ctx: Context, request: Request, kind: Kind) -> Response {
match self.dispatch_fallible(&mut ctx, request, kind) {
Ok(response) => response,
Err(error) => Response {
body: Body::Error(format!("{error}")),
},
}
}
fn dispatch_fallible(
&self,
ctx: &mut Context,
request: Request,
kind: Kind,
) -> Result<Response> {
let method = match self.methods.get(&request.method) {
Some(method) => method,
None => bail!(DispatchError::MethodNotFound {
method: request.method,
}),
};
if method.get_kind() != kind {
bail!(DispatchError::InvalidRpcKind {
method: request.method,
kind,
});
};
method.dispatch(ctx, request)
}
pub fn handle_km_status_update(&self, status: KeyManagerStatus) {
if let Some(handler) = self.km_status_handler.as_ref() {
handler(status)
}
}
pub fn handle_km_quote_policy_update(&self, policy: QuotePolicy) {
if let Some(handler) = self.km_quote_policy_handler.as_ref() {
handler(policy)
}
}
pub fn set_keymanager_status_update_handler(
&mut self,
f: Option<Box<KeyManagerStatusHandler>>,
) {
self.km_status_handler = f;
}
pub fn set_keymanager_quote_policy_update_handler(
&mut self,
f: Option<Box<KeyManagerQuotePolicyHandler>>,
) {
self.km_quote_policy_handler = f;
}
}