Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ impl<'tcx> Analyzer<'tcx> {

/// Computes the signature of the local function.
///
/// This is a drop-in replacement of `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
/// This works like `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
/// but extracts parameter and return types directly from the given `body` to obtain a signature that
/// reflects potential type instantiations happened after `optimized_mir`.
pub fn local_fn_sig_with_body(
Expand All @@ -364,4 +364,14 @@ impl<'tcx> Analyzer<'tcx> {
sig.abi,
)
}

/// Computes the signature of the local function.
///
/// This works like `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`,
/// but extracts parameter and return types directly from [`mir::Body`] to obtain a signature that
/// reflects the actual type of lifted closure functions.
pub fn local_fn_sig(&self, local_def_id: LocalDefId) -> mir_ty::FnSig<'tcx> {
let body = self.tcx.optimized_mir(local_def_id);
self.local_fn_sig_with_body(local_def_id, body)
}
}
78 changes: 68 additions & 10 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,55 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
) -> Vec<chc::Clause> {
let mut clauses = Vec::new();

if expected_args.is_empty() {
// elaboration: we need at least one predicate variable in parameter (see mir_function_ty_impl)
expected_args.push(rty::RefinedType::unrefined(rty::Type::unit()).vacuous());
}
tracing::debug!(
got = %got.display(),
expected = %crate::pretty::FunctionType::new(&expected_args, &expected_ret).display(),
"fn_sub_type"
);

match got.abi {
rty::FunctionAbi::Rust => {
if expected_args.is_empty() {
// elaboration: we need at least one predicate variable in parameter (see mir_function_ty_impl)
expected_args.push(rty::RefinedType::unrefined(rty::Type::unit()).vacuous());
}
}
rty::FunctionAbi::RustCall => {
// &Closure, { v: (own i32, own bool) | v = (<0>, <false>) }
// =>
// &Closure, { v: i32 | (<v>, _) = (<0>, <false>) }, { v: bool | (_, <v>) = (<0>, <false>) }

let rty::RefinedType { ty, mut refinement } =
expected_args.pop().expect("rust-call last arg");
let ty = ty.into_tuple().expect("rust-call last arg is tuple");
let mut replacement_tuple = Vec::new(); // will be (<v>, _) or (_, <v>)
for elem in &ty.elems {
let existential = refinement.existentials.push(elem.ty.to_sort());
replacement_tuple.push(chc::Term::var(rty::RefinedTypeVar::Existential(
existential,
)));
}

for (i, elem) in ty.elems.into_iter().enumerate() {
// all tuple elements are boxed during the translation to rty::Type
let mut param_ty = elem.deref();
param_ty
.refinement
.push_conj(refinement.clone().subst_value_var(|| {
let mut value_elems = replacement_tuple.clone();
value_elems[i] = chc::Term::var(rty::RefinedTypeVar::Value).boxed();
chc::Term::tuple(value_elems)
}));
expected_args.push(param_ty);
}

tracing::info!(
expected = %crate::pretty::FunctionType::new(&expected_args, &expected_ret).display(),
"rust-call expanded",
);
}
}

// TODO: check sty and length is equal
let mut builder = self.env.build_clause();
for (param_idx, param_rty) in got.params.iter_enumerated() {
Expand Down Expand Up @@ -175,6 +214,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
chc::Term::bool(val.try_to_bool().unwrap()),
)
}
(mir_ty::TyKind::Tuple(tys), _) if tys.is_empty() => {
PlaceType::with_ty_and_term(rty::Type::unit(), chc::Term::tuple(vec![]))
}
(mir_ty::TyKind::Closure(_, args), _) if args.as_closure().upvar_tys().is_empty() => {
PlaceType::with_ty_and_term(rty::Type::unit(), chc::Term::tuple(vec![]))
}
(
mir_ty::TyKind::Ref(_, elem, Mutability::Not),
ConstValue::Scalar(Scalar::Ptr(ptr, _)),
Expand Down Expand Up @@ -568,12 +613,25 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
let ret = rty::RefinedType::new(rty::Type::unit(), ret_formula.into());
rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into()
}
Some((def_id, args)) => self
.ctx
.def_ty_with_args(def_id, args)
.expect("unknown def")
.ty
.vacuous(),
Some((def_id, args)) => {
let param_env = self.tcx.param_env(self.local_def_id);
let instance =
mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap();
let resolved_def_id = if let Some(instance) = instance {
instance.def_id()
} else {
def_id
};
if def_id != resolved_def_id {
tracing::info!(?def_id, ?resolved_def_id, "resolve",);
}

self.ctx
.def_ty_with_args(resolved_def_id, args)
.expect("unknown def")
.ty
.vacuous()
}
_ => self.operand_type(func.clone()).ty,
};
let expected_args: IndexVec<_, _> = args
Expand Down
7 changes: 2 additions & 5 deletions src/analyze/crate_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {

#[tracing::instrument(skip(self), fields(def_id = %self.tcx.def_path_str(local_def_id)))]
fn refine_fn_def(&mut self, local_def_id: LocalDefId) {
let sig = self.ctx.local_fn_sig(local_def_id);

let mut analyzer = self.ctx.local_def_analyzer(local_def_id);

if analyzer.is_annotated_as_trusted() {
assert!(analyzer.is_fully_annotated());
self.trusted.insert(local_def_id.to_def_id());
}

let sig = self
.tcx
.fn_sig(local_def_id)
.instantiate_identity()
.skip_binder();
use mir_ty::TypeVisitableExt as _;
if sig.has_param() && !analyzer.is_fully_annotated() {
self.ctx.register_deferred_def(local_def_id.to_def_id());
Expand Down
4 changes: 4 additions & 0 deletions src/chc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,10 @@ impl<V> Term<V> {
Term::Mut(Box::new(t1), Box::new(t2))
}

pub fn boxed(self) -> Self {
Term::Box(Box::new(self))
}

pub fn box_current(self) -> Self {
Term::BoxCurrent(Box::new(self))
}
Expand Down
17 changes: 15 additions & 2 deletions src/refine/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ impl<'tcx> TypeBuilder<'tcx> {
unimplemented!("unsupported ADT: {:?}", ty);
}
}
mir_ty::TyKind::Closure(_, args) => self.build(args.as_closure().tupled_upvars_ty()),
kind => unimplemented!("unrefined_ty: {:?}", kind),
}
}
Expand All @@ -183,6 +184,11 @@ impl<'tcx> TypeBuilder<'tcx> {
registry: &'a mut R,
sig: mir_ty::FnSig<'tcx>,
) -> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
let abi = match sig.abi {
rustc_target::spec::abi::Abi::Rust => rty::FunctionAbi::Rust,
rustc_target::spec::abi::Abi::RustCall => rty::FunctionAbi::RustCall,
_ => unimplemented!("unsupported function ABI: {:?}", sig.abi),
};
FunctionTemplateTypeBuilder {
inner: self.clone(),
registry,
Expand All @@ -198,6 +204,7 @@ impl<'tcx> TypeBuilder<'tcx> {
param_rtys: Default::default(),
param_refinement: None,
ret_rty: None,
abi,
}
}
}
Expand Down Expand Up @@ -282,6 +289,7 @@ where
unimplemented!("unsupported ADT: {:?}", ty);
}
}
mir_ty::TyKind::Closure(_, args) => self.build(args.as_closure().tupled_upvars_ty()),
kind => unimplemented!("ty: {:?}", kind),
}
}
Expand All @@ -301,9 +309,12 @@ where
where
I: IntoIterator<Item = (Local, mir_ty::TypeAndMut<'tcx>)>,
{
// this is necessary for local_def::Analyzer::elaborate_unused_args
let mut live_locals: Vec<_> = live_locals.into_iter().collect();
live_locals.sort_by_key(|(local, _)| *local);

let mut locals = IndexVec::<rty::FunctionParamIdx, _>::new();
let mut tys = Vec::new();
// TODO: avoid two iteration and assumption of FunctionParamIdx match between locals and ty
for (local, ty) in live_locals {
locals.push((local, ty.mutbl));
tys.push(ty);
Expand All @@ -316,6 +327,7 @@ where
param_rtys: Default::default(),
param_refinement: None,
ret_rty: None,
abi: Default::default(),
}
.build();
BasicBlockType { ty, locals }
Expand All @@ -331,6 +343,7 @@ pub struct FunctionTemplateTypeBuilder<'tcx, 'a, R> {
param_refinement: Option<rty::Refinement<rty::FunctionParamIdx>>,
param_rtys: HashMap<rty::FunctionParamIdx, rty::RefinedType<rty::FunctionParamIdx>>,
ret_rty: Option<rty::RefinedType<rty::FunctionParamIdx>>,
abi: rty::FunctionAbi,
}

impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
Expand Down Expand Up @@ -439,6 +452,6 @@ where
.with_scope(&builder)
.build_refined(self.ret_ty)
});
rty::FunctionType::new(param_rtys, ret_rty)
rty::FunctionType::new(param_rtys, ret_rty).with_abi(self.abi)
}
}
89 changes: 81 additions & 8 deletions src/rty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,36 @@ where
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum FunctionAbi {
#[default]
Rust,
RustCall,
}

