diff options
author | Gary Guo <gary@garyguo.net> | 2022-11-10 17:41:18 +0100 |
---|---|---|
committer | Miguel Ojeda <ojeda@kernel.org> | 2022-12-04 01:59:15 +0100 |
commit | b44becc5ee808e02bbda0f90ee0584f206693a33 (patch) | |
tree | 88c3ec578019f9a208d9a6251066d136eae97fe6 /rust/macros | |
parent | 60f18c225f5f6939e348c4b067711d10afd33151 (diff) |
rust: macros: add `#[vtable]` proc macro
This procedural macro attribute provides a simple way to declare
a trait with a set of operations that later users can partially
implement, providing compile-time `HAS_*` boolean associated
constants that indicate whether a particular operation was overridden.
This is useful as the Rust counterpart to structs like
`file_operations` where some pointers may be `NULL`, indicating
an operation is not provided.
For instance:
#[vtable]
trait Operations {
fn read(...) -> Result<usize> {
Err(EINVAL)
}
fn write(...) -> Result<usize> {
Err(EINVAL)
}
}
#[vtable]
impl Operations for S {
fn read(...) -> Result<usize> {
...
}
}
assert_eq!(<S as Operations>::HAS_READ, true);
assert_eq!(<S as Operations>::HAS_WRITE, false);
Signed-off-by: Gary Guo <gary@garyguo.net>
Reviewed-by: Sergio González Collado <sergio.collado@gmail.com>
[Reworded, adapted for upstream and applied latest changes]
Signed-off-by: Miguel Ojeda <ojeda@kernel.org>
Diffstat (limited to 'rust/macros')
-rw-r--r-- | rust/macros/lib.rs | 52 | ||||
-rw-r--r-- | rust/macros/vtable.rs | 95 |
2 files changed, 147 insertions, 0 deletions
diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs index 15555e7ff487..e40caaf0a656 100644 --- a/rust/macros/lib.rs +++ b/rust/macros/lib.rs @@ -5,6 +5,7 @@ mod concat_idents; mod helpers; mod module; +mod vtable; use proc_macro::TokenStream; @@ -72,6 +73,57 @@ pub fn module(ts: TokenStream) -> TokenStream { module::module(ts) } +/// Declares or implements a vtable trait. +/// +/// Linux's use of pure vtables is very close to Rust traits, but they differ +/// in how unimplemented functions are represented. In Rust, traits can provide +/// default implementation for all non-required methods (and the default +/// implementation could just return `Error::EINVAL`); Linux typically use C +/// `NULL` pointers to represent these functions. +/// +/// This attribute is intended to close the gap. Traits can be declared and +/// implemented with the `#[vtable]` attribute, and a `HAS_*` associated constant +/// will be generated for each method in the trait, indicating if the implementor +/// has overridden a method. +/// +/// This attribute is not needed if all methods are required. +/// +/// # Examples +/// +/// ```ignore +/// use kernel::prelude::*; +/// +/// // Declares a `#[vtable]` trait +/// #[vtable] +/// pub trait Operations: Send + Sync + Sized { +/// fn foo(&self) -> Result<()> { +/// Err(EINVAL) +/// } +/// +/// fn bar(&self) -> Result<()> { +/// Err(EINVAL) +/// } +/// } +/// +/// struct Foo; +/// +/// // Implements the `#[vtable]` trait +/// #[vtable] +/// impl Operations for Foo { +/// fn foo(&self) -> Result<()> { +/// # Err(EINVAL) +/// // ... +/// } +/// } +/// +/// assert_eq!(<Foo as Operations>::HAS_FOO, true); +/// assert_eq!(<Foo as Operations>::HAS_BAR, false); +/// ``` +#[proc_macro_attribute] +pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream { + vtable::vtable(attr, ts) +} + /// Concatenate two identifiers. /// /// This is useful in macros that need to declare or reference items with names diff --git a/rust/macros/vtable.rs b/rust/macros/vtable.rs new file mode 100644 index 000000000000..34d5e7fb5768 --- /dev/null +++ b/rust/macros/vtable.rs @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: GPL-2.0 + +use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; +use std::collections::HashSet; +use std::fmt::Write; + +pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream { + let mut tokens: Vec<_> = ts.into_iter().collect(); + + // Scan for the `trait` or `impl` keyword. + let is_trait = tokens + .iter() + .find_map(|token| match token { + TokenTree::Ident(ident) => match ident.to_string().as_str() { + "trait" => Some(true), + "impl" => Some(false), + _ => None, + }, + _ => None, + }) + .expect("#[vtable] attribute should only be applied to trait or impl block"); + + // Retrieve the main body. The main body should be the last token tree. + let body = match tokens.pop() { + Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, + _ => panic!("cannot locate main body of trait or impl block"), + }; + + let mut body_it = body.stream().into_iter(); + let mut functions = Vec::new(); + let mut consts = HashSet::new(); + while let Some(token) = body_it.next() { + match token { + TokenTree::Ident(ident) if ident.to_string() == "fn" => { + let fn_name = match body_it.next() { + Some(TokenTree::Ident(ident)) => ident.to_string(), + // Possibly we've encountered a fn pointer type instead. + _ => continue, + }; + functions.push(fn_name); + } + TokenTree::Ident(ident) if ident.to_string() == "const" => { + let const_name = match body_it.next() { + Some(TokenTree::Ident(ident)) => ident.to_string(), + // Possibly we've encountered an inline const block instead. + _ => continue, + }; + consts.insert(const_name); + } + _ => (), + } + } + + let mut const_items; + if is_trait { + const_items = " + /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable) + /// attribute when implementing this trait. + const USE_VTABLE_ATTR: (); + " + .to_owned(); + + for f in functions { + let gen_const_name = format!("HAS_{}", f.to_uppercase()); + // Skip if it's declared already -- this allows user override. + if consts.contains(&gen_const_name) { + continue; + } + // We don't know on the implementation-site whether a method is required or provided + // so we have to generate a const for all methods. + write!( + const_items, + "/// Indicates if the `{f}` method is overridden by the implementor. + const {gen_const_name}: bool = false;", + ) + .unwrap(); + } + } else { + const_items = "const USE_VTABLE_ATTR: () = ();".to_owned(); + + for f in functions { + let gen_const_name = format!("HAS_{}", f.to_uppercase()); + if consts.contains(&gen_const_name) { + continue; + } + write!(const_items, "const {gen_const_name}: bool = true;").unwrap(); + } + } + + let new_body = vec![const_items.parse().unwrap(), body.stream()] + .into_iter() + .collect(); + tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body))); + tokens.into_iter().collect() +} |