diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index ec8465c9..6a40943e 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -141,6 +141,14 @@ pub fn exists_path() -> [Symbol; 3] { ] } +pub fn invariant_marker_path() -> [Symbol; 3] { + [ + Symbol::intern("thrust"), + Symbol::intern("def"), + Symbol::intern("invariant_marker"), + ] +} + /// A [`annot::Resolver`] implementation for resolving function parameters. /// /// The parameter names and their sorts needs to be configured via diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index aae82be4..5503382f 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -40,6 +40,10 @@ where } impl<'tcx> FormulaFn<'tcx> { + pub fn formula(&self) -> &chc::Formula { + &self.formula + } + pub fn to_require_annot(&self) -> AnnotFormula { AnnotFormula::Formula(self.formula.clone()) } diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 2ec663d7..5aa060d1 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -1024,6 +1024,19 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { None } + /// Whether this block's terminator is a loop-invariant marker call. + fn terminator_is_invariant_marker(&self) -> Option { + let term = &self.body.basic_blocks[self.basic_block].terminator().kind; + if let TerminatorKind::Call { func, target, .. } = term { + if let Some((def_id, _)) = func.const_fn_def() { + if Some(def_id) == self.ctx.def_ids().invariant_marker() { + return Some(target.expect("invariant marker call must have a target")); + } + } + } + None + } + fn analyze_statements(&mut self) { for local in self.drop_points.before_statements.clone() { tracing::info!(?local, "implicitly dropped before statements"); @@ -1065,6 +1078,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { source_info: term.source_info, }; } + if let Some(target) = self.terminator_is_invariant_marker() { + tracing::debug!(?term, "skip invariant marker"); + return mir::Terminator { + kind: TerminatorKind::Goto { target }, + source_info: term.source_info, + }; + } self.rust_call_visitor().visit_terminator(&mut term); self.reborrow_visitor().visit_terminator(&mut term); tracing::debug!(term = ?term.kind); diff --git a/src/analyze/did_cache.rs b/src/analyze/did_cache.rs index c110adfc..98cc382e 100644 --- a/src/analyze/did_cache.rs +++ b/src/analyze/did_cache.rs @@ -25,6 +25,7 @@ struct DefIds { array_model_store: OnceCell>, exists: OnceCell>, + invariant_marker: OnceCell>, } /// Retrieves and caches well-known [`DefId`]s. @@ -176,4 +177,11 @@ impl<'tcx> DefIdCache<'tcx> { .exists .get_or_init(|| self.annotated_def(&crate::analyze::annot::exists_path())) } + + pub fn invariant_marker(&self) -> Option { + *self + .def_ids + .invariant_marker + .get_or_init(|| self.annotated_def(&crate::analyze::annot::invariant_marker_path())) + } } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 9f7f57ee..64d4100d 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -805,8 +805,135 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } } + /// Scans the body for loop-invariant marker calls and groups them by + /// enclosing loop header. Multiple invariants for the same header are kept + /// in source order; the caller is responsible for AND'ing them. + fn collect_loop_invariant_annotations( + &self, + ) -> HashMap)>> { + let mut loop_invariants: HashMap<_, Vec<_>> = HashMap::new(); + for (bb, data) in self.body.basic_blocks.iter_enumerated() { + let Some(term) = &data.terminator else { + continue; + }; + let mir::TerminatorKind::Call { func, args, .. } = &term.kind else { + continue; + }; + let Some((def_id, _)) = func.const_fn_def() else { + continue; + }; + if Some(def_id) != self.ctx.def_ids().invariant_marker() { + continue; + } + + let arg_ty = args[0].node.ty(&self.body.local_decls, self.tcx); + let mir_ty::TyKind::FnDef(formula_def_id, generic_args) = arg_ty.kind() else { + panic!("invariant marker argument must be a formula function item"); + }; + let formula_def_id = formula_def_id + .as_local() + .expect("invariant formula function must be local"); + let header = Self::loop_header_of(&self.body, bb).unwrap_or_else(|| { + panic!("no enclosing loop header for invariant marker at {bb:?}") + }); + loop_invariants + .entry(header) + .or_default() + .push((formula_def_id, *generic_args)); + } + loop_invariants + } + + /// Walks up the dominator tree from the marker block to the innermost + /// enclosing loop header: the first dominator that needs its own + /// precondition (in-degree >= 2) and has a back edge. + fn loop_header_of(body: &Body<'_>, marker_bb: BasicBlock) -> Option { + let doms = body.basic_blocks.dominators(); + let preds = body.basic_blocks.predecessors(); + let mut cur = Some(marker_bb); + while let Some(bb) = cur { + if analyze::basic_block::needs_own_precondition(body, bb) + && preds[bb].iter().any(|&p| doms.dominates(bb, p)) + { + return Some(bb); + } + cur = doms.immediate_dominator(bb); + } + None + } + + /// Resolves the live local matching a source variable name at the given + /// basic block, among the locals that are parameters of `bty`. + /// + /// When several distinct live locals share the name (e.g. two shadowed + /// variables that are both loop-carried), the mapping is ambiguous; rather + /// than silently pick one, this raises a fatal error. Disambiguating which + /// shadow an invariant refers to is left as future work. + fn local_of_name_in_bb(&self, name: rustc_span::Symbol, bty: &BasicBlockType) -> Option { + let mut found: Option = None; + for vdi in &self.body.var_debug_info { + if vdi.name != name { + continue; + } + let mir::VarDebugInfoContents::Place(place) = vdi.value else { + continue; + }; + if !place.projection.is_empty() { + continue; + } + if bty.param_of_local(place.local).is_none() { + continue; + } + match found { + None => found = Some(place.local), + Some(prev) if prev == place.local => {} + Some(_) => self.tcx.dcx().fatal(format!( + "loop invariant refers to `{name}`, which is ambiguous at the loop header: \ + multiple live variables share this name (e.g. through shadowing). \ + Rename the variables to disambiguate." + )), + } + } + found + } + + /// Translates a user-provided loop invariant (a formula function over named + /// live variables) into a precondition refinement over `bty`'s parameters. + /// Each formula parameter names a live variable at the loop header and is + /// mapped to the corresponding basic-block parameter. + fn build_invariant_precondition( + &self, + formula_def_id: LocalDefId, + generic_args: mir_ty::GenericArgsRef<'tcx>, + bty: &BasicBlockType, + ) -> rty::Refinement { + let formula_fn = self + .ctx + .formula_fn_with_args(formula_def_id, generic_args) + .expect("invariant formula function is not registered"); + let idents = self.tcx.fn_arg_idents(formula_def_id.to_def_id()); + + let mut mapping: Vec = Vec::with_capacity(idents.len()); + for ident in idents { + let name = ident.expect("invariant parameters must be named").name; + let local = self.local_of_name_in_bb(name, bty).unwrap_or_else(|| { + self.tcx.dcx().fatal(format!( + "loop invariant refers to `{name}`, which is not a live variable at the loop header" + )) + }); + mapping.push(bty.param_of_local(local).unwrap()); + } + + formula_fn + .formula() + .clone() + .subst_var(|idx| chc::Term::var(rty::RefinedTypeVar::Free(mapping[idx.index()]))) + .into() + } + fn refine_basic_blocks(&mut self) { use rustc_mir_dataflow::Analysis as _; + let loop_invariants = self.collect_loop_invariant_annotations(); let mut results = rustc_mir_dataflow::impls::MaybeLiveLocals .iterate_to_fixpoint(self.tcx, &self.body, None) .into_results_cursor(&self.body); @@ -851,7 +978,23 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } // function return type is basic block return type let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - if analyze::basic_block::needs_own_precondition(&self.body, bb) { + if let Some(invariants) = loop_invariants.get(&bb) { + // A user-supplied loop invariant fully replaces inference: build + // the block type without a precondition pvar and install the + // invariant as its precondition. Multiple `invariant!` calls at + // the same header are AND'd in source order. + let mut bty = self + .type_builder + .build_basic_block(&self.body, live_locals, ret_ty); + let mut inv = rty::Refinement::top(); + for &(formula_def_id, generic_args) in invariants { + let one = self.build_invariant_precondition(formula_def_id, generic_args, &bty); + inv.push_conj(one); + } + bty.set_precondition(inv); + self.ctx + .register_basic_block_ty_with_precondition(self.local_def_id, bb, bty); + } else if analyze::basic_block::needs_own_precondition(&self.body, bb) { let bty = self .type_builder .for_template(&mut self.ctx) diff --git a/std.rs b/std.rs index dc30924e..1b0ad409 100644 --- a/std.rs +++ b/std.rs @@ -302,6 +302,13 @@ mod thrust_models { pub fn exists(_x: T) -> bool { unimplemented!() } + + #[thrust::def::invariant_marker] + #[thrust::ignored] + #[inline(never)] + pub fn __invariant_marker(_f: F) { + unimplemented!() + } } #[thrust::extern_spec_fn] diff --git a/tests/ui/fail/loop_invariant.rs b/tests/ui/fail/loop_invariant.rs new file mode 100644 index 00000000..ba72e936 --- /dev/null +++ b/tests/ui/fail/loop_invariant.rs @@ -0,0 +1,20 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +fn main() { + let mut x = 1_i64; + let mut y = 1_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64| x >= 1); + let t1 = x; + let t2 = y; + x = t1 + t2; + y = t1 + t2; + } + assert!(y >= 1); +} diff --git a/tests/ui/fail/loop_invariant_method.rs b/tests/ui/fail/loop_invariant_method.rs new file mode 100644 index 00000000..fec3edaf --- /dev/null +++ b/tests/ui/fail/loop_invariant_method.rs @@ -0,0 +1,24 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +struct Counter; + +impl Counter { + fn run(&self) { + let mut x = 5_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64| x >= 1); + x = x - 1; + } + assert!(x >= 1); + } +} + +fn main() { + Counter.run(); +} diff --git a/tests/ui/fail/loop_invariant_multi.rs b/tests/ui/fail/loop_invariant_multi.rs new file mode 100644 index 00000000..bb040f14 --- /dev/null +++ b/tests/ui/fail/loop_invariant_multi.rs @@ -0,0 +1,24 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +// Both invariants are AND'd. If either is too weak to be inductive, the +// verification fails — here `y >= 2` does not hold initially. + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +fn main() { + let mut x = 1_i64; + let mut y = 1_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64| x >= 1); + thrust_macros::invariant!(|y: i64| y >= 2); + let t1 = x; + let t2 = y; + x = t1 + t2; + y = t1 + t2; + } + assert!(x >= 1 && y >= 2); +} diff --git a/tests/ui/fail/loop_invariant_mut.rs b/tests/ui/fail/loop_invariant_mut.rs new file mode 100644 index 00000000..3074aae8 --- /dev/null +++ b/tests/ui/fail/loop_invariant_mut.rs @@ -0,0 +1,17 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +fn main() { + let mut x = 5_i64; + let p = &mut x; + while rand() == 0 { + thrust_macros::invariant!(|p: &mut i64| *p >= 1); + *p = *p - 1; + } + assert!(*p >= 1); +} diff --git a/tests/ui/fail/loop_invariant_nested.rs b/tests/ui/fail/loop_invariant_nested.rs new file mode 100644 index 00000000..1d2486d1 --- /dev/null +++ b/tests/ui/fail/loop_invariant_nested.rs @@ -0,0 +1,20 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +fn main() { + let mut x = 1_i64; + while rand() == 0 { + let mut y = 1_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64| x >= 1); + y = x + y; + } + x = x + y; + } + assert!(x >= 1); +} diff --git a/tests/ui/pass/loop_invariant.rs b/tests/ui/pass/loop_invariant.rs new file mode 100644 index 00000000..c64b0c84 --- /dev/null +++ b/tests/ui/pass/loop_invariant.rs @@ -0,0 +1,20 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +fn main() { + let mut x = 1_i64; + let mut y = 1_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64, y: i64| x >= 1 && y >= 1); + let t1 = x; + let t2 = y; + x = t1 + t2; + y = t1 + t2; + } + assert!(y >= 1); +} diff --git a/tests/ui/pass/loop_invariant_method.rs b/tests/ui/pass/loop_invariant_method.rs new file mode 100644 index 00000000..198195cd --- /dev/null +++ b/tests/ui/pass/loop_invariant_method.rs @@ -0,0 +1,24 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +struct Counter; + +impl Counter { + fn run(&self) { + let mut x = 1_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64| x >= 1); + x = x + 1; + } + assert!(x >= 1); + } +} + +fn main() { + Counter.run(); +} diff --git a/tests/ui/pass/loop_invariant_multi.rs b/tests/ui/pass/loop_invariant_multi.rs new file mode 100644 index 00000000..058b2d1b --- /dev/null +++ b/tests/ui/pass/loop_invariant_multi.rs @@ -0,0 +1,24 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +// Multiple `invariant!` calls at the same loop header are AND'd: the proof +// below needs both `x >= 1` and `y >= 1` to be carried across the back edge. + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +fn main() { + let mut x = 1_i64; + let mut y = 1_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64| x >= 1); + thrust_macros::invariant!(|y: i64| y >= 1); + let t1 = x; + let t2 = y; + x = t1 + t2; + y = t1 + t2; + } + assert!(x >= 1 && y >= 1); +} diff --git a/tests/ui/pass/loop_invariant_mut.rs b/tests/ui/pass/loop_invariant_mut.rs new file mode 100644 index 00000000..ca9f46e6 --- /dev/null +++ b/tests/ui/pass/loop_invariant_mut.rs @@ -0,0 +1,17 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +fn main() { + let mut x = 1_i64; + let p = &mut x; + while rand() == 0 { + thrust_macros::invariant!(|p: &mut i64| *p >= 1); + *p = *p + 1; + } + assert!(*p >= 1); +} diff --git a/tests/ui/pass/loop_invariant_nested.rs b/tests/ui/pass/loop_invariant_nested.rs new file mode 100644 index 00000000..2618af18 --- /dev/null +++ b/tests/ui/pass/loop_invariant_nested.rs @@ -0,0 +1,20 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +fn main() { + let mut x = 1_i64; + while rand() == 0 { + let mut y = 1_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64, y: i64| x >= 1 && y >= 1); + y = x + y; + } + x = x + y; + } + assert!(x >= 1); +} diff --git a/thrust-macros/src/invariant.rs b/thrust-macros/src/invariant.rs new file mode 100644 index 00000000..8edf62eb --- /dev/null +++ b/thrust-macros/src/invariant.rs @@ -0,0 +1,54 @@ +//! Expansion of `thrust_macros::invariant!`. +//! +//! Expands a closure with concrete parameter types into a +//! `#[thrust::formula_fn]` over `Model::Ty` parameters and a marker call +//! referencing it. + +use std::sync::atomic::{AtomicUsize, Ordering}; + +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, FnArg}; + +use crate::fn_params_with_model_ty; + +static COUNTER: AtomicUsize = AtomicUsize::new(0); + +pub fn expand(input: TokenStream) -> TokenStream { + let closure = parse_macro_input!(input as syn::ExprClosure); + + let mut fn_params: Vec = Vec::new(); + for param in &closure.inputs { + let syn::Pat::Type(pt) = param else { + return syn::Error::new_spanned( + param, + "invariant closure parameters must have explicit types, e.g. `|x: i64| ...`", + ) + .to_compile_error() + .into(); + }; + let pat = &pt.pat; + let ty = &pt.ty; + fn_params.push(syn::parse_quote!(#pat: #ty)); + } + + let model_ty_params = fn_params_with_model_ty(&fn_params); + let body = &closure.body; + + let id = COUNTER.fetch_add(1, Ordering::Relaxed); + let name = format_ident!("_thrust_invariant_{}", id); + + quote! { + { + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[thrust::formula_fn] + fn #name(#model_ty_params) -> bool { + #body + } + + thrust_models::__invariant_marker(#name) + } + } + .into() +} diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 7bc0c042..99df73dc 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -6,6 +6,8 @@ use syn::{ WherePredicate, }; +mod invariant; + #[derive(Debug, Clone)] enum FnOuterItem { ItemImpl(syn::ItemImpl), @@ -62,6 +64,25 @@ impl FnOuterItem { } } +/// Declares a loop invariant inside a loop body: +/// +/// ```ignore +/// fn f() { +/// while cond { +/// thrust_macros::invariant!(|x: i64| x >= 1); +/// ... +/// } +/// } +/// ``` +/// +/// The argument is a closure whose parameters name the live variables the +/// invariant refers to (with their types) and whose body is the invariant +/// predicate. +#[proc_macro] +pub fn invariant(input: TokenStream) -> TokenStream { + invariant::expand(input) +} + #[proc_macro_attribute] pub fn context(_attr: TokenStream, item: TokenStream) -> TokenStream { let mut outer_item = syn::parse_macro_input!(item as FnOuterItem);