impl std::fmt::Display for FunctionAbi {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.write_str(self.name())
}
}

impl FunctionAbi {
pub fn name(&self) -> &'static str {
match self {
FunctionAbi::Rust => "rust",
FunctionAbi::RustCall => "rust-call",
}
}

pub fn is_rust(&self) -> bool {
matches!(self, FunctionAbi::Rust)
}

pub fn is_rust_call(&self) -> bool {
matches!(self, FunctionAbi::RustCall)
}
}

/// A function type.
///
/// In Thrust, function types are closed. Because of that, function types, thus its parameters and
Expand All @@ -92,6 +122,7 @@ where
pub struct FunctionType {
pub params: IndexVec<FunctionParamIdx, RefinedType<FunctionParamIdx>>,
pub ret: Box<RefinedType<FunctionParamIdx>>,
pub abi: FunctionAbi,
}

impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b FunctionType
Expand All @@ -100,15 +131,25 @@ where
D::Doc: Clone,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> {
let abi = match self.abi {
FunctionAbi::Rust => allocator.nil(),
abi => allocator
.text("extern")
.append(allocator.space())
.append(allocator.as_string(abi))
.append(allocator.space()),
};
let separator = allocator.text(",").append(allocator.line());
allocator
.intersperse(self.params.iter().map(|ty| ty.pretty(allocator)), separator)
.parens()
.append(allocator.space())
.append(allocator.text("→"))
.append(allocator.line())
.append(self.ret.pretty(allocator))
.group()
abi.append(
allocator
.intersperse(self.params.iter().map(|ty| ty.pretty(allocator)), separator)
.parens(),
)
.append(allocator.space())
.append(allocator.text("→"))
.append(allocator.line())
.append(self.ret.pretty(allocator))
.group()
}
}

