1use std::{
3 collections::{BTreeMap, HashMap},
4 io::{BufReader, BufWriter, Read, Write},
5 sync::{
6 atomic::{AtomicUsize, Ordering},
7 Arc, Mutex,
8 },
9};
10
11use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
12use crossbeam::channel;
13use slog::{debug, error, info, warn, Logger};
14use thiserror::Error;
15use tokio::sync::oneshot;
16
17use crate::{
18 common::{logger::get_logger, namespace::Namespace, version::Version},
19 config::Config,
20 consensus::{tendermint, verifier::Verifier},
21 dispatcher::Dispatcher,
22 future::block_on,
23 identity::Identity,
24 storage::KeyValue,
25 types::{Body, Error, Message, MessageType, RuntimeInfoRequest, RuntimeInfoResponse},
26 TeeType, BUILD_INFO,
27};
28
/// Transport connecting the runtime to its host.
///
/// The available variants depend on the build: Unix domain sockets are not
/// available inside SGX enclaves, and vsock is only compiled in with the `tdx`
/// feature.
pub enum Stream {
    /// Unix domain socket (not available when targeting SGX).
    #[cfg(not(target_env = "sgx"))]
    Unix(std::os::unix::net::UnixStream),
    /// TCP socket.
    Tcp(std::net::TcpStream),
    /// Vsock socket (only with the `tdx` feature).
    #[cfg(feature = "tdx")]
    Vsock(vsock::VsockStream),
}
37
38impl Read for &Stream {
39 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
40 #[allow(clippy::borrow_deref_ref)]
41 match self {
42 #[cfg(not(target_env = "sgx"))]
43 Stream::Unix(stream) => (&*stream).read(buf),
44 Stream::Tcp(stream) => (&*stream).read(buf),
45 #[cfg(feature = "tdx")]
46 Stream::Vsock(stream) => (&*stream).read(buf),
47 }
48 }
49}
50
51impl Write for &Stream {
52 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
53 #[allow(clippy::borrow_deref_ref)]
54 match self {
55 #[cfg(not(target_env = "sgx"))]
56 Stream::Unix(stream) => (&*stream).write(buf),
57 Stream::Tcp(stream) => (&*stream).write(buf),
58 #[cfg(feature = "tdx")]
59 Stream::Vsock(stream) => (&*stream).write(buf),
60 }
61 }
62
63 fn flush(&mut self) -> std::io::Result<()> {
64 #[allow(clippy::borrow_deref_ref)]
65 match self {
66 #[cfg(not(target_env = "sgx"))]
67 Stream::Unix(stream) => (&*stream).flush(),
68 Stream::Tcp(stream) => (&*stream).flush(),
69 #[cfg(feature = "tdx")]
70 Stream::Vsock(stream) => (&*stream).flush(),
71 }
72 }
73}
74
75const MAX_MESSAGE_SIZE: usize = 32 * 1024 * 1024; #[derive(Error, Debug)]
79pub enum ProtocolError {
80 #[error("message too large")]
81 MessageTooLarge,
82 #[error("method not supported")]
83 MethodNotSupported,
84 #[error("invalid response")]
85 InvalidResponse,
86 #[error("attestation required")]
87 #[allow(unused)]
88 AttestationRequired,
89 #[error("host environment information not configured")]
90 HostInfoNotConfigured,
91 #[error("incompatible consensus backend")]
92 IncompatibleConsensusBackend,
93 #[error("invalid runtime id (expected: {0} got: {1})")]
94 InvalidRuntimeId(Namespace, Namespace),
95 #[error("already initialized")]
96 AlreadyInitialized,
97 #[error("channel closed")]
98 ChannelClosed,
99}
100
101impl From<ProtocolError> for Error {
102 fn from(err: ProtocolError) -> Self {
103 Self {
104 module: "protocol".to_string(),
105 code: 1,
106 message: err.to_string(),
107 }
108 }
109}
110
111#[derive(Debug, Clone)]
113pub struct HostInfo {
114 pub runtime_id: Namespace,
116 pub consensus_backend: String,
118 pub consensus_protocol_version: Version,
120 pub consensus_chain_context: String,
122 pub local_config: BTreeMap<String, cbor::Value>,
127}
128
129pub struct Protocol {
131 logger: Logger,
133 #[cfg_attr(
135 not(any(target_env = "sgx", feature = "debug-mock-sgx")),
136 allow(unused)
137 )]
138 identity: Arc<Identity>,
139 dispatcher: Arc<Dispatcher>,
141 outgoing_tx: channel::Sender<Message>,
143 outgoing_rx: channel::Receiver<Message>,
145 stream: Stream,
147 last_request_id: AtomicUsize,
149 pending_out_requests: Mutex<HashMap<u64, oneshot::Sender<Body>>>,
151 config: Config,
153 host_info: Mutex<Option<HostInfo>>,
155 tokio_runtime: tokio::runtime::Handle,
157}
158
159impl Protocol {
160 pub(crate) fn new(
162 tokio_runtime: tokio::runtime::Handle,
163 stream: Stream,
164 identity: Arc<Identity>,
165 dispatcher: Arc<Dispatcher>,
166 config: Config,
167 ) -> Self {
168 let logger = get_logger("runtime/protocol");
169
170 let (outgoing_tx, outgoing_rx) = channel::unbounded();
171
172 Self {
173 logger,
174 identity,
175 dispatcher,
176 outgoing_tx,
177 outgoing_rx,
178 stream,
179 last_request_id: AtomicUsize::new(0),
180 pending_out_requests: Mutex::new(HashMap::new()),
181 config,
182 host_info: Mutex::new(None),
183 tokio_runtime,
184 }
185 }
186
187 pub fn get_config(&self) -> &Config {
189 &self.config
190 }
191
192 pub fn get_identity(&self) -> Option<&Arc<Identity>> {
194 self.identity.quote()?;
195 Some(&self.identity)
196 }
197
198 pub fn get_runtime_id(&self) -> Namespace {
204 self.host_info
205 .lock()
206 .unwrap()
207 .as_ref()
208 .expect("host environment information should be set")
209 .runtime_id
210 }
211
212 pub fn get_host_info(&self) -> HostInfo {
218 self.host_info
219 .lock()
220 .unwrap()
221 .as_ref()
222 .expect("host environment information should be set")
223 .clone()
224 }
225
226 pub(crate) fn start(self: &Arc<Protocol>) {
228 let protocol = self.clone();
230 std::thread::spawn(move || protocol.io_write());
231
232 self.io_read();
234 }
235
236 fn io_read(self: &Arc<Protocol>) {
237 info!(self.logger, "Starting protocol reader thread");
238 let mut reader = BufReader::new(&self.stream);
239
240 loop {
241 if let Err(error) = self.handle_message(&mut reader) {
242 error!(self.logger, "Failed to handle message"; "err" => %error);
243 break;
244 }
245 }
246
247 info!(self.logger, "Protocol reader thread is terminating");
248 }
249
250 fn io_write(self: &Arc<Protocol>) {
251 info!(self.logger, "Starting protocol writer thread");
252
253 while let Ok(
254 message @ Message {
255 id, message_type, ..
256 },
257 ) = self.outgoing_rx.recv()
258 {
259 if let Err(error) = self.write_message(message) {
260 warn!(self.logger, "Failed to write message"; "err" => %error);
261 self.handle_write_failure(id, message_type, error);
262 }
263 }
264
265 info!(self.logger, "Protocol writer thread is terminating");
266 }
267
268 pub fn call_host(&self, body: Body) -> Result<Body, Error> {
276 block_on(self.call_host_async(body))
277 }
278
279 pub async fn call_host_async(&self, body: Body) -> Result<Body, Error> {
281 let id = self.last_request_id.fetch_add(1, Ordering::SeqCst) as u64;
282 let message = Message {
283 id,
284 body,
285 message_type: MessageType::Request,
286 };
287
288 let (tx, rx) = oneshot::channel();
290 {
291 let mut pending_requests = self.pending_out_requests.lock().unwrap();
292 pending_requests.insert(id, tx);
293 }
294
295 self.send_message(message).map_err(Error::from)?;
297
298 let result = rx
299 .await
300 .map_err(|_| Error::from(ProtocolError::ChannelClosed))?;
301 match result {
302 Body::Error(err) => Err(err),
303 body => Ok(body),
304 }
305 }
306
307 pub fn send_response(&self, id: u64, body: Body) -> anyhow::Result<()> {
309 self.send_message(Message {
310 id,
311 body,
312 message_type: MessageType::Response,
313 })
314 }
315
316 fn send_message(&self, message: Message) -> anyhow::Result<()> {
317 self.outgoing_tx.send(message).map_err(|err| err.into())
318 }
319
320 fn decode_message<R: Read>(&self, mut reader: R) -> anyhow::Result<Message> {
321 let length = reader.read_u32::<BigEndian>()? as usize;
322 if length > MAX_MESSAGE_SIZE {
323 return Err(ProtocolError::MessageTooLarge.into());
324 }
325
326 let mut buffer = vec![0; length];
328 reader.read_exact(&mut buffer)?;
329
330 let message = cbor::from_slice(&buffer)
331 .map_err(|error| {
332 warn!(self.logger, "Failed to decode message"; "err" => %error);
333 debug!(self.logger, "Malformed message"; "bytes" => ?buffer);
334 error
335 })
336 .unwrap_or_default();
337
338 Ok(message)
339 }
340
341 fn write_message(&self, message: Message) -> anyhow::Result<()> {
342 let buffer = cbor::to_vec(message);
343 if buffer.len() > MAX_MESSAGE_SIZE {
344 return Err(ProtocolError::MessageTooLarge.into());
345 }
346
347 let mut writer = BufWriter::new(&self.stream);
348 writer.write_u32::<BigEndian>(buffer.len() as u32)?;
349 writer.write_all(&buffer)?;
350
351 Ok(())
352 }
353
354 fn handle_write_failure(
355 &self,
356 message_id: u64,
357 message_type: MessageType,
358 error: anyhow::Error,
359 ) {
360 match message_type {
361 MessageType::Request => {
362 let response_sender = {
364 let mut pending_requests = self.pending_out_requests.lock().unwrap();
365 pending_requests.remove(&message_id)
366 };
367
368 if let Some(response_sender) = response_sender {
369 let error_body = Body::Error(Error::new(
370 "rhp/write",
371 1,
372 &format!("Failed to write request: {error}"),
373 ));
374
375 if response_sender.send(error_body).is_err() {
376 warn!(
377 self.logger,
378 "Failed to deliver error response to local handler"
379 );
380 }
381 }
382 }
383 MessageType::Response => {
384 let error_response = Message {
386 id: message_id,
387 message_type: MessageType::Response,
388 body: Body::Error(Error::new(
389 "rhp/write",
390 1,
391 &format!("Failed to write response: {error}"),
392 )),
393 };
394
395 if self.write_message(error_response).is_err() {
396 warn!(self.logger, "Failed to write error message"; "err" => %error);
397 }
398 }
399 _ => {
400 warn!(self.logger, "Write failure for invalid message type"; "err" => %error)
401 }
402 }
403 }
404
405 fn handle_message<R: Read>(self: &Arc<Protocol>, reader: R) -> anyhow::Result<()> {
406 let message = self.decode_message(reader)?;
407
408 match message.message_type {
409 MessageType::Request => {
410 let id = message.id;
412
413 let body = match self.handle_request(id, message.body) {
414 Ok(Some(result)) => result,
415 Ok(None) => {
416 return Ok(());
419 }
420 Err(error) => Body::Error(Error::new("rhp/dispatcher", 1, &format!("{error}"))),
421 };
422
423 self.send_message(Message {
425 id,
426 message_type: MessageType::Response,
427 body,
428 })?;
429 }
430 MessageType::Response => {
431 let response_sender = {
433 let mut pending_requests = self.pending_out_requests.lock().unwrap();
434 pending_requests.remove(&message.id)
435 };
436
437 match response_sender {
438 Some(response_sender) => {
439 if response_sender.send(message.body).is_err() {
440 warn!(self.logger, "Failed to deliver response to local handler");
441 }
442 }
443 None => {
444 warn!(self.logger, "Received response message for unknown request"; "msg_id" => message.id);
445 }
446 }
447 }
448 _ => warn!(self.logger, "Received a malformed message"),
449 }
450
451 Ok(())
452 }
453
454 fn handle_request(
455 self: &Arc<Protocol>,
456 id: u64,
457 request: Body,
458 ) -> anyhow::Result<Option<Body>> {
459 match request {
460 Body::RuntimeInfoRequest(request) => Ok(Some(Body::RuntimeInfoResponse(
462 self.initialize_guest(request)?,
463 ))),
464 Body::RuntimePingRequest {} => Ok(Some(Body::Empty {})),
465 Body::RuntimeShutdownRequest {} => {
466 info!(self.logger, "Received worker shutdown request");
467 Err(ProtocolError::MethodNotSupported.into())
468 }
469 Body::RuntimeAbortRequest {} => {
470 info!(self.logger, "Received worker abort request");
471 Err(ProtocolError::MethodNotSupported.into())
472 }
473
474 Body::RuntimeCapabilityTEERakInitRequest { .. }
476 | Body::RuntimeCapabilityTEERakReportRequest {}
477 | Body::RuntimeCapabilityTEERakAvrRequest { .. }
478 | Body::RuntimeCapabilityTEERakQuoteRequest { .. }
479 | Body::RuntimeCapabilityTEEUpdateEndorsementRequest { .. } => {
480 self.dispatcher.queue_request(id, request)?;
481 Ok(None)
482 }
483
484 Body::RuntimeRPCCallRequest { .. }
486 | Body::RuntimeLocalRPCCallRequest { .. }
487 | Body::RuntimeCheckTxBatchRequest { .. }
488 | Body::RuntimeExecuteTxBatchRequest { .. }
489 | Body::RuntimeNotifyRequest { .. }
490 | Body::RuntimeKeyManagerStatusUpdateRequest { .. }
491 | Body::RuntimeKeyManagerQuotePolicyUpdateRequest { .. }
492 | Body::RuntimeQueryRequest { .. }
493 | Body::RuntimeConsensusSyncRequest { .. } => {
494 self.ensure_initialized()?;
495 self.dispatcher.queue_request(id, request)?;
496 Ok(None)
497 }
498
499 _ => {
500 warn!(self.logger, "Received unsupported request"; "req" => format!("{request:?}"));
501 Err(ProtocolError::MethodNotSupported.into())
502 }
503 }
504 }
505
506 fn initialize_guest(
507 self: &Arc<Protocol>,
508 host_info: RuntimeInfoRequest,
509 ) -> anyhow::Result<RuntimeInfoResponse> {
510 info!(self.logger, "Received host environment information";
511 "runtime_id" => ?host_info.runtime_id,
512 "consensus_backend" => &host_info.consensus_backend,
513 "consensus_protocol_version" => ?host_info.consensus_protocol_version,
514 "consensus_chain_context" => &host_info.consensus_chain_context,
515 "local_config" => ?host_info.local_config,
516 );
517
518 if tendermint::BACKEND_NAME != host_info.consensus_backend {
519 return Err(ProtocolError::IncompatibleConsensusBackend.into());
520 }
521 let mut local_host_info = self.host_info.lock().unwrap();
522 if local_host_info.is_some() {
523 return Err(ProtocolError::AlreadyInitialized.into());
524 }
525
526 let consensus_verifier: Box<dyn Verifier> =
528 if let Some(ref trust_root) = self.config.trust_root {
529 if host_info.runtime_id != trust_root.runtime_id {
531 return Err(ProtocolError::InvalidRuntimeId(
532 trust_root.runtime_id,
533 host_info.runtime_id,
534 )
535 .into());
536 }
537
538 let verifier = tendermint::verifier::Verifier::new(
540 self.clone(),
541 self.tokio_runtime.clone(),
542 trust_root.clone(),
543 host_info.runtime_id,
544 host_info.consensus_chain_context.clone(),
545 );
546 let handle = verifier.handle();
547 verifier.start();
548
549 Box::new(handle)
550 } else {
551 let verifier = tendermint::verifier::NopVerifier::new(self.clone());
553 verifier.start();
554
555 Box::new(verifier)
556 };
557
558 *local_host_info = Some(HostInfo {
560 runtime_id: host_info.runtime_id,
561 consensus_backend: host_info.consensus_backend,
562 consensus_protocol_version: host_info.consensus_protocol_version,
563 consensus_chain_context: host_info.consensus_chain_context,
564 local_config: host_info.local_config,
565 });
566
567 self.dispatcher.start(self.clone(), consensus_verifier);
569
570 Ok(RuntimeInfoResponse {
571 protocol_version: BUILD_INFO.protocol_version,
572 runtime_version: self.config.version,
573 features: self.config.features.clone(),
574 })
575 }
576
577 pub fn ensure_initialized(&self) -> anyhow::Result<()> {
579 self.host_info
580 .lock()
581 .unwrap()
582 .as_ref()
583 .ok_or(ProtocolError::HostInfoNotConfigured)?;
584
585 match BUILD_INFO.tee_type {
586 TeeType::Sgx | TeeType::Tdx => {
587 self.identity
588 .quote()
589 .ok_or(ProtocolError::AttestationRequired)?;
590 }
591 TeeType::None => {}
592 }
593
594 Ok(())
595 }
596}
597
598pub struct ProtocolUntrustedLocalStorage {
606 protocol: Arc<Protocol>,
607}
608
609impl ProtocolUntrustedLocalStorage {
610 pub fn new(protocol: Arc<Protocol>) -> Self {
611 Self { protocol }
612 }
613}
614
615impl KeyValue for ProtocolUntrustedLocalStorage {
616 fn get(&self, key: Vec<u8>) -> Result<Vec<u8>, Error> {
617 match self
618 .protocol
619 .call_host(Body::HostLocalStorageGetRequest { key })?
620 {
621 Body::HostLocalStorageGetResponse { value } => Ok(value),
622 _ => Err(ProtocolError::InvalidResponse.into()),
623 }
624 }
625
626 fn insert(&self, key: Vec<u8>, value: Vec<u8>) -> Result<(), Error> {
627 match self
628 .protocol
629 .call_host(Body::HostLocalStorageSetRequest { key, value })?
630 {
631 Body::HostLocalStorageSetResponse {} => Ok(()),
632 _ => Err(ProtocolError::InvalidResponse.into()),
633 }
634 }
635}