Skip to content

Rust: Support blanket implementations #20133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions rust/ql/lib/codeql/rust/internal/PathResolution.qll
Original file line number Diff line number Diff line change
Expand Up @@ -933,15 +933,36 @@ class TypeParamItemNode extends TypeItemNode instanceof TypeParam {
}

pragma[nomagic]
Path getABoundPath() {
exists(TypeBoundList tbl | result = tbl.getABound().getTypeRepr().(PathTypeRepr).getPath() |
tbl = super.getTypeBoundList()
TypeBound getBound(int index) {
result = super.getTypeBoundList().getBound(index)
or
exists(int offset |
offset = super.getTypeBoundList().getNumberOfBounds()
or
tbl = this.getAWherePred().getTypeBoundList()
not super.hasTypeBoundList() and
offset = 0
|
result = this.getAWherePred().getTypeBoundList().getBound(index - offset)
)
}

pragma[nomagic]
Path getBoundPath(int index) {
result = this.getBound(index).getTypeRepr().(PathTypeRepr).getPath()
}

Path getABoundPath() { result = this.getBoundPath(_) }

pragma[nomagic]
ItemNode resolveBound(int index) {
result =
rank[index + 1](int i, ItemNode item |
item = resolvePath(this.getBoundPath(i))
|
item order by i
)
}

ItemNode resolveABound() { result = resolvePath(this.getABoundPath()) }

/**
Expand Down
142 changes: 142 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -1518,10 +1518,10 @@
// Propagate the function's return type to the call expression
exists(TypePath path0 | result = invokedClosureFnTypeAt(ce, path0) |
n = ce.getCall() and
path = path0.stripPrefix(fnReturnPath())
or
// Propagate the function's parameter type to the arguments
exists(int index |

Check warning

Code scanning / CodeQL

Misspelling Warning

This comment contains the common misspelling 'reciever', which should instead be 'receiver'.
n = ce.getCall().getArgList().getArg(index) and
path = path0.stripPrefix(fnParameterPath(ce.getCall().getNumberOfArgs(), index))
)
Expand Down Expand Up @@ -1643,6 +1643,10 @@
methodCandidate(type, name, arity, impl)
}

/**
* Holds if `mc` has `rootType` as the root type of the reciever and the target
* method is named `name` and has arity `arity`
*/
pragma[nomagic]
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
rootType = mc.getTypeAt(TypePath::nil()) and
Expand Down Expand Up @@ -1841,6 +1845,142 @@
else any()
}

private module BlanketImplementation {
/**
* Holds if `impl` is a blanket implementation, that is, an implementation of a
* trait for a type parameter.
*/
private TypeParamItemNode getBlanketImplementationTypeParam(Impl impl) {
result = impl.(ImplItemNode).resolveSelfTy() and
result = impl.getGenericParamList().getAGenericParam() and
not exists(impl.getAttributeMacroExpansion())
}

predicate isBlanketImplementation(Impl impl) { exists(getBlanketImplementationTypeParam(impl)) }

private Impl getPotentialDuplicated(string fileName, string traitName, int arity, string tpName) {
tpName = getBlanketImplementationTypeParam(result).getName() and
fileName = result.getLocation().getFile().getBaseName() and
traitName = result.(ImplItemNode).resolveTraitTy().getName() and
arity = result.(ImplItemNode).resolveTraitTy().(Trait).getNumberOfGenericParams()
}

/**
* Holds if `impl1` and `impl2` are duplicates and `impl2` is more "canonical"
* than `impl1`.
*/
predicate duplicatedImpl(Impl impl1, Impl impl2) {
exists(string fileName, string traitName, int arity, string tpName |
impl1 = getPotentialDuplicated(fileName, traitName, arity, tpName) and
impl2 = getPotentialDuplicated(fileName, traitName, arity, tpName) and
impl1.getLocation().getFile().getAbsolutePath() <
impl2.getLocation().getFile().getAbsolutePath()
)
}

predicate hasNoDuplicates(Impl impl) {
not duplicatedImpl(impl, _) and isBlanketImplementation(impl)
}

/**
* We currently consider blanket implementations to be in scope "globally",
* even though they actually need to be imported to be used. One downside of
* this is that the libraries included in the database can often occur several
* times for different library versions. This causes the same blanket
* implementations to exist multiple times, and these add no useful
* information.
*
* We detect these duplicates based on some files heuristic (same trait name,
* file name, etc.). For these duplicates we select the one with the greatest
* file name (which usually is also the one with the greatest library version
* in the path)
*/
Impl getCanonicalImpl(Impl impl) {
result =
max(Impl impl0, Location l |
duplicatedImpl(impl, impl0) and l = impl0.getLocation()
|
impl0 order by l.getFile().getAbsolutePath(), l.getStartLine()
)
or
hasNoDuplicates(impl) and result = impl
}

predicate isCanonicalBlanketImplementation(Impl impl) { impl = getCanonicalImpl(impl) }

/**
* Holds if `impl` is a blanket implementation for a type parameter and the type
* parameter must implement `trait`.
*/
private predicate blanketImplementationTraitBound(Impl impl, Trait t) {
t =
min(Trait trait, int i |
trait = getBlanketImplementationTypeParam(impl).resolveBound(i) and
// Exclude traits that are "trivial" in the sense that they are known to
// not narrow things down very much.
not trait.getName().getText() =
[
"Sized", "Clone", "Fn", "FnOnce", "FnMut",
// The auto traits
"Send", "Sync", "Unpin", "UnwindSafe", "RefUnwindSafe"
]
|
trait order by i
)
}

private predicate blanketImplementationMethod(
Impl impl, Trait trait, string name, int arity, Function f
) {
isCanonicalBlanketImplementation(impl) and
blanketImplementationTraitBound(impl, trait) and
f.getParamList().hasSelfParam() and
arity = f.getParamList().getNumberOfParams() and
(
f = impl.(ImplItemNode).getAssocItem(name)
or
// If the the trait has a method with a default implementation, then that

Check warning

Code scanning / CodeQL

Comment has repeated word Warning

The comment repeats the.
// target is interesting as well.
not exists(impl.(ImplItemNode).getAssocItem(name)) and
f = impl.(ImplItemNode).resolveTraitTy().getAssocItem(name)
) and
// If the method is already available through one of the trait bounds on the
// type parameter (because they share a common trait ancestor) then ignore
// it.
not getBlanketImplementationTypeParam(impl).resolveABound().(TraitItemNode).getASuccessor(name) =
f
}

predicate methodCallMatchesBlanketImpl(MethodCall mc, Type t, Impl impl, Trait trait, Function f) {
// Only check method calls where we have ruled out inherent method targets.
// Ideally we would also check if non-blanket method targets have been ruled
// out.
methodCallHasNoInherentTarget(mc) and
exists(string name, int arity |
isMethodCall(mc, t, name, arity) and
blanketImplementationMethod(impl, trait, name, arity, f)
)
}

module SatisfiesConstraintInput implements SatisfiesConstraintInputSig<MethodCall> {
pragma[nomagic]
predicate relevantConstraint(MethodCall mc, Type constraint) {
methodCallMatchesBlanketImpl(mc, _, _, constraint.(TraitType).getTrait(), _)
}

predicate useUniversalConditions() { none() }
}

predicate hasBlanketImpl(MethodCall mc, Type t, Impl impl, Trait trait, Function f) {
SatisfiesConstraint<MethodCall, SatisfiesConstraintInput>::satisfiesConstraintType(mc,
TTrait(trait), _, _) and
methodCallMatchesBlanketImpl(mc, t, impl, trait, f)
}

pragma[nomagic]
Function getMethodFromBlanketImpl(MethodCall mc) { hasBlanketImpl(mc, _, _, _, result) }
}

/** Gets a method from an `impl` block that matches the method call `mc`. */
pragma[nomagic]
private Function getMethodFromImpl(MethodCall mc) {
Expand Down Expand Up @@ -1876,6 +2016,8 @@
// The method comes from an `impl` block targeting the type of the receiver.
result = getMethodFromImpl(mc)
or
result = BlanketImplementation::getMethodFromBlanketImpl(mc)
or
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
Expand Down
123 changes: 123 additions & 0 deletions rust/ql/test/library-tests/type-inference/blanket_impl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Tests for method resolution targeting blanket trait implementations

mod basic_blanket_impl {
#[derive(Debug, Copy, Clone)]
struct S1;

trait Clone1 {
fn clone1(&self) -> Self;
}

trait Duplicatable {
fn duplicate(&self) -> Self
where
Self: Sized;
}

impl Clone1 for S1 {
// S1::clone1
fn clone1(&self) -> Self {
*self // $ target=deref
}
}

// Blanket implementation for all types that implement Display and Clone
impl<T: Clone1> Duplicatable for T {
// Clone1duplicate
fn duplicate(&self) -> Self {
self.clone1() // $ target=clone1
}
}

pub fn test_basic_blanket() {
let x = S1.clone1(); // $ target=S1::clone1
println!("{x:?}");
let y = S1.duplicate(); // $ target=Clone1duplicate
println!("{y:?}");
}
}

mod extension_trait_blanket_impl {
// 1. Elements a trait that is implemented for a type parameter
// 2. An extension trait
// 3. A blanket implementation of the extension trait for a type parameter

trait Flag {
fn read_flag(&self) -> bool;
}

trait TryFlag {
fn try_read_flag(&self) -> Option<bool>;
}

impl<Fl> TryFlag for Fl
where
Fl: Flag,
{
fn try_read_flag(&self) -> Option<bool> {
Some(self.read_flag()) // $ target=read_flag
}
}

trait TryFlagExt: TryFlag {
// TryFlagExt::try_read_flag_twice
fn try_read_flag_twice(&self) -> Option<bool> {
self.try_read_flag() // $ target=try_read_flag
}
}

impl<T: TryFlag> TryFlagExt for T {}

trait AnotherTryFlag {
// AnotherTryFlag::try_read_flag_twice
fn try_read_flag_twice(&self) -> Option<bool>;
}

struct MyTryFlag {
flag: bool,
}

impl TryFlag for MyTryFlag {
// MyTryFlag::try_read_flag
fn try_read_flag(&self) -> Option<bool> {
Some(self.flag) // $ fieldof=MyTryFlag
}
}

struct MyFlag {
flag: bool,
}

impl Flag for MyFlag {
// MyFlag::read_flag
fn read_flag(&self) -> bool {
self.flag // $ fieldof=MyFlag
}
}

struct MyOtherFlag {
flag: bool,
}

impl AnotherTryFlag for MyOtherFlag {
// MyOtherFlag::try_read_flag_twice
fn try_read_flag_twice(&self) -> Option<bool> {
Some(self.flag) // $ fieldof=MyOtherFlag
}
}

fn test() {
let my_try_flag = MyTryFlag { flag: true };
let result = my_try_flag.try_read_flag_twice(); // $ target=TryFlagExt::try_read_flag_twice

let my_flag = MyFlag { flag: true };
// Here `TryFlagExt::try_read_flag_twice` is since there is a blanket
// implementaton of `TryFlag` for `Flag`.
let result = my_flag.try_read_flag_twice(); // $ MISSING: target=TryFlagExt::try_read_flag_twice

let my_other_flag = MyOtherFlag { flag: true };
// Here `TryFlagExt::try_read_flag_twice` is _not_ a target since
// `MyOtherFlag` does not implement `TryFlag`.
let result = my_other_flag.try_read_flag_twice(); // $ target=MyOtherFlag::try_read_flag_twice
}
}
2 changes: 1 addition & 1 deletion rust/ql/test/library-tests/type-inference/dyn_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ fn test_assoc_type(obj: &dyn AssocTrait<i64, AP = bool>) {
pub fn test() {
test_basic_dyn_trait(&MyStruct { value: 42 }); // $ target=test_basic_dyn_trait
test_generic_dyn_trait(&GenStruct {
value: "".to_string(),
value: "".to_string(), // $ target=to_string
}); // $ target=test_generic_dyn_trait
test_poly_dyn_trait(); // $ target=test_poly_dyn_trait
test_assoc_type(&GenStruct { value: 100 }); // $ target=test_assoc_type
Expand Down
7 changes: 4 additions & 3 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ mod method_non_parametric_trait_impl {

fn type_bound_type_parameter_impl<TP: MyTrait<S1>>(thing: TP) -> S1 {
// The trait bound on `TP` makes the implementation of `ConvertTo` valid
thing.convert_to() // $ MISSING: target=T::convert_to
thing.convert_to() // $ target=T::convert_to
}

pub fn f() {
Expand Down Expand Up @@ -388,7 +388,7 @@ mod method_non_parametric_trait_impl {
let x = get_snd_fst(c); // $ type=x:S1 target=get_snd_fst

let thing = MyThing { a: S1 };
let i = thing.convert_to(); // $ MISSING: type=i:S1 target=T::convert_to
let i = thing.convert_to(); // $ type=i:S1 target=T::convert_to
let j = convert_to(thing); // $ type=j:S1 target=convert_to
}
}
Expand Down Expand Up @@ -1292,7 +1292,7 @@ mod method_call_type_conversion {
let t = x7.m1(); // $ target=m1 type=t:& type=t:&T.S2
println!("{:?}", x7);

let x9: String = "Hello".to_string(); // $ type=x9:String
let x9: String = "Hello".to_string(); // $ type=x9:String target=to_string

// Implicit `String` -> `str` conversion happens via the `Deref` trait:
// https://doc.rust-lang.org/std/string/struct.String.html#deref.
Expand Down Expand Up @@ -2487,6 +2487,7 @@ pub mod pattern_matching_experimental {
mod closure;
mod dereference;
mod dyn_type;
mod blanket_impl;

fn main() {
field_access::f(); // $ target=f
Expand Down
Loading
Loading