From 1dd9a0e4ded67423c3c0ddb88b21eb5a6f1d45d6 Mon Sep 17 00:00:00 2001 From: Soispha Date: Sat, 14 Oct 2023 18:46:25 +0200 Subject: [PATCH] feat(command_interface): Add support for functions (and thus callbacks) --- src/app/command_interface/command_list/mod.rs | 22 +++- .../command_transfer_value/lua.rs | 18 ++- .../command_transfer_value/mod.rs | 9 +- .../command_transfer_value/support_types.rs | 68 ++++++++++ .../lua_command_manager/mod.rs | 118 +++++++++++++++--- .../event_types/event/handlers/command.rs | 4 + 6 files changed, 215 insertions(+), 24 deletions(-) create mode 100644 src/app/command_interface/command_transfer_value/support_types.rs diff --git a/src/app/command_interface/command_list/mod.rs b/src/app/command_interface/command_list/mod.rs index 8237e6c..15ca747 100644 --- a/src/app/command_interface/command_list/mod.rs +++ b/src/app/command_interface/command_list/mod.rs @@ -5,9 +5,11 @@ use language_macros::parse_command_enum; // TODO(@soispha): Should these paths be moved to the proc macro? // As they are not static, it could be easier for other people, // if they stay here. -use mlua::IntoLua; -use crate::app::command_interface::command_transfer_value::CommandTransferValue; +use crate::app::command_interface::command_transfer_value::{ + support_types::Function, CommandTransferValue, +}; use crate::app::Event; +use mlua::IntoLua; parse_command_enum! { commands { @@ -39,6 +41,10 @@ commands { /// the help pages at the start declare help: fn(Option), + // Register a function to be used with the Trinitrix api + // (This function is not actually implemented here) + /* declare register_function: false, */ + /// Function that change the UI, or UI state namespace ui { /// Shows the command line @@ -66,6 +72,18 @@ commands { /// This is mainly used to display the final /// output of evaluated lua commands. declare display_output: fn(String), + + /// This namespace is used to store some command specific data (like functions, as + /// ensuring memory locations stay allocated in garbage collected language is hard) + /// + /// Treat it as an implementation detail + namespace __private { + /// This command is a no-op, it's just here to ensure that the '__private' + /// namespace get actually created + // FIXME(@soispha): Add an attribute to namespaces to avoid having to use + // empty functions <2023-10-14> + declare __private_initializer: fn(), + }, }, }, }, diff --git a/src/app/command_interface/command_transfer_value/lua.rs b/src/app/command_interface/command_transfer_value/lua.rs index b0146dd..36f8f65 100644 --- a/src/app/command_interface/command_transfer_value/lua.rs +++ b/src/app/command_interface/command_transfer_value/lua.rs @@ -3,7 +3,23 @@ use std::collections::HashMap; use cli_log::info; use mlua::{ErrorContext, FromLua, IntoLua, LuaSerdeExt, Value}; -use super::{CommandTransferValue, Table}; +use super::{support_types::Function, CommandTransferValue, Table}; + +impl<'lua> FromLua<'lua> for Function { + fn from_lua(value: Value<'lua>, lua: &'lua mlua::Lua) -> mlua::Result { + match value { + Value::String(function_id) => { + return Ok(Function::new( + function_id + .to_str() + .context("Failed to convert lua string to rust string")? + .to_owned(), + )) + } + _ => unreachable!("The Function type can only take functions!"), + }; + } +} impl<'lua> IntoLua<'lua> for CommandTransferValue { fn into_lua(self, lua: &'lua mlua::Lua) -> mlua::Result> { diff --git a/src/app/command_interface/command_transfer_value/mod.rs b/src/app/command_interface/command_transfer_value/mod.rs index 44e16fc..b9aa4f1 100644 --- a/src/app/command_interface/command_transfer_value/mod.rs +++ b/src/app/command_interface/command_transfer_value/mod.rs @@ -2,6 +2,9 @@ use std::{collections::HashMap, fmt::Display}; use serde::{Deserialize, Serialize}; +use self::support_types::Function; + +pub mod support_types; pub mod type_conversions; // language support @@ -30,9 +33,10 @@ pub enum CommandTransferValue { /// A table, dictionary or HashMap Table(HashMap), - // Reference to a Lua function (or closure). - // /* TODO */ Function(Function), + /// Reference to a function (or closure). + /// This 'Function' value is obtained by registering a function with 'register_function()' + Function(Function), // Reference to a Lua thread (or coroutine). // /* TODO */ Thread(Thread<'lua>), @@ -55,6 +59,7 @@ impl Display for CommandTransferValue { // TODO(@Soispha): The following line should be a real display call, but how do you // format a HashMap? CommandTransferValue::Table(table) => f.write_str(&format!("{:#?}", table)), + CommandTransferValue::Function(function) => function.fmt(f), } } } diff --git a/src/app/command_interface/command_transfer_value/support_types.rs b/src/app/command_interface/command_transfer_value/support_types.rs new file mode 100644 index 0000000..d9ade5d --- /dev/null +++ b/src/app/command_interface/command_transfer_value/support_types.rs @@ -0,0 +1,68 @@ +use std::{ + fmt::Display, + sync::atomic::{AtomicUsize, Ordering}, +}; + +use mlua::{ErrorContext, IntoLua, Lua, Table}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Function { + id: String, +} + +impl Function { + pub fn new(function_uuid: String) -> Self { + Function { id: function_uuid } + } + pub fn from_lua_function(function: mlua::Function, lua: &Lua) -> mlua::Result { + // TODO(@soispha): Does this expose a vulnerability, as the ids are predictable? <2023-10-14> + static COUNTER: AtomicUsize = AtomicUsize::new(0); + let id = COUNTER.fetch_add(1, Ordering::Relaxed); + + let globals = lua.globals(); + // This is always initialized, as the namespaces are specified in the 'command_list' module + let private: Table = globals + .get::<&str, mlua::Table>("trinitrix") + .context("Failed to access 'trinitrix'")? + .get::<&str, mlua::Table>("api") + .context("Failed to access 'api'")? + .get::<&str, mlua::Table>("raw") + .context("Failed to access 'raw'")? + .get::<&str, mlua::Table>("__private") + .context("Failed to access '__private'")?; + private.set(id, function)?; + + Ok(Function::new(format!("{}", id))) + } + pub fn call(&self, lua: &Lua) -> mlua::Result<()> { + // This is always initialized, as the namespaces are specified in the 'command_list' module + let private: Table = lua + .globals() + .get::<&str, mlua::Table>("trinitrix") + .context("Failed to access 'trinitrix'")? + .get::<&str, mlua::Table>("api") + .context("Failed to access 'api'")? + .get::<&str, mlua::Table>("raw") + .context("Failed to access 'raw'")? + .get::<&str, mlua::Table>("__private") + .context("Failed to access '__private'")?; + let function: mlua::Function = private + .get(self.id.clone()) + .context("Failed to get function associated with callback!")?; + + function.call(()).context("Failed to call function") + } +} + +impl<'lua> IntoLua<'lua> for Function { + fn into_lua(self, lua: &'lua Lua) -> mlua::Result> { + Ok(self.id.into_lua(lua)?) + } +} + +impl Display for Function { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.id.fmt(f) + } +} diff --git a/src/app/command_interface/lua_command_manager/mod.rs b/src/app/command_interface/lua_command_manager/mod.rs index 520bf5c..f33d054 100644 --- a/src/app/command_interface/lua_command_manager/mod.rs +++ b/src/app/command_interface/lua_command_manager/mod.rs @@ -2,10 +2,11 @@ use std::thread; use anyhow::{Context, Result}; use cli_log::{debug, error, info}; -use mlua::{Function, Value}; +use mlua::{ErrorContext, Lua, Value}; use once_cell::sync::OnceCell; use tokio::{ runtime::Builder, + select, sync::{mpsc, Mutex}, task::{self, LocalSet}, }; @@ -21,12 +22,15 @@ use crate::app::{ events::event_types::Event, }; +use super::command_transfer_value::support_types::Function; + static LUA: OnceCell> = OnceCell::new(); /// This structure contains the necessary state for running an embedded Lua runtime (i.e. /// the tread, the Lua memory, etc.). pub struct LuaCommandManager { lua_command_tx: mpsc::Sender, + lua_function_tx: mpsc::Sender, } impl LuaCommandManager { @@ -36,10 +40,17 @@ impl LuaCommandManager { .await .expect("The receiver should not be dropped at this time"); } + pub async fn execute_function(&self, function: Function) { + self.lua_function_tx + .send(function) + .await + .expect("The receiver should not be dropped"); + } pub fn new(event_call_tx: mpsc::Sender) -> Self { info!("Spawning lua code execution thread..."); let (lua_command_tx, mut lua_command_rx) = mpsc::channel::(256); + let (lua_function_tx, mut lua_function_rx) = mpsc::channel::(256); thread::spawn(move || { let rt = Builder::new_current_thread().enable_all().build().expect( "Should always be able to build \ @@ -51,43 +62,112 @@ impl LuaCommandManager { "Lua command handling initialized, \ waiting for commands.." ); - while let Some(command) = lua_command_rx.recv().await { - debug!("Recieved lua code (in LuaCommandHandler): {}", &command); - let local_event_call_tx = event_call_tx.clone(); + let mut done = false; + let moved_event_call_tx = event_call_tx.clone(); + while !done { + select! { + command = lua_command_rx.recv() => { + if let Some(command) = command { + debug!("Received lua code (in LuaCommandHandler): {}", &command); + let local_event_call_tx = moved_event_call_tx.clone(); - task::spawn_local(async move { - exec_lua(&command, local_event_call_tx).await.expect( - "This should return all relevent errors \ - by other messages, \ - this should never error", - ); - }); + task::spawn_local(async move { + exec_lua(&command, local_event_call_tx).await.expect( + "This should return all relevent errors \ + by other messages, \ + this should never error", + ); + }); + } else { + done = true; + } + }, + function_id = lua_function_rx.recv() => { + if let Some(function_id) = function_id { + debug!("Received lua function (in LuaCommandHandler): {}", &function_id); + let local_event_call_tx = moved_event_call_tx.clone(); + + task::spawn_local(async move { + let lua = initialize_lua(local_event_call_tx.clone()).await; + let out = function_id.call(&lua).map_err(|err| async move { + error!("Lua function `{}` returned error: `{}`", function_id, err); + local_event_call_tx + .send(Event::CommandEvent( + Command::Trinitrix(Api(Raw(RaiseError(err.to_string())))), + None, + )) + .await.expect( + "This should return all relevent errors \ + by other messages, \ + this should never error", + ); + }); + if let Err(err) = out { + err.await; + } + }); + } else { + done = true; + } + }, + else => done = true, + } } }); rt.block_on(local); }); - LuaCommandManager { lua_command_tx } + LuaCommandManager { + lua_command_tx, + lua_function_tx, + } } } -async fn exec_lua(lua_code: &str, event_call_tx: mpsc::Sender) -> Result<()> { - let second_event_call_tx = event_call_tx.clone(); +async fn initialize_lua<'a>( + event_call_tx: mpsc::Sender, +) -> matrix_sdk::locks::MutexGuard<'a, Lua> { let lua = LUA .get_or_init(|| { - Mutex::new(add_lua_functions_to_globals( - mlua::Lua::new(), - second_event_call_tx, - )) + let lua: Lua = add_lua_functions_to_globals(mlua::Lua::new(), event_call_tx); + { + let wrapped_register_function = lua + .create_function( + |lua: &Lua, function: mlua::Function| -> mlua::Result { + ErrorContext::context( + Function::from_lua_function(function, lua), + "Failed to register function", + ) + }, + ) + .expect("This always works, as the function is static"); + let trinitrix_api: mlua::Table = lua + .globals() + .get::<&str, mlua::Table>("trinitrix") + .expect("This was set in the add_lua_functions_to_globals function") + .get::<&str, mlua::Table>("api") + .expect("Same reason"); + trinitrix_api + .set("register_function", wrapped_register_function) + .expect("This should work"); + } + + Mutex::new(lua) }) .lock() .await; + lua +} + +async fn exec_lua(lua_code: &str, event_call_tx: mpsc::Sender) -> Result<()> { + let lua = initialize_lua(event_call_tx.clone()).await; info!("Recieved code to execute: `{}`, executing...", &lua_code); let output = lua.load(lua_code).eval_async::().await; match output { Ok(out) => { - let to_string_fn: Function = lua.globals().get("tostring").expect("This always exists"); + let to_string_fn: mlua::Function = + lua.globals().get("tostring").expect("This always exists"); let output: String = to_string_fn.call(out).expect("tostring should not error"); info!("Lua code `{}` evaluated to: `{}`", lua_code, &output); diff --git a/src/app/events/event_types/event/handlers/command.rs b/src/app/events/event_types/event/handlers/command.rs index 4b655ad..f984892 100644 --- a/src/app/events/event_types/event/handlers/command.rs +++ b/src/app/events/event_types/event/handlers/command.rs @@ -160,6 +160,10 @@ pub async fn handle( send_status_output!(output); EventStatus::Ok } + Raw::Private(_) => { + // no-op, read the comment about it in the `command_list` + EventStatus::Ok + } }, }, },