Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow duplicated enum variant indices #628

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ The derive implementation supports the following attributes:
- `codec(encoded_as = "OtherType")`: Needs to be placed above a field and makes the field being
encoded by using `OtherType`.
- `codec(index = 0)`: Needs to be placed above an enum variant to make the variant use the given
index when encoded. By default the index is determined by counting from `0` beginning wth the
index when encoded. By default the index is determined by counting from `0` beginning with the
first variant.
- `codec(encode_bound)`, `codec(decode_bound)` and `codec(mel_bound)`: All 3 attributes take
in a `where` clause for the `Encode`, `Decode` and `MaxEncodedLen` trait implementation for
Expand Down
23 changes: 16 additions & 7 deletions derive/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,18 @@ pub fn quote(
Ok(variants) => variants,
Err(e) => return e.to_compile_error(),
};

let recurse = variants.iter().enumerate().map(|(i, v)| {
match utils::check_indexes(variants.iter()).map_err(|e| e.to_compile_error()) {
Ok(()) => (),
Err(e) => return e,
};
let mut items = vec![];
for (index, v) in variants.iter().enumerate() {
let name = &v.ident;
let index = utils::variant_index(v, i);
let index = match utils::variant_index(v, index).map_err(|e| e.into_compile_error())
{
Ok(i) => i,
Err(e) => return e,
};

let create = create_instance(
quote! { #type_name #type_generics :: #name },
Expand All @@ -57,7 +65,7 @@ pub fn quote(
crate_path,
);

quote_spanned! { v.span() =>
let item = quote_spanned! { v.span() =>
#[allow(clippy::unnecessary_cast)]
__codec_x_edqy if __codec_x_edqy == #index as ::core::primitive::u8 => {
// NOTE: This lambda is necessary to work around an upstream bug
Expand All @@ -68,8 +76,9 @@ pub fn quote(
#create
})();
},
}
});
};
items.push(item);
}

