From 0268714c06accac4830f4b09d901259d3788a29f Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 24 Nov 2025 13:03:30 +0900 Subject: [PATCH] Implement #[thrust::extern_spec_fn] --- src/analyze/annot.rs | 4 +++ src/analyze/crate_.rs | 12 ++++++- src/analyze/local_def.rs | 54 ++++++++++++++++++++++++++++++- tests/ui/fail/extern_spec_take.rs | 14 ++++++++ tests/ui/pass/extern_spec_take.rs | 15 +++++++++ 5 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 tests/ui/fail/extern_spec_take.rs create mode 100644 tests/ui/pass/extern_spec_take.rs diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 91dd209..2dbb9ea 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -33,6 +33,10 @@ pub fn callable_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("callable")] } +pub fn extern_spec_fn_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("extern_spec_fn")] +} + /// A [`annot::Resolver`] implementation for resolving function parameters. /// /// The parameter names and their sorts needs to be configured via diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 2a17b11..7e3d420 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -48,12 +48,22 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.trusted.insert(local_def_id.to_def_id()); } + if analyzer.is_annotated_as_extern_spec_fn() { + assert!(analyzer.is_fully_annotated()); + self.trusted.insert(local_def_id.to_def_id()); + } + use mir_ty::TypeVisitableExt as _; if sig.has_param() && !analyzer.is_fully_annotated() { self.ctx.register_deferred_def(local_def_id.to_def_id()); } else { let expected = analyzer.expected_ty(); - self.ctx.register_def(local_def_id.to_def_id(), expected); + let target_def_id = if analyzer.is_annotated_as_extern_spec_fn() { + analyzer.extern_spec_fn_target_def_id() + } else { + local_def_id.to_def_id() + }; + self.ctx.register_def(target_def_id, expected); } } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 50a4397..d556ef0 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -6,7 +6,7 @@ use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; use rustc_middle::mir::{self, BasicBlock, Body, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt, TypeAndMut}; -use rustc_span::def_id::LocalDefId; +use rustc_span::def_id::{DefId, LocalDefId}; use rustc_span::symbol::Ident; use crate::analyze; @@ -126,6 +126,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .is_some() } + pub fn is_annotated_as_extern_spec_fn(&self) -> bool { + self.tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::extern_spec_fn_path(), + ) + .next() + .is_some() + } + // TODO: unify this logic with extraction functions above pub fn is_fully_annotated(&self) -> bool { let has_require = self @@ -240,6 +250,48 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::RefinedType::unrefined(builder.build().into()) } + /// Extract the target DefId from `#[thrust::extern_spec_fn]` function. + pub fn extern_spec_fn_target_def_id(&self) -> DefId { + struct ExtractDefId<'tcx> { + tcx: TyCtxt<'tcx>, + outer_def_id: LocalDefId, + inner_def_id: Option, + } + + impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ExtractDefId<'tcx> { + type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies; + + fn nested_visit_map(&mut self) -> Self::Map { + self.tcx.hir() + } + + fn visit_qpath( + &mut self, + qpath: &rustc_hir::QPath<'tcx>, + hir_id: rustc_hir::HirId, + _span: rustc_span::Span, + ) { + let typeck_result = self.tcx.typeck(self.outer_def_id); + if let rustc_hir::def::Res::Def(_, def_id) = typeck_result.qpath_res(qpath, hir_id) + { + assert!(self.inner_def_id.is_none(), "invalid extern_spec_fn"); + self.inner_def_id = Some(def_id); + } + } + } + + use rustc_hir::intravisit::Visitor as _; + let mut visitor = ExtractDefId { + tcx: self.tcx, + outer_def_id: self.local_def_id, + inner_def_id: None, + }; + if let rustc_hir::Node::Item(item) = self.tcx.hir_node_by_def_id(self.local_def_id) { + visitor.visit_item(item); + } + visitor.inner_def_id.expect("invalid extern_spec_fn") + } + fn is_mut_param(&self, param_idx: rty::FunctionParamIdx) -> bool { let param_local = analyze::local_of_function_param(param_idx); self.body.local_decls[param_local].mutability.is_mut() diff --git a/tests/ui/fail/extern_spec_take.rs b/tests/ui/fail/extern_spec_take.rs new file mode 100644 index 0000000..5687edd --- /dev/null +++ b/tests/ui/fail/extern_spec_take.rs @@ -0,0 +1,14 @@ +//@error-in-other-file: Unsat + +#[thrust::extern_spec_fn] +#[thrust::requires(true)] +#[thrust::ensures(result == *dest && ^dest == 0)] +fn _extern_spec_take(dest: &mut i32) -> i32 { + std::mem::take(dest) +} + +fn main() { + let mut x = 42; + let old = std::mem::take(&mut x); + assert!(x == 42); +} diff --git a/tests/ui/pass/extern_spec_take.rs b/tests/ui/pass/extern_spec_take.rs new file mode 100644 index 0000000..a72df6d --- /dev/null +++ b/tests/ui/pass/extern_spec_take.rs @@ -0,0 +1,15 @@ +//@check-pass + +#[thrust::extern_spec_fn] +#[thrust::requires(true)] +#[thrust::ensures(result == *dest && ^dest == 0)] +fn _extern_spec_take(dest: &mut i32) -> i32 { + std::mem::take(dest) +} + +fn main() { + let mut x = 42; + let old = std::mem::take(&mut x); + assert!(old == 42); + assert!(x == 0); +}