Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 85 additions & 23 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::collections::HashMap;
use std::rc::Rc;

use rustc_hir::lang_items::LangItem;
use rustc_index::IndexVec;
use rustc_middle::mir::{self, BasicBlock, Local};
use rustc_middle::ty::{self as mir_ty, TyCtxt};
use rustc_span::def_id::{DefId, LocalDefId};
Expand Down Expand Up @@ -114,6 +115,33 @@ enum DefTy<'tcx> {
Deferred(DeferredDefTy<'tcx>),
}

#[derive(Debug, Clone, Default)]
pub struct EnumDefs {
defs: HashMap<DefId, rty::EnumDatatypeDef>,
}

impl EnumDefs {
pub fn find_by_name(&self, name: &chc::DatatypeSymbol) -> Option<&rty::EnumDatatypeDef> {
self.defs.values().find(|def| &def.name == name)
}

pub fn get(&self, def_id: DefId) -> Option<&rty::EnumDatatypeDef> {
self.defs.get(&def_id)
}

pub fn insert(&mut self, def_id: DefId, def: rty::EnumDatatypeDef) {
self.defs.insert(def_id, def);
}
}

impl refine::EnumDefProvider for Rc<RefCell<EnumDefs>> {
fn enum_def(&self, name: &chc::DatatypeSymbol) -> rty::EnumDatatypeDef {
self.borrow().find_by_name(name).unwrap().clone()
}
}

pub type Env = refine::Env<Rc<RefCell<EnumDefs>>>;

#[derive(Clone)]
pub struct Analyzer<'tcx> {
tcx: TyCtxt<'tcx>,
Expand All @@ -131,7 +159,7 @@ pub struct Analyzer<'tcx> {
basic_blocks: HashMap<LocalDefId, HashMap<BasicBlock, BasicBlockType>>,
def_ids: did_cache::DefIdCache<'tcx>,

enum_defs: Rc<RefCell<HashMap<DefId, rty::EnumDatatypeDef>>>,
enum_defs: Rc<RefCell<EnumDefs>>,
}

impl<'tcx> crate::refine::TemplateRegistry for Analyzer<'tcx> {
Expand Down Expand Up @@ -174,7 +202,58 @@ impl<'tcx> Analyzer<'tcx> {
}
}

