Skip to content

Commit 018674c

Browse files
authored
Merge pull request #21333 from hvitved/rust/type-inference-restrict-receiver-type-propagation
Rust: Restrict type propagation into receivers
2 parents 266130b + f9869da commit 018674c

File tree

5 files changed

+569
-403
lines changed

5 files changed

+569
-403
lines changed

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 78 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -778,13 +778,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
778778
prefix1 = TypePath::singleton(getArrayTypeParameter()) and
779779
prefix2.isEmpty()
780780
or
781-
exists(Struct s |
782-
n2 = [n1.(RangeExpr).getStart(), n1.(RangeExpr).getEnd()] and
783-
prefix1 = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and
784-
prefix2.isEmpty() and
785-
s = getRangeType(n1)
786-
)
787-
or
788781
exists(ClosureExpr ce, int index |
789782
n1 = ce and
790783
n2 = ce.getParam(index).getPat() and
@@ -829,6 +822,12 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
829822
bodyReturns(parent, child) and
830823
strictcount(Expr e | bodyReturns(parent, e)) > 1 and
831824
prefix.isEmpty()
825+
or
826+
exists(Struct s |
827+
child = [parent.(RangeExpr).getStart(), parent.(RangeExpr).getEnd()] and
828+
prefix = TypePath::singleton(TTypeParamTypeParameter(s.getGenericParamList().getATypeParam())) and
829+
s = getRangeType(parent)
830+
)
832831
}
833832