let read_byte_err_msg =
format!("Could not decode `{type_name}`, failed to read variant byte");
Expand All @@ -79,7 +88,7 @@ pub fn quote(
match #input.read_byte()
.map_err(|e| e.chain(#read_byte_err_msg))?
{
#( #recurse )*
#( #items )*
_ => {
#[allow(clippy::redundant_closure_call)]
return (move || {
Expand Down
26 changes: 17 additions & 9 deletions derive/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,19 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
if variants.is_empty() {
return quote!();
}

let recurse = variants.iter().enumerate().map(|(i, f)| {
match utils::check_indexes(variants.iter()).map_err(|e| e.to_compile_error()) {
Ok(()) => (),
Err(e) => return e,
};
let mut items = vec![];
for (index, f) in variants.iter().enumerate() {
let name = &f.ident;
let index = utils::variant_index(f, i);

match f.fields {
let index = match utils::variant_index(f, index).map_err(|e| e.into_compile_error())
{
Ok(i) => i,
Err(e) => return e,
};
let item = match f.fields {
Fields::Named(ref fields) => {
let fields = &fields.named;
let field_name = |_, ident: &Option<Ident>| quote!(#ident);
Expand Down Expand Up @@ -389,11 +396,12 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS

[hinting, encoding]
},
}
});
};
items.push(item)
}

let recurse_hinting = recurse.clone().map(|[hinting, _]| hinting);
let recurse_encoding = recurse.clone().map(|[_, encoding]| encoding);
let recurse_hinting = items.iter().map(|[hinting, _]| hinting);
let recurse_encoding = items.iter().map(|[_, encoding]| encoding);

let hinting = quote! {
// The variant index uses 1 byte.
Expand Down
4 changes: 2 additions & 2 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ fn wrap_with_dummy_const(
/// * if variant has attribute: `#[codec(index = "$n")]` then n
/// * else if variant has discriminant (like 3 in `enum T { A = 3 }`) then the discriminant.
/// * else its position in the variant set, excluding skipped variants, but including variant with
/// discriminant or attribute. Warning this position does collision with discriminant or attribute
/// index.
/// discriminant or attribute. Warning this position does collision with discriminant or attribute
/// index.
///
/// variant attributes:
/// * `#[codec(skip)]`: the variant is not encoded.
Expand Down
79 changes: 65 additions & 14 deletions derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
//! NOTE: attributes finder must be checked using check_attribute first,
//! otherwise the macro can panic.

use std::str::FromStr;
use std::{collections::HashMap, str::FromStr};

use proc_macro2::TokenStream;
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DataEnum,
DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta,
Path, Variant,
DeriveInput, ExprLit, Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue,
NestedMeta, Path, Variant,
};

fn find_meta_item<'a, F, R, I, M>(mut itr: I, mut pred: F) -> Option<R>
Expand All @@ -38,11 +38,60 @@ where
})
}

/// check usage of variant indexes with #[scale(index = $int)] attribute or
/// explicit discriminant on the variant
pub fn check_indexes<'a, I: Iterator<Item = &'a &'a Variant>>(values: I) -> syn::Result<()> {
let mut map: HashMap<u8, Span> = HashMap::new();
for (i, v) in values.enumerate() {
if let Some(index) = find_meta_item(v.attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
if nv.path.is_ident("index") {
if let Lit::Int(ref v) = nv.lit {
let byte = v
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
return Some(byte);
}
}
}
None
}) {
if let Some(span) = map.insert(index, v.span()) {
let mut error = syn::Error::new(v.span(), "Duplicate variant index. qed");
error.combine(syn::Error::new(span, "Variant index already defined here."));
return Err(error)
}
} else {
match v.discriminant.as_ref() {
Some((_, syn::Expr::Lit(ExprLit { lit: syn::Lit::Int(lit_int), .. }))) => {
let index = lit_int
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
if let Some(span) = map.insert(index, v.span()) {
let mut error = syn::Error::new(v.span(), "Duplicate variant index. qed");
error.combine(syn::Error::new(span, "Variant index already defined here."));
return Err(error)
}
},
Some((_, _)) => return Err(syn::Error::new(v.span(), "Invalid discriminant. qed")),
None =>
if let Some(span) = map.insert(i.try_into().unwrap(), v.span()) {
let mut error =
syn::Error::new(span, "Custom variant index is duplicated later. qed");
error.combine(syn::Error::new(v.span(), "Variant index derived here."));
return Err(error)
},
}
}
}
pkhry marked this conversation as resolved.
Show resolved Hide resolved
Ok(())
}

/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute
/// is found, fall back to the discriminant or just the variant index.
pub fn variant_index(v: &Variant, i: usize) -> TokenStream {
pub fn variant_index(v: &Variant, index: usize) -> syn::Result<TokenStream> {
// first look for an attribute
let index = find_meta_item(v.attrs.iter(), |meta| {
let codec_index = find_meta_item(v.attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
if nv.path.is_ident("index") {
if let Lit::Int(ref v) = nv.lit {
Expand All @@ -56,14 +105,16 @@ pub fn variant_index(v: &Variant, i: usize) -> TokenStream {

None
});

// then fallback to discriminant or just index
index.map(|i| quote! { #i }).unwrap_or_else(|| {
v.discriminant
.as_ref()
.map(|(_, expr)| quote! { #expr })
.unwrap_or_else(|| quote! { #i })
})
if let Some(index) = codec_index {
Ok(quote! { #index })
} else {
match v.discriminant.as_ref() {
Some((_, expr @ syn::Expr::Lit(ExprLit { lit: syn::Lit::Int(_), .. }))) =>
Ok(quote! { #expr }),
Some((_, expr)) => Err(syn::Error::new(expr.span(), "Invalid discriminant. qed")),
None => Ok(quote! { #index }),
}
}
}

/// Look for a `#[codec(encoded_as = "SomeType")]` outer attribute on the given
Expand Down
1 change: 1 addition & 0 deletions tests/scale_codec_ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ fn scale_codec_ui_tests() {
let t = trybuild::TestCases::new();
t.compile_fail("tests/scale_codec_ui/*.rs");
t.pass("tests/scale_codec_ui/pass/*.rs");
t.compile_fail("tests/scale_codec_ui/fail/*.rs");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not put those tests at "tests/scale_codec_ui/" ?
or move all failing test in the fail dir.

(anyway I don't mind)

}
17 changes: 17 additions & 0 deletions tests/scale_codec_ui/fail/codec_duplicate_index.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
enum T {
A = 3,
#[codec(index = 3)]
B,
}

#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
enum T1 {
A,
#[codec(index = 0)]
B,
}

fn main() {}
23 changes: 23 additions & 0 deletions tests/scale_codec_ui/fail/codec_duplicate_index.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
error: Duplicate variant index. qed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QED means: quod erat demonstrandum = which was to be demonstrated.

I think it doesn't fit in an error message.

Maybe writing: scale codec error: Invalid variant index, the variant index is duplicated.

--> tests/scale_codec_ui/fail/codec_duplicate_index.rs:5:2
|
5 | #[codec(index = 3)]
| ^

error: Variant index already defined here.
--> tests/scale_codec_ui/fail/codec_duplicate_index.rs:4:2
|
4 | A = 3,
| ^

error: Duplicate variant index. qed
--> tests/scale_codec_ui/fail/codec_duplicate_index.rs:13:2
|
13 | #[codec(index = 0)]
| ^

error: Variant index already defined here.
--> tests/scale_codec_ui/fail/codec_duplicate_index.rs:12:2
|
12 | A,
| ^
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
enum T {
A = 1,
B,
}

#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
enum T2 {
#[codec(index = 1)]
A,
B,
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
error: Custom variant index is duplicated later. qed
--> tests/scale_codec_ui/fail/discriminant_variant_counted_in_default_index.rs:4:2
|
4 | A = 1,
| ^

error: Variant index derived here.
--> tests/scale_codec_ui/fail/discriminant_variant_counted_in_default_index.rs:5:2
|
5 | B,
| ^

error: Custom variant index is duplicated later. qed
--> tests/scale_codec_ui/fail/discriminant_variant_counted_in_default_index.rs:11:2
|
11 | #[codec(index = 1)]
| ^

error: Variant index derived here.
--> tests/scale_codec_ui/fail/discriminant_variant_counted_in_default_index.rs:13:2
|
13 | B,
| ^
22 changes: 6 additions & 16 deletions tests/variant_number.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,6 @@
use parity_scale_codec::Encode;
use parity_scale_codec_derive::Encode as DeriveEncode;

#[test]
fn discriminant_variant_counted_in_default_index() {
#[derive(DeriveEncode)]
enum T {
A = 1,
B,
}

assert_eq!(T::A.encode(), vec![1]);
assert_eq!(T::B.encode(), vec![1]);
}

#[test]
fn skipped_variant_not_counted_in_default_index() {
#[derive(DeriveEncode)]
Expand All @@ -27,14 +15,16 @@ fn skipped_variant_not_counted_in_default_index() {
}

#[test]
fn index_attr_variant_counted_and_reused_in_default_index() {
fn index_attr_variant_duplicates_indices() {
// Tests codec index overriding and that variant indexes are without duplicates
#[derive(DeriveEncode)]
enum T {
#[codec(index = 0)]
A = 1,
#[codec(index = 1)]
A,
B,
B = 0,
}

assert_eq!(T::A.encode(), vec![1]);
assert_eq!(T::A.encode(), vec![0]);
assert_eq!(T::B.encode(), vec![1]);
}