1use std::{
3 collections::{BTreeMap, HashMap},
4 io::{BufReader, BufWriter, Read, Write},
5 sync::{
6 atomic::{AtomicUsize, Ordering},
7 Arc, Mutex,
8 },
9};
10
11use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
12use crossbeam::channel;
13use slog::{debug, error, info, warn, Logger};
14use thiserror::Error;
15use tokio::sync::oneshot;
16
17use crate::{
18 common::{logger::get_logger, namespace::Namespace, version::Version},
19 config::Config,
20 consensus::{tendermint, verifier::Verifier},
21 dispatcher::Dispatcher,
22 future::block_on,
23 identity::Identity,
24 storage::KeyValue,
25 types::{Body, Error, Message, MessageType, RuntimeInfoRequest, RuntimeInfoResponse},
26 TeeType, BUILD_INFO,
27};
28
/// Transport stream over which `Protocol` reads and writes framed messages.
///
/// `Read` and `Write` are implemented for `&Stream`, so a shared reference is
/// enough to perform I/O.
pub enum Stream {
    /// Unix domain socket; compiled out when `target_env = "sgx"`.
    #[cfg(not(target_env = "sgx"))]
    Unix(std::os::unix::net::UnixStream),
    /// TCP socket.
    Tcp(std::net::TcpStream),
    /// Vsock socket; only available with the `tdx` feature enabled.
    #[cfg(feature = "tdx")]
    Vsock(vsock::VsockStream),
}
37
38impl Read for &Stream {
39 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
40 #[allow(clippy::borrow_deref_ref)]
41 match self {
42 #[cfg(not(target_env = "sgx"))]
43 Stream::Unix(stream) => (&*stream).read(buf),
44 Stream::Tcp(stream) => (&*stream).read(buf),
45 #[cfg(feature = "tdx")]
46 Stream::Vsock(stream) => (&*stream).read(buf),
47 }
48 }
49}
50
51impl Write for &Stream {
52 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
53 #[allow(clippy::borrow_deref_ref)]
54 match self {
55 #[cfg(not(target_env = "sgx"))]
56 Stream::Unix(stream) => (&*stream).write(buf),
57 Stream::Tcp(stream) => (&*stream).write(buf),
58 #[cfg(feature = "tdx")]
59 Stream::Vsock(stream) => (&*stream).write(buf),
60 }
61 }
62
63 fn flush(&mut self) -> std::io::Result<()> {
64 #[allow(clippy::borrow_deref_ref)]
65 match self {
66 #[cfg(not(target_env = "sgx"))]
67 Stream::Unix(stream) => (&*stream).flush(),
68 Stream::Tcp(stream) => (&*stream).flush(),
69 #[cfg(feature = "tdx")]
70 Stream::Vsock(stream) => (&*stream).flush(),
71 }
72 }
73}
74
/// Upper bound, in bytes, on the size of a single encoded message (16 MiB).
/// Enforced both when decoding incoming and when writing outgoing messages.
const MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024;