pub fn register_enum_def(&mut self, def_id: DefId, enum_def: rty::EnumDatatypeDef) {
fn build_enum_def(&self, def_id: DefId) -> rty::EnumDatatypeDef {
let adt = self.tcx.adt_def(def_id);

let name = refine::datatype_symbol(self.tcx, def_id);
let variants: IndexVec<_, _> = adt
.variants()
.iter()
.map(|variant| {
let name = refine::datatype_symbol(self.tcx, variant.def_id);
// TODO: consider using TyCtxt::tag_for_variant
let discr = resolve_discr(self.tcx, variant.discr);
let field_tys = variant
.fields
.iter()
.map(|field| {
let field_ty = self.tcx.type_of(field.did).instantiate_identity();
TypeBuilder::new(self.tcx, def_id).build(field_ty)
})
.collect();
rty::EnumVariantDef {
name,
discr,
field_tys,
}
})
.collect();

let generics = self.tcx.generics_of(def_id);
let ty_params = (0..generics.count())
.filter(|idx| {
matches!(
generics.param_at(*idx, self.tcx).kind,
mir_ty::GenericParamDefKind::Type { .. }
)
})
.count();
tracing::debug!(?def_id, ?name, ?ty_params, "ty_params count");

rty::EnumDatatypeDef {
name,
ty_params,
variants,
}
}

pub fn get_or_register_enum_def(&self, def_id: DefId) -> rty::EnumDatatypeDef {
let mut enum_defs = self.enum_defs.borrow_mut();
if let Some(enum_def) = enum_defs.get(def_id) {
return enum_def.clone();
}

let enum_def = self.build_enum_def(def_id);
tracing::debug!(def_id = ?def_id, enum_def = ?enum_def, "register_enum_def");
let ctors = enum_def
.variants
Expand All @@ -199,21 +278,10 @@ impl<'tcx> Analyzer<'tcx> {
params: enum_def.ty_params,
ctors,
};
self.enum_defs.borrow_mut().insert(def_id, enum_def);
enum_defs.insert(def_id, enum_def.clone());
self.system.borrow_mut().datatypes.push(datatype);
}

pub fn find_enum_variant(
&self,
ty_sym: &chc::DatatypeSymbol,
v_sym: &chc::DatatypeSymbol,
) -> Option<rty::EnumVariantDef> {
self.enum_defs
.borrow()
.iter()
.find(|(_, d)| &d.name == ty_sym)
.and_then(|(_, d)| d.variants.iter().find(|v| &v.name == v_sym))
.cloned()
enum_def
}

pub fn register_def(&mut self, def_id: DefId, rty: rty::RefinedType) {
Expand Down Expand Up @@ -304,14 +372,8 @@ impl<'tcx> Analyzer<'tcx> {
self.register_def(panic_def_id, rty::RefinedType::unrefined(panic_ty.into()));
}

pub fn new_env(&self) -> refine::Env {
let defs = self
.enum_defs
.borrow()
.values()
.map(|def| (def.name.clone(), def.clone()))
.collect();
refine::Env::new(defs)
pub fn new_env(&self) -> Env {
refine::Env::new(Rc::clone(&self.enum_defs))
}

pub fn crate_analyzer(&mut self) -> crate_::Analyzer<'tcx, '_> {
Expand Down
60 changes: 41 additions & 19 deletions src/analyze/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use crate::analyze;
use crate::chc;
use crate::pretty::PrettyDisplayExt as _;
use crate::refine::{
self, Assumption, BasicBlockType, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx,
TypeBuilder, Var,
Assumption, BasicBlockType, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, TypeBuilder,
Var,
};
use crate::rty::{
self, ClauseBuilderExt as _, ClauseScope as _, ShiftExistential as _, Subtyping as _,
Expand All @@ -34,7 +34,7 @@ pub struct Analyzer<'tcx, 'ctx> {
body: Cow<'tcx, Body<'tcx>>,

type_builder: TypeBuilder<'tcx>,
env: Env,
env: analyze::Env,
local_decls: IndexVec<Local, mir::LocalDecl<'tcx>>,
// TODO: remove this
prophecy_vars: HashMap<usize, TempVarIdx>,
Expand Down Expand Up @@ -350,16 +350,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
.map(|operand| self.operand_type(operand).boxed())
.collect();
match *kind {
mir::AggregateKind::Adt(did, variant_id, args, _, _)
mir::AggregateKind::Adt(did, variant_idx, args, _, _)
if self.tcx.def_kind(did) == DefKind::Enum =>
{
let adt = self.tcx.adt_def(did);
let ty_sym = refine::datatype_symbol(self.tcx, did);
let variant = adt.variant(variant_id);
let v_sym = refine::datatype_symbol(self.tcx, variant.def_id);

let enum_variant_def = self.ctx.find_enum_variant(&ty_sym, &v_sym).unwrap();
let variant_rtys = enum_variant_def
let enum_def = self.ctx.get_or_register_enum_def(did);
let variant_def = &enum_def.variants[variant_idx];
let variant_rtys = variant_def
.field_tys
.clone()
.into_iter()
Expand All @@ -386,7 +382,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {

let sort_args: Vec<_> =
rty_args.iter().map(|rty| rty.ty.to_sort()).collect();
let ty = rty::EnumType::new(ty_sym.clone(), rty_args).into();
let ty = rty::EnumType::new(enum_def.name.clone(), rty_args).into();

let mut builder = PlaceTypeBuilder::default();
let mut field_terms = Vec::new();
Expand All @@ -396,7 +392,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}
builder.build(
ty,
chc::Term::datatype_ctor(ty_sym, sort_args, v_sym, field_terms),
chc::Term::datatype_ctor(
enum_def.name,
sort_args,
variant_def.name.clone(),
field_terms,
),
)
}
_ => PlaceType::tuple(field_tys),
Expand Down Expand Up @@ -924,6 +925,31 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}
}
}

fn register_enum_defs(&mut self) {
for local_decl in &self.local_decls {
use mir_ty::{TypeSuperVisitable as _, TypeVisitable as _};
#[derive(Default)]
struct EnumCollector {
enums: std::collections::HashSet<DefId>,
}
impl<'tcx> mir_ty::TypeVisitor<mir_ty::TyCtxt<'tcx>> for EnumCollector {
fn visit_ty(&mut self, ty: mir_ty::Ty<'tcx>) {
if let mir_ty::TyKind::Adt(adt_def, _) = ty.kind() {
if adt_def.is_enum() {
self.enums.insert(adt_def.did());
}
}
ty.super_visit_with(self);
}
}
let mut visitor = EnumCollector::default();
local_decl.ty.visit_with(&mut visitor);
for def_id in visitor.enums {
self.ctx.get_or_register_enum_def(def_id);
}
}
}
}

/// Turns [`rty::RefinedType<Var>`] into [`rty::RefinedType<T>`].
Expand Down Expand Up @@ -967,7 +993,7 @@ impl<T> UnbindAtoms<T> {
self.existentials.extend(var_ty.existentials);
}

pub fn unbind(mut self, env: &Env, ty: rty::RefinedType<Var>) -> rty::RefinedType<T> {
pub fn unbind(mut self, env: &analyze::Env, ty: rty::RefinedType<Var>) -> rty::RefinedType<T> {
let rty::RefinedType {
ty: src_ty,
refinement,
Expand Down Expand Up @@ -1136,14 +1162,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
self
}

pub fn env(&mut self, env: Env) -> &mut Self {
self.env = env;
self
}

pub fn run(&mut self, expected: &BasicBlockType) {
let span = tracing::info_span!("bb", bb = ?self.basic_block);
let _guard = span.enter();
self.register_enum_defs();

let params = expected.as_ref().params.clone();
self.bind_locals(&params);
Expand Down
55 changes: 0 additions & 55 deletions src/analyze/crate_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@

use std::collections::HashSet;

use rustc_hir::def::DefKind;
use rustc_index::IndexVec;
use rustc_middle::ty::{self as mir_ty, TyCtxt};
use rustc_span::def_id::{DefId, LocalDefId};

use crate::analyze;
use crate::chc;
use crate::refine::{self, TypeBuilder};
use crate::rty::{self, ClauseBuilderExt as _};

/// An implementation of local crate analysis.
Expand Down Expand Up @@ -167,57 +164,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
}
}
}

fn register_enum_defs(&mut self) {
for local_def_id in self.tcx.iter_local_def_id() {
let DefKind::Enum = self.tcx.def_kind(local_def_id) else {
continue;
};
let adt = self.tcx.adt_def(local_def_id);

let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id());
let variants: IndexVec<_, _> = adt
.variants()
.iter()
.map(|variant| {
let name = refine::datatype_symbol(self.tcx, variant.def_id);
// TODO: consider using TyCtxt::tag_for_variant
let discr = analyze::resolve_discr(self.tcx, variant.discr);
let field_tys = variant
.fields
.iter()
.map(|field| {
let field_ty = self.tcx.type_of(field.did).instantiate_identity();
TypeBuilder::new(self.tcx, local_def_id.to_def_id()).build(field_ty)
})
.collect();
rty::EnumVariantDef {
name,
discr,
field_tys,
}
})
.collect();

let generics = self.tcx.generics_of(local_def_id);
let ty_params = (0..generics.count())
.filter(|idx| {
matches!(
generics.param_at(*idx, self.tcx).kind,
mir_ty::GenericParamDefKind::Type { .. }
)
})
.count();
tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count");

let def = rty::EnumDatatypeDef {
name,
ty_params,
variants,
};
self.ctx.register_enum_def(local_def_id.to_def_id(), def);
}
}
}

impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
Expand All @@ -231,7 +177,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
let span = tracing::debug_span!("crate", krate = %self.tcx.crate_name(rustc_span::def_id::LOCAL_CRATE));
let _guard = span.enter();

self.register_enum_defs();
self.refine_local_defs();
self.analyze_local_defs();
self.assert_callable_entry();
Expand Down
29 changes: 26 additions & 3 deletions src/chc/unbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,42 @@ fn unbox_term(term: Term) -> Term {
Term::App(fun, args) => Term::App(fun, args.into_iter().map(unbox_term).collect()),
Term::Tuple(ts) => Term::Tuple(ts.into_iter().map(unbox_term).collect()),
Term::TupleProj(t, i) => Term::TupleProj(Box::new(unbox_term(*t)), i),
Term::DatatypeCtor(s1, s2, args) => {
Term::DatatypeCtor(s1, s2, args.into_iter().map(unbox_term).collect())
}
Term::DatatypeCtor(s1, s2, args) => Term::DatatypeCtor(
unbox_datatype_sort(s1),
s2,
args.into_iter().map(unbox_term).collect(),
),
Term::DatatypeDiscr(sym, arg) => Term::DatatypeDiscr(sym, Box::new(unbox_term(*arg))),
Term::FormulaExistentialVar(sort, name) => {
Term::FormulaExistentialVar(unbox_sort(sort), name)
}
}
}

fn unbox_matcher_pred(pred: MatcherPred) -> Pred {
let MatcherPred {
datatype_symbol,
datatype_args,
} = pred;
let datatype_args = datatype_args.into_iter().map(unbox_sort).collect();
Pred::Matcher(MatcherPred {
datatype_symbol,
datatype_args,
})
}

fn unbox_pred(pred: Pred) -> Pred {
match pred {
Pred::Known(pred) => Pred::Known(pred),
Pred::Var(pred) => Pred::Var(pred),
Pred::Matcher(pred) => unbox_matcher_pred(pred),
}
}

fn unbox_atom(atom: Atom) -> Atom {
let Atom { guard, pred, args } = atom;
let guard = guard.map(|fo| Box::new(unbox_formula(*fo)));
let pred = unbox_pred(pred);
let args = args.into_iter().map(unbox_term).collect();
Atom { guard, pred, args }
}
Expand Down
Loading