Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ pub fn exists_path() -> [Symbol; 3] {
]
}

pub fn invariant_marker_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("invariant_marker"),
]
}

/// A [`annot::Resolver`] implementation for resolving function parameters.
///
/// The parameter names and their sorts needs to be configured via
Expand Down
4 changes: 4 additions & 0 deletions src/analyze/annot_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ where
}

impl<'tcx> FormulaFn<'tcx> {
pub fn formula(&self) -> &chc::Formula<rty::FunctionParamIdx> {
&self.formula
}

pub fn to_require_annot(&self) -> AnnotFormula<rty::FunctionParamIdx> {
AnnotFormula::Formula(self.formula.clone())
}
Expand Down
20 changes: 20 additions & 0 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,19 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
None
}

/// Whether this block's terminator is a loop-invariant marker call.
fn terminator_is_invariant_marker(&self) -> Option<BasicBlock> {
let term = &self.body.basic_blocks[self.basic_block].terminator().kind;
if let TerminatorKind::Call { func, target, .. } = term {
if let Some((def_id, _)) = func.const_fn_def() {
if Some(def_id) == self.ctx.def_ids().invariant_marker() {
return Some(target.expect("invariant marker call must have a target"));
}
}
}
None
}

fn analyze_statements(&mut self) {
for local in self.drop_points.before_statements.clone() {
tracing::info!(?local, "implicitly dropped before statements");
Expand Down Expand Up @@ -1065,6 +1078,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
source_info: term.source_info,
};
}
if let Some(target) = self.terminator_is_invariant_marker() {
tracing::debug!(?term, "skip invariant marker");
return mir::Terminator {
kind: TerminatorKind::Goto { target },
source_info: term.source_info,
};
}
self.rust_call_visitor().visit_terminator(&mut term);
self.reborrow_visitor().visit_terminator(&mut term);
tracing::debug!(term = ?term.kind);
Expand Down
8 changes: 8 additions & 0 deletions src/analyze/did_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ struct DefIds {
array_model_store: OnceCell<Option<DefId>>,

exists: OnceCell<Option<DefId>>,
invariant_marker: OnceCell<Option<DefId>>,
}

/// Retrieves and caches well-known [`DefId`]s.
Expand Down Expand Up @@ -176,4 +177,11 @@ impl<'tcx> DefIdCache<'tcx> {
.exists
.get_or_init(|| self.annotated_def(&crate::analyze::annot::exists_path()))
}

pub fn invariant_marker(&self) -> Option<DefId> {
*self
.def_ids
.invariant_marker
.get_or_init(|| self.annotated_def(&crate::analyze::annot::invariant_marker_path()))
}
}
145 changes: 144 additions & 1 deletion src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -805,8 +805,135 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}
}

/// Scans the body for loop-invariant marker calls and groups them by
/// enclosing loop header. Multiple invariants for the same header are kept
/// in source order; the caller is responsible for AND'ing them.
fn collect_loop_invariant_annotations(
&self,
) -> HashMap<BasicBlock, Vec<(LocalDefId, mir_ty::GenericArgsRef<'tcx>)>> {
let mut loop_invariants: HashMap<_, Vec<_>> = HashMap::new();
for (bb, data) in self.body.basic_blocks.iter_enumerated() {
let Some(term) = &data.terminator else {
continue;
};
let mir::TerminatorKind::Call { func, args, .. } = &term.kind else {
continue;
};
let Some((def_id, _)) = func.const_fn_def() else {
continue;
};
if Some(def_id) != self.ctx.def_ids().invariant_marker() {
continue;
}

let arg_ty = args[0].node.ty(&self.body.local_decls, self.tcx);
let mir_ty::TyKind::FnDef(formula_def_id, generic_args) = arg_ty.kind() else {
panic!("invariant marker argument must be a formula function item");
};
let formula_def_id = formula_def_id
.as_local()
.expect("invariant formula function must be local");
let header = Self::loop_header_of(&self.body, bb).unwrap_or_else(|| {
panic!("no enclosing loop header for invariant marker at {bb:?}")
});
loop_invariants
.entry(header)
.or_default()
.push((formula_def_id, *generic_args));
}
loop_invariants
}

/// Walks up the dominator tree from the marker block to the innermost
/// enclosing loop header: the first dominator that needs its own
/// precondition (in-degree >= 2) and has a back edge.
fn loop_header_of(body: &Body<'_>, marker_bb: BasicBlock) -> Option<BasicBlock> {
let doms = body.basic_blocks.dominators();
let preds = body.basic_blocks.predecessors();
let mut cur = Some(marker_bb);
while let Some(bb) = cur {
if analyze::basic_block::needs_own_precondition(body, bb)
&& preds[bb].iter().any(|&p| doms.dominates(bb, p))
{
return Some(bb);
}
cur = doms.immediate_dominator(bb);
}
None
}

/// Resolves the live local matching a source variable name at the given
/// basic block, among the locals that are parameters of `bty`.
///
/// When several distinct live locals share the name (e.g. two shadowed
/// variables that are both loop-carried), the mapping is ambiguous; rather
/// than silently pick one, this raises a fatal error. Disambiguating which
/// shadow an invariant refers to is left as future work.
fn local_of_name_in_bb(&self, name: rustc_span::Symbol, bty: &BasicBlockType) -> Option<Local> {
let mut found: Option<Local> = None;
for vdi in &self.body.var_debug_info {
if vdi.name != name {
continue;
}
let mir::VarDebugInfoContents::Place(place) = vdi.value else {
continue;
};
if !place.projection.is_empty() {
continue;
}
if bty.param_of_local(place.local).is_none() {
continue;
}
match found {
None => found = Some(place.local),
Some(prev) if prev == place.local => {}
Some(_) => self.tcx.dcx().fatal(format!(
"loop invariant refers to `{name}`, which is ambiguous at the loop header: \
multiple live variables share this name (e.g. through shadowing). \
Rename the variables to disambiguate."
)),
}
}
found
}

