diff --git a/src/cbs/connection.rs b/src/cbs/connection.rs index a89c277..4397a48 100644 --- a/src/cbs/connection.rs +++ b/src/cbs/connection.rs @@ -56,14 +56,21 @@ impl UnstableConnection { let packet = Packet::recv(&mut sock_rx, &cipher, &nonce).await?; match packet { Packet::Request { id, body } => { - Packet::response(id_pool.acquire(), id, Response::Success) - .send(&mut sock_tx, &cipher, &nonce) - .await?; + match body { + Request::HandshakeUpgradeConnection => { + Packet::response(id_pool.acquire(), id, Response::Success) + .send(&mut sock_tx, &cipher, &nonce) + .await?; - cli_log::info!( - "CBS {id}: upgraded connection to encrypted messagepack", - id = self.id - ); + cli_log::info!( + "CBS {id}: upgraded connection to encrypted messagepack", + id = self.id + ); + } + req => return Err(anyhow!( + "expected cbs to send: Request::HandshakeUpgradeConnection, but got: Request::{req}" + )) + } } body => { return Err(anyhow!( @@ -72,6 +79,10 @@ impl UnstableConnection { } } + Packet::request(id_pool.acquire(), Request::HandshakeSuccess) + .send(&mut sock_tx, &cipher, &nonce) + .await?; + // Poll packets from socket { let cipher = cipher.clone();