forked from trinitrix/core
feat(command_interface): Add support for functions (and thus callbacks)
This commit is contained in:
parent
d7b93178e8
commit
1dd9a0e4de
|
@ -5,9 +5,11 @@ use language_macros::parse_command_enum;
|
|||
// TODO(@soispha): Should these paths be moved to the proc macro?
|
||||
// As they are not static, it could be easier for other people,
|
||||
// if they stay here.
|
||||
use mlua::IntoLua;
|
||||
use crate::app::command_interface::command_transfer_value::CommandTransferValue;
|
||||
use crate::app::command_interface::command_transfer_value::{
|
||||
support_types::Function, CommandTransferValue,
|
||||
};
|
||||
use crate::app::Event;
|
||||
use mlua::IntoLua;
|
||||
|
||||
parse_command_enum! {
|
||||
commands {
|
||||
|
@ -39,6 +41,10 @@ commands {
|
|||
/// the help pages at the start
|
||||
declare help: fn(Option<String>),
|
||||
|
||||
// Register a function to be used with the Trinitrix api
|
||||
// (This function is not actually implemented here)
|
||||
/* declare register_function: false, */
|
||||
|
||||
/// Function that change the UI, or UI state
|
||||
namespace ui {
|
||||
/// Shows the command line
|
||||
|
@ -66,6 +72,18 @@ commands {
|
|||
/// This is mainly used to display the final
|
||||
/// output of evaluated lua commands.
|
||||
declare display_output: fn(String),
|
||||
|
||||
/// This namespace is used to store some command specific data (like functions, as
|
||||
/// ensuring memory locations stay allocated in garbage collected language is hard)
|
||||
///
|
||||
/// Treat it as an implementation detail
|
||||
namespace __private {
|
||||
/// This command is a no-op, it's just here to ensure that the '__private'
|
||||
/// namespace get actually created
|
||||
// FIXME(@soispha): Add an attribute to namespaces to avoid having to use
|
||||
// empty functions <2023-10-14>
|
||||
declare __private_initializer: fn(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -3,7 +3,23 @@ use std::collections::HashMap;
|
|||
use cli_log::info;
|
||||
use mlua::{ErrorContext, FromLua, IntoLua, LuaSerdeExt, Value};
|
||||
|
||||
use super::{CommandTransferValue, Table};
|
||||
use super::{support_types::Function, CommandTransferValue, Table};
|
||||
|
||||
impl<'lua> FromLua<'lua> for Function {
|
||||
fn from_lua(value: Value<'lua>, lua: &'lua mlua::Lua) -> mlua::Result<Self> {
|
||||
match value {
|
||||
Value::String(function_id) => {
|
||||
return Ok(Function::new(
|
||||
function_id
|
||||
.to_str()
|
||||
.context("Failed to convert lua string to rust string")?
|
||||
.to_owned(),
|
||||
))
|
||||
}
|
||||
_ => unreachable!("The Function type can only take functions!"),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> IntoLua<'lua> for CommandTransferValue {
|
||||
fn into_lua(self, lua: &'lua mlua::Lua) -> mlua::Result<mlua::Value<'lua>> {
|
||||
|
|
|
@ -2,6 +2,9 @@ use std::{collections::HashMap, fmt::Display};
|
|||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use self::support_types::Function;
|
||||
|
||||
pub mod support_types;
|
||||
pub mod type_conversions;
|
||||
|
||||
// language support
|
||||
|
@ -30,9 +33,10 @@ pub enum CommandTransferValue {
|
|||
|
||||
/// A table, dictionary or HashMap
|
||||
Table(HashMap<String, CommandTransferValue>),
|
||||
// Reference to a Lua function (or closure).
|
||||
// /* TODO */ Function(Function),
|
||||
|
||||
/// Reference to a function (or closure).
|
||||
/// This 'Function' value is obtained by registering a function with 'register_function()'
|
||||
Function(Function),
|
||||
// Reference to a Lua thread (or coroutine).
|
||||
// /* TODO */ Thread(Thread<'lua>),
|
||||
|
||||
|
@ -55,6 +59,7 @@ impl Display for CommandTransferValue {
|
|||
// TODO(@Soispha): The following line should be a real display call, but how do you
|
||||
// format a HashMap?
|
||||
CommandTransferValue::Table(table) => f.write_str(&format!("{:#?}", table)),
|
||||
CommandTransferValue::Function(function) => function.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
use std::{
|
||||
fmt::Display,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
};
|
||||
|
||||
use mlua::{ErrorContext, IntoLua, Lua, Table};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct Function {
|
||||
id: String,
|
||||
}
|
||||
|
||||
impl Function {
|
||||
pub fn new(function_uuid: String) -> Self {
|
||||
Function { id: function_uuid }
|
||||
}
|
||||
pub fn from_lua_function(function: mlua::Function, lua: &Lua) -> mlua::Result<Self> {
|
||||
// TODO(@soispha): Does this expose a vulnerability, as the ids are predictable? <2023-10-14>
|
||||
static COUNTER: AtomicUsize = AtomicUsize::new(0);
|
||||
let id = COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let globals = lua.globals();
|
||||
// This is always initialized, as the namespaces are specified in the 'command_list' module
|
||||
let private: Table = globals
|
||||
.get::<&str, mlua::Table>("trinitrix")
|
||||
.context("Failed to access 'trinitrix'")?
|
||||
.get::<&str, mlua::Table>("api")
|
||||
.context("Failed to access 'api'")?
|
||||
.get::<&str, mlua::Table>("raw")
|
||||
.context("Failed to access 'raw'")?
|
||||
.get::<&str, mlua::Table>("__private")
|
||||
.context("Failed to access '__private'")?;
|
||||
private.set(id, function)?;
|
||||
|
||||
Ok(Function::new(format!("{}", id)))
|
||||
}
|
||||
pub fn call(&self, lua: &Lua) -> mlua::Result<()> {
|
||||
// This is always initialized, as the namespaces are specified in the 'command_list' module
|
||||
let private: Table = lua
|
||||
.globals()
|
||||
.get::<&str, mlua::Table>("trinitrix")
|
||||
.context("Failed to access 'trinitrix'")?
|
||||
.get::<&str, mlua::Table>("api")
|
||||
.context("Failed to access 'api'")?
|
||||
.get::<&str, mlua::Table>("raw")
|
||||
.context("Failed to access 'raw'")?
|
||||
.get::<&str, mlua::Table>("__private")
|
||||
.context("Failed to access '__private'")?;
|
||||
let function: mlua::Function = private
|
||||
.get(self.id.clone())
|
||||
.context("Failed to get function associated with callback!")?;
|
||||
|
||||
function.call(()).context("Failed to call function")
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> IntoLua<'lua> for Function {
|
||||
fn into_lua(self, lua: &'lua Lua) -> mlua::Result<mlua::Value<'lua>> {
|
||||
Ok(self.id.into_lua(lua)?)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Function {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.id.fmt(f)
|
||||
}
|
||||
}
|
|
@ -2,10 +2,11 @@ use std::thread;
|
|||
|
||||
use anyhow::{Context, Result};
|
||||
use cli_log::{debug, error, info};
|
||||
use mlua::{Function, Value};
|
||||
use mlua::{ErrorContext, Lua, Value};
|
||||
use once_cell::sync::OnceCell;
|
||||
use tokio::{
|
||||
runtime::Builder,
|
||||
select,
|
||||
sync::{mpsc, Mutex},
|
||||
task::{self, LocalSet},
|
||||
};
|
||||
|
@ -21,12 +22,15 @@ use crate::app::{
|
|||
events::event_types::Event,
|
||||
};
|
||||
|
||||
use super::command_transfer_value::support_types::Function;
|
||||
|
||||
static LUA: OnceCell<Mutex<mlua::Lua>> = OnceCell::new();
|
||||
|
||||
/// This structure contains the necessary state for running an embedded Lua runtime (i.e.
|
||||
/// the tread, the Lua memory, etc.).
|
||||
pub struct LuaCommandManager {
|
||||
lua_command_tx: mpsc::Sender<String>,
|
||||
lua_function_tx: mpsc::Sender<Function>,
|
||||
}
|
||||
|
||||
impl LuaCommandManager {
|
||||
|
@ -36,10 +40,17 @@ impl LuaCommandManager {
|
|||
.await
|
||||
.expect("The receiver should not be dropped at this time");
|
||||
}
|
||||
pub async fn execute_function(&self, function: Function) {
|
||||
self.lua_function_tx
|
||||
.send(function)
|
||||
.await
|
||||
.expect("The receiver should not be dropped");
|
||||
}
|
||||
|
||||
pub fn new(event_call_tx: mpsc::Sender<Event>) -> Self {
|
||||
info!("Spawning lua code execution thread...");
|
||||
let (lua_command_tx, mut lua_command_rx) = mpsc::channel::<String>(256);
|
||||
let (lua_function_tx, mut lua_function_rx) = mpsc::channel::<Function>(256);
|
||||
thread::spawn(move || {
|
||||
let rt = Builder::new_current_thread().enable_all().build().expect(
|
||||
"Should always be able to build \
|
||||
|
@ -51,43 +62,112 @@ impl LuaCommandManager {
|
|||
"Lua command handling initialized, \
|
||||
waiting for commands.."
|
||||
);
|
||||
while let Some(command) = lua_command_rx.recv().await {
|
||||
debug!("Recieved lua code (in LuaCommandHandler): {}", &command);
|
||||
let local_event_call_tx = event_call_tx.clone();
|
||||
let mut done = false;
|
||||
let moved_event_call_tx = event_call_tx.clone();
|
||||
while !done {
|
||||
select! {
|
||||
command = lua_command_rx.recv() => {
|
||||
if let Some(command) = command {
|
||||
debug!("Received lua code (in LuaCommandHandler): {}", &command);
|
||||
let local_event_call_tx = moved_event_call_tx.clone();
|
||||
|
||||
task::spawn_local(async move {
|
||||
exec_lua(&command, local_event_call_tx).await.expect(
|
||||
"This should return all relevent errors \
|
||||
by other messages, \
|
||||
this should never error",
|
||||
);
|
||||
});
|
||||
task::spawn_local(async move {
|
||||
exec_lua(&command, local_event_call_tx).await.expect(
|
||||
"This should return all relevent errors \
|
||||
by other messages, \
|
||||
this should never error",
|
||||
);
|
||||
});
|
||||
} else {
|
||||
done = true;
|
||||
}
|
||||
},
|
||||
function_id = lua_function_rx.recv() => {
|
||||
if let Some(function_id) = function_id {
|
||||
debug!("Received lua function (in LuaCommandHandler): {}", &function_id);
|
||||
let local_event_call_tx = moved_event_call_tx.clone();
|
||||
|
||||
task::spawn_local(async move {
|
||||
let lua = initialize_lua(local_event_call_tx.clone()).await;
|
||||
let out = function_id.call(&lua).map_err(|err| async move {
|
||||
error!("Lua function `{}` returned error: `{}`", function_id, err);
|
||||
local_event_call_tx
|
||||
.send(Event::CommandEvent(
|
||||
Command::Trinitrix(Api(Raw(RaiseError(err.to_string())))),
|
||||
None,
|
||||
))
|
||||
.await.expect(
|
||||
"This should return all relevent errors \
|
||||
by other messages, \
|
||||
this should never error",
|
||||
);
|
||||
});
|
||||
if let Err(err) = out {
|
||||
err.await;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
done = true;
|
||||
}
|
||||
},
|
||||
else => done = true,
|
||||
}
|
||||
}
|
||||
});
|
||||
rt.block_on(local);
|
||||
});
|
||||
|
||||
LuaCommandManager { lua_command_tx }
|
||||
LuaCommandManager {
|
||||
lua_command_tx,
|
||||
lua_function_tx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn exec_lua(lua_code: &str, event_call_tx: mpsc::Sender<Event>) -> Result<()> {
|
||||
let second_event_call_tx = event_call_tx.clone();
|
||||
async fn initialize_lua<'a>(
|
||||
event_call_tx: mpsc::Sender<Event>,
|
||||
) -> matrix_sdk::locks::MutexGuard<'a, Lua> {
|
||||
let lua = LUA
|
||||
.get_or_init(|| {
|
||||
Mutex::new(add_lua_functions_to_globals(
|
||||
mlua::Lua::new(),
|
||||
second_event_call_tx,
|
||||
))
|
||||
let lua: Lua = add_lua_functions_to_globals(mlua::Lua::new(), event_call_tx);
|
||||
{
|
||||
let wrapped_register_function = lua
|
||||
.create_function(
|
||||
|lua: &Lua, function: mlua::Function| -> mlua::Result<Function> {
|
||||
ErrorContext::context(
|
||||
Function::from_lua_function(function, lua),
|
||||
"Failed to register function",
|
||||
)
|
||||
},
|
||||
)
|
||||
.expect("This always works, as the function is static");
|
||||
let trinitrix_api: mlua::Table = lua
|
||||
.globals()
|
||||
.get::<&str, mlua::Table>("trinitrix")
|
||||
.expect("This was set in the add_lua_functions_to_globals function")
|
||||
.get::<&str, mlua::Table>("api")
|
||||
.expect("Same reason");
|
||||
trinitrix_api
|
||||
.set("register_function", wrapped_register_function)
|
||||
.expect("This should work");
|
||||
}
|
||||
|
||||
Mutex::new(lua)
|
||||
})
|
||||
.lock()
|
||||
.await;
|
||||
lua
|
||||
}
|
||||
|
||||
async fn exec_lua(lua_code: &str, event_call_tx: mpsc::Sender<Event>) -> Result<()> {
|
||||
let lua = initialize_lua(event_call_tx.clone()).await;
|
||||
|
||||
info!("Recieved code to execute: `{}`, executing...", &lua_code);
|
||||
let output = lua.load(lua_code).eval_async::<Value>().await;
|
||||
match output {
|
||||
Ok(out) => {
|
||||
let to_string_fn: Function = lua.globals().get("tostring").expect("This always exists");
|
||||
let to_string_fn: mlua::Function =
|
||||
lua.globals().get("tostring").expect("This always exists");
|
||||
let output: String = to_string_fn.call(out).expect("tostring should not error");
|
||||
info!("Lua code `{}` evaluated to: `{}`", lua_code, &output);
|
||||
|
||||
|
|
|
@ -160,6 +160,10 @@ pub async fn handle(
|
|||
send_status_output!(output);
|
||||
EventStatus::Ok
|
||||
}
|
||||
Raw::Private(_) => {
|
||||
// no-op, read the comment about it in the `command_list`
|
||||
EventStatus::Ok
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
Loading…
Reference in New Issue