diff --git a/Cargo.toml b/Cargo.toml index 132617d..3423f0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,9 @@ categories = ["no-std"] edition = "2021" exclude = ["/tests", "/.github"] +[features] +serde = ["dep:serde", "bitflags-derive-macros/serde"] + [dependencies.bitflags-derive-macros] path = "macros" version = "0.0.3" @@ -18,6 +21,10 @@ version = "0.0.3" [dependencies.bitflags] version = "2" +[dependencies.serde] +version = "1" +optional = true + # Adding new library support to `bitflags-derive`: # # 1. Add an optional dependency here diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 984883c..093941c 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -13,6 +13,9 @@ edition = "2021" [lib] proc-macro = true +[features] +serde = [] + [dependencies.proc-macro2] version = "1" diff --git a/macros/src/lib.rs b/macros/src/lib.rs index b4a1389..7549860 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -15,6 +15,9 @@ mod debug; mod display; mod from_str; +#[cfg(feature = "serde")] +mod serde; + /** Derive [`Debug`](https://doc.rust-lang.org/std/fmt/trait.Debug.html). @@ -48,6 +51,26 @@ pub fn derive_bitflags_from_str(item: proc_macro::TokenStream) -> proc_macro::To from_str::expand(syn::parse_macro_input!(item as syn::DeriveInput)).unwrap_or_compile_error() } +/** +Derive [`Serialize`](https://docs.rs/serde/latest/serde/trait.Serialize.html). +*/ +#[cfg(feature = "serde")] +#[proc_macro_derive(FlagsSerialize)] +pub fn derive_bitflags_serialize(item: proc_macro::TokenStream) -> proc_macro::TokenStream { + serde::serialize::expand(syn::parse_macro_input!(item as syn::DeriveInput)) + .unwrap_or_compile_error() +} + +/** +Derive [`Deserialize`](https://docs.rs/serde/latest/serde/trait.Deserialize.html). +*/ +#[cfg(feature = "serde")] +#[proc_macro_derive(FlagsDeserialize)] +pub fn derive_bitflags_deserialize(item: proc_macro::TokenStream) -> proc_macro::TokenStream { + serde::deserialize::expand(syn::parse_macro_input!(item as syn::DeriveInput)) + .unwrap_or_compile_error() +} + trait ResultExt { fn unwrap_or_compile_error(self) -> proc_macro::TokenStream; } diff --git a/macros/src/serde.rs b/macros/src/serde.rs new file mode 100644 index 0000000..cf5bf2d --- /dev/null +++ b/macros/src/serde.rs @@ -0,0 +1,44 @@ +pub(crate) mod serialize { + use proc_macro2::TokenStream; + + pub(crate) fn expand(item: syn::DeriveInput) -> Result { + let ident = item.ident; + let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); + + Ok( + quote!(impl #impl_generics bitflags_derive::__private::serde::Serialize for #ident #ty_generics #where_clause { + fn serialize(&self, serializer: S) -> bitflags_derive::__private::core::result::Result { + bitflags_derive::__private::serde::serialize(self, serializer) + } + }), + ) + } +} + +pub(crate) mod deserialize { + use proc_macro2::{Span, TokenStream}; + + pub(crate) fn expand(item: syn::DeriveInput) -> Result { + let ident = item.ident; + + let de_lt = syn::Lifetime::new("'bitflags_derive_de", Span::call_site()); + + let mut de_generics = item.generics.clone(); + de_generics + .params + .push_value(syn::GenericParam::Lifetime(syn::LifetimeParam::new( + de_lt.clone(), + ))); + + let (impl_generics, _, where_clause) = de_generics.split_for_impl(); + let (_, ty_generics, _) = item.generics.split_for_impl(); + + Ok( + quote!(impl #impl_generics bitflags_derive::__private::serde::Deserialize<#de_lt> for #ident #ty_generics #where_clause { + fn deserialize>(deserializer: D) -> bitflags_derive::__private::core::result::Result { + bitflags_derive::__private::serde::deserialize(deserializer) + } + }), + ) + } +} diff --git a/src/__private/serde.rs b/src/__private/serde.rs new file mode 100644 index 0000000..4159857 --- /dev/null +++ b/src/__private/serde.rs @@ -0,0 +1,80 @@ +extern crate serde; + +use bitflags::{ + parser::{self, ParseHex, WriteHex}, + Flags, +}; + +use core::{fmt, str}; + +pub use serde::{ + de::{Error, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; + +struct AsDisplay<'a, B>(pub(crate) &'a B); + +impl<'a, B: Flags> fmt::Display for AsDisplay<'a, B> +where + B::Bits: WriteHex, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + parser::to_writer(self.0, f) + } +} + +/** +Serialize a set of flags as a human-readable string or their underlying bits. + +Any unknown bits will be retained. +*/ +pub fn serialize(flags: &B, serializer: S) -> Result +where + B::Bits: WriteHex + Serialize, +{ + // Serialize human-readable flags as a string like `"A | B"` + if serializer.is_human_readable() { + serializer.collect_str(&AsDisplay(flags)) + } + // Serialize non-human-readable flags directly as the underlying bits + else { + flags.bits().serialize(serializer) + } +} + +/** +Deserialize a set of flags from a human-readable string or their underlying bits. + +Any unknown bits will be retained. +*/ +pub fn deserialize<'de, B: Flags, D: Deserializer<'de>>(deserializer: D) -> Result +where + B::Bits: ParseHex + Deserialize<'de>, +{ + if deserializer.is_human_readable() { + // Deserialize human-readable flags by parsing them from strings like `"A | B"` + struct FlagsVisitor(core::marker::PhantomData); + + impl<'de, B: Flags> Visitor<'de> for FlagsVisitor + where + B::Bits: ParseHex, + { + type Value = B; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("a string value of `|` separated flags") + } + + fn visit_str(self, flags: &str) -> Result { + parser::from_str(flags).map_err(|e| E::custom(e)) + } + } + + deserializer.deserialize_str(FlagsVisitor(Default::default())) + } else { + // Deserialize non-human-readable flags directly from the underlying bits + let bits = B::Bits::deserialize(deserializer)?; + + Ok(B::from_bits_retain(bits)) + } +} diff --git a/src/lib.rs b/src/lib.rs index 906f51b..9d47e75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,6 +50,9 @@ pub use bitflags_derive_macros::*; #[doc(hidden)] pub mod __private { + #[cfg(feature = "serde")] + pub mod serde; + pub use bitflags; pub use core; } diff --git a/tests/ui/Cargo.toml b/tests/ui/Cargo.toml index 57e712b..43379ff 100644 --- a/tests/ui/Cargo.toml +++ b/tests/ui/Cargo.toml @@ -4,11 +4,18 @@ version = "0.0.0" edition = "2021" publish = false +[features] +serde = ["bitflags-derive/serde", "dep:serde_test"] + [dependencies.bitflags-derive] path = "../../" [dependencies.bitflags] version = "2" +[dependencies.serde_test] +version = "1" +optional = true + [dependencies.trybuild] version = "1" diff --git a/tests/ui/src/lib.rs b/tests/ui/src/lib.rs index 78dce11..33e5079 100644 --- a/tests/ui/src/lib.rs +++ b/tests/ui/src/lib.rs @@ -9,3 +9,6 @@ extern crate bitflags_derive; mod debug; mod display; mod from_str; + +#[cfg(feature = "serde")] +mod serde; diff --git a/tests/ui/src/serde.rs b/tests/ui/src/serde.rs new file mode 100644 index 0000000..5d30418 --- /dev/null +++ b/tests/ui/src/serde.rs @@ -0,0 +1,21 @@ +use serde_test::{assert_tokens, Configure, Token::*}; + +#[test] +fn derive_serialize_deserialize() { + bitflags! { + #[derive(FlagsSerialize, FlagsDeserialize, PartialEq, Eq, Debug)] + struct Flags: u8 { + const A = 1; + const B = 1 << 1; + const C = 1 << 2; + } + } + + assert_tokens(&Flags::empty().readable(), &[Str("")]); + + assert_tokens(&Flags::empty().compact(), &[U8(0)]); + + assert_tokens(&(Flags::A | Flags::B).readable(), &[Str("A | B")]); + + assert_tokens(&(Flags::A | Flags::B).compact(), &[U8(1 | 2)]); +}