diff --git a/src/app/events/mod.rs b/src/app/events/mod.rs index 7e3ab6a..522ddf1 100644 --- a/src/app/events/mod.rs +++ b/src/app/events/mod.rs @@ -22,7 +22,7 @@ mod handlers; pub mod listeners; -use anyhow::{Context, Result}; +use anyhow::{Context, Error, Result}; use crate::{ app::{command_interface::Commands, App}, @@ -33,10 +33,29 @@ use crossterm::event::Event as CrosstermEvent; use handlers::{command, input}; use uuid::Uuid; +#[derive(Debug)] +pub enum ErrorEvent { + CBSCrash(Uuid, Error), +} + +impl ErrorEvent { + pub async fn handle(self, app: &mut App) -> Result { + match self { + Self::CBSCrash(cbs, err) => { + // TODO: Kill CBS process + // TODO: Kill CBS connection threads + cli_log::error!("The CBS handler for {cbs} crashed: {err}"); + Ok(EventStatus::Ok) + } + } + } +} + #[derive(Debug)] pub enum Event { InputEvent(CrosstermEvent), CBSPacket(Uuid, triba_packet::Packet), + Error(ErrorEvent), // FIXME(@soispha): The `String` here is just wrong <2024-05-03> CommandEvent(Commands, Option>), @@ -56,6 +75,8 @@ impl Event { Ok(EventStatus::Ok) } + Event::Error(err) => err.handle(app).await, + Event::LuaCommand(lua_code) => { warn!( "Got lua code to execute, but no exectuter is available:\n{}", diff --git a/src/cbs/connection.rs b/src/cbs/connection.rs index 7a56b96..b2df663 100644 --- a/src/cbs/connection.rs +++ b/src/cbs/connection.rs @@ -1,36 +1,35 @@ -use crate::app::events::Event as AppEvent; +use crate::app::events::{ErrorEvent as AppErrorEvent, Event as AppEvent}; use aes_gcm_siv::{Aes256GcmSiv, Nonce}; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use interprocess::local_socket::tokio::{RecvHalf, SendHalf}; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use triba_packet::{IdPool, Packet, Request, Response}; use uuid::Uuid; -pub struct Connection { - req_tx: mpsc::UnboundedSender, - resp_tx: mpsc::UnboundedSender<(Response, u64)>, -} - -impl Connection { - pub async fn send_request(&self, body: Request) -> Result<()> { - self.req_tx.send(body)?; - Ok(()) - } - - pub async fn send_response(&self, body: Response, receiver: u64) -> Result<()> { - self.resp_tx.send((body, receiver))?; - Ok(()) - } -} - enum Event { ToCBSReq(Request), ToCBSResp(Response, u64), FromCBS(Packet), } +pub struct Connection { + tx: mpsc::UnboundedSender, +} + +impl Connection { + pub async fn send_request(&self, body: Request) -> Result<()> { + self.tx.send(Event::ToCBSReq(body))?; + Ok(()) + } + + pub async fn send_response(&self, body: Response, receiver: u64) -> Result<()> { + self.tx.send(Event::ToCBSResp(body, receiver))?; + Ok(()) + } +} + pub struct UnstableConnection { kill_token: CancellationToken, id: Uuid, @@ -109,37 +108,20 @@ impl UnstableConnection { let nonce = nonce.clone(); let tx = tx.clone(); + let id = self.id.clone(); + let main_tx = self.main_tx.clone(); + tokio::spawn(async move { - loop { - let packet = Packet::recv(&mut sock_rx, &cipher, &nonce).await.unwrap(); - tx.send(Event::FromCBS(packet)).unwrap(); + match poll_from_socket(&mut sock_rx, &cipher, &nonce, tx).await { + Err(e) => main_tx + .send(AppEvent::Error(AppErrorEvent::CBSCrash(id, e))) + .await + .expect("Failed to propagate error back to main queue."), + Ok(_) => (), } }); } - let (core_req_tx, mut core_req_rx) = mpsc::unbounded_channel(); - - // Poll requests from core - { - let tx = tx.clone(); - tokio::spawn(async move { - loop { - let body = core_req_rx.recv().await.unwrap(); - tx.send(Event::ToCBSReq(body)).unwrap(); - } - }); - } - - let (core_resp_tx, mut core_resp_rx) = mpsc::unbounded_channel(); - - // Poll responses from core - tokio::spawn(async move { - loop { - let (body, req) = core_resp_rx.recv().await.unwrap(); - tx.send(Event::ToCBSResp(body, req)).unwrap(); - } - }); - // Handle and route all packets { let cipher = cipher.clone(); @@ -149,37 +131,76 @@ impl UnstableConnection { let id = self.id.clone(); tokio::spawn(async move { - loop { - let event = tokio::select! { - event = rx.recv() => event.unwrap(), - _ = kill_token.cancelled() => break, - }; - - match event { - Event::ToCBSReq(req) => { - Packet::request(id_pool.acquire(), req) - .send(&mut sock_tx, &cipher, &nonce) - .await - .unwrap(); - } - Event::ToCBSResp(resp, req) => { - Packet::response(id_pool.acquire(), req, resp) - .send(&mut sock_tx, &cipher, &nonce) - .await - .unwrap(); - } - Event::FromCBS(packet) => main_tx - .send(AppEvent::CBSPacket(id.clone(), packet)) - .await - .unwrap(), - } + match route_packets( + rx, + kill_token, + id_pool, + &mut sock_tx, + &cipher, + &nonce, + main_tx.clone(), + id, + ) + .await + { + Err(e) => main_tx + .send(AppEvent::Error(AppErrorEvent::CBSCrash(id, e))) + .await + .expect("Failed to propagate error back to main queue."), + Ok(_) => (), } }); } - Ok(Connection { - req_tx: core_req_tx, - resp_tx: core_resp_tx, - }) + Ok(Connection { tx }) } } + +async fn poll_from_socket( + rx: &mut RecvHalf, + cipher: &Aes256GcmSiv, + nonce: &Nonce, + tx: mpsc::UnboundedSender, +) -> Result<()> { + loop { + let packet = Packet::recv(rx, cipher, nonce).await?; + tx.send(Event::FromCBS(packet))?; + } +} + +async fn route_packets( + mut rx: mpsc::UnboundedReceiver, + kill_token: CancellationToken, + mut id_pool: IdPool, + sock_tx: &mut SendHalf, + cipher: &Aes256GcmSiv, + nonce: &Nonce, + main_tx: mpsc::Sender, + id: Uuid, +) -> Result<()> { + loop { + let event = tokio::select! { + event = rx.recv() => event.context("The cbs event queue was closed unexpectedly.")?, + _ = kill_token.cancelled() => break, + }; + + match event { + Event::ToCBSReq(req) => { + Packet::request(id_pool.acquire(), req) + .send(sock_tx, cipher, nonce) + .await?; + } + Event::ToCBSResp(resp, req) => { + Packet::response(id_pool.acquire(), req, resp) + .send(sock_tx, cipher, nonce) + .await?; + } + Event::FromCBS(packet) => { + main_tx + .send(AppEvent::CBSPacket(id.clone(), packet)) + .await? + } + } + } + Ok(()) +} diff --git a/src/cbs/dummy.rs b/src/cbs/dummy.rs index e5cb73d..aeb93bd 100644 --- a/src/cbs/dummy.rs +++ b/src/cbs/dummy.rs @@ -4,5 +4,7 @@ use uuid::Uuid; pub async fn cbs(sock_name: Name<'_>, id: Uuid) { let (session, rx) = triba::Session::new(id, sock_name).await.unwrap(); - loop {} + loop { + tokio::task::yield_now().await; + } }