From 3e8722433d348bb18ae5cd2ad08fc88ce5bb0303 Mon Sep 17 00:00:00 2001 From: Soispha Date: Sun, 16 Jul 2023 14:01:48 +0200 Subject: [PATCH] Fix(lua_macros): Expand to generate the required types and functions --- lua_macros/Cargo.toml | 5 +- lua_macros/src/generate_noop_lua_function.rs | 11 ++ lua_macros/src/lib.rs | 99 ++++++------ lua_macros/src/mark_as_ci_command.rs | 161 +++++++++++++++++++ lua_macros/src/struct_to_ci_enum.rs | 140 ++++++++++++++++ 5 files changed, 369 insertions(+), 47 deletions(-) create mode 100644 lua_macros/src/generate_noop_lua_function.rs create mode 100644 lua_macros/src/mark_as_ci_command.rs create mode 100644 lua_macros/src/struct_to_ci_enum.rs diff --git a/lua_macros/Cargo.toml b/lua_macros/Cargo.toml index bf1583b..cb754e2 100644 --- a/lua_macros/Cargo.toml +++ b/lua_macros/Cargo.toml @@ -4,9 +4,10 @@ version = "0.1.0" edition = "2021" [lib] -crate_type = ["proc-macro"] +proc-macro = true [dependencies] +convert_case = "0.6.0" proc-macro2 = "1.0.64" quote = "1.0.29" -syn = "2.0.25" +syn = { version = "2.0.25", features = ["extra-traits", "full", "parsing"] } diff --git a/lua_macros/src/generate_noop_lua_function.rs b/lua_macros/src/generate_noop_lua_function.rs new file mode 100644 index 0000000..eac9e0a --- /dev/null +++ b/lua_macros/src/generate_noop_lua_function.rs @@ -0,0 +1,11 @@ +use proc_macro2::TokenStream as TokenStream2; +use quote::ToTokens; + +// TODO: Do we need this noop? +pub fn generate_default_lua_function(input: &syn::Field) -> TokenStream2 { + let output: TokenStream2 = syn::parse(input.into_token_stream().into()) + .expect("This is generated from valid rust code, it should stay that way."); + + output +} + diff --git a/lua_macros/src/lib.rs b/lua_macros/src/lib.rs index 777665e..12fe422 100644 --- a/lua_macros/src/lib.rs +++ b/lua_macros/src/lib.rs @@ -1,59 +1,68 @@ +mod generate_noop_lua_function; +mod mark_as_ci_command; +mod struct_to_ci_enum; + +use generate_noop_lua_function::generate_default_lua_function; +use mark_as_ci_command::generate_final_function; use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; -use quote::{format_ident, quote}; -use syn; +use quote::quote; +use struct_to_ci_enum::{generate_command_enum, generate_generate_ci_function}; +use syn::{self, ItemFn, Field, parse::Parser}; #[proc_macro_attribute] -pub fn generate_ci_functions(_: TokenStream, input: TokenStream) -> TokenStream { +pub fn turn_struct_to_ci_commands(_attrs: TokenStream, input: TokenStream) -> TokenStream { // Construct a representation of Rust code as a syntax tree // that we can manipulate let input = syn::parse(input) .expect("This should always be valid rust code, as it's extracted from direct code"); // Build the trait implementation - generate_generate_ci_functions(&input) + let generate_ci_function: TokenStream2 = generate_generate_ci_function(&input); + + let command_enum = generate_command_enum(&input); + + quote! { + #command_enum + #generate_ci_function + } + .into() } -fn generate_generate_ci_functions(input: &syn::DeriveInput) -> TokenStream { - let input_tokens: TokenStream2 = match &input.data { - syn::Data::Struct(input) => match &input.fields { - syn::Fields::Named(named_fields) => named_fields - .named - .iter() - .map(|field| -> TokenStream2 { - let field_ident = field.ident.as_ref().expect( - "These are only the named field, thus they all should have a name.", - ); - let function_name_ident = format_ident!("fun_{}", field_ident); - let function_name = format!("{}", field_ident); - quote! { - let #function_name_ident = context.create_function(#field_ident).expect( - &format!( - "The function: `{}` should be defined", - #function_name - ) - ); - globals.set(#function_name, #function_name_ident).expect( - &format!( - "Setting a static global value ({}, fun_{}) should work", - #function_name, - #function_name - ) - ); - } - .into() - }) - .collect(), - _ => unimplemented!("Only implemented for named fileds"), - }, - _ => unimplemented!("Only for implemented for structs"), - }; +/// Generate a default lua function implementation. +#[proc_macro_attribute] +pub fn gen_lua_function(_attrs: TokenStream, input: TokenStream) -> TokenStream { + // Construct a representation of Rust code as a syntax tree + // that we can manipulate + // + let parser = Field::parse_named; + let input = parser.parse(input) + .expect("This is only defined for named fileds."); - let gen = quote! { - pub fn generate_ci_functions(context: &mut Context) { - let globals = context.globals(); - #input_tokens - } - }; - gen.into() + + // Build the trait implementation + let default_lua_function: TokenStream2 = generate_default_lua_function(&input); + + quote! { + #default_lua_function + } + .into() +} + +/// Turn a function into a valid ci command function +#[proc_macro_attribute] +pub fn ci_command(_attrs: TokenStream, input: TokenStream) -> TokenStream { + // Construct a representation of Rust code as a syntax tree + // that we can manipulate + let mut input: ItemFn = syn::parse(input) + .expect("This should always be valid rust code, as it's extracted from direct code"); + + // Build the trait implementation + let output_function: TokenStream2 = generate_final_function(&mut input); + + //panic!("{:#?}", output_function); + quote! { + #output_function + } + .into() } diff --git a/lua_macros/src/mark_as_ci_command.rs b/lua_macros/src/mark_as_ci_command.rs new file mode 100644 index 0000000..08b715a --- /dev/null +++ b/lua_macros/src/mark_as_ci_command.rs @@ -0,0 +1,161 @@ +use convert_case::{Case, Casing}; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote, ToTokens}; +use syn::{Block, Expr, ExprBlock, GenericArgument, ReturnType, Stmt, Type}; + +pub fn generate_final_function(input: &mut syn::ItemFn) -> TokenStream2 { + append_tx_send_code(input); + + let output: TokenStream2 = syn::parse(input.into_token_stream().into()) + .expect("This is generated from valid rust code, it should stay that way."); + + output +} + +fn append_tx_send_code(input: &mut syn::ItemFn) -> &mut syn::ItemFn { + let function_name_pascal = format_ident!( + "{}", + input + .sig + .ident + .clone() + .to_string() + .from_case(Case::Snake) + .to_case(Case::Pascal) + ); + + let tx_send = match &input.sig.output { + syn::ReturnType::Default => { + todo!( + "Does this case even trigger? All functions should have a output of (Result<$type, rlua::Error>)" + ); + quote! { + { + let tx: std::sync::mpsc::Sender = + context + .named_registry_value("sender_for_ci_commands") + .expect("This exists, it was set before"); + + tx + .send(Event::CommandEvent(Command::#function_name_pascal)) + .expect("This should work, as the reciever is not dropped"); + } + } + } + syn::ReturnType::Type(_, ret_type) => { + let return_type = match *(ret_type.clone()) { + syn::Type::Path(path) => { + match path + .path + .segments + .first() + .expect("This is expected to be only one path segment") + .arguments + .to_owned() + { + syn::PathArguments::AngleBracketed(angled_path) => { + let angled_path = angled_path.args.to_owned(); + let filtered_paths: Vec<_> = angled_path + .into_iter() + .filter(|generic_arg| { + if let GenericArgument::Type(generic_type) = generic_arg { + if let Type::Path(_) = generic_type { + true + } else { + false + } + } else { + false + } + }) + .collect(); + + // There should only be two segments (the type is ) + if filtered_paths.len() > 2 { + unreachable!( + "There should be no more than two filtered_output, but got: {:#?}", + filtered_paths + ) + } else if filtered_paths.len() <= 0 { + unreachable!( + "There should be more than zero filtered_output, but got: {:#?}", + filtered_paths + ) + } + if filtered_paths.len() == 2 { + // There is something else than rlua + let gen_type = if let GenericArgument::Type(ret_type) = + filtered_paths + .first() + .expect("One path segment should exists") + .to_owned() + { + ret_type + } else { + unreachable!("These were filtered above."); + }; + let return_type_as_type_prepared = quote! {-> #gen_type}; + + let return_type_as_return_type: ReturnType = syn::parse( + return_type_as_type_prepared.to_token_stream().into(), + ) + .expect("This is valid."); + return_type_as_return_type + } else { + // There is only rlua + ReturnType::Default + } + } + _ => unimplemented!("Only for angled paths"), + } + } + _ => unimplemented!("Only for path types"), + }; + match return_type { + ReturnType::Default => { + quote! { + { + let tx: std::sync::mpsc::Sender = + context + .named_registry_value("sender_for_ci_commands") + .expect("This exists, it was set before"); + + tx + .send(Event::CommandEvent(Command::#function_name_pascal)) + .expect("This should work, as the reciever is not dropped"); + } + } + } + ReturnType::Type(_, _) => { + quote! { + { + let tx: std::sync::mpsc::Sender = + context + .named_registry_value("sender_for_ci_commands") + .expect("This exists, it was set before"); + + tx + .send(Event::CommandEvent(Command::#function_name_pascal(input_str))) + .expect("This should work, as the reciever is not dropped"); + } + } + } + } + } + }; + + let tx_send_block: Block = + syn::parse(tx_send.into()).expect("This is a static string, it will always parse"); + let tx_send_expr_block = ExprBlock { + attrs: vec![], + label: None, + block: tx_send_block, + }; + let mut tx_send_stmt = vec![Stmt::Expr(Expr::Block(tx_send_expr_block), None)]; + + let mut new_stmts: Vec = Vec::with_capacity(input.block.stmts.len() + 1); + new_stmts.append(&mut tx_send_stmt); + new_stmts.append(&mut input.block.stmts); + input.block.stmts = new_stmts; + input +} diff --git a/lua_macros/src/struct_to_ci_enum.rs b/lua_macros/src/struct_to_ci_enum.rs new file mode 100644 index 0000000..f231249 --- /dev/null +++ b/lua_macros/src/struct_to_ci_enum.rs @@ -0,0 +1,140 @@ +use convert_case::{Case, Casing}; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; +use syn::{self, ReturnType}; + +pub fn generate_generate_ci_function(input: &syn::DeriveInput) -> TokenStream2 { + let mut functions_to_generate: Vec = vec![]; + let input_tokens: TokenStream2 = match &input.data { + syn::Data::Struct(input) => match &input.fields { + syn::Fields::Named(named_fields) => named_fields + .named + .iter() + .map(|field| -> TokenStream2 { + if field.attrs.iter().any(|attribute| { + attribute.path() + == &syn::parse_str::("gen_default_lua_function") + .expect("This is valid rust code") + }) { + let function_name = field + .ident + .as_ref() + .expect("These are only the named field, thus they all should have a name."); + functions_to_generate.push(quote! { + #[ci_command] + fn #function_name(context: Context, input_str: String) -> Result<(), rlua::Error> { + Ok(()) + } + }); + generate_ci_part(field) + } else { + generate_ci_part(field) + } + }) + .collect(), + + _ => unimplemented!("Only implemented for named fileds"), + }, + _ => unimplemented!("Only implemented for structs"), + }; + + let functions_to_generate: TokenStream2 = functions_to_generate.into_iter().collect(); + let gen = quote! { + pub fn generate_ci_functions( + context: &mut rlua::Context, + tx: std::sync::mpsc::Sender) + { + context.set_named_registry_value("sender_for_ci_commands", tx).expect("This should always work, as the value is added before all else"); + let globals = context.globals(); + #input_tokens + } + #functions_to_generate + }; + gen.into() +} + +fn generate_ci_part(field: &syn::Field) -> TokenStream2 { + let field_ident = field + .ident + .as_ref() + .expect("These are only the named field, thus they all should have a name."); + let function_name_ident = format_ident!("fun_{}", field_ident); + let function_name = format!("{}", field_ident); + quote! { + let #function_name_ident = context.create_function(#field_ident).expect( + &format!( + "The function: `{}` should be defined", + #function_name + ) + ); + + globals.set(#function_name, #function_name_ident).expect( + &format!( + "Setting a static global value ({}, fun_{}) should work", + #function_name, + #function_name + ) + ); + } + .into() +} + +pub fn generate_command_enum(input: &syn::DeriveInput) -> TokenStream2 { + let input_tokens: TokenStream2 = match &input.data { + syn::Data::Struct(input) => match &input.fields { + syn::Fields::Named(named_fields) => named_fields + .named + .iter() + .map(|field| -> TokenStream2 { + let field_ident = field + .ident + .as_ref() + .expect("These are only the named field, thus they all should have a name."); + + let enum_variant_type = match &field.ty { + syn::Type::BareFn(function) => { + let return_path: &ReturnType = &function.output; + match return_path { + ReturnType::Default => None, + ReturnType::Type(_, return_type) => Some(match *(return_type.to_owned()) { + syn::Type::Path(path_type) => path_type + .path + .get_ident() + .expect("A path should either be complete, or only conain one segment") + .to_owned(), + _ => unimplemented!("This is only implemented for path types"), + }), + } + } + _ => unimplemented!("This is only implemented for bare function types"), + }; + + let enum_variant_name = format_ident!( + "{}", + field_ident.to_string().from_case(Case::Snake).to_case(Case::Pascal) + ); + if enum_variant_type.is_some() { + quote! { + #enum_variant_name (#enum_variant_type), + } + .into() + } else { + quote! { + #enum_variant_name, + } + } + }) + .collect(), + _ => unimplemented!("Only implemented for named fileds"), + }, + _ => unimplemented!("Only implemented for structs"), + }; + + let gen = quote! { + #[derive(Debug)] + pub enum Command { + #input_tokens + } + }; + gen.into() +}