/// Errors raised by the protocol layer.
#[derive(Error, Debug)]
pub enum ProtocolError {
    /// An incoming or outgoing message exceeds `MAX_MESSAGE_SIZE`.
    #[error("message too large")]
    MessageTooLarge,
    /// The received request is not handled by this side of the protocol.
    #[error("method not supported")]
    MethodNotSupported,
    /// The response body did not have the variant expected for the request.
    #[error("invalid response")]
    InvalidResponse,
    /// A request needs an attestation quote, but the identity has none yet.
    // NOTE(review): `allow(unused)` presumably covers build configurations in
    // which this variant is never constructed — confirm.
    #[error("attestation required")]
    #[allow(unused)]
    AttestationRequired,
    /// Host environment information has not been received yet.
    #[error("host environment information not configured")]
    HostInfoNotConfigured,
    /// The host reported a consensus backend other than the supported one.
    #[error("incompatible consensus backend")]
    IncompatibleConsensusBackend,
    /// The host-supplied runtime ID differs from the one in the trust root.
    /// Fields are `(expected, got)`.
    #[error("invalid runtime id (expected: {0} got: {1})")]
    InvalidRuntimeId(Namespace, Namespace),
    /// Host environment information was already set by an earlier request.
    #[error("already initialized")]
    AlreadyInitialized,
    /// The response channel was closed before a response was delivered.
    #[error("channel closed")]
    ChannelClosed,
}
100
101impl From<ProtocolError> for Error {
102 fn from(err: ProtocolError) -> Self {
103 Self {
104 module: "protocol".to_string(),
105 code: 1,
106 message: err.to_string(),
107 }
108 }
109}
110
/// Information about the host environment, recorded from the host's
/// `RuntimeInfoRequest` when the guest is initialized.
#[derive(Debug, Clone)]
pub struct HostInfo {
    /// Identifier of the runtime, as reported by the host.
    pub runtime_id: Namespace,
    /// Name of the consensus backend reported by the host.
    pub consensus_backend: String,
    /// Consensus protocol version reported by the host.
    pub consensus_protocol_version: Version,
    /// Consensus chain context reported by the host.
    pub consensus_chain_context: String,
    /// Local configuration passed by the host, as string-keyed CBOR values.
    // NOTE(review): presumably node-local and therefore not to be trusted for
    // anything consensus-relevant — confirm against the host implementation.
    pub local_config: BTreeMap<String, cbor::Value>,
}
128
/// Runtime-side endpoint of the protocol spoken with the host over a `Stream`.
///
/// Incoming messages are read and dispatched by the reader loop, while outgoing
/// messages are queued on an internal channel and written by a writer thread.
pub struct Protocol {
    /// Logger used by the protocol.
    logger: Logger,
    /// Runtime identity; consulted for the presence of an attestation quote.
    #[cfg_attr(
        not(any(target_env = "sgx", feature = "debug-mock-sgx")),
        allow(unused)
    )]
    identity: Arc<Identity>,
    /// Dispatcher to which incoming requests are queued.
    dispatcher: Arc<Dispatcher>,
    /// Sending half of the outgoing message queue.
    outgoing_tx: channel::Sender<Message>,
    /// Receiving half of the outgoing message queue, drained by the writer thread.
    outgoing_rx: channel::Receiver<Message>,
    /// Underlying stream used for both reading and writing.
    stream: Stream,
    /// Counter from which identifiers of outgoing requests are allocated.
    last_request_id: AtomicUsize,
    /// Outgoing requests awaiting a response, keyed by request identifier.
    pending_out_requests: Mutex<HashMap<u64, oneshot::Sender<Body>>>,
    /// Runtime configuration.
    config: Config,
    /// Host environment information; `None` until the guest is initialized.
    host_info: Mutex<Option<HostInfo>>,
    /// Handle of the Tokio runtime, handed to the consensus verifier.
    tokio_runtime: tokio::runtime::Handle,
}
158
159impl Protocol {
160 pub(crate) fn new(
162 tokio_runtime: tokio::runtime::Handle,
163 stream: Stream,
164 identity: Arc<Identity>,
165 dispatcher: Arc<Dispatcher>,
166 config: Config,
167 ) -> Self {
168 let logger = get_logger("runtime/protocol");
169
170 let (outgoing_tx, outgoing_rx) = channel::unbounded();
171
172 Self {
173 logger,
174 identity,
175 dispatcher,
176 outgoing_tx,
177 outgoing_rx,
178 stream,
179 last_request_id: AtomicUsize::new(0),
180 pending_out_requests: Mutex::new(HashMap::new()),
181 config,
182 host_info: Mutex::new(None),
183 tokio_runtime,
184 }
185 }
186
/// Returns the runtime configuration this protocol instance was created with.
pub fn get_config(&self) -> &Config {
    &self.config
}
191
192 pub fn get_identity(&self) -> Option<&Arc<Identity>> {
194 self.identity.quote()?;
195 Some(&self.identity)
196 }
197
198 pub fn get_runtime_id(&self) -> Namespace {
204 self.host_info
205 .lock()
206 .unwrap()
207 .as_ref()
208 .expect("host environment information should be set")
209 .runtime_id
210 }
211
212 pub fn get_host_info(&self) -> HostInfo {
218 self.host_info
219 .lock()
220 .unwrap()
221 .as_ref()
222 .expect("host environment information should be set")
223 .clone()
224 }
225
226 pub(crate) fn start(self: &Arc<Protocol>) {
228 let protocol = self.clone();
230 std::thread::spawn(move || protocol.io_write());
231
232 self.io_read();
234 }
235
236 fn io_read(self: &Arc<Protocol>) {
237 info!(self.logger, "Starting protocol reader thread");
238 let mut reader = BufReader::new(&self.stream);
239
240 loop {
241 if let Err(error) = self.handle_message(&mut reader) {
242 error!(self.logger, "Failed to handle message"; "err" => %error);
243 break;
244 }
245 }
246
247 info!(self.logger, "Protocol reader thread is terminating");
248 }
249
250 fn io_write(self: &Arc<Protocol>) {
251 info!(self.logger, "Starting protocol writer thread");
252
253 while let Ok(message) = self.outgoing_rx.recv() {
254 if let Err(error) = self.write_message(message) {
255 warn!(self.logger, "Failed to write message"; "err" => %error);
256 }
257 }
258
259 info!(self.logger, "Protocol writer thread is terminating");
260 }
261
262 pub fn call_host(&self, body: Body) -> Result<Body, Error> {
270 block_on(self.call_host_async(body))
271 }
272
273 pub async fn call_host_async(&self, body: Body) -> Result<Body, Error> {
275 let id = self.last_request_id.fetch_add(1, Ordering::SeqCst) as u64;
276 let message = Message {
277 id,
278 body,
279 message_type: MessageType::Request,
280 };
281
282 let (tx, rx) = oneshot::channel();
284 {
285 let mut pending_requests = self.pending_out_requests.lock().unwrap();
286 pending_requests.insert(id, tx);
287 }
288
289 self.send_message(message).map_err(Error::from)?;
291
292 let result = rx
293 .await
294 .map_err(|_| Error::from(ProtocolError::ChannelClosed))?;
295 match result {
296 Body::Error(err) => Err(err),
297 body => Ok(body),
298 }
299 }
300
301 pub fn send_response(&self, id: u64, body: Body) -> anyhow::Result<()> {
303 self.send_message(Message {
304 id,
305 body,
306 message_type: MessageType::Response,
307 })
308 }
309
310 fn send_message(&self, message: Message) -> anyhow::Result<()> {
311 self.outgoing_tx.send(message).map_err(|err| err.into())
312 }
313
314 fn decode_message<R: Read>(&self, mut reader: R) -> anyhow::Result<Message> {
315 let length = reader.read_u32::<BigEndian>()? as usize;
316 if length > MAX_MESSAGE_SIZE {
317 return Err(ProtocolError::MessageTooLarge.into());
318 }
319
320 let mut buffer = vec![0; length];
322 reader.read_exact(&mut buffer)?;
323
324 let message = cbor::from_slice(&buffer)
325 .map_err(|error| {
326 warn!(self.logger, "Failed to decode message"; "err" => %error);
327 debug!(self.logger, "Malformed message"; "bytes" => ?buffer);
328 error
329 })
330 .unwrap_or_default();
331
332 Ok(message)
333 }
334
335 fn write_message(&self, message: Message) -> anyhow::Result<()> {
336 let buffer = cbor::to_vec(message);
337 if buffer.len() > MAX_MESSAGE_SIZE {
338 return Err(ProtocolError::MessageTooLarge.into());
339 }
340
341 let mut writer = BufWriter::new(&self.stream);
342 writer.write_u32::<BigEndian>(buffer.len() as u32)?;
343 writer.write_all(&buffer)?;
344
345 Ok(())
346 }
347
348 fn handle_message<R: Read>(self: &Arc<Protocol>, reader: R) -> anyhow::Result<()> {
349 let message = self.decode_message(reader)?;
350
351 match message.message_type {
352 MessageType::Request => {
353 let id = message.id;
355
356 let body = match self.handle_request(id, message.body) {
357 Ok(Some(result)) => result,
358 Ok(None) => {
359 return Ok(());
362 }
363 Err(error) => Body::Error(Error::new("rhp/dispatcher", 1, &format!("{error}"))),
364 };
365
366 self.send_message(Message {
368 id,
369 message_type: MessageType::Response,
370 body,
371 })?;
372 }
373 MessageType::Response => {
374 let response_sender = {
376 let mut pending_requests = self.pending_out_requests.lock().unwrap();
377 pending_requests.remove(&message.id)
378 };
379
380 match response_sender {
381 Some(response_sender) => {
382 if response_sender.send(message.body).is_err() {
383 warn!(self.logger, "Unable to deliver response to local handler");
384 }
385 }
386 None => {
387 warn!(self.logger, "Received response message for unknown request"; "msg_id" => message.id);
388 }
389 }
390 }
391 _ => warn!(self.logger, "Received a malformed message"),
392 }
393
394 Ok(())
395 }
396
/// Routes a single incoming request.
///
/// Returns `Ok(Some(body))` when the response is produced right here,
/// `Ok(None)` when the request was queued on the dispatcher (no response is
/// returned from this function in that case), and `Err` when the request is
/// rejected or queueing fails.
fn handle_request(
    self: &Arc<Protocol>,
    id: u64,
    request: Body,
) -> anyhow::Result<Option<Body>> {
    match request {
        // Host information: initialize the guest and answer immediately.
        Body::RuntimeInfoRequest(request) => Ok(Some(Body::RuntimeInfoResponse(
            self.initialize_guest(request)?,
        ))),
        // Ping: answer immediately with an empty body.
        Body::RuntimePingRequest {} => Ok(Some(Body::Empty {})),
        // Shutdown and abort are logged but rejected as unsupported.
        Body::RuntimeShutdownRequest {} => {
            info!(self.logger, "Received worker shutdown request");
            Err(ProtocolError::MethodNotSupported.into())
        }
        Body::RuntimeAbortRequest {} => {
            info!(self.logger, "Received worker abort request");
            Err(ProtocolError::MethodNotSupported.into())
        }

        // TEE capability requests are queued without the
        // `ensure_initialized` check.
        // NOTE(review): presumably because they are what establishes the
        // attestation that `ensure_initialized` requires — confirm.
        Body::RuntimeCapabilityTEERakInitRequest { .. }
        | Body::RuntimeCapabilityTEERakReportRequest {}
        | Body::RuntimeCapabilityTEERakAvrRequest { .. }
        | Body::RuntimeCapabilityTEERakQuoteRequest { .. }
        | Body::RuntimeCapabilityTEEUpdateEndorsementRequest { .. } => {
            self.dispatcher.queue_request(id, request)?;
            Ok(None)
        }

        // All remaining supported requests are queued only once the
        // `ensure_initialized` check passes.
        Body::RuntimeRPCCallRequest { .. }
        | Body::RuntimeLocalRPCCallRequest { .. }
        | Body::RuntimeCheckTxBatchRequest { .. }
        | Body::RuntimeExecuteTxBatchRequest { .. }
        | Body::RuntimeNotifyRequest { .. }
        | Body::RuntimeKeyManagerStatusUpdateRequest { .. }
        | Body::RuntimeKeyManagerQuotePolicyUpdateRequest { .. }
        | Body::RuntimeQueryRequest { .. }
        | Body::RuntimeConsensusSyncRequest { .. } => {
            self.ensure_initialized()?;
            self.dispatcher.queue_request(id, request)?;
            Ok(None)
        }

        // Anything else is logged and rejected.
        _ => {
            warn!(self.logger, "Received unsupported request"; "req" => format!("{request:?}"));
            Err(ProtocolError::MethodNotSupported.into())
        }
    }
}
448
/// Initializes the guest from the host-provided environment information.
///
/// Validates the request, creates and starts the consensus verifier, records
/// the host information, starts the dispatcher, and returns the runtime's
/// protocol version, runtime version and features.
///
/// # Errors
///
/// * `ProtocolError::IncompatibleConsensusBackend` if the host's consensus
///   backend is not `tendermint::BACKEND_NAME`,
/// * `ProtocolError::AlreadyInitialized` if host information is already set,
/// * `ProtocolError::InvalidRuntimeId` if a trust root is configured and its
///   runtime ID differs from the one supplied by the host.
fn initialize_guest(
    self: &Arc<Protocol>,
    host_info: RuntimeInfoRequest,
) -> anyhow::Result<RuntimeInfoResponse> {
    info!(self.logger, "Received host environment information";
        "runtime_id" => ?host_info.runtime_id,
        "consensus_backend" => &host_info.consensus_backend,
        "consensus_protocol_version" => ?host_info.consensus_protocol_version,
        "consensus_chain_context" => &host_info.consensus_chain_context,
        "local_config" => ?host_info.local_config,
    );

    if tendermint::BACKEND_NAME != host_info.consensus_backend {
        return Err(ProtocolError::IncompatibleConsensusBackend.into());
    }
    // The guard is held until this function returns, so the "already
    // initialized" check and the later assignment happen under one lock.
    let mut local_host_info = self.host_info.lock().unwrap();
    if local_host_info.is_some() {
        return Err(ProtocolError::AlreadyInitialized.into());
    }

    // Create and start the consensus verifier.
    let consensus_verifier: Box<dyn Verifier> =
        if let Some(ref trust_root) = self.config.trust_root {
            // The host-provided runtime ID must match the trust root's.
            if host_info.runtime_id != trust_root.runtime_id {
                return Err(ProtocolError::InvalidRuntimeId(
                    trust_root.runtime_id,
                    host_info.runtime_id,
                )
                .into());
            }

            // A trust root is configured: use the Tendermint verifier. The
            // handle is taken before `start()` and is what gets boxed.
            // NOTE(review): `start()` presumably consumes the verifier and
            // runs it in the background — confirm.
            let verifier = tendermint::verifier::Verifier::new(
                self.clone(),
                self.tokio_runtime.clone(),
                trust_root.clone(),
                host_info.runtime_id,
                host_info.consensus_chain_context.clone(),
            );
            let handle = verifier.handle();
            verifier.start();

            Box::new(handle)
        } else {
            // No trust root is configured: use the no-op verifier.
            let verifier = tendermint::verifier::NopVerifier::new(self.clone());
            verifier.start();

            Box::new(verifier)
        };

    // Record the host environment information.
    *local_host_info = Some(HostInfo {
        runtime_id: host_info.runtime_id,
        consensus_backend: host_info.consensus_backend,
        consensus_protocol_version: host_info.consensus_protocol_version,
        consensus_chain_context: host_info.consensus_chain_context,
        local_config: host_info.local_config,
    });

    // Start the dispatcher, giving it this protocol instance and the verifier.
    self.dispatcher.start(self.clone(), consensus_verifier);

    Ok(RuntimeInfoResponse {
        protocol_version: BUILD_INFO.protocol_version,
        runtime_version: self.config.version,
        features: self.config.features.clone(),
    })
}
519
520 pub fn ensure_initialized(&self) -> anyhow::Result<()> {
522 self.host_info
523 .lock()
524 .unwrap()
525 .as_ref()
526 .ok_or(ProtocolError::HostInfoNotConfigured)?;
527
528 match BUILD_INFO.tee_type {
529 TeeType::Sgx | TeeType::Tdx => {
530 self.identity
531 .quote()
532 .ok_or(ProtocolError::AttestationRequired)?;
533 }
534 TeeType::None => {}
535 }
536
537 Ok(())
538 }
539}
540
/// `KeyValue` store whose operations are forwarded to the host as local
/// storage requests over the protocol.
// NOTE(review): the name suggests the host-side storage is untrusted, so
// callers presumably have to protect stored values themselves — confirm.
pub struct ProtocolUntrustedLocalStorage {
    /// Protocol instance used to call the host.
    protocol: Arc<Protocol>,
}
551
552impl ProtocolUntrustedLocalStorage {
553 pub fn new(protocol: Arc<Protocol>) -> Self {
554 Self { protocol }
555 }
556}
557
558impl KeyValue for ProtocolUntrustedLocalStorage {
559 fn get(&self, key: Vec<u8>) -> Result<Vec<u8>, Error> {
560 match self
561 .protocol
562 .call_host(Body::HostLocalStorageGetRequest { key })?
563 {
564 Body::HostLocalStorageGetResponse { value } => Ok(value),
565 _ => Err(ProtocolError::InvalidResponse.into()),
566 }
567 }
568
569 fn insert(&self, key: Vec<u8>, value: Vec<u8>) -> Result<(), Error> {
570 match self
571 .protocol
572 .call_host(Body::HostLocalStorageSetRequest { key, value })?
573 {
574 Body::HostLocalStorageSetResponse {} => Ok(()),
575 _ => Err(ProtocolError::InvalidResponse.into()),
576 }
577 }
578}