Skip to content

Commit

Permalink
Move custom underlying type specification to strong_type attribute (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
yunjhongwu authored Feb 16, 2024
1 parent 116d5ee commit ba5de87
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 77 deletions.
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,10 @@ println!("{:?}", Second::new(std::f64::consts::E)); // "Second { value: 2.718281
struct Dollar(i32);

#[derive(StrongType)]
#[strong_type(auto_operators)]
#[custom_underlying(i32)]
#[strong_type(auto_operators, underlying = i32)]
struct Cash(Dollar);

#[derive(StrongType)]
#[strong_type(auto_operators)]
#[custom_underlying(i32)]
#[strong_type(underlying = i32)]
struct Coin(Cash);
```
8 changes: 7 additions & 1 deletion strong-type-derive/src/detail/basic.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use proc_macro2::TokenStream;
use quote::quote;

pub(crate) fn implement_basic(name: &syn::Ident, value_type: &syn::Ident) -> TokenStream {
pub(crate) fn implement_basic(
name: &syn::Ident,
value_type: &syn::Ident,
primitive_type: &syn::Ident,
) -> TokenStream {
quote! {
impl #name {
pub fn new(value: impl Into<#value_type>) -> Self {
Expand All @@ -11,6 +15,8 @@ pub(crate) fn implement_basic(name: &syn::Ident, value_type: &syn::Ident) -> Tok

impl StrongType for #name {
type UnderlyingType = #value_type;
type PrimitiveType = #primitive_type;

}

impl std::fmt::Debug for #name {
Expand Down
86 changes: 30 additions & 56 deletions strong-type-derive/src/detail/underlying_type_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,68 +16,21 @@ pub(crate) enum ValueTypeGroup {
pub(crate) struct TypeInfo {
pub primitive_type: syn::Ident,
pub value_type: syn::Ident,
pub type_group: ValueTypeGroup,
}

fn get_type_ident(input: &DeriveInput) -> Option<syn::Ident> {
if let Data::Struct(ref data_struct) = input.data {
if let Type::Path(ref path) = &data_struct.fields.iter().next().unwrap().ty {
return Some(path.path.segments.last().unwrap().ident.clone());
}
}
None
}

fn get_primitive_from_custom_underlying(input: &DeriveInput) -> Option<syn::Ident> {
for attr in input.attrs.iter() {
if attr.path().is_ident("custom_underlying") {
let mut primitive = None;
attr.parse_nested_meta(|meta| match meta.path.get_ident() {
Some(ident) => {
primitive = Some(ident.clone());
Ok(())
}
None => Err(meta.error("Unsupported attribute")),
})
.ok()?;
return primitive;
}
}

None
}

pub(crate) fn get_type(input: &DeriveInput) -> TypeInfo {
if let Some(value_type) = get_type_ident(input) {
match get_primitive_from_custom_underlying(input) {
Some(primitive_type) => TypeInfo {
primitive_type: primitive_type.clone(),
value_type: value_type.clone(),
type_group: get_type_group(&primitive_type, UnderlyingType::Derived),
},
None => TypeInfo {
primitive_type: value_type.clone(),
value_type: value_type.clone(),
type_group: get_type_group(&value_type, UnderlyingType::Primitive),
},
}
} else {
panic!("Unsupported input")
}
pub type_group: Option<ValueTypeGroup>,
}

pub(crate) fn get_type_group(
value_type: &syn::Ident,
underlying_type: UnderlyingType,
) -> ValueTypeGroup {
) -> Option<ValueTypeGroup> {
if value_type == "i8"
|| value_type == "i16"
|| value_type == "i32"
|| value_type == "i64"
|| value_type == "i128"
|| value_type == "isize"
{
return ValueTypeGroup::Int(underlying_type);
return Some(ValueTypeGroup::Int(underlying_type));
}
if value_type == "u8"
|| value_type == "u16"
Expand All @@ -86,19 +39,40 @@ pub(crate) fn get_type_group(
|| value_type == "u128"
|| value_type == "usize"
{
return ValueTypeGroup::UInt(underlying_type);
return Some(ValueTypeGroup::UInt(underlying_type));
}
if value_type == "f32" || value_type == "f64" {
return ValueTypeGroup::Float(underlying_type);
return Some(ValueTypeGroup::Float(underlying_type));
}
if value_type == "bool" {
return ValueTypeGroup::Bool(underlying_type);
return Some(ValueTypeGroup::Bool(underlying_type));
}
if value_type == "char" {
return ValueTypeGroup::Char(underlying_type);
return Some(ValueTypeGroup::Char(underlying_type));
}
if value_type == "String" {
return ValueTypeGroup::String(underlying_type);
return Some(ValueTypeGroup::String(underlying_type));
}

None
}

fn get_type_ident(input: &DeriveInput) -> Option<syn::Ident> {
if let Data::Struct(ref data_struct) = input.data {
if let Type::Path(ref path) = &data_struct.fields.iter().next().unwrap().ty {
return Some(path.path.segments.last().unwrap().ident.clone());
}
}
None
}

pub(crate) fn get_type(input: &DeriveInput) -> TypeInfo {
if let Some(value_type) = get_type_ident(input) {
return TypeInfo {
primitive_type: value_type.clone(),
value_type: value_type.clone(),
type_group: get_type_group(&value_type, UnderlyingType::Primitive),
};
}
panic!("Unsupported type: {}", value_type);
panic!("Unable to find underlying value type");
}
23 changes: 19 additions & 4 deletions strong-type-derive/src/detail/utils.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,41 @@
use crate::detail::underlying_type_utils::get_type_group;
use crate::detail::{get_type, TypeInfo, UnderlyingType};
use syn::{Data, DeriveInput, Fields, Visibility};

#[derive(Default)]
pub(crate) struct StrongTypeAttributes {
pub has_auto_operators: bool,
pub has_custom_display: bool,
pub type_info: TypeInfo,
}

pub(crate) fn get_attributes(input: &DeriveInput) -> StrongTypeAttributes {
let mut attributes = StrongTypeAttributes::default();
let mut attributes = StrongTypeAttributes {
has_auto_operators: false,
has_custom_display: false,
type_info: get_type(input),
};

for attr in input.attrs.iter() {
if attr.path().is_ident("strong_type") {
if let Err(message) = attr.parse_nested_meta(|meta| {

if meta.path.is_ident("auto_operators") {
attributes.has_auto_operators = true;
Ok(())
} else if meta.path.is_ident("custom_display") {
attributes.has_custom_display = true;
Ok(())
} else if meta.path.is_ident("underlying") {
if let Ok(strm) = meta.value() {
if let Ok(primitive_type) = strm.parse::<syn::Ident>() {
attributes.type_info.type_group = get_type_group(&primitive_type, UnderlyingType::Derived);
attributes.type_info.primitive_type = primitive_type;
} else {
panic!("Failed to parse custom underlying {}", strm);
}
}
Ok(())
} else {
Err(meta.error(format!("Invalid strong_type attribute {}, should be one of {{auto_operators, custom_display}}",
Err(meta.error(format!("Invalid strong_type attribute {}, should be one of {{auto_operators, custom_display, underlying=type}}",
meta.path.get_ident().expect("Failed to parse strong_type attributes."))))
}
}) {
Expand Down
17 changes: 10 additions & 7 deletions strong-type-derive/src/strong_type.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::detail::{
get_attributes, get_type, implement_arithmetic, implement_basic, implement_basic_primitive,
get_attributes, implement_arithmetic, implement_basic, implement_basic_primitive,
implement_basic_string, implement_bit_shift, implement_bool_ops, implement_constants,
implement_constants_derived, implement_display, implement_hash, implement_infinity,
implement_limit, implement_nan, implement_negate, implement_primitive_accessor,
Expand All @@ -17,18 +17,21 @@ pub(super) fn expand_strong_type(input: DeriveInput) -> TokenStream {
}

let name = &input.ident;
let TypeInfo {
primitive_type,
value_type,
type_group,
} = get_type(&input);
let StrongTypeAttributes {
has_auto_operators,
has_custom_display,
type_info:
TypeInfo {
primitive_type,
value_type,
type_group,
},
} = get_attributes(&input);
let type_group = type_group
.unwrap_or_else(|| panic!("Unable to determine the primitive type of {}", value_type));

let mut ast = quote!();
ast.extend(implement_basic(name, &value_type));
ast.extend(implement_basic(name, &value_type, &primitive_type));

if !has_custom_display {
ast.extend(implement_display(name));
Expand Down
9 changes: 4 additions & 5 deletions strong-type-tests/tests/custom_underlying.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ mod tests {
struct Dollar(i32);

#[derive(StrongType)]
#[strong_type(auto_operators)]
#[custom_underlying(i32)]
#[strong_type(auto_operators, underlying=i32)]
struct Cash(Dollar);
test_type::<Cash>();
assert_eq!(mem::size_of::<Cash>(), mem::size_of::<i32>());
Expand All @@ -30,7 +29,7 @@ mod tests {
);

#[derive(StrongType)]
#[custom_underlying(i32)]
#[strong_type(underlying=i32)]
struct Coin(Cash);
test_type::<Coin>();
assert_eq!(mem::size_of::<Coin>(), mem::size_of::<i32>());
Expand All @@ -51,7 +50,7 @@ mod tests {
struct Tag(String);

#[derive(StrongType)]
#[custom_underlying(String)]
#[strong_type(underlying=String)]
struct Name(Tag);

test_type::<Name>();
Expand All @@ -62,7 +61,7 @@ mod tests {
);

#[derive(StrongType)]
#[custom_underlying(String)]
#[strong_type(underlying=String)]
struct Surname(Name);
assert_eq!(mem::size_of::<Surname>(), mem::size_of::<String>());
assert_eq!(
Expand Down
1 change: 1 addition & 0 deletions strong-type/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ pub use strong_type_derive::StrongType;

pub trait StrongType: Debug + PartialEq + PartialOrd + Clone + Default + Send + Sync {
type UnderlyingType: Default;
type PrimitiveType;
}

0 comments on commit ba5de87

Please sign in to comment.