834833
/**
@@ -1031,10 +1030,10 @@ private module StructExprMatchingInput implements MatchingInputSig {
10311030
private module StructExprMatching = Matching<StructExprMatchingInput>;
10321031

10331032
pragma[nomagic]
1034-
private Type inferStructExprType0(AstNode n, boolean isReturn, TypePath path) {
1033+
private Type inferStructExprType0(AstNode n, FunctionPosition pos, TypePath path) {
10351034
exists(StructExprMatchingInput::Access a, StructExprMatchingInput::AccessPosition apos |
10361035
n = a.getNodeAt(apos) and
1037-
if apos.isStructPos() then isReturn = true else isReturn = false
1036+
if apos.isStructPos() then pos.isReturn() else pos.asPosition() = 0 // the actual position doesn't matter, as long as it is positional
10381037
|
10391038
result = StructExprMatching::inferAccessType(a, apos, path)
10401039
or
@@ -1113,6 +1112,25 @@ private Trait getCallExprTraitQualifier(CallExpr ce) {
11131112
* Provides functionality related to context-based typing of calls.
11141113
*/
11151114
private module ContextTyping {
1115+
/**
1116+
* Holds if `f` mentions type parameter `tp` at some non-return position,
1117+
* possibly via a constraint on another mentioned type parameter.
1118+
*/
1119+
pragma[nomagic]
1120+
private predicate assocFunctionMentionsTypeParameterAtNonRetPos(
1121+
ImplOrTraitItemNode i, Function f, TypeParameter tp
1122+
) {
1123+
exists(FunctionPosition nonRetPos |
1124+
not nonRetPos.isReturn() and
1125+
tp = getAssocFunctionTypeAt(f, i, nonRetPos, _)
1126+
)
1127+
or
1128+
exists(TypeParameter mid |
1129+
assocFunctionMentionsTypeParameterAtNonRetPos(i, f, mid) and
1130+
tp = getATypeParameterConstraint(mid, _)
1131+
)
1132+
}
1133+
11161134
/**
11171135
* Holds if the return type of the function `f` inside `i` at `path` is type
11181136
* parameter `tp`, and `tp` does not appear in the type of any parameter of
@@ -1129,12 +1147,7 @@ private module ContextTyping {
11291147
) {
11301148
pos.isReturn() and
11311149
tp = getAssocFunctionTypeAt(f, i, pos, path) and
1132-
not exists(FunctionPosition nonResPos | not nonResPos.isReturn() |
1133-
tp = getAssocFunctionTypeAt(f, i, nonResPos, _)
1134-
or
1135-
// `Self` types in traits implicitly mention all type parameters of the trait
1136-
getAssocFunctionTypeAt(f, i, nonResPos, _) = TSelfTypeParameter(i)
1137-
)
1150+
not assocFunctionMentionsTypeParameterAtNonRetPos(i, f, tp)
11381151
}
11391152

11401153
/**
@@ -1184,7 +1197,7 @@ private module ContextTyping {
11841197
pragma[nomagic]
11851198
private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) }
11861199

1187-
signature Type inferCallTypeSig(AstNode n, boolean isReturn, TypePath path);
1200+
signature Type inferCallTypeSig(AstNode n, FunctionPosition pos, TypePath path);
11881201

11891202
/**
11901203
* Given a predicate `inferCallType` for inferring the type of a call at a given
@@ -1194,19 +1207,31 @@ private module ContextTyping {
11941207
*/
11951208
module CheckContextTyping<inferCallTypeSig/3 inferCallType> {
11961209
pragma[nomagic]
1197-
private Type inferCallTypeFromContextCand(AstNode n, TypePath prefix, TypePath path) {
1198-
result = inferCallType(n, false, path) and
1210+
private Type inferCallNonReturnType(AstNode n, FunctionPosition pos, TypePath path) {
1211+
result = inferCallType(n, pos, path) and
1212+
not pos.isReturn()
1213+
}
1214+
1215+
pragma[nomagic]
1216+
private Type inferCallNonReturnType(
1217+
AstNode n, FunctionPosition pos, TypePath prefix, TypePath path
1218+
) {
1219+
result = inferCallNonReturnType(n, pos, path) and
11991220
hasUnknownType(n) and
12001221
prefix = path.getAPrefix()
12011222
}
12021223

12031224
pragma[nomagic]
12041225
Type check(AstNode n, TypePath path) {
1205-
result = inferCallType(n, true, path)
1226+
result = inferCallType(n, any(FunctionPosition pos | pos.isReturn()), path)
12061227
or
1207-
exists(TypePath prefix |
1208-
result = inferCallTypeFromContextCand(n, prefix, path) and
1228+
exists(FunctionPosition pos, TypePath prefix |
1229+
result = inferCallNonReturnType(n, pos, prefix, path) and
12091230
hasUnknownTypeAt(n, prefix)
1231+
|
1232+
// Never propagate type information directly into the receiver, since its type
1233+
// must already have been known in order to resolve the call
1234+
if pos.isSelf() then not prefix.isEmpty() else any()
12101235
)
12111236
}
12121237
}
@@ -2607,12 +2632,9 @@ private Type inferMethodCallType0(
26072632
}
26082633

26092634
pragma[nomagic]
2610-
private Type inferMethodCallTypeNonSelf(AstNode n, boolean isReturn, TypePath path) {
2611-
exists(MethodCallMatchingInput::AccessPosition apos |
2612-
result = inferMethodCallType0(_, apos, n, _, path) and
2613-
not apos.isSelf() and
2614-
if apos.isReturn() then isReturn = true else isReturn = false
2615-
)
2635+
private Type inferMethodCallTypeNonSelf(AstNode n, FunctionPosition pos, TypePath path) {
2636+
result = inferMethodCallType0(_, pos, n, _, path) and
2637+
not pos.isSelf()
26162638
}
26172639

26182640
/**
@@ -2623,12 +2645,12 @@ private Type inferMethodCallTypeNonSelf(AstNode n, boolean isReturn, TypePath pa
26232645
* empty, at which point the inferred type can be applied back to `n`.
26242646
*/
26252647
pragma[nomagic]
2626-
private Type inferMethodCallTypeSelf(AstNode n, DerefChain derefChain, TypePath path) {
2648+
private Type inferMethodCallTypeSelf(MethodCall mc, AstNode n, DerefChain derefChain, TypePath path) {
26272649
exists(
26282650
MethodCallMatchingInput::AccessPosition apos, string derefChainBorrow, BorrowKind borrow,
26292651
TypePath path0
26302652
|
2631-
result = inferMethodCallType0(_, apos, n, derefChainBorrow, path0) and
2653+
result = inferMethodCallType0(mc, apos, n, derefChainBorrow, path0) and
26322654
apos.isSelf() and
26332655
MethodCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow)
26342656
|
@@ -2647,7 +2669,7 @@ private Type inferMethodCallTypeSelf(AstNode n, DerefChain derefChain, TypePath
26472669
DerefChain derefChain0, Type t0, TypePath path0, DerefImplItemNode impl, Type selfParamType,
26482670
TypePath selfPath
26492671
|
2650-
t0 = inferMethodCallTypeSelf(n, derefChain0, path0) and
2672+
t0 = inferMethodCallTypeSelf(mc, n, derefChain0, path0) and
26512673
derefChain0.isCons(impl, derefChain) and
26522674
selfParamType = impl.resolveSelfTypeAt(selfPath)
26532675
|
@@ -2664,11 +2686,13 @@ private Type inferMethodCallTypeSelf(AstNode n, DerefChain derefChain, TypePath
26642686
)
26652687
}
26662688

2667-
private Type inferMethodCallTypePreCheck(AstNode n, boolean isReturn, TypePath path) {
2668-
result = inferMethodCallTypeNonSelf(n, isReturn, path)
2689+
private Type inferMethodCallTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) {
2690+
result = inferMethodCallTypeNonSelf(n, pos, path)
26692691
or
2670-
result = inferMethodCallTypeSelf(n, DerefChain::nil(), path) and
2671-
isReturn = false
2692+
exists(MethodCall mc |
2693+
result = inferMethodCallTypeSelf(mc, n, DerefChain::nil(), path) and
2694+
if mc instanceof CallExpr then pos.asPosition() = 0 else pos.isSelf()
2695+
)
26722696
}
26732697

26742698
/**
@@ -3301,14 +3325,11 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
33013325
private module NonMethodCallMatching = Matching<NonMethodCallMatchingInput>;
33023326

33033327
pragma[nomagic]
3304-
private Type inferNonMethodCallType0(AstNode n, boolean isReturn, TypePath path) {
3305-
exists(NonMethodCallMatchingInput::Access a, NonMethodCallMatchingInput::AccessPosition apos |
3306-
n = a.getNodeAt(apos) and
3307-
if apos.isReturn() then isReturn = true else isReturn = false
3308-
|
3309-
result = NonMethodCallMatching::inferAccessType(a, apos, path)
3328+
private Type inferNonMethodCallType0(AstNode n, FunctionPosition pos, TypePath path) {
3329+
exists(NonMethodCallMatchingInput::Access a | n = a.getNodeAt(pos) |
3330+
result = NonMethodCallMatching::inferAccessType(a, pos, path)
33103331
or
3311-
a.hasUnknownTypeAt(apos, path) and
3332+
a.hasUnknownTypeAt(pos, path) and
33123333
result = TUnknownType()
33133334
)
33143335
}
@@ -3379,11 +3400,10 @@ private module OperationMatchingInput implements MatchingInputSig {
33793400
private module OperationMatching = Matching<OperationMatchingInput>;
33803401

33813402
pragma[nomagic]
3382-
private Type inferOperationType0(AstNode n, boolean isReturn, TypePath path) {
3383-
exists(OperationMatchingInput::Access a, OperationMatchingInput::AccessPosition apos |
3384-
n = a.getNodeAt(apos) and
3385-
result = OperationMatching::inferAccessType(a, apos, path) and
3386-
if apos.isReturn() then isReturn = true else isReturn = false
3403+
private Type inferOperationType0(AstNode n, FunctionPosition pos, TypePath path) {
3404+
exists(OperationMatchingInput::Access a |
3405+
n = a.getNodeAt(pos) and
3406+
result = OperationMatching::inferAccessType(a, pos, path)
33873407
)
33883408
}
33893409

@@ -3716,11 +3736,13 @@ private module AwaitSatisfiesConstraintInput implements SatisfiesConstraintInput
37163736
}
37173737
}
37183738

3739+
private module AwaitSatisfiesConstraint =
3740+
SatisfiesConstraint<AwaitTarget, AwaitSatisfiesConstraintInput>;
3741+
37193742
pragma[nomagic]
37203743
private Type inferAwaitExprType(AstNode n, TypePath path) {
37213744
exists(TypePath exprPath |
3722-
SatisfiesConstraint<AwaitTarget, AwaitSatisfiesConstraintInput>::satisfiesConstraintType(n.(AwaitExpr)
3723-
.getExpr(), _, exprPath, result) and
3745+
AwaitSatisfiesConstraint::satisfiesConstraintType(n.(AwaitExpr).getExpr(), _, exprPath, result) and
37243746
exprPath.isCons(getFutureOutputTypeParameter(), path)
37253747
)
37263748
}
@@ -3922,13 +3944,15 @@ private AssociatedTypeTypeParameter getIntoIteratorItemTypeParameter() {
39223944
result = getAssociatedTypeTypeParameter(any(IntoIteratorTrait t).getItemType())
39233945
}
39243946

3947+
private module ForIterableSatisfiesConstraint =
3948+
SatisfiesConstraint<ForIterableExpr, ForIterableSatisfiesConstraintInput>;
3949+
39253950
pragma[nomagic]
39263951
private Type inferForLoopExprType(AstNode n, TypePath path) {
39273952
// type of iterable -> type of pattern (loop variable)
39283953
exists(ForExpr fe, TypePath exprPath, AssociatedTypeTypeParameter tp |
39293954
n = fe.getPat() and
3930-
SatisfiesConstraint<ForIterableExpr, ForIterableSatisfiesConstraintInput>::satisfiesConstraintType(fe.getIterable(),
3931-
_, exprPath, result) and
3955+
ForIterableSatisfiesConstraint::satisfiesConstraintType(fe.getIterable(), _, exprPath, result) and
39323956
exprPath.isCons(tp, path)
39333957
|
39343958
tp = getIntoIteratorItemTypeParameter()
@@ -3963,10 +3987,12 @@ private module InvokedClosureSatisfiesConstraintInput implements
39633987
}
39643988
}
39653989

3990+
private module InvokedClosureSatisfiesConstraint =
3991+
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>;
3992+
39663993
/** Gets the type of `ce` when viewed as an implementation of `FnOnce`. */
39673994
private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
3968-
SatisfiesConstraint<InvokedClosureExpr, InvokedClosureSatisfiesConstraintInput>::satisfiesConstraintType(ce,
3969-
_, path, result)
3995+
InvokedClosureSatisfiesConstraint::satisfiesConstraintType(ce, _, path, result)
39703996
}
39713997

39723998
/**
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
multipleResolvedTargets
22
| main.rs:2223:9:2223:31 | ... .my_add(...) |
33
| main.rs:2225:9:2225:29 | ... .my_add(...) |
4-
| main.rs:2723:13:2723:17 | x.f() |
4+
| main.rs:2733:13:2733:17 | x.f() |

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2636,6 +2636,13 @@ mod block_types {
26362636
}
26372637

26382638
mod context_typed {
2639+
#[derive(Default)]
2640+
struct S;
2641+
2642+
impl S {
2643+
fn f(self) {}
2644+
}
2645+
26392646
pub fn f() {
26402647
let x = None; // $ type=x:T.i32
26412648
let x: Option<i32> = x;
@@ -2683,6 +2690,9 @@ mod context_typed {
26832690

26842691
let y = Default::default(); // $ type=y:i32 target=default
26852692
x.push(y); // $ target=push
2693+
2694+
let s = Default::default(); // $ target=default type=s:S
2695+
S::f(s); // $ target=f
26862696
}
26872697
}
26882698

@@ -2740,6 +2750,7 @@ mod blanket_impl;
27402750
mod closure;
27412751
mod dereference;
27422752
mod dyn_type;
2753+
mod regressions;
27432754

27442755
fn main() {
27452756
field_access::f(); // $ target=f
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
mod regression1 {
2+
3+
pub struct S<T>(T);
4+
5+
pub enum E {
6+
V { vec: Vec<E> },
7+
}
8+
9+
impl<T> From<S<T>> for Option<T> {
10+
fn from(s: S<T>) -> Self {
11+
Some(s.0) // $ fieldof=S
12+
}
13+
}
14+
15+
pub fn f() -> E {
16+
let mut vec_e = Vec::new(); // $ target=new
17+
let mut opt_e = None;
18+
19+
let e = E::V { vec: Vec::new() }; // $ target=new
20+
21+
if let Some(e) = opt_e {
22+
vec_e.push(e); // $ target=push
23+
}
24+
opt_e = e.into(); // $ target=into
25+
26+
#[rustfmt::skip]
27+
let _ = if let Some(last) = vec_e.pop() // $ target=pop
28+
{
29+
opt_e = last.into(); // $ target=into
30+
};
31+
32+
opt_e.unwrap() // $ target=unwrap
33+
}
34+
}

0 commit comments

Comments
 (0)