diff --git a/tests/ui/fail/loop_invariant_generic.rs b/tests/ui/fail/loop_invariant_generic.rs new file mode 100644 index 0000000..05c390e --- /dev/null +++ b/tests/ui/fail/loop_invariant_generic.rs @@ -0,0 +1,22 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +#[thrust_macros::context] +fn keep(v: T) { + let mut x = v; + while rand() == 0 { + thrust_macros::invariant!(|v: T| v == v); + x = v; + } + assert!(x == v); +} + +fn main() { + keep(0_i64); + keep(true); +} diff --git a/tests/ui/fail/loop_invariant_self.rs b/tests/ui/fail/loop_invariant_self.rs new file mode 100644 index 0000000..4fb50ba --- /dev/null +++ b/tests/ui/fail/loop_invariant_self.rs @@ -0,0 +1,29 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +struct Counter(i64); +impl thrust_models::Model for Counter { + type Ty = (thrust_models::model::Int,); +} + +#[thrust_macros::context] +impl Counter { + fn run(self) { + let mut c = self; + let mut x = 1_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64, c: Self| x >= 2 && c == c); + x = x + 1; + c = Counter(0); + } + let _last = c; + assert!(x >= 1); + } +} + +fn main() { Counter(0).run(); } diff --git a/tests/ui/pass/loop_invariant_generic.rs b/tests/ui/pass/loop_invariant_generic.rs new file mode 100644 index 0000000..1cf831a --- /dev/null +++ b/tests/ui/pass/loop_invariant_generic.rs @@ -0,0 +1,22 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +#[thrust_macros::context] +fn keep(v: T) { + let mut x = v; + while rand() == 0 { + thrust_macros::invariant!(|x: T, v: T| x == v); + x = v; + } + assert!(x == v); +} + +fn main() { + keep(0_i64); + keep(true); +} diff --git a/tests/ui/pass/loop_invariant_self.rs b/tests/ui/pass/loop_invariant_self.rs new file mode 100644 index 0000000..acef38f --- /dev/null +++ b/tests/ui/pass/loop_invariant_self.rs @@ -0,0 +1,29 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +struct Counter(i64); +impl thrust_models::Model for Counter { + type Ty = (thrust_models::model::Int,); +} + +#[thrust_macros::context] +impl Counter { + fn run(self) { + let mut c = self; + let mut x = 1_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64, c: Self| x >= 1 && c == c); + x = x + 1; + c = Counter(0); + } + let _last = c; + assert!(x >= 1); + } +} + +fn main() { Counter(0).run(); } diff --git a/thrust-macros/src/context.rs b/thrust-macros/src/context.rs new file mode 100644 index 0000000..7e1e17e --- /dev/null +++ b/thrust-macros/src/context.rs @@ -0,0 +1,210 @@ +//! Expansion of `#[thrust_macros::context]`. +//! +//! Supplies the surrounding context that thrust annotations within an item +//! cannot see by themselves: stamps the enclosing `impl`/`trait` header onto +//! each method (so method-level `requires`/`ensures` recover the outer +//! generics) and threads that generic context into every `invariant!(...)`. + +use proc_macro::TokenStream; +use proc_macro2::{TokenStream as TokenStream2, TokenTree}; +use quote::{quote, ToTokens}; +use syn::{parse_macro_input, GenericParam, Generics, WherePredicate}; + +use crate::FnOuterItem; + +pub(super) fn expand(item: TokenStream) -> TokenStream { + let mut item = parse_macro_input!(item as syn::Item); + process_context_item(&mut item); + item.into_token_stream().into() +} + +/// Stamps outer context onto methods and threads the generic context into the +/// invariants of every function body in the item (recursing through modules). +fn process_context_item(item: &mut syn::Item) { + match item { + syn::Item::Fn(item_fn) => { + let generics = item_fn.sig.generics.clone(); + let threaded = thread_invariants(&mut item_fn.block, &generics, None); + if threaded.found { + inject_model_bounds(&mut item_fn.sig.generics, None, false); + } + } + syn::Item::Impl(item_impl) => { + let outer = FnOuterItem::ItemImpl(item_impl.clone()).into_header_only(); + for impl_item in &mut item_impl.items { + let syn::ImplItem::Fn(method) = impl_item else { + continue; + }; + method + .attrs + .push(syn::parse_quote!(#[thrust::_outer_context(#outer)])); + let generics = method.sig.generics.clone(); + let threaded = thread_invariants(&mut method.block, &generics, Some(&outer)); + if threaded.found { + inject_model_bounds(&mut method.sig.generics, Some(&outer), threaded.self_used); + } + } + } + syn::Item::Trait(item_trait) => { + let outer = FnOuterItem::ItemTrait(item_trait.clone()).into_header_only(); + for trait_item in &mut item_trait.items { + let syn::TraitItem::Fn(method) = trait_item else { + continue; + }; + method + .attrs + .push(syn::parse_quote!(#[thrust::_outer_context(#outer)])); + if let Some(block) = &mut method.default { + let generics = method.sig.generics.clone(); + let threaded = thread_invariants(block, &generics, Some(&outer)); + if threaded.found { + inject_model_bounds( + &mut method.sig.generics, + Some(&outer), + threaded.self_used, + ); + } + } + } + } + syn::Item::Mod(item_mod) => { + if let Some((_, items)) = &mut item_mod.content { + for inner in items { + process_context_item(inner); + } + } + } + _ => {} + } +} + +struct Threaded { + found: bool, + self_used: bool, +} + +/// Prepends the generic context to every `invariant!(...)` in a function body. +fn thread_invariants( + block: &mut syn::Block, + generics: &Generics, + outer: Option<&FnOuterItem>, +) -> Threaded { + use syn::visit_mut::VisitMut as _; + + let context = invariant_context_tokens(generics, outer); + let mut threader = InvariantThreader { + context, + is_method: outer.is_some(), + found: false, + self_used: false, + }; + threader.visit_block_mut(block); + Threaded { + found: threader.found, + self_used: threader.self_used, + } +} + +/// Builds the `[generic-params] [where-predicates]` prefix that `invariant!` +/// consumes: every generic parameter in scope (the function's own and, for +/// methods, the outer ones), the existing where predicates, and the +/// `Model`/`PartialEq` bounds those parameters require. +fn invariant_context_tokens(generics: &Generics, outer: Option<&FnOuterItem>) -> TokenStream2 { + let mut params: Vec = generics.params.iter().cloned().collect(); + let mut preds: Vec = generics + .where_clause + .as_ref() + .map(|wc| wc.predicates.iter().cloned().collect()) + .unwrap_or_default(); + if let Some(outer) = outer { + params.extend(outer.generics().params.iter().cloned()); + if let Some(wc) = &outer.generics().where_clause { + preds.extend(wc.predicates.iter().cloned()); + } + } + for param in ¶ms { + if let GenericParam::Type(tp) = param { + let ident = &tp.ident; + preds.push(syn::parse_quote!(#ident: thrust_models::Model)); + preds.push(syn::parse_quote!(<#ident as thrust_models::Model>::Ty: PartialEq)); + } + } + quote! { [#(#params),*] [#(#preds),*] } +} + +struct InvariantThreader { + context: TokenStream2, + is_method: bool, + found: bool, + self_used: bool, +} + +impl syn::visit_mut::VisitMut for InvariantThreader { + fn visit_macro_mut(&mut self, mac: &mut syn::Macro) { + syn::visit_mut::visit_macro_mut(self, mac); + if is_invariant_macro(&mac.path) { + self.found = true; + // An invariant in a method may name `Self` in its variable types. + // A nested item cannot, so signal `invariant!` to re-declare `Self` + // as a synthetic generic, but only when it is actually used (so we + // do not over-constrain the host method with `Self: Model`). + let uses_self = self.is_method && tokens_contain_self(&mac.tokens); + self.self_used |= uses_self; + let self_marker = if uses_self { quote!(Self) } else { quote!() }; + let context = &self.context; + let original = &mac.tokens; + mac.tokens = quote! { [#self_marker] #context #original }; + } + } +} + +fn tokens_contain_self(tokens: &TokenStream2) -> bool { + tokens.clone().into_iter().any(|tt| match tt { + TokenTree::Ident(ident) => ident == "Self", + TokenTree::Group(group) => tokens_contain_self(&group.stream()), + _ => false, + }) +} + +/// Adds `T: Model` and `::Ty: PartialEq` bounds for every type +/// parameter in scope to a function's where clause. The marker call generated +/// for an invariant instantiates a `Model`-bounded formula function, so the +/// function hosting the call must itself satisfy those bounds. When an +/// invariant names `Self`, `invariant!` instantiates the formula function with +/// `Self`, so the same bounds are added for `Self` (`with_self`). +fn inject_model_bounds(generics: &mut Generics, outer: Option<&FnOuterItem>, with_self: bool) { + let mut tys: Vec = generics + .params + .iter() + .filter_map(|p| match p { + GenericParam::Type(tp) => Some(tp.ident.to_token_stream()), + _ => None, + }) + .collect(); + if let Some(outer) = outer { + for param in &outer.generics().params { + if let GenericParam::Type(tp) = param { + tys.push(tp.ident.to_token_stream()); + } + } + } + if with_self { + tys.push(quote!(Self)); + } + if tys.is_empty() { + return; + } + let where_clause = generics.make_where_clause(); + for ty in tys { + where_clause + .predicates + .push(syn::parse_quote!(#ty: thrust_models::Model)); + where_clause + .predicates + .push(syn::parse_quote!(<#ty as thrust_models::Model>::Ty: PartialEq)); + } +} + +fn is_invariant_macro(path: &syn::Path) -> bool { + path.segments.last().is_some_and(|s| s.ident == "invariant") +} diff --git a/thrust-macros/src/invariant.rs b/thrust-macros/src/invariant.rs index 8edf62e..a0d7526 100644 --- a/thrust-macros/src/invariant.rs +++ b/thrust-macros/src/invariant.rs @@ -1,21 +1,35 @@ //! Expansion of `thrust_macros::invariant!`. //! -//! Expands a closure with concrete parameter types into a -//! `#[thrust::formula_fn]` over `Model::Ty` parameters and a marker call -//! referencing it. +//! Expands into a `#[thrust::formula_fn]` over `Model::Ty` parameters plus a +//! trusted marker call referencing it. The formula function re-declares the +//! threaded generics (shadowing the enclosing ones) and is instantiated via +//! turbofish, which is what lets an in-body macro support generic-typed +//! variables. When `#[thrust_macros::context]` has threaded in a `Self` +//! marker, `Self` is re-declared the same way as a synthetic type parameter +//! and instantiated with the real `Self` via turbofish (legal in expression +//! position). use std::sync::atomic::{AtomicUsize, Ordering}; use proc_macro::TokenStream; -use quote::{format_ident, quote}; -use syn::{parse_macro_input, FnArg}; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + parse_macro_input, punctuated::Punctuated, visit_mut::VisitMut, FnArg, GenericParam, + WherePredicate, +}; use crate::fn_params_with_model_ty; static COUNTER: AtomicUsize = AtomicUsize::new(0); pub fn expand(input: TokenStream) -> TokenStream { - let closure = parse_macro_input!(input as syn::ExprClosure); + let InvariantInput { + self_ty, + params, + wheres, + closure, + } = parse_macro_input!(input as InvariantInput); let mut fn_params: Vec = Vec::new(); for param in &closure.inputs { @@ -32,23 +46,131 @@ pub fn expand(input: TokenStream) -> TokenStream { fn_params.push(syn::parse_quote!(#pat: #ty)); } + let mut def_params: Vec = Vec::new(); + let mut def_wheres: Vec = wheres.iter().map(|w| w.to_token_stream()).collect(); + let mut turbofish_args: Vec = Vec::new(); + + // `Self` cannot appear in a nested item, so rewrite it to a synthetic type + // parameter and pass the real `Self` (legal here, in expression position) + // through the turbofish. + if let Some(self_ty) = &self_ty { + let synth: syn::Ident = format_ident!("__ThrustSelf"); + for param in &mut fn_params { + SelfRewriter { synth: &synth }.visit_fn_arg_mut(param); + } + def_params.push(quote!(#synth)); + def_wheres.push(quote!(#synth: thrust_models::Model)); + def_wheres.push(quote!(<#synth as thrust_models::Model>::Ty: PartialEq)); + turbofish_args.push(self_ty.to_token_stream()); + } + + for param in ¶ms { + def_params.push(param.to_token_stream()); + match param { + GenericParam::Type(tp) => turbofish_args.push(tp.ident.to_token_stream()), + GenericParam::Const(cp) => turbofish_args.push(cp.ident.to_token_stream()), + GenericParam::Lifetime(_) => {} + } + } + let model_ty_params = fn_params_with_model_ty(&fn_params); let body = &closure.body; let id = COUNTER.fetch_add(1, Ordering::Relaxed); let name = format_ident!("_thrust_invariant_{}", id); + let def_generics = if def_params.is_empty() { + quote!() + } else { + quote!(<#(#def_params),*>) + }; + let where_clause = if def_wheres.is_empty() { + quote!() + } else { + quote!(where #(#def_wheres),*) + }; + let turbofish = if turbofish_args.is_empty() { + quote!() + } else { + quote!(::<#(#turbofish_args),*>) + }; + quote! { { #[allow(unused_variables)] #[allow(non_snake_case)] #[thrust::formula_fn] - fn #name(#model_ty_params) -> bool { + fn #name #def_generics(#model_ty_params) -> bool #where_clause { #body } - thrust_models::__invariant_marker(#name) + thrust_models::__invariant_marker(#name #turbofish) } } .into() } + +/// Rewrites a bare `Self` type path to a synthetic type parameter, so the type +/// can be named inside a nested formula function. Qualified paths such as +/// `Self::Assoc` are left untouched (and are not supported in invariants). +struct SelfRewriter<'a> { + synth: &'a syn::Ident, +} + +impl VisitMut for SelfRewriter<'_> { + fn visit_path_mut(&mut self, path: &mut syn::Path) { + syn::visit_mut::visit_path_mut(self, path); + if path.leading_colon.is_none() + && path.segments.len() == 1 + && path.segments[0].ident == "Self" + { + path.segments[0].ident = self.synth.clone(); + } + } +} + +/// The input to `invariant!`. When `#[thrust_macros::context]` has run it is +/// the threaded form `[self] [generic-params] [where-predicates] ` +/// (the `[self]` group holds `Self` when an invariant in a method names it, +/// and is empty otherwise). Without `context` the input is just ``. +struct InvariantInput { + self_ty: Option, + params: Punctuated, + wheres: Punctuated, + closure: syn::ExprClosure, +} + +impl syn::parse::Parse for InvariantInput { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + if !input.peek(syn::token::Bracket) { + // No context threaded in: a bare `invariant!(|...| ...)`. + let closure = input.parse()?; + return Ok(Self { + self_ty: None, + params: Punctuated::new(), + wheres: Punctuated::new(), + closure, + }); + } + let self_group; + syn::bracketed!(self_group in input); + let self_ty = if self_group.is_empty() { + None + } else { + Some(self_group.parse()?) + }; + let params_group; + syn::bracketed!(params_group in input); + let params = params_group.parse_terminated(GenericParam::parse, syn::Token![,])?; + let wheres_group; + syn::bracketed!(wheres_group in input); + let wheres = wheres_group.parse_terminated(WherePredicate::parse, syn::Token![,])?; + let closure = input.parse()?; + Ok(Self { + self_ty, + params, + wheres, + closure, + }) + } +} diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 99df73d..98d973e 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -1,67 +1,20 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; -use quote::{format_ident, quote, ToTokens}; -use syn::{ - parse_macro_input, punctuated::Punctuated, FnArg, GenericParam, Generics, TypeParamBound, - WherePredicate, -}; +use quote::quote; +use syn::{FnArg, Generics}; +mod context; mod invariant; +mod spec; -#[derive(Debug, Clone)] -enum FnOuterItem { - ItemImpl(syn::ItemImpl), - ItemTrait(syn::ItemTrait), -} +// ===== proc-macro entry points ===== +// +// A proc-macro crate must declare these at the crate root; each delegates to +// the relevant module. -impl syn::parse::Parse for FnOuterItem { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - use syn::parse::discouraged::Speculative as _; - - let fork = input.fork(); - if let Ok(item_impl) = fork.parse::() { - input.advance_to(&fork); - return Ok(Self::ItemImpl(item_impl)); - } - - let fork = input.fork(); - if let Ok(item_trait) = fork.parse::() { - input.advance_to(&fork); - return Ok(Self::ItemTrait(item_trait)); - } - - Err(input.error("expected an impl block or a trait definition")) - } -} - -impl quote::ToTokens for FnOuterItem { - fn to_tokens(&self, tokens: &mut TokenStream2) { - match self { - FnOuterItem::ItemImpl(item_impl) => item_impl.to_tokens(tokens), - FnOuterItem::ItemTrait(item_trait) => item_trait.to_tokens(tokens), - } - } -} - -impl FnOuterItem { - fn into_header_only(mut self) -> Self { - match &mut self { - FnOuterItem::ItemImpl(item_impl) => { - item_impl.items.clear(); - } - FnOuterItem::ItemTrait(item_trait) => { - item_trait.items.clear(); - } - } - self - } - - fn generics(&self) -> &Generics { - match self { - FnOuterItem::ItemImpl(item_impl) => &item_impl.generics, - FnOuterItem::ItemTrait(item_trait) => &item_trait.generics, - } - } +#[proc_macro_attribute] +pub fn context(_attr: TokenStream, item: TokenStream) -> TokenStream { + context::expand(item) } /// Declares a loop invariant inside a loop body: @@ -83,555 +36,88 @@ pub fn invariant(input: TokenStream) -> TokenStream { invariant::expand(input) } -#[proc_macro_attribute] -pub fn context(_attr: TokenStream, item: TokenStream) -> TokenStream { - let mut outer_item = syn::parse_macro_input!(item as FnOuterItem); - let outer_header = outer_item.clone().into_header_only(); - match &mut outer_item { - FnOuterItem::ItemImpl(item_impl) => { - for item in &mut item_impl.items { - let syn::ImplItem::Fn(item) = item else { - continue; - }; - item.attrs - .push(syn::parse_quote!(#[thrust::_outer_context(#outer_header)])); - } - } - FnOuterItem::ItemTrait(item_trait) => { - for item in &mut item_trait.items { - let syn::TraitItem::Fn(item) = item else { - continue; - }; - item.attrs - .push(syn::parse_quote!(#[thrust::_outer_context(#outer_header)])); - } - } - } - - outer_item.into_token_stream().into() -} - -#[allow(clippy::enum_variant_names)] -#[derive(Debug, Clone)] -enum FnItemWithSignature { - ItemFn(syn::ItemFn), - ImplItemFn(syn::ImplItemFn), - TraitItemFn(syn::TraitItemFn), -} - -impl syn::parse::Parse for FnItemWithSignature { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - use syn::parse::discouraged::Speculative as _; - - let fork = input.fork(); - if let Ok(item_fn) = fork.parse::() { - input.advance_to(&fork); - return Ok(Self::ItemFn(item_fn)); - } - - let fork = input.fork(); - if let Ok(impl_item_fn) = fork.parse::() { - input.advance_to(&fork); - return Ok(Self::ImplItemFn(impl_item_fn)); - } - - let fork = input.fork(); - if let Ok(trait_item_fn) = fork.parse::() { - input.advance_to(&fork); - return Ok(Self::TraitItemFn(trait_item_fn)); - } - - Err(input.error("expected a free function, an impl method, or a trait method")) - } -} - -impl quote::ToTokens for FnItemWithSignature { - fn to_tokens(&self, tokens: &mut TokenStream2) { - match self { - FnItemWithSignature::ItemFn(item_fn) => item_fn.to_tokens(tokens), - FnItemWithSignature::ImplItemFn(impl_item_fn) => impl_item_fn.to_tokens(tokens), - FnItemWithSignature::TraitItemFn(trait_item_fn) => trait_item_fn.to_tokens(tokens), - } - } -} - -impl FnItemWithSignature { - fn block(&self) -> Option<&syn::Block> { - match self { - FnItemWithSignature::ItemFn(item_fn) => Some(&item_fn.block), - FnItemWithSignature::ImplItemFn(impl_item_fn) => Some(&impl_item_fn.block), - FnItemWithSignature::TraitItemFn(_) => None, - } - } - - fn block_mut(&mut self) -> Option<&mut syn::Block> { - match self { - FnItemWithSignature::ItemFn(item_fn) => Some(&mut item_fn.block), - FnItemWithSignature::ImplItemFn(impl_item_fn) => Some(&mut impl_item_fn.block), - FnItemWithSignature::TraitItemFn(_) => None, - } - } - - fn attrs(&self) -> &[syn::Attribute] { - match self { - FnItemWithSignature::ItemFn(item_fn) => &item_fn.attrs, - FnItemWithSignature::ImplItemFn(impl_item_fn) => &impl_item_fn.attrs, - FnItemWithSignature::TraitItemFn(trait_item_fn) => &trait_item_fn.attrs, - } - } - - fn attrs_mut(&mut self) -> &mut Vec { - match self { - FnItemWithSignature::ItemFn(item_fn) => &mut item_fn.attrs, - FnItemWithSignature::ImplItemFn(impl_item_fn) => &mut impl_item_fn.attrs, - FnItemWithSignature::TraitItemFn(trait_item_fn) => &mut trait_item_fn.attrs, - } - } - - fn sig(&self) -> &syn::Signature { - match self { - FnItemWithSignature::ItemFn(item_fn) => &item_fn.sig, - FnItemWithSignature::ImplItemFn(impl_item_fn) => &impl_item_fn.sig, - FnItemWithSignature::TraitItemFn(trait_item_fn) => &trait_item_fn.sig, - } - } -} - #[proc_macro_attribute] pub fn predicate(_attr: TokenStream, item: TokenStream) -> TokenStream { - let func = parse_macro_input!(item as FnItemWithSignature); - let outer_context = match extract_outer_context(&func) { - Ok(ctx) => ctx, - Err(e) => { - let err = e.to_compile_error(); - return quote! { #err #func }.into(); - } - }; - - let name = &func.sig().ident; - let def_generics = generic_params_tokens(&func.sig().generics); - let model_ty_params = fn_params_with_model_ty(&func.sig().inputs); - let model_ret = fn_return_ty_with_model_ty(&func.sig().output); - - let model_preds = model_where_predicates(&func, outer_context.as_ref()); - let extended_where = extended_where_clause(&func, &model_preds); - - let sig = quote! { - #[allow(dead_code)] - #[thrust::predicate] - fn #name #def_generics(#model_ty_params) -> #model_ret #extended_where - }; - if let Some(block) = func.block() { - quote! { #sig #block }.into() - } else { - quote! { #sig; }.into() - } + spec::expand_predicate(item) } #[proc_macro_attribute] pub fn requires(attr: TokenStream, item: TokenStream) -> TokenStream { - let expr = TokenStream2::from(attr); - let mut func = parse_macro_input!(item as FnItemWithSignature); - - let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) { - Ok((req, ens)) => (req, ens), - Err(e) => return e.to_compile_error().into(), - }; - func.attrs_mut().push(syn::parse_quote!( - #[::thrust_macros::_requires_ensures((#req_expr) && (#expr), #ens_expr)] - )); - - func.into_token_stream().into() + spec::expand_requires(attr, item) } #[proc_macro_attribute] pub fn ensures(attr: TokenStream, item: TokenStream) -> TokenStream { - let expr = TokenStream2::from(attr); - let mut func = parse_macro_input!(item as FnItemWithSignature); - - let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) { - Ok((req, ens)) => (req, ens), - Err(e) => return e.to_compile_error().into(), - }; - func.attrs_mut().push(syn::parse_quote!( - #[::thrust_macros::_requires_ensures(#req_expr, (#ens_expr) && (#expr))] - )); - - func.into_token_stream().into() -} - -fn extract_requires_ensures(func: &mut FnItemWithSignature) -> syn::Result<(syn::Expr, syn::Expr)> { - let mut result = None; - - let requires_ensures_path: syn::Path = syn::parse_quote!(::thrust_macros::_requires_ensures); - - for attr in func.attrs() { - if attr.path() == &requires_ensures_path { - if result.is_some() { - return Err(syn::Error::new_spanned( - attr, - "multiple _requires_ensures attributes found; expected at most one", - )); - } - - let parser = Punctuated::::parse_separated_nonempty; - let mut exprs = attr.parse_args_with(parser)?; - if exprs.len() != 2 { - return Err(syn::Error::new_spanned( - attr, - "expected exactly two comma-separated expressions in _requires_ensures attribute", - )); - } - let ens_expr = exprs.pop().unwrap().into_value(); - let req_expr = exprs.pop().unwrap().into_value(); - result = Some((req_expr, ens_expr)); - } - } - - func.attrs_mut() - .retain(|attr| attr.path() != &requires_ensures_path); - - if let Some((req_expr, ens_expr)) = result { - Ok((req_expr, ens_expr)) - } else { - Ok((syn::parse_quote!(true), syn::parse_quote!(true))) - } + spec::expand_ensures(attr, item) } #[proc_macro_attribute] pub fn _requires_ensures(attr: TokenStream, item: TokenStream) -> TokenStream { - use syn::parse::Parser as _; - let parser = Punctuated::::parse_separated_nonempty; - let mut exprs = match parser.parse(attr.clone()) { - Ok(exprs) => exprs, - Err(e) => return e.to_compile_error().into(), - }; - if exprs.len() != 2 { - return syn::Error::new_spanned( - TokenStream2::from(attr), - "expected exactly two comma-separated expressions in _requires_ensures attribute", - ) - .to_compile_error() - .into(); - } - - let ens_expr = exprs.pop().unwrap().into_value(); - let req_expr = exprs.pop().unwrap().into_value(); - - let func = parse_macro_input!(item as FnItemWithSignature); - let outer_context = match extract_outer_context(&func) { - Ok(ctx) => ctx, - Err(e) => { - let err = e.to_compile_error(); - return quote! { #err #func }.into(); - } - }; - let mut tokens = ExpandedTokens::new(func, req_expr, ens_expr); - if let Some(ctx) = outer_context { - tokens = tokens.with_outer_context(ctx); - } - tokens.into_token_stream().into() -} - -fn extract_outer_context(func: &FnItemWithSignature) -> syn::Result> { - let outer_context_path: syn::Path = syn::parse_quote!(thrust::_outer_context); - let mut outer_context = None; - for attr in func.attrs() { - if attr.path() != &outer_context_path { - continue; - } - - let item = attr.parse_args()?; - if outer_context.is_some() { - return Err(syn::Error::new_spanned( - attr, - "multiple _outer_context attributes found; expected at most one", - )); - } - outer_context = Some(item); - } - if mentions_self(func.sig()) && outer_context.is_none() { - return Err(syn::Error::new_spanned( - func.sig().ident.clone(), - "Wrap the surrounding impl block or trait definition with #[thrust_macros::context] to annotate methods", - )); - } - Ok(outer_context) + spec::expand_requires_ensures(attr, item) } -struct ExpandedTokens { - func: FnItemWithSignature, +// ===== shared helpers ===== +// +// Used by more than one of the modules above. They live in the crate root and +// stay private: a private root item is visible to every descendant module, so +// no `pub(crate)` is required. - requires_name: syn::Ident, - ensures_name: syn::Ident, - req_expr: syn::Expr, - ens_expr: syn::Expr, - - def_generics: TokenStream2, - turbofish: TokenStream2, - - model_ty_params: TokenStream2, - ret_model_ty: syn::Type, - - outer_context: Option, -} - -impl quote::ToTokens for ExpandedTokens { - fn to_tokens(&self, tokens: &mut TokenStream2) { - if self.is_extern_spec_fn() { - self.expand_extern_spec_fn().to_tokens(tokens); - } else { - self.expand().to_tokens(tokens); - } - } +/// An `impl` or `trait` header carried by the `#[thrust::_outer_context(..)]` +/// attribute so a method can recover its enclosing generics. +#[derive(Debug, Clone)] +enum FnOuterItem { + ItemImpl(syn::ItemImpl), + ItemTrait(syn::ItemTrait), } -impl ExpandedTokens { - pub fn new( - func: FnItemWithSignature, - mut req_expr: syn::Expr, - mut ens_expr: syn::Expr, - ) -> Self { - let name = &func.sig().ident; - let requires_name = format_ident!("_thrust_requires_{}", name); - let ensures_name = format_ident!("_thrust_ensures_{}", name); - - let generics = &func.sig().generics; - let def_generics = generic_params_tokens(generics); - let turbofish = generic_turbofish(generics); - - let model_ty_params = fn_params_with_model_ty(&func.sig().inputs); - let ret_model_ty = fn_return_ty_with_model_ty(&func.sig().output); - - if func.sig().receiver().is_some() { - rewrite_self_in_expr(&mut req_expr); - rewrite_self_in_expr(&mut ens_expr); - } - - Self { - func, - req_expr, - ens_expr, - requires_name, - ensures_name, - def_generics, - turbofish, - model_ty_params, - ret_model_ty, - outer_context: None, - } - } - - pub fn with_outer_context(mut self, outer_item: FnOuterItem) -> Self { - self.outer_context = Some(outer_item); - self - } - - fn extended_where_clause(&self) -> TokenStream2 { - let model_preds = model_where_predicates(&self.func, self.outer_context.as_ref()); - extended_where_clause(&self.func, &model_preds) - } - - fn is_extern_spec_fn(&self) -> bool { - let extern_spec_fn_path: syn::Path = syn::parse_quote!(thrust::extern_spec_fn); - self.func - .attrs() - .iter() - .any(|a| a.path() == &extern_spec_fn_path) - } - - fn requires_fn(&self) -> TokenStream2 { - let requires_name = &self.requires_name; - let def_generics = &self.def_generics; - let model_ty_params = &self.model_ty_params; - let extended_where = self.extended_where_clause(); - let req_expr = &self.req_expr; - - quote! { - #[allow(unused_variables)] - #[allow(non_snake_case)] - #[thrust::formula_fn] - fn #requires_name #def_generics(#model_ty_params) -> bool #extended_where { - #req_expr - } - } - } - - fn ensures_fn(&self) -> TokenStream2 { - let ensures_name = &self.ensures_name; - let def_generics = &self.def_generics; - let model_ty_params = &self.model_ty_params; - let extended_where = self.extended_where_clause(); - let ret_model_ty = &self.ret_model_ty; - let ens_expr = &self.ens_expr; - - quote! { - #[allow(unused_variables)] - #[allow(non_snake_case)] - #[thrust::formula_fn] - fn #ensures_name #def_generics(result: #ret_model_ty, #model_ty_params) -> bool #extended_where { - #ens_expr - } - } - } - - fn path_prefix(&self) -> Option { - self.outer_context.as_ref()?; - Some(quote!(Self::)) - } +impl syn::parse::Parse for FnOuterItem { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + use syn::parse::discouraged::Speculative as _; - fn expand(&self) -> TokenStream2 { - let mut func = self.func.clone(); - let trusted_path: syn::Path = syn::parse_quote!(thrust::trusted); - for attr in func.attrs_mut() { - if attr.path() == &trusted_path { - *attr = syn::parse_quote!(#[thrust::ignored]); - } + let fork = input.fork(); + if let Ok(item_impl) = fork.parse::() { + input.advance_to(&fork); + return Ok(Self::ItemImpl(item_impl)); } - let requires_fn = self.requires_fn(); - let ensures_fn = self.ensures_fn(); - - let extern_spec_name = format_ident!("_thrust_extern_spec_{}", self.func.sig().ident); - let def_generics = &self.def_generics; - let orig_output = &self.func.sig().output; - let extended_where = self.extended_where_clause(); - - let requires_name = &self.requires_name; - let ensures_name = &self.ensures_name; - let turbofish = &self.turbofish; - let path_prefix = self.path_prefix(); - - let name = &self.func.sig().ident; - let (extern_spec_inputs, call_args) = rewrite_inputs_for_call(&self.func.sig().inputs); - - quote! { - #func - - #requires_fn - #ensures_fn - - #[thrust::extern_spec_fn] - #[allow(path_statements)] - fn #extern_spec_name #def_generics(#extern_spec_inputs) #orig_output #extended_where { - #[thrust::requires_path] - #path_prefix #requires_name #turbofish; - - #[thrust::ensures_path] - #path_prefix #ensures_name #turbofish; - - #path_prefix #name #turbofish(#call_args) - } + let fork = input.fork(); + if let Ok(item_trait) = fork.parse::() { + input.advance_to(&fork); + return Ok(Self::ItemTrait(item_trait)); } - } - - fn expand_extern_spec_fn(&self) -> TokenStream2 { - let requires_name = &self.requires_name; - let ensures_name = &self.ensures_name; - let turbofish = &self.turbofish; - let path_prefix = self.path_prefix(); - - let mut func = self.func.clone(); - let func_tokens = if let Some(block) = func.block_mut() { - let orig_stmts = block.stmts.drain(..).collect::>(); - *block = syn::parse_quote!({ - #[thrust::requires_path] - #path_prefix #requires_name #turbofish; - - #[thrust::ensures_path] - #path_prefix #ensures_name #turbofish; - - #(#orig_stmts)* - }); - quote! { - #[allow(path_statements)] - #func - } - } else { - let error = syn::Error::new_spanned( - func.sig().ident.clone(), - "extern_spec_fn must have a function body", - ) - .into_compile_error(); - quote! { - #error - #func - } - }; - - let requires_fn = self.requires_fn(); - let ensures_fn = self.ensures_fn(); - quote! { - #requires_fn - #ensures_fn - - #func_tokens - } + Err(input.error("expected an impl block or a trait definition")) } } -fn mentions_self(sig: &syn::Signature) -> bool { - struct Visitor { - mentions_self: bool, - } - - impl syn::visit::Visit<'_> for Visitor { - fn visit_ident(&mut self, i: &syn::Ident) { - if i == "self" || i == "Self" { - self.mentions_self = true; - } +impl quote::ToTokens for FnOuterItem { + fn to_tokens(&self, tokens: &mut TokenStream2) { + match self { + FnOuterItem::ItemImpl(item_impl) => item_impl.to_tokens(tokens), + FnOuterItem::ItemTrait(item_trait) => item_trait.to_tokens(tokens), } } - - let mut visitor = Visitor { - mentions_self: false, - }; - use syn::visit::Visit as _; - visitor.visit_signature(sig); - visitor.mentions_self } -fn rewrite_self_in_expr(expr: &mut syn::Expr) { - struct Visitor; - - impl syn::visit_mut::VisitMut for Visitor { - fn visit_ident_mut(&mut self, ident: &mut syn::Ident) { - if ident == "self" { - *ident = format_ident!("self_"); +impl FnOuterItem { + fn into_header_only(mut self) -> Self { + match &mut self { + FnOuterItem::ItemImpl(item_impl) => { + item_impl.items.clear(); + } + FnOuterItem::ItemTrait(item_trait) => { + item_trait.items.clear(); } } + self } - use syn::visit_mut::VisitMut as _; - Visitor.visit_expr_mut(expr); -} - -/// Returns `` — the generic param list for function definitions, -/// without a where clause. -fn generic_params_tokens(generics: &Generics) -> TokenStream2 { - if generics.params.is_empty() { - return quote!(); - } - let params = &generics.params; - quote!(<#params>) -} - -/// Returns `::` for turbofish use, or nothing if no generic params. -fn generic_turbofish(generics: &Generics) -> TokenStream2 { - let args: Vec = generics - .params - .iter() - .flat_map(|p| match p { - GenericParam::Type(tp) => Some(tp.ident.to_token_stream()), - GenericParam::Lifetime(_) => None, - GenericParam::Const(cp) => Some(cp.ident.to_token_stream()), - }) - .collect(); - if args.is_empty() { - return quote!(); + fn generics(&self) -> &Generics { + match self { + FnOuterItem::ItemImpl(item_impl) => &item_impl.generics, + FnOuterItem::ItemTrait(item_trait) => &item_trait.generics, + } } - quote!(::<#(#args),*>) } /// Maps each function parameter `x: T` to `x: ::Ty`. @@ -655,163 +141,3 @@ where } quote!(#(#model_inputs),*) } - -/// For the extern_spec wrapper: replaces every typed parameter with a fresh `_arg_N` ident, -/// returning `(rewritten_inputs_tokens, call_args_tokens)`. -fn rewrite_inputs_for_call( - inputs: &syn::punctuated::Punctuated, -) -> (TokenStream2, TokenStream2) { - let mut rewritten: Vec = Vec::new(); - let mut call_args: Vec = Vec::new(); - - for (i, arg) in inputs.iter().enumerate() { - match arg { - FnArg::Typed(pt) => { - let fresh = format_ident!("_arg_{}", i); - let ty = &pt.ty; - rewritten.push(quote!(#fresh: #ty)); - call_args.push(fresh.to_token_stream()); - } - FnArg::Receiver(_) => { - rewritten.push(arg.to_token_stream()); - call_args.push(quote!(self)); - } - } - } - - (quote!(#(#rewritten),*), quote!(#(#call_args),*)) -} - -/// Returns `T: thrust_models::Model` predicates for every type param that does not -/// already carry an `Fn`, `FnOnce`, or `FnMut` bound. -fn model_where_predicates( - func: &FnItemWithSignature, - outer_context: Option<&FnOuterItem>, -) -> Vec { - struct GenericTypeParam { - ident: syn::Ident, - bounds: Vec, - } - - impl From for GenericTypeParam { - fn from(tp: syn::TypeParam) -> Self { - Self { - ident: tp.ident, - bounds: tp.bounds.into_iter().collect(), - } - } - } - - impl GenericTypeParam { - fn has_fn_bound(&self) -> bool { - self.bounds.iter().any(|b| { - let TypeParamBound::Trait(tb) = b else { - return false; - }; - tb.path.segments.last().is_some_and(|s| { - matches!(s.ident.to_string().as_str(), "Fn" | "FnOnce" | "FnMut") - }) - }) - } - } - - let mut generic_type_params: Vec = Vec::new(); - for param in &func.sig().generics.params { - let GenericParam::Type(tp) = param else { - continue; - }; - generic_type_params.push(tp.clone().into()); - } - if let Some(outer_item) = outer_context { - for param in &outer_item.generics().params { - let GenericParam::Type(tp) = param else { - continue; - }; - generic_type_params.push(tp.clone().into()); - } - if let FnOuterItem::ItemTrait(outer_item) = &outer_item { - generic_type_params.push(GenericTypeParam { - ident: format_ident!("Self"), - bounds: outer_item.supertraits.iter().cloned().collect(), - }); - } - } - generic_type_params.retain(|p| !p.has_fn_bound()); - - let mut predicates: Vec = Vec::new(); - for param in &generic_type_params { - let ident = ¶m.ident; - predicates.push(syn::parse_quote!(#ident: thrust_models::Model)); - predicates.push(syn::parse_quote!(<#ident as thrust_models::Model>::Ty: PartialEq)); - } - - struct Visitor { - generic_type_params: Vec, - generic_paths: Vec, - } - - impl syn::visit::Visit<'_> for Visitor { - fn visit_type_path(&mut self, tp: &syn::TypePath) { - for param in &self.generic_type_params { - if let Some(qself) = &tp.qself { - let param = ¶m.ident; - let param_ty: syn::Type = syn::parse_quote!(#param); - if *qself.ty == param_ty { - self.generic_paths.push(tp.clone()); - } - } - if tp.path.segments.len() > 1 - && tp.path.segments.first().unwrap().ident == param.ident - && tp.qself.is_none() - { - self.generic_paths.push(tp.clone()); - } - } - syn::visit::visit_type_path(self, tp); - } - } - - let mut visitor = Visitor { - generic_type_params, - generic_paths: Vec::new(), - }; - use syn::visit::Visit as _; - for arg in &func.sig().inputs { - visitor.visit_fn_arg(arg); - } - visitor.visit_return_type(&func.sig().output); - for tp in visitor.generic_paths { - predicates.push(syn::parse_quote!(#tp: thrust_models::Model)); - predicates.push(syn::parse_quote!(<#tp as thrust_models::Model>::Ty: PartialEq)); - } - - predicates -} - -/// Builds `where , `. -/// Returns an empty token stream when both sets are empty. -fn extended_where_clause( - func: &FnItemWithSignature, - model_preds: &Vec, -) -> TokenStream2 { - let existing: Vec<&WherePredicate> = func - .sig() - .generics - .where_clause - .as_ref() - .map(|wc| wc.predicates.iter().collect()) - .unwrap_or_default(); - - if existing.is_empty() && model_preds.is_empty() { - return quote!(); - } - - quote! { where #(#existing,)* #(#model_preds),* } -} - -fn fn_return_ty_with_model_ty(ret: &syn::ReturnType) -> syn::Type { - match ret { - syn::ReturnType::Default => syn::parse_quote!(<() as thrust_models::Model>::Ty), - syn::ReturnType::Type(_, ty) => syn::parse_quote!(<#ty as thrust_models::Model>::Ty), - } -} diff --git a/thrust-macros/src/spec.rs b/thrust-macros/src/spec.rs new file mode 100644 index 0000000..2758291 --- /dev/null +++ b/thrust-macros/src/spec.rs @@ -0,0 +1,692 @@ +//! Expansion of `#[thrust_macros::requires]`, `#[thrust_macros::ensures]`, +//! `#[thrust_macros::predicate]`, and the internal `_requires_ensures` glue. +//! +//! `requires`/`ensures` accumulate their predicates into a single +//! `_requires_ensures` attribute, which expands into `#[thrust::formula_fn]` +//! companions (over `Model::Ty` parameters) plus an extern-spec wrapper that +//! references them. + +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + parse_macro_input, punctuated::Punctuated, FnArg, GenericParam, Generics, TypeParamBound, + WherePredicate, +}; + +use crate::{fn_params_with_model_ty, FnOuterItem}; + +pub(super) fn expand_predicate(item: TokenStream) -> TokenStream { + let func = parse_macro_input!(item as FnItemWithSignature); + let outer_context = match extract_outer_context(&func) { + Ok(ctx) => ctx, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + }; + + let name = &func.sig().ident; + let def_generics = generic_params_tokens(&func.sig().generics); + let model_ty_params = fn_params_with_model_ty(&func.sig().inputs); + let model_ret = fn_return_ty_with_model_ty(&func.sig().output); + + let model_preds = model_where_predicates(&func, outer_context.as_ref()); + let extended_where = extended_where_clause(&func, &model_preds); + + let sig = quote! { + #[allow(dead_code)] + #[thrust::predicate] + fn #name #def_generics(#model_ty_params) -> #model_ret #extended_where + }; + if let Some(block) = func.block() { + quote! { #sig #block }.into() + } else { + quote! { #sig; }.into() + } +} + +pub(super) fn expand_requires(attr: TokenStream, item: TokenStream) -> TokenStream { + let expr = TokenStream2::from(attr); + let mut func = parse_macro_input!(item as FnItemWithSignature); + + let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) { + Ok((req, ens)) => (req, ens), + Err(e) => return e.to_compile_error().into(), + }; + func.attrs_mut().push(syn::parse_quote!( + #[::thrust_macros::_requires_ensures((#req_expr) && (#expr), #ens_expr)] + )); + + func.into_token_stream().into() +} + +pub(super) fn expand_ensures(attr: TokenStream, item: TokenStream) -> TokenStream { + let expr = TokenStream2::from(attr); + let mut func = parse_macro_input!(item as FnItemWithSignature); + + let (req_expr, ens_expr) = match extract_requires_ensures(&mut func) { + Ok((req, ens)) => (req, ens), + Err(e) => return e.to_compile_error().into(), + }; + func.attrs_mut().push(syn::parse_quote!( + #[::thrust_macros::_requires_ensures(#req_expr, (#ens_expr) && (#expr))] + )); + + func.into_token_stream().into() +} + +pub(super) fn expand_requires_ensures(attr: TokenStream, item: TokenStream) -> TokenStream { + use syn::parse::Parser as _; + let parser = Punctuated::::parse_separated_nonempty; + let mut exprs = match parser.parse(attr.clone()) { + Ok(exprs) => exprs, + Err(e) => return e.to_compile_error().into(), + }; + if exprs.len() != 2 { + return syn::Error::new_spanned( + TokenStream2::from(attr), + "expected exactly two comma-separated expressions in _requires_ensures attribute", + ) + .to_compile_error() + .into(); + } + + let ens_expr = exprs.pop().unwrap().into_value(); + let req_expr = exprs.pop().unwrap().into_value(); + + let func = parse_macro_input!(item as FnItemWithSignature); + let outer_context = match extract_outer_context(&func) { + Ok(ctx) => ctx, + Err(e) => { + let err = e.to_compile_error(); + return quote! { #err #func }.into(); + } + }; + let mut tokens = ExpandedTokens::new(func, req_expr, ens_expr); + if let Some(ctx) = outer_context { + tokens = tokens.with_outer_context(ctx); + } + tokens.into_token_stream().into() +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone)] +enum FnItemWithSignature { + ItemFn(syn::ItemFn), + ImplItemFn(syn::ImplItemFn), + TraitItemFn(syn::TraitItemFn), +} + +impl syn::parse::Parse for FnItemWithSignature { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + use syn::parse::discouraged::Speculative as _; + + let fork = input.fork(); + if let Ok(item_fn) = fork.parse::() { + input.advance_to(&fork); + return Ok(Self::ItemFn(item_fn)); + } + + let fork = input.fork(); + if let Ok(impl_item_fn) = fork.parse::() { + input.advance_to(&fork); + return Ok(Self::ImplItemFn(impl_item_fn)); + } + + let fork = input.fork(); + if let Ok(trait_item_fn) = fork.parse::() { + input.advance_to(&fork); + return Ok(Self::TraitItemFn(trait_item_fn)); + } + + Err(input.error("expected a free function, an impl method, or a trait method")) + } +} + +impl quote::ToTokens for FnItemWithSignature { + fn to_tokens(&self, tokens: &mut TokenStream2) { + match self { + FnItemWithSignature::ItemFn(item_fn) => item_fn.to_tokens(tokens), + FnItemWithSignature::ImplItemFn(impl_item_fn) => impl_item_fn.to_tokens(tokens), + FnItemWithSignature::TraitItemFn(trait_item_fn) => trait_item_fn.to_tokens(tokens), + } + } +} + +impl FnItemWithSignature { + fn block(&self) -> Option<&syn::Block> { + match self { + FnItemWithSignature::ItemFn(item_fn) => Some(&item_fn.block), + FnItemWithSignature::ImplItemFn(impl_item_fn) => Some(&impl_item_fn.block), + FnItemWithSignature::TraitItemFn(_) => None, + } + } + + fn block_mut(&mut self) -> Option<&mut syn::Block> { + match self { + FnItemWithSignature::ItemFn(item_fn) => Some(&mut item_fn.block), + FnItemWithSignature::ImplItemFn(impl_item_fn) => Some(&mut impl_item_fn.block), + FnItemWithSignature::TraitItemFn(_) => None, + } + } + + fn attrs(&self) -> &[syn::Attribute] { + match self { + FnItemWithSignature::ItemFn(item_fn) => &item_fn.attrs, + FnItemWithSignature::ImplItemFn(impl_item_fn) => &impl_item_fn.attrs, + FnItemWithSignature::TraitItemFn(trait_item_fn) => &trait_item_fn.attrs, + } + } + + fn attrs_mut(&mut self) -> &mut Vec { + match self { + FnItemWithSignature::ItemFn(item_fn) => &mut item_fn.attrs, + FnItemWithSignature::ImplItemFn(impl_item_fn) => &mut impl_item_fn.attrs, + FnItemWithSignature::TraitItemFn(trait_item_fn) => &mut trait_item_fn.attrs, + } + } + + fn sig(&self) -> &syn::Signature { + match self { + FnItemWithSignature::ItemFn(item_fn) => &item_fn.sig, + FnItemWithSignature::ImplItemFn(impl_item_fn) => &impl_item_fn.sig, + FnItemWithSignature::TraitItemFn(trait_item_fn) => &trait_item_fn.sig, + } + } +} + +fn extract_requires_ensures(func: &mut FnItemWithSignature) -> syn::Result<(syn::Expr, syn::Expr)> { + let mut result = None; + + let requires_ensures_path: syn::Path = syn::parse_quote!(::thrust_macros::_requires_ensures); + + for attr in func.attrs() { + if attr.path() == &requires_ensures_path { + if result.is_some() { + return Err(syn::Error::new_spanned( + attr, + "multiple _requires_ensures attributes found; expected at most one", + )); + } + + let parser = Punctuated::::parse_separated_nonempty; + let mut exprs = attr.parse_args_with(parser)?; + if exprs.len() != 2 { + return Err(syn::Error::new_spanned( + attr, + "expected exactly two comma-separated expressions in _requires_ensures attribute", + )); + } + let ens_expr = exprs.pop().unwrap().into_value(); + let req_expr = exprs.pop().unwrap().into_value(); + result = Some((req_expr, ens_expr)); + } + } + + func.attrs_mut() + .retain(|attr| attr.path() != &requires_ensures_path); + + if let Some((req_expr, ens_expr)) = result { + Ok((req_expr, ens_expr)) + } else { + Ok((syn::parse_quote!(true), syn::parse_quote!(true))) + } +} + +fn extract_outer_context(func: &FnItemWithSignature) -> syn::Result> { + let outer_context_path: syn::Path = syn::parse_quote!(thrust::_outer_context); + let mut outer_context = None; + for attr in func.attrs() { + if attr.path() != &outer_context_path { + continue; + } + + let item = attr.parse_args()?; + if outer_context.is_some() { + return Err(syn::Error::new_spanned( + attr, + "multiple _outer_context attributes found; expected at most one", + )); + } + outer_context = Some(item); + } + if mentions_self(func.sig()) && outer_context.is_none() { + return Err(syn::Error::new_spanned( + func.sig().ident.clone(), + "Wrap the surrounding impl block or trait definition with #[thrust_macros::context] to annotate methods", + )); + } + Ok(outer_context) +} + +struct ExpandedTokens { + func: FnItemWithSignature, + + requires_name: syn::Ident, + ensures_name: syn::Ident, + req_expr: syn::Expr, + ens_expr: syn::Expr, + + def_generics: TokenStream2, + turbofish: TokenStream2, + + model_ty_params: TokenStream2, + ret_model_ty: syn::Type, + + outer_context: Option, +} + +impl quote::ToTokens for ExpandedTokens { + fn to_tokens(&self, tokens: &mut TokenStream2) { + if self.is_extern_spec_fn() { + self.expand_extern_spec_fn().to_tokens(tokens); + } else { + self.expand().to_tokens(tokens); + } + } +} + +impl ExpandedTokens { + fn new(func: FnItemWithSignature, mut req_expr: syn::Expr, mut ens_expr: syn::Expr) -> Self { + let name = &func.sig().ident; + let requires_name = format_ident!("_thrust_requires_{}", name); + let ensures_name = format_ident!("_thrust_ensures_{}", name); + + let generics = &func.sig().generics; + let def_generics = generic_params_tokens(generics); + let turbofish = generic_turbofish(generics); + + let model_ty_params = fn_params_with_model_ty(&func.sig().inputs); + let ret_model_ty = fn_return_ty_with_model_ty(&func.sig().output); + + if func.sig().receiver().is_some() { + rewrite_self_in_expr(&mut req_expr); + rewrite_self_in_expr(&mut ens_expr); + } + + Self { + func, + req_expr, + ens_expr, + requires_name, + ensures_name, + def_generics, + turbofish, + model_ty_params, + ret_model_ty, + outer_context: None, + } + } + + fn with_outer_context(mut self, outer_item: FnOuterItem) -> Self { + self.outer_context = Some(outer_item); + self + } + + fn extended_where_clause(&self) -> TokenStream2 { + let model_preds = model_where_predicates(&self.func, self.outer_context.as_ref()); + extended_where_clause(&self.func, &model_preds) + } + + fn is_extern_spec_fn(&self) -> bool { + let extern_spec_fn_path: syn::Path = syn::parse_quote!(thrust::extern_spec_fn); + self.func + .attrs() + .iter() + .any(|a| a.path() == &extern_spec_fn_path) + } + + fn requires_fn(&self) -> TokenStream2 { + let requires_name = &self.requires_name; + let def_generics = &self.def_generics; + let model_ty_params = &self.model_ty_params; + let extended_where = self.extended_where_clause(); + let req_expr = &self.req_expr; + + quote! { + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[thrust::formula_fn] + fn #requires_name #def_generics(#model_ty_params) -> bool #extended_where { + #req_expr + } + } + } + + fn ensures_fn(&self) -> TokenStream2 { + let ensures_name = &self.ensures_name; + let def_generics = &self.def_generics; + let model_ty_params = &self.model_ty_params; + let extended_where = self.extended_where_clause(); + let ret_model_ty = &self.ret_model_ty; + let ens_expr = &self.ens_expr; + + quote! { + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[thrust::formula_fn] + fn #ensures_name #def_generics(result: #ret_model_ty, #model_ty_params) -> bool #extended_where { + #ens_expr + } + } + } + + fn path_prefix(&self) -> Option { + self.outer_context.as_ref()?; + Some(quote!(Self::)) + } + + fn expand(&self) -> TokenStream2 { + let mut func = self.func.clone(); + let trusted_path: syn::Path = syn::parse_quote!(thrust::trusted); + for attr in func.attrs_mut() { + if attr.path() == &trusted_path { + *attr = syn::parse_quote!(#[thrust::ignored]); + } + } + + let requires_fn = self.requires_fn(); + let ensures_fn = self.ensures_fn(); + + let extern_spec_name = format_ident!("_thrust_extern_spec_{}", self.func.sig().ident); + let def_generics = &self.def_generics; + let orig_output = &self.func.sig().output; + let extended_where = self.extended_where_clause(); + + let requires_name = &self.requires_name; + let ensures_name = &self.ensures_name; + let turbofish = &self.turbofish; + let path_prefix = self.path_prefix(); + + let name = &self.func.sig().ident; + let (extern_spec_inputs, call_args) = rewrite_inputs_for_call(&self.func.sig().inputs); + + quote! { + #func + + #requires_fn + #ensures_fn + + #[thrust::extern_spec_fn] + #[allow(path_statements)] + fn #extern_spec_name #def_generics(#extern_spec_inputs) #orig_output #extended_where { + #[thrust::requires_path] + #path_prefix #requires_name #turbofish; + + #[thrust::ensures_path] + #path_prefix #ensures_name #turbofish; + + #path_prefix #name #turbofish(#call_args) + } + } + } + + fn expand_extern_spec_fn(&self) -> TokenStream2 { + let requires_name = &self.requires_name; + let ensures_name = &self.ensures_name; + let turbofish = &self.turbofish; + let path_prefix = self.path_prefix(); + + let mut func = self.func.clone(); + let func_tokens = if let Some(block) = func.block_mut() { + let orig_stmts = block.stmts.drain(..).collect::>(); + *block = syn::parse_quote!({ + #[thrust::requires_path] + #path_prefix #requires_name #turbofish; + + #[thrust::ensures_path] + #path_prefix #ensures_name #turbofish; + + #(#orig_stmts)* + }); + quote! { + #[allow(path_statements)] + #func + } + } else { + let error = syn::Error::new_spanned( + func.sig().ident.clone(), + "extern_spec_fn must have a function body", + ) + .into_compile_error(); + quote! { + #error + #func + } + }; + + let requires_fn = self.requires_fn(); + let ensures_fn = self.ensures_fn(); + + quote! { + #requires_fn + #ensures_fn + + #func_tokens + } + } +} + +fn mentions_self(sig: &syn::Signature) -> bool { + struct Visitor { + mentions_self: bool, + } + + impl syn::visit::Visit<'_> for Visitor { + fn visit_ident(&mut self, i: &syn::Ident) { + if i == "self" || i == "Self" { + self.mentions_self = true; + } + } + } + + let mut visitor = Visitor { + mentions_self: false, + }; + use syn::visit::Visit as _; + visitor.visit_signature(sig); + visitor.mentions_self +} + +fn rewrite_self_in_expr(expr: &mut syn::Expr) { + struct Visitor; + + impl syn::visit_mut::VisitMut for Visitor { + fn visit_ident_mut(&mut self, ident: &mut syn::Ident) { + if ident == "self" { + *ident = format_ident!("self_"); + } + } + } + + use syn::visit_mut::VisitMut as _; + Visitor.visit_expr_mut(expr); +} + +/// Returns `` — the generic param list for function definitions, +/// without a where clause. +fn generic_params_tokens(generics: &Generics) -> TokenStream2 { + if generics.params.is_empty() { + return quote!(); + } + let params = &generics.params; + quote!(<#params>) +} + +/// Returns `::` for turbofish use, or nothing if no generic params. +fn generic_turbofish(generics: &Generics) -> TokenStream2 { + let args: Vec = generics + .params + .iter() + .flat_map(|p| match p { + GenericParam::Type(tp) => Some(tp.ident.to_token_stream()), + GenericParam::Lifetime(_) => None, + GenericParam::Const(cp) => Some(cp.ident.to_token_stream()), + }) + .collect(); + if args.is_empty() { + return quote!(); + } + quote!(::<#(#args),*>) +} + +/// For the extern_spec wrapper: replaces every typed parameter with a fresh `_arg_N` ident, +/// returning `(rewritten_inputs_tokens, call_args_tokens)`. +fn rewrite_inputs_for_call( + inputs: &syn::punctuated::Punctuated, +) -> (TokenStream2, TokenStream2) { + let mut rewritten: Vec = Vec::new(); + let mut call_args: Vec = Vec::new(); + + for (i, arg) in inputs.iter().enumerate() { + match arg { + FnArg::Typed(pt) => { + let fresh = format_ident!("_arg_{}", i); + let ty = &pt.ty; + rewritten.push(quote!(#fresh: #ty)); + call_args.push(fresh.to_token_stream()); + } + FnArg::Receiver(_) => { + rewritten.push(arg.to_token_stream()); + call_args.push(quote!(self)); + } + } + } + + (quote!(#(#rewritten),*), quote!(#(#call_args),*)) +} + +/// Returns `T: thrust_models::Model` predicates for every type param that does not +/// already carry an `Fn`, `FnOnce`, or `FnMut` bound. +fn model_where_predicates( + func: &FnItemWithSignature, + outer_context: Option<&FnOuterItem>, +) -> Vec { + struct GenericTypeParam { + ident: syn::Ident, + bounds: Vec, + } + + impl From for GenericTypeParam { + fn from(tp: syn::TypeParam) -> Self { + Self { + ident: tp.ident, + bounds: tp.bounds.into_iter().collect(), + } + } + } + + impl GenericTypeParam { + fn has_fn_bound(&self) -> bool { + self.bounds.iter().any(|b| { + let TypeParamBound::Trait(tb) = b else { + return false; + }; + tb.path.segments.last().is_some_and(|s| { + matches!(s.ident.to_string().as_str(), "Fn" | "FnOnce" | "FnMut") + }) + }) + } + } + + let mut generic_type_params: Vec = Vec::new(); + for param in &func.sig().generics.params { + let GenericParam::Type(tp) = param else { + continue; + }; + generic_type_params.push(tp.clone().into()); + } + if let Some(outer_item) = outer_context { + for param in &outer_item.generics().params { + let GenericParam::Type(tp) = param else { + continue; + }; + generic_type_params.push(tp.clone().into()); + } + if let FnOuterItem::ItemTrait(outer_item) = &outer_item { + generic_type_params.push(GenericTypeParam { + ident: format_ident!("Self"), + bounds: outer_item.supertraits.iter().cloned().collect(), + }); + } + } + generic_type_params.retain(|p| !p.has_fn_bound()); + + let mut predicates: Vec = Vec::new(); + for param in &generic_type_params { + let ident = ¶m.ident; + predicates.push(syn::parse_quote!(#ident: thrust_models::Model)); + predicates.push(syn::parse_quote!(<#ident as thrust_models::Model>::Ty: PartialEq)); + } + + struct Visitor { + generic_type_params: Vec, + generic_paths: Vec, + } + + impl syn::visit::Visit<'_> for Visitor { + fn visit_type_path(&mut self, tp: &syn::TypePath) { + for param in &self.generic_type_params { + if let Some(qself) = &tp.qself { + let param = ¶m.ident; + let param_ty: syn::Type = syn::parse_quote!(#param); + if *qself.ty == param_ty { + self.generic_paths.push(tp.clone()); + } + } + if tp.path.segments.len() > 1 + && tp.path.segments.first().unwrap().ident == param.ident + && tp.qself.is_none() + { + self.generic_paths.push(tp.clone()); + } + } + syn::visit::visit_type_path(self, tp); + } + } + + let mut visitor = Visitor { + generic_type_params, + generic_paths: Vec::new(), + }; + use syn::visit::Visit as _; + for arg in &func.sig().inputs { + visitor.visit_fn_arg(arg); + } + visitor.visit_return_type(&func.sig().output); + for tp in visitor.generic_paths { + predicates.push(syn::parse_quote!(#tp: thrust_models::Model)); + predicates.push(syn::parse_quote!(<#tp as thrust_models::Model>::Ty: PartialEq)); + } + + predicates +} + +/// Builds `where , `. +/// Returns an empty token stream when both sets are empty. +fn extended_where_clause( + func: &FnItemWithSignature, + model_preds: &Vec, +) -> TokenStream2 { + let existing: Vec<&WherePredicate> = func + .sig() + .generics + .where_clause + .as_ref() + .map(|wc| wc.predicates.iter().collect()) + .unwrap_or_default(); + + if existing.is_empty() && model_preds.is_empty() { + return quote!(); + } + + quote! { where #(#existing,)* #(#model_preds),* } +} + +fn fn_return_ty_with_model_ty(ret: &syn::ReturnType) -> syn::Type { + match ret { + syn::ReturnType::Default => syn::parse_quote!(<() as thrust_models::Model>::Ty), + syn::ReturnType::Type(_, ty) => syn::parse_quote!(<#ty as thrust_models::Model>::Ty), + } +}