Expand All @@ -120,9 +161,15 @@ impl FunctionType {
FunctionType {
params,
ret: Box::new(ret),
abi: FunctionAbi::Rust,
}
}

pub fn with_abi(mut self, abi: FunctionAbi) -> Self {
self.abi = abi;
self
}

/// Because function types are always closed in Thrust, we can convert this into
/// [`Type<Closed>`].
pub fn into_closed_ty(self) -> Type<Closed> {
Expand Down Expand Up @@ -1304,6 +1351,32 @@ impl<FV> RefinedType<FV> {
RefinedType { ty, refinement }
}

/// Returns a dereferenced type of the immutable reference or owned pointer.
///
/// e.g. `{ v: Box<T> | φ } --> { v: T | φ[box v/v] }`
pub fn deref(self) -> Self {
let RefinedType {
ty,
refinement: outer_refinement,
} = self;
let inner_ty = ty.into_pointer().expect("invalid deref");
if inner_ty.is_mut() {
// losing info about proph
panic!("invalid deref");
}
let RefinedType {
ty: inner_ty,
refinement: mut inner_refinement,
} = *inner_ty.elem;
inner_refinement.push_conj(
outer_refinement.subst_value_var(|| chc::Term::var(RefinedTypeVar::Value).boxed()),
);
RefinedType {
ty: inner_ty,
refinement: inner_refinement,
}
}

pub fn subst_var<F, W>(self, mut f: F) -> RefinedType<W>
where
F: FnMut(FV) -> chc::Term<W>,
Expand Down
12 changes: 12 additions & 0 deletions tests/ui/fail/closure_mut.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

fn main() {
let mut x = 1;
let mut incr = |by: i32| {
x += by;
};
incr(5);
incr(5);
assert!(x == 10);
}
12 changes: 12 additions & 0 deletions tests/ui/fail/closure_mut_0.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

fn main() {
let mut x = 1;
x += 1;
let mut incr = || {
x += 1;
};
incr();
assert!(x == 2);
}
9 changes: 9 additions & 0 deletions tests/ui/fail/closure_no_capture.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

fn main() {
let incr = |x| {
x + 1
};
assert!(incr(2) == 2);
}
Loading