/// Translates a user-provided loop invariant (a formula function over named
/// live variables) into a precondition refinement over `bty`'s parameters.
/// Each formula parameter names a live variable at the loop header and is
/// mapped to the corresponding basic-block parameter.
fn build_invariant_precondition(
&self,
formula_def_id: LocalDefId,
generic_args: mir_ty::GenericArgsRef<'tcx>,
bty: &BasicBlockType,
) -> rty::Refinement<rty::FunctionParamIdx> {
let formula_fn = self
.ctx
.formula_fn_with_args(formula_def_id, generic_args)
.expect("invariant formula function is not registered");
let idents = self.tcx.fn_arg_idents(formula_def_id.to_def_id());

let mut mapping: Vec<rty::FunctionParamIdx> = Vec::with_capacity(idents.len());
for ident in idents {
let name = ident.expect("invariant parameters must be named").name;
let local = self.local_of_name_in_bb(name, bty).unwrap_or_else(|| {
self.tcx.dcx().fatal(format!(
"loop invariant refers to `{name}`, which is not a live variable at the loop header"
))
});
mapping.push(bty.param_of_local(local).unwrap());
}

formula_fn
.formula()
.clone()
.subst_var(|idx| chc::Term::var(rty::RefinedTypeVar::Free(mapping[idx.index()])))
.into()
}

fn refine_basic_blocks(&mut self) {
use rustc_mir_dataflow::Analysis as _;
let loop_invariants = self.collect_loop_invariant_annotations();
let mut results = rustc_mir_dataflow::impls::MaybeLiveLocals
.iterate_to_fixpoint(self.tcx, &self.body, None)
.into_results_cursor(&self.body);
Expand Down Expand Up @@ -851,7 +978,23 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}
// function return type is basic block return type
let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty;
if analyze::basic_block::needs_own_precondition(&self.body, bb) {
if let Some(invariants) = loop_invariants.get(&bb) {
// A user-supplied loop invariant fully replaces inference: build
// the block type without a precondition pvar and install the
// invariant as its precondition. Multiple `invariant!` calls at
// the same header are AND'd in source order.
let mut bty = self
.type_builder
.build_basic_block(&self.body, live_locals, ret_ty);
let mut inv = rty::Refinement::top();
for &(formula_def_id, generic_args) in invariants {
let one = self.build_invariant_precondition(formula_def_id, generic_args, &bty);
inv.push_conj(one);
}
bty.set_precondition(inv);
self.ctx
.register_basic_block_ty_with_precondition(self.local_def_id, bb, bty);
} else if analyze::basic_block::needs_own_precondition(&self.body, bb) {
let bty = self
.type_builder
.for_template(&mut self.ctx)
Expand Down
7 changes: 7 additions & 0 deletions std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,13 @@ mod thrust_models {
pub fn exists<T>(_x: T) -> bool {
unimplemented!()
}

#[thrust::def::invariant_marker]
#[thrust::ignored]
#[inline(never)]
pub fn __invariant_marker<F>(_f: F) {
unimplemented!()
}
}

#[thrust::extern_spec_fn]
Expand Down
20 changes: 20 additions & 0 deletions tests/ui/fail/loop_invariant.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

fn main() {
let mut x = 1_i64;
let mut y = 1_i64;
while rand() == 0 {
thrust_macros::invariant!(|x: i64| x >= 1);
let t1 = x;
let t2 = y;
x = t1 + t2;
y = t1 + t2;
}
assert!(y >= 1);
}
24 changes: 24 additions & 0 deletions tests/ui/fail/loop_invariant_method.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

struct Counter;

impl Counter {
fn run(&self) {
let mut x = 5_i64;
while rand() == 0 {
thrust_macros::invariant!(|x: i64| x >= 1);
x = x - 1;
}
assert!(x >= 1);
}
}

fn main() {
Counter.run();
}
24 changes: 24 additions & 0 deletions tests/ui/fail/loop_invariant_multi.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

// Both invariants are AND'd. If either is too weak to be inductive, the
// verification fails — here `y >= 2` does not hold initially.

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

fn main() {
let mut x = 1_i64;
let mut y = 1_i64;
while rand() == 0 {
thrust_macros::invariant!(|x: i64| x >= 1);
thrust_macros::invariant!(|y: i64| y >= 2);
let t1 = x;
let t2 = y;
x = t1 + t2;
y = t1 + t2;
}
assert!(x >= 1 && y >= 2);
}
17 changes: 17 additions & 0 deletions tests/ui/fail/loop_invariant_mut.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

fn main() {
let mut x = 5_i64;
let p = &mut x;
while rand() == 0 {
thrust_macros::invariant!(|p: &mut i64| *p >= 1);
*p = *p - 1;
}
assert!(*p >= 1);
}
20 changes: 20 additions & 0 deletions tests/ui/fail/loop_invariant_nested.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

fn main() {
let mut x = 1_i64;
while rand() == 0 {
let mut y = 1_i64;
while rand() == 0 {
thrust_macros::invariant!(|x: i64| x >= 1);
y = x + y;
}
x = x + y;
}
assert!(x >= 1);
}
20 changes: 20 additions & 0 deletions tests/ui/pass/loop_invariant.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//@check-pass
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

fn main() {
let mut x = 1_i64;
let mut y = 1_i64;
while rand() == 0 {
thrust_macros::invariant!(|x: i64, y: i64| x >= 1 && y >= 1);
let t1 = x;
let t2 = y;
x = t1 + t2;
y = t1 + t2;
}
assert!(y >= 1);
}
Loading