forked from trinitrix/core
1
0
Fork 0

feat(command_interface): Add support for functions (and thus callbacks)

This commit is contained in:
Benedikt Peetz 2023-10-14 18:46:25 +02:00
parent d7b93178e8
commit 1dd9a0e4de
Signed by: bpeetz
GPG Key ID: A5E94010C3A642AD
6 changed files with 215 additions and 24 deletions

View File

@ -5,9 +5,11 @@ use language_macros::parse_command_enum;
// TODO(@soispha): Should these paths be moved to the proc macro? // TODO(@soispha): Should these paths be moved to the proc macro?
// As they are not static, it could be easier for other people, // As they are not static, it could be easier for other people,
// if they stay here. // if they stay here.
use mlua::IntoLua; use crate::app::command_interface::command_transfer_value::{
use crate::app::command_interface::command_transfer_value::CommandTransferValue; support_types::Function, CommandTransferValue,
};
use crate::app::Event; use crate::app::Event;
use mlua::IntoLua;
parse_command_enum! { parse_command_enum! {
commands { commands {
@ -39,6 +41,10 @@ commands {
/// the help pages at the start /// the help pages at the start
declare help: fn(Option<String>), declare help: fn(Option<String>),
// Register a function to be used with the Trinitrix api
// (This function is not actually implemented here)
/* declare register_function: false, */
/// Function that change the UI, or UI state /// Function that change the UI, or UI state
namespace ui { namespace ui {
/// Shows the command line /// Shows the command line
@ -66,6 +72,18 @@ commands {
/// This is mainly used to display the final /// This is mainly used to display the final
/// output of evaluated lua commands. /// output of evaluated lua commands.
declare display_output: fn(String), declare display_output: fn(String),
/// This namespace is used to store some command specific data (like functions, as
/// ensuring memory locations stay allocated in garbage collected language is hard)
///
/// Treat it as an implementation detail
namespace __private {
/// This command is a no-op, it's just here to ensure that the '__private'
/// namespace get actually created
// FIXME(@soispha): Add an attribute to namespaces to avoid having to use
// empty functions <2023-10-14>
declare __private_initializer: fn(),
},
}, },
}, },
}, },

View File

@ -3,7 +3,23 @@ use std::collections::HashMap;
use cli_log::info; use cli_log::info;
use mlua::{ErrorContext, FromLua, IntoLua, LuaSerdeExt, Value}; use mlua::{ErrorContext, FromLua, IntoLua, LuaSerdeExt, Value};
use super::{CommandTransferValue, Table}; use super::{support_types::Function, CommandTransferValue, Table};
impl<'lua> FromLua<'lua> for Function {
fn from_lua(value: Value<'lua>, lua: &'lua mlua::Lua) -> mlua::Result<Self> {
match value {
Value::String(function_id) => {
return Ok(Function::new(
function_id
.to_str()
.context("Failed to convert lua string to rust string")?
.to_owned(),
))
}
_ => unreachable!("The Function type can only take functions!"),
};
}
}
impl<'lua> IntoLua<'lua> for CommandTransferValue { impl<'lua> IntoLua<'lua> for CommandTransferValue {
fn into_lua(self, lua: &'lua mlua::Lua) -> mlua::Result<mlua::Value<'lua>> { fn into_lua(self, lua: &'lua mlua::Lua) -> mlua::Result<mlua::Value<'lua>> {

View File

@ -2,6 +2,9 @@ use std::{collections::HashMap, fmt::Display};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use self::support_types::Function;
pub mod support_types;
pub mod type_conversions; pub mod type_conversions;
// language support // language support
@ -30,9 +33,10 @@ pub enum CommandTransferValue {
/// A table, dictionary or HashMap /// A table, dictionary or HashMap
Table(HashMap<String, CommandTransferValue>), Table(HashMap<String, CommandTransferValue>),
// Reference to a Lua function (or closure).
// /* TODO */ Function(Function),
/// Reference to a function (or closure).
/// This 'Function' value is obtained by registering a function with 'register_function()'
Function(Function),
// Reference to a Lua thread (or coroutine). // Reference to a Lua thread (or coroutine).
// /* TODO */ Thread(Thread<'lua>), // /* TODO */ Thread(Thread<'lua>),
@ -55,6 +59,7 @@ impl Display for CommandTransferValue {
// TODO(@Soispha): The following line should be a real display call, but how do you // TODO(@Soispha): The following line should be a real display call, but how do you
// format a HashMap? // format a HashMap?
CommandTransferValue::Table(table) => f.write_str(&format!("{:#?}", table)), CommandTransferValue::Table(table) => f.write_str(&format!("{:#?}", table)),
CommandTransferValue::Function(function) => function.fmt(f),
} }
} }
} }

View File

@ -0,0 +1,68 @@
use std::{
fmt::Display,
sync::atomic::{AtomicUsize, Ordering},
};
use mlua::{ErrorContext, IntoLua, Lua, Table};
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Function {
id: String,
}
impl Function {
pub fn new(function_uuid: String) -> Self {
Function { id: function_uuid }
}
pub fn from_lua_function(function: mlua::Function, lua: &Lua) -> mlua::Result<Self> {
// TODO(@soispha): Does this expose a vulnerability, as the ids are predictable? <2023-10-14>
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let id = COUNTER.fetch_add(1, Ordering::Relaxed);
let globals = lua.globals();
// This is always initialized, as the namespaces are specified in the 'command_list' module
let private: Table = globals
.get::<&str, mlua::Table>("trinitrix")
.context("Failed to access 'trinitrix'")?
.get::<&str, mlua::Table>("api")
.context("Failed to access 'api'")?
.get::<&str, mlua::Table>("raw")
.context("Failed to access 'raw'")?
.get::<&str, mlua::Table>("__private")
.context("Failed to access '__private'")?;
private.set(id, function)?;
Ok(Function::new(format!("{}", id)))
}
pub fn call(&self, lua: &Lua) -> mlua::Result<()> {
// This is always initialized, as the namespaces are specified in the 'command_list' module
let private: Table = lua
.globals()
.get::<&str, mlua::Table>("trinitrix")
.context("Failed to access 'trinitrix'")?
.get::<&str, mlua::Table>("api")
.context("Failed to access 'api'")?
.get::<&str, mlua::Table>("raw")
.context("Failed to access 'raw'")?
.get::<&str, mlua::Table>("__private")
.context("Failed to access '__private'")?;
let function: mlua::Function = private
.get(self.id.clone())
.context("Failed to get function associated with callback!")?;
function.call(()).context("Failed to call function")
}
}
impl<'lua> IntoLua<'lua> for Function {
fn into_lua(self, lua: &'lua Lua) -> mlua::Result<mlua::Value<'lua>> {
Ok(self.id.into_lua(lua)?)
}
}
impl Display for Function {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.id.fmt(f)
}
}

View File

@ -2,10 +2,11 @@ use std::thread;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use cli_log::{debug, error, info}; use cli_log::{debug, error, info};
use mlua::{Function, Value}; use mlua::{ErrorContext, Lua, Value};
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use tokio::{ use tokio::{
runtime::Builder, runtime::Builder,
select,
sync::{mpsc, Mutex}, sync::{mpsc, Mutex},
task::{self, LocalSet}, task::{self, LocalSet},
}; };
@ -21,12 +22,15 @@ use crate::app::{
events::event_types::Event, events::event_types::Event,
}; };
use super::command_transfer_value::support_types::Function;
static LUA: OnceCell<Mutex<mlua::Lua>> = OnceCell::new(); static LUA: OnceCell<Mutex<mlua::Lua>> = OnceCell::new();
/// This structure contains the necessary state for running an embedded Lua runtime (i.e. /// This structure contains the necessary state for running an embedded Lua runtime (i.e.
/// the tread, the Lua memory, etc.). /// the tread, the Lua memory, etc.).
pub struct LuaCommandManager { pub struct LuaCommandManager {
lua_command_tx: mpsc::Sender<String>, lua_command_tx: mpsc::Sender<String>,
lua_function_tx: mpsc::Sender<Function>,
} }
impl LuaCommandManager { impl LuaCommandManager {
@ -36,10 +40,17 @@ impl LuaCommandManager {
.await .await
.expect("The receiver should not be dropped at this time"); .expect("The receiver should not be dropped at this time");
} }
pub async fn execute_function(&self, function: Function) {
self.lua_function_tx
.send(function)
.await
.expect("The receiver should not be dropped");
}
pub fn new(event_call_tx: mpsc::Sender<Event>) -> Self { pub fn new(event_call_tx: mpsc::Sender<Event>) -> Self {
info!("Spawning lua code execution thread..."); info!("Spawning lua code execution thread...");
let (lua_command_tx, mut lua_command_rx) = mpsc::channel::<String>(256); let (lua_command_tx, mut lua_command_rx) = mpsc::channel::<String>(256);
let (lua_function_tx, mut lua_function_rx) = mpsc::channel::<Function>(256);
thread::spawn(move || { thread::spawn(move || {
let rt = Builder::new_current_thread().enable_all().build().expect( let rt = Builder::new_current_thread().enable_all().build().expect(
"Should always be able to build \ "Should always be able to build \
@ -51,9 +62,14 @@ impl LuaCommandManager {
"Lua command handling initialized, \ "Lua command handling initialized, \
waiting for commands.." waiting for commands.."
); );
while let Some(command) = lua_command_rx.recv().await { let mut done = false;
debug!("Recieved lua code (in LuaCommandHandler): {}", &command); let moved_event_call_tx = event_call_tx.clone();
let local_event_call_tx = event_call_tx.clone(); while !done {
select! {
command = lua_command_rx.recv() => {
if let Some(command) = command {
debug!("Received lua code (in LuaCommandHandler): {}", &command);
let local_event_call_tx = moved_event_call_tx.clone();
task::spawn_local(async move { task::spawn_local(async move {
exec_lua(&command, local_event_call_tx).await.expect( exec_lua(&command, local_event_call_tx).await.expect(
@ -62,32 +78,96 @@ impl LuaCommandManager {
this should never error", this should never error",
); );
}); });
} else {
done = true;
}
},
function_id = lua_function_rx.recv() => {
if let Some(function_id) = function_id {
debug!("Received lua function (in LuaCommandHandler): {}", &function_id);
let local_event_call_tx = moved_event_call_tx.clone();
task::spawn_local(async move {
let lua = initialize_lua(local_event_call_tx.clone()).await;
let out = function_id.call(&lua).map_err(|err| async move {
error!("Lua function `{}` returned error: `{}`", function_id, err);
local_event_call_tx
.send(Event::CommandEvent(
Command::Trinitrix(Api(Raw(RaiseError(err.to_string())))),
None,
))
.await.expect(
"This should return all relevent errors \
by other messages, \
this should never error",
);
});
if let Err(err) = out {
err.await;
}
});
} else {
done = true;
}
},
else => done = true,
}
} }
}); });
rt.block_on(local); rt.block_on(local);
}); });
LuaCommandManager { lua_command_tx } LuaCommandManager {
lua_command_tx,
lua_function_tx,
}
} }
} }
async fn exec_lua(lua_code: &str, event_call_tx: mpsc::Sender<Event>) -> Result<()> { async fn initialize_lua<'a>(
let second_event_call_tx = event_call_tx.clone(); event_call_tx: mpsc::Sender<Event>,
) -> matrix_sdk::locks::MutexGuard<'a, Lua> {
let lua = LUA let lua = LUA
.get_or_init(|| { .get_or_init(|| {
Mutex::new(add_lua_functions_to_globals( let lua: Lua = add_lua_functions_to_globals(mlua::Lua::new(), event_call_tx);
mlua::Lua::new(), {
second_event_call_tx, let wrapped_register_function = lua
)) .create_function(
|lua: &Lua, function: mlua::Function| -> mlua::Result<Function> {
ErrorContext::context(
Function::from_lua_function(function, lua),
"Failed to register function",
)
},
)
.expect("This always works, as the function is static");
let trinitrix_api: mlua::Table = lua
.globals()
.get::<&str, mlua::Table>("trinitrix")
.expect("This was set in the add_lua_functions_to_globals function")
.get::<&str, mlua::Table>("api")
.expect("Same reason");
trinitrix_api
.set("register_function", wrapped_register_function)
.expect("This should work");
}
Mutex::new(lua)
}) })
.lock() .lock()
.await; .await;
lua
}
async fn exec_lua(lua_code: &str, event_call_tx: mpsc::Sender<Event>) -> Result<()> {
let lua = initialize_lua(event_call_tx.clone()).await;
info!("Recieved code to execute: `{}`, executing...", &lua_code); info!("Recieved code to execute: `{}`, executing...", &lua_code);
let output = lua.load(lua_code).eval_async::<Value>().await; let output = lua.load(lua_code).eval_async::<Value>().await;
match output { match output {
Ok(out) => { Ok(out) => {
let to_string_fn: Function = lua.globals().get("tostring").expect("This always exists"); let to_string_fn: mlua::Function =
lua.globals().get("tostring").expect("This always exists");
let output: String = to_string_fn.call(out).expect("tostring should not error"); let output: String = to_string_fn.call(out).expect("tostring should not error");
info!("Lua code `{}` evaluated to: `{}`", lua_code, &output); info!("Lua code `{}` evaluated to: `{}`", lua_code, &output);

View File

@ -160,6 +160,10 @@ pub async fn handle(
send_status_output!(output); send_status_output!(output);
EventStatus::Ok EventStatus::Ok
} }
Raw::Private(_) => {
// no-op, read the comment about it in the `command_list`
EventStatus::Ok
}
}, },
}, },
}, },