Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/ui/fail/loop_invariant_generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

#[thrust_macros::context]
fn keep<T: Copy + PartialEq>(v: T) {
let mut x = v;
while rand() == 0 {
thrust_macros::invariant!(|v: T| v == v);
x = v;
}
assert!(x == v);
}

fn main() {
keep(0_i64);
keep(true);
}
29 changes: 29 additions & 0 deletions tests/ui/fail/loop_invariant_self.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

struct Counter(i64);
impl thrust_models::Model for Counter {
type Ty = (thrust_models::model::Int,);
}

#[thrust_macros::context]
impl Counter {
fn run(self) {
let mut c = self;
let mut x = 1_i64;
while rand() == 0 {
thrust_macros::invariant!(|x: i64, c: Self| x >= 2 && c == c);
x = x + 1;
c = Counter(0);
}
let _last = c;
assert!(x >= 1);
}
}

fn main() { Counter(0).run(); }
22 changes: 22 additions & 0 deletions tests/ui/pass/loop_invariant_generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//@check-pass
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

#[thrust_macros::context]
fn keep<T: Copy + PartialEq>(v: T) {
let mut x = v;
while rand() == 0 {
thrust_macros::invariant!(|x: T, v: T| x == v);
x = v;
}
assert!(x == v);
}

fn main() {
keep(0_i64);
keep(true);
}
29 changes: 29 additions & 0 deletions tests/ui/pass/loop_invariant_self.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//@check-pass
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

struct Counter(i64);
impl thrust_models::Model for Counter {
type Ty = (thrust_models::model::Int,);
}

#[thrust_macros::context]
impl Counter {
fn run(self) {
let mut c = self;
let mut x = 1_i64;
while rand() == 0 {
thrust_macros::invariant!(|x: i64, c: Self| x >= 1 && c == c);
x = x + 1;
c = Counter(0);
}
let _last = c;
assert!(x >= 1);
}
}

fn main() { Counter(0).run(); }
210 changes: 210 additions & 0 deletions thrust-macros/src/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
//! Expansion of `#[thrust_macros::context]`.
//!
//! Supplies the surrounding context that thrust annotations within an item
//! cannot see by themselves: stamps the enclosing `impl`/`trait` header onto
//! each method (so method-level `requires`/`ensures` recover the outer
//! generics) and threads that generic context into every `invariant!(...)`.

use proc_macro::TokenStream;
use proc_macro2::{TokenStream as TokenStream2, TokenTree};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, GenericParam, Generics, WherePredicate};

use crate::FnOuterItem;

pub(super) fn expand(item: TokenStream) -> TokenStream {
let mut item = parse_macro_input!(item as syn::Item);
process_context_item(&mut item);
item.into_token_stream().into()
}

