1
0
mirror of https://github.com/danog/strum.git synced 2024-11-26 12:04:38 +01:00

EnumDiscriminant inherits the repr and discriminant values (#288)

* moved repr extraction to helpers

* repr pass-through added

* add discriminant pass through

* remove dev artifact

---------

Co-authored-by: Jason Scatena <jscatena@amazon.com>
This commit is contained in:
Jason Scatena 2023-10-28 21:57:36 -04:00 committed by GitHub
parent d32af44f2e
commit e8b2ff1b48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 51 deletions

View File

@ -21,6 +21,7 @@ pub struct StrumTypeProperties {
pub discriminant_others: Vec<TokenStream>,
pub discriminant_vis: Option<Visibility>,
pub use_phf: bool,
pub enum_repr: Option<TokenStream>,
}
impl HasTypeProperties for DeriveInput {
@ -103,6 +104,17 @@ impl HasTypeProperties for DeriveInput {
}
}
let attrs = &self.attrs;
for attr in attrs {
if let Ok(list) = attr.meta.require_list() {
if let Some(ident) = list.path.get_ident() {
if ident == "repr" {
output.enum_repr = Some(list.tokens.clone())
}
}
}
}
Ok(output)
}
}

View File

@ -40,10 +40,16 @@ pub fn enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
// Pass through all other attributes
let pass_though_attributes = type_properties.discriminant_others;
let repr = type_properties.enum_repr.map(|repr| quote!(#[repr(#repr)]));
// Add the variants without fields, but exclude the `strum` meta item
let mut discriminants = Vec::new();
for variant in variants {
let ident = &variant.ident;
let discriminant = variant
.discriminant
.as_ref()
.map(|(_, expr)| quote!( = #expr));
// Don't copy across the "strum" meta attribute. Only passthrough the whitelisted
// attributes and proxy `#[strum_discriminants(...)]` attributes
@ -81,7 +87,7 @@ pub fn enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
})
.collect::<Result<Vec<_>, _>>()?;
discriminants.push(quote! { #(#attrs)* #ident });
discriminants.push(quote! { #(#attrs)* #ident #discriminant});
}
// Ideally:
@ -153,6 +159,7 @@ pub fn enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
Ok(quote! {
/// Auto-generated discriminant enum variants
#derives
#repr
#(#[ #pass_though_attributes ])*
#discriminants_vis enum #discriminants_name {
#(#discriminants),*

View File

@ -1,62 +1,32 @@
use heck::ToShoutySnakeCase;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens};
use syn::{Data, DeriveInput, Fields, PathArguments, Type, TypeParen};
use quote::{format_ident, quote};
use syn::{Data, DeriveInput, Fields, Type};
use crate::helpers::{non_enum_error, HasStrumVariantProperties};
use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};
pub fn from_repr_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let name = &ast.ident;
let gen = &ast.generics;
let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
let vis = &ast.vis;
let attrs = &ast.attrs;
let mut discriminant_type: Type = syn::parse("usize".parse().unwrap()).unwrap();
for attr in attrs {
let path = attr.path();
let mut ts = if let Ok(ts) = attr
.meta
.require_list()
.map(|metas| metas.to_token_stream().into_iter())
{
ts
} else {
continue;
};
// Discard the path
let _ = ts.next();
let tokens: TokenStream = ts.collect();
if path.leading_colon.is_some() {
continue;
}
if path.segments.len() != 1 {
continue;
}
let segment = path.segments.first().unwrap();
if segment.ident != "repr" {
continue;
}
if segment.arguments != PathArguments::None {
continue;
}
let typ_paren = match syn::parse2::<Type>(tokens.clone()) {
Ok(Type::Paren(TypeParen { elem, .. })) => *elem,
_ => continue,
};
let inner_path = match &typ_paren {
Type::Path(t) => t,
_ => continue,
};
if let Some(seg) = inner_path.path.segments.last() {
for t in &[
"u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize",
] {
if seg.ident == t {
discriminant_type = typ_paren;
break;
if let Some(type_path) = ast
.get_type_properties()
.ok()
.and_then(|tp| tp.enum_repr)
.and_then(|repr_ts| syn::parse2::<Type>(repr_ts).ok())
{
if let Type::Path(path) = type_path.clone() {
if let Some(seg) = path.path.segments.last() {
for t in &[
"u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize",
] {
if seg.ident == t {
discriminant_type = type_path;
break;
}
}
}
}

View File

@ -1,6 +1,9 @@
use enum_variant_type::EnumVariantType;
use strum::{Display, EnumDiscriminants, EnumIter, EnumMessage, EnumString, IntoEnumIterator};
use std::mem::{align_of, size_of};
use enum_variant_type::EnumVariantType;
use strum::{
Display, EnumDiscriminants, EnumIter, EnumMessage, EnumString, FromRepr, IntoEnumIterator,
};
mod core {} // ensure macros call `::core`
@ -305,3 +308,58 @@ fn crate_module_path_test() {
assert_eq!(expected, discriminants);
}
#[allow(dead_code)]
#[derive(EnumDiscriminants)]
#[repr(u16)]
enum WithReprUInt {
Variant0,
Variant1,
}
#[test]
fn with_repr_uint() {
// These tests would not be proof of proper functioning on a 16 bit system
assert_eq!(size_of::<u16>(), size_of::<WithReprUIntDiscriminants>());
assert_eq!(
size_of::<WithReprUInt>(),
size_of::<WithReprUIntDiscriminants>()
)
}
#[allow(dead_code)]
#[derive(EnumDiscriminants)]
#[repr(align(16), u8)]
enum WithReprAlign {
Variant0,
Variant1,
}
#[test]
fn with_repr_align() {
assert_eq!(
align_of::<WithReprAlign>(),
align_of::<WithReprAlignDiscriminants>()
);
assert_eq!(16, align_of::<WithReprAlignDiscriminants>());
}
#[allow(dead_code)]
#[derive(EnumDiscriminants)]
#[strum_discriminants(derive(FromRepr))]
enum WithExplicitDicriminantValue {
Variant0 = 42 + 100,
Variant1 = 11,
}
#[test]
fn with_explicit_discriminant_value() {
assert_eq!(
WithExplicitDicriminantValueDiscriminants::from_repr(11),
Some(WithExplicitDicriminantValueDiscriminants::Variant1)
);
assert_eq!(
142,
WithExplicitDicriminantValueDiscriminants::Variant0 as u8
);
}