/// Stamps outer context onto methods and threads the generic context into the
/// invariants of every function body in the item (recursing through modules).
fn process_context_item(item: &mut syn::Item) {
match item {
syn::Item::Fn(item_fn) => {
let generics = item_fn.sig.generics.clone();
let threaded = thread_invariants(&mut item_fn.block, &generics, None);
if threaded.found {
inject_model_bounds(&mut item_fn.sig.generics, None, false);
}
}
syn::Item::Impl(item_impl) => {
let outer = FnOuterItem::ItemImpl(item_impl.clone()).into_header_only();
for impl_item in &mut item_impl.items {
let syn::ImplItem::Fn(method) = impl_item else {
continue;
};
method
.attrs
.push(syn::parse_quote!(#[thrust::_outer_context(#outer)]));
let generics = method.sig.generics.clone();
let threaded = thread_invariants(&mut method.block, &generics, Some(&outer));
if threaded.found {
inject_model_bounds(&mut method.sig.generics, Some(&outer), threaded.self_used);
}
}
}
syn::Item::Trait(item_trait) => {
let outer = FnOuterItem::ItemTrait(item_trait.clone()).into_header_only();
for trait_item in &mut item_trait.items {
let syn::TraitItem::Fn(method) = trait_item else {
continue;
};
method
.attrs
.push(syn::parse_quote!(#[thrust::_outer_context(#outer)]));
if let Some(block) = &mut method.default {
let generics = method.sig.generics.clone();
let threaded = thread_invariants(block, &generics, Some(&outer));
if threaded.found {
inject_model_bounds(
&mut method.sig.generics,
Some(&outer),
threaded.self_used,
);
}
}
}
}
syn::Item::Mod(item_mod) => {
if let Some((_, items)) = &mut item_mod.content {
for inner in items {
process_context_item(inner);
}
}
}
_ => {}
}
}

struct Threaded {
found: bool,
self_used: bool,
}

/// Prepends the generic context to every `invariant!(...)` in a function body.
fn thread_invariants(
block: &mut syn::Block,
generics: &Generics,
outer: Option<&FnOuterItem>,
) -> Threaded {
use syn::visit_mut::VisitMut as _;

let context = invariant_context_tokens(generics, outer);
let mut threader = InvariantThreader {
context,
is_method: outer.is_some(),
found: false,
self_used: false,
};
threader.visit_block_mut(block);
Threaded {
found: threader.found,
self_used: threader.self_used,
}
}

/// Builds the `[generic-params] [where-predicates]` prefix that `invariant!`
/// consumes: every generic parameter in scope (the function's own and, for
/// methods, the outer ones), the existing where predicates, and the
/// `Model`/`PartialEq` bounds those parameters require.
fn invariant_context_tokens(generics: &Generics, outer: Option<&FnOuterItem>) -> TokenStream2 {
let mut params: Vec<GenericParam> = generics.params.iter().cloned().collect();
let mut preds: Vec<WherePredicate> = generics
.where_clause
.as_ref()
.map(|wc| wc.predicates.iter().cloned().collect())
.unwrap_or_default();
if let Some(outer) = outer {
params.extend(outer.generics().params.iter().cloned());
if let Some(wc) = &outer.generics().where_clause {
preds.extend(wc.predicates.iter().cloned());
}
}
for param in &params {
if let GenericParam::Type(tp) = param {
let ident = &tp.ident;
preds.push(syn::parse_quote!(#ident: thrust_models::Model));
preds.push(syn::parse_quote!(<#ident as thrust_models::Model>::Ty: PartialEq));
}
}
quote! { [#(#params),*] [#(#preds),*] }
}

struct InvariantThreader {
context: TokenStream2,
is_method: bool,
found: bool,
self_used: bool,
}

impl syn::visit_mut::VisitMut for InvariantThreader {
fn visit_macro_mut(&mut self, mac: &mut syn::Macro) {
syn::visit_mut::visit_macro_mut(self, mac);
if is_invariant_macro(&mac.path) {
self.found = true;
// An invariant in a method may name `Self` in its variable types.
// A nested item cannot, so signal `invariant!` to re-declare `Self`
// as a synthetic generic, but only when it is actually used (so we
// do not over-constrain the host method with `Self: Model`).
let uses_self = self.is_method && tokens_contain_self(&mac.tokens);
self.self_used |= uses_self;
let self_marker = if uses_self { quote!(Self) } else { quote!() };
let context = &self.context;
let original = &mac.tokens;
mac.tokens = quote! { [#self_marker] #context #original };
}
}
}

fn tokens_contain_self(tokens: &TokenStream2) -> bool {
tokens.clone().into_iter().any(|tt| match tt {
TokenTree::Ident(ident) => ident == "Self",
TokenTree::Group(group) => tokens_contain_self(&group.stream()),
_ => false,
})
}

/// Adds `T: Model` and `<T as Model>::Ty: PartialEq` bounds for every type
/// parameter in scope to a function's where clause. The marker call generated
/// for an invariant instantiates a `Model`-bounded formula function, so the
/// function hosting the call must itself satisfy those bounds. When an
/// invariant names `Self`, `invariant!` instantiates the formula function with
/// `Self`, so the same bounds are added for `Self` (`with_self`).
fn inject_model_bounds(generics: &mut Generics, outer: Option<&FnOuterItem>, with_self: bool) {
let mut tys: Vec<TokenStream2> = generics
.params
.iter()
.filter_map(|p| match p {
GenericParam::Type(tp) => Some(tp.ident.to_token_stream()),
_ => None,
})
.collect();
if let Some(outer) = outer {
for param in &outer.generics().params {
if let GenericParam::Type(tp) = param {
tys.push(tp.ident.to_token_stream());
}
}
}
if with_self {
tys.push(quote!(Self));
}
if tys.is_empty() {
return;
}
let where_clause = generics.make_where_clause();
for ty in tys {
where_clause
.predicates
.push(syn::parse_quote!(#ty: thrust_models::Model));
where_clause
.predicates
.push(syn::parse_quote!(<#ty as thrust_models::Model>::Ty: PartialEq));
}
}

fn is_invariant_macro(path: &syn::Path) -> bool {
path.segments.last().is_some_and(|s| s.ident == "invariant")
}
Loading