1
0
mirror of https://github.com/danog/psalm.git synced 2024-11-30 04:39:00 +01:00

Merge pull request #7417 from klimick/partially-applied-closure-inference

Contextual type inference for high order function arg
This commit is contained in:
orklah 2022-01-20 21:03:46 +01:00 committed by GitHub
commit 6f1a5e8a47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 581 additions and 27 deletions

View File

@ -19,6 +19,7 @@ use Psalm\Internal\Codebase\TaintFlowGraph;
use Psalm\Internal\DataFlow\TaintSink;
use Psalm\Internal\MethodIdentifier;
use Psalm\Internal\Stubs\Generator\StubsGenerator;
use Psalm\Internal\Type\Comparator\CallableTypeComparator;
use Psalm\Internal\Type\Comparator\UnionTypeComparator;
use Psalm\Internal\Type\TemplateInferredTypeReplacer;
use Psalm\Internal\Type\TemplateResult;
@ -196,7 +197,21 @@ class ArgumentsAnalyzer
$toggled_class_exists = true;
}
if (($arg->value instanceof PhpParser\Node\Expr\Closure
$high_order_template_result = null;
if (($arg->value instanceof PhpParser\Node\Expr\FuncCall
|| $arg->value instanceof PhpParser\Node\Expr\MethodCall
|| $arg->value instanceof PhpParser\Node\Expr\StaticCall)
&& $param
&& $function_storage = self::getHighOrderFuncStorage($context, $statements_analyzer, $arg->value)
) {
$high_order_template_result = self::handleHighOrderFuncCallArg(
$statements_analyzer,
$template_result ?? new TemplateResult([], []),
$function_storage,
$param
);
} elseif (($arg->value instanceof PhpParser\Node\Expr\Closure
|| $arg->value instanceof PhpParser\Node\Expr\ArrowFunction)
&& $param
&& !$arg->value->getDocComment()
@ -217,7 +232,15 @@ class ArgumentsAnalyzer
$context->inside_call = true;
if (ExpressionAnalyzer::analyze($statements_analyzer, $arg->value, $context) === false) {
if (ExpressionAnalyzer::analyze(
$statements_analyzer,
$arg->value,
$context,
false,
null,
false,
$high_order_template_result
) === false) {
$context->inside_call = $was_inside_call;
return false;
@ -315,6 +338,172 @@ class ArgumentsAnalyzer
}
}
private static function getHighOrderFuncStorage(
Context $context,
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\CallLike $function_like_call
): ?FunctionLikeStorage {
$codebase = $statements_analyzer->getCodebase();
try {
if ($function_like_call instanceof PhpParser\Node\Expr\FuncCall) {
$function_id = strtolower((string) $function_like_call->name->getAttribute('resolvedName'));
if (empty($function_id)) {
return null;
}
return $codebase->functions->getStorage($statements_analyzer, $function_id);
}
if ($function_like_call instanceof PhpParser\Node\Expr\MethodCall &&
$function_like_call->var instanceof PhpParser\Node\Expr\Variable &&
$function_like_call->name instanceof PhpParser\Node\Identifier &&
is_string($function_like_call->var->name) &&
isset($context->vars_in_scope['$' . $function_like_call->var->name])
) {
$lhs_type = $context->vars_in_scope['$' . $function_like_call->var->name]->getSingleAtomic();
if (!$lhs_type instanceof Type\Atomic\TNamedObject) {
return null;
}
$method_id = new MethodIdentifier(
$lhs_type->value,
strtolower((string)$function_like_call->name)
);
return $codebase->methods->getStorage($method_id);
}
if ($function_like_call instanceof PhpParser\Node\Expr\StaticCall &&
$function_like_call->name instanceof PhpParser\Node\Identifier
) {
$method_id = new MethodIdentifier(
(string)$function_like_call->class->getAttribute('resolvedName'),
strtolower($function_like_call->name->name)
);
return $codebase->methods->getStorage($method_id);
}
} catch (UnexpectedValueException $e) {
return null;
}
return null;
}
/**
* Compiles TemplateResult for high-order functions ($func_call)
* by previous template args ($inferred_template_result).
*
* It's need for proper template replacement:
*
* ```
* * template T
* * return Closure(T): T
* function id(): Closure { ... }
*
* * template A
* * template B
* *
* * param list<A> $_items
* * param callable(A): B $_ab
* * return list<B>
* function map(array $items, callable $ab): array { ... }
*
* // list<int>
* $numbers = [1, 2, 3];
*
* $result = map($numbers, id());
* // $result is list<int> because template T of id() was inferred by previous arg.
* ```
*/
private static function handleHighOrderFuncCallArg(
StatementsAnalyzer $statements_analyzer,
TemplateResult $inferred_template_result,
FunctionLikeStorage $storage,
FunctionLikeParameter $actual_func_param
): ?TemplateResult {
$codebase = $statements_analyzer->getCodebase();
$input_hof_atomic = $storage->return_type && $storage->return_type->isSingle()
? $storage->return_type->getSingleAtomic()
: null;
// Try upcast invokable to callable type.
if ($input_hof_atomic instanceof Type\Atomic\TNamedObject &&
$input_hof_atomic->value !== 'Closure' &&
$codebase->classExists($input_hof_atomic->value)
) {
$callable_from_invokable = CallableTypeComparator::getCallableFromAtomic(
$codebase,
$input_hof_atomic
);
if ($callable_from_invokable) {
$invoke_id = new MethodIdentifier($input_hof_atomic->value, '__invoke');
$declaring_invoke_id = $codebase->methods->getDeclaringMethodId($invoke_id);
$storage = $codebase->methods->getStorage($declaring_invoke_id ?? $invoke_id);
$input_hof_atomic = $callable_from_invokable;
}
}
if (!$input_hof_atomic instanceof TClosure && !$input_hof_atomic instanceof TCallable) {
return null;
}
$container_hof_atomic = $actual_func_param->type && $actual_func_param->type->isSingle()
? $actual_func_param->type->getSingleAtomic()
: null;
if (!$container_hof_atomic instanceof TClosure && !$container_hof_atomic instanceof TCallable) {
return null;
}
$replaced_container_hof_atomic = new Union([clone $container_hof_atomic]);
// Replaces all input args in container function.
//
// For example:
// The map function expects callable(A):B as second param
// We know that previous arg type is list<int> where the int is the A template.
// Then we can replace callable(A): B to callable(int):B using $inferred_template_result.
TemplateInferredTypeReplacer::replace(
$replaced_container_hof_atomic,
$inferred_template_result,
$codebase
);
/** @var TClosure|TCallable $container_hof_atomic */
$container_hof_atomic = $replaced_container_hof_atomic->getSingleAtomic();
$high_order_template_result = new TemplateResult($storage->template_types ?: [], []);
// We can replace each templated param for the input function.
// Example:
// map($numbers, id());
// We know that map expects callable(int):B because the $numbers is list<int>.
// We know that id() returns callable(T):T.
// Then we can replace templated params sequentially using the expected callable(int):B.
foreach ($input_hof_atomic->params ?? [] as $offset => $actual_func_param) {
if ($actual_func_param->type &&
$actual_func_param->type->getTemplateTypes() &&
isset($container_hof_atomic->params[$offset])
) {
TemplateStandinTypeReplacer::replace(
clone $actual_func_param->type,
$high_order_template_result,
$codebase,
null,
$container_hof_atomic->params[$offset]->type
);
}
}
return $high_order_template_result;
}
/**
* @param array<int, PhpParser\Node\Arg> $args
*/

View File

@ -82,7 +82,8 @@ class FunctionCallAnalyzer extends CallAnalyzer
public static function analyze(
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\FuncCall $stmt,
Context $context
Context $context,
?TemplateResult $template_result = null
): bool {
$function_name = $stmt->name;
@ -166,10 +167,12 @@ class FunctionCallAnalyzer extends CallAnalyzer
}
if (!$is_first_class_callable) {
$template_result = null;
if (isset($function_call_info->function_storage->template_types)) {
$template_result = new TemplateResult($function_call_info->function_storage->template_types ?: [], []);
if (!$template_result) {
$template_result = new TemplateResult([], []);
}
$template_result->template_types += $function_call_info->function_storage->template_types ?: [];
}
ArgumentsAnalyzer::analyze(
@ -205,6 +208,10 @@ class FunctionCallAnalyzer extends CallAnalyzer
}
}
$already_inferred_lower_bounds = $template_result
? $template_result->lower_bounds
: [];
$template_result = new TemplateResult([], []);
// do this here to allow closure param checks
@ -229,6 +236,8 @@ class FunctionCallAnalyzer extends CallAnalyzer
$function_call_info->function_id
);
$template_result->lower_bounds += $already_inferred_lower_bounds;
if ($function_name instanceof PhpParser\Node\Name && $function_call_info->function_id) {
$stmt_type = FunctionCallReturnTypeFetcher::fetch(
$statements_analyzer,

View File

@ -76,7 +76,8 @@ class AtomicMethodCallAnalyzer extends CallAnalyzer
?Atomic $static_type,
bool $is_intersection,
?string $lhs_var_id,
AtomicMethodCallAnalysisResult $result
AtomicMethodCallAnalysisResult $result,
?TemplateResult $inferred_template_result = null
): void {
if ($lhs_type_part instanceof TTemplateParam
&& !$lhs_type_part->as->isMixed()
@ -440,7 +441,8 @@ class AtomicMethodCallAnalyzer extends CallAnalyzer
$static_type,
$lhs_var_id,
$method_id,
$result
$result,
$inferred_template_result
);
$statements_analyzer->node_data = $old_node_data;

View File

@ -68,7 +68,8 @@ class ExistingAtomicMethodCallAnalyzer extends CallAnalyzer
?Atomic $static_type,
?string $lhs_var_id,
MethodIdentifier $method_id,
AtomicMethodCallAnalysisResult $result
AtomicMethodCallAnalysisResult $result,
?TemplateResult $inferred_template_result = null
): Union {
$config = $codebase->config;
@ -220,6 +221,10 @@ class ExistingAtomicMethodCallAnalyzer extends CallAnalyzer
$template_result = new TemplateResult([], $class_template_params ?: []);
$template_result->lower_bounds += $method_template_params;
if ($inferred_template_result) {
$template_result->lower_bounds += $inferred_template_result->lower_bounds;
}
if ($codebase->store_node_types
&& !$context->collect_initializations
&& !$context->collect_mutations

View File

@ -11,6 +11,7 @@ use Psalm\Internal\Analyzer\Statements\Expression\CallAnalyzer;
use Psalm\Internal\Analyzer\Statements\Expression\ExpressionIdentifier;
use Psalm\Internal\Analyzer\Statements\ExpressionAnalyzer;
use Psalm\Internal\Analyzer\StatementsAnalyzer;
use Psalm\Internal\Type\TemplateResult;
use Psalm\Issue\InvalidMethodCall;
use Psalm\Issue\InvalidScope;
use Psalm\Issue\NullReference;
@ -43,7 +44,8 @@ class MethodCallAnalyzer extends CallAnalyzer
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\MethodCall $stmt,
Context $context,
bool $real_method_call = true
bool $real_method_call = true,
?TemplateResult $template_result = null
): bool {
$was_inside_call = $context->inside_call;
@ -194,7 +196,8 @@ class MethodCallAnalyzer extends CallAnalyzer
: null,
false,
$lhs_var_id,
$result
$result,
$template_result
);
if (isset($context->vars_in_scope[$lhs_var_id])
&& ($possible_new_class_type = $context->vars_in_scope[$lhs_var_id]) instanceof Union

View File

@ -41,7 +41,8 @@ class StaticCallAnalyzer extends CallAnalyzer
public static function analyze(
StatementsAnalyzer $statements_analyzer,
PhpParser\Node\Expr\StaticCall $stmt,
Context $context
Context $context,
?TemplateResult $template_result = null
): bool {
$method_id = null;
@ -219,7 +220,8 @@ class StaticCallAnalyzer extends CallAnalyzer
$lhs_type->ignore_nullable_issues,
$moved_call,
$has_mock,
$has_existing_method
$has_existing_method,
$template_result
);
}

View File

@ -22,6 +22,7 @@ use Psalm\Internal\Analyzer\Statements\Expression\Fetch\AtomicPropertyFetchAnaly
use Psalm\Internal\Analyzer\Statements\ExpressionAnalyzer;
use Psalm\Internal\Analyzer\StatementsAnalyzer;
use Psalm\Internal\MethodIdentifier;
use Psalm\Internal\Type\TemplateResult;
use Psalm\Internal\Type\TypeExpander;
use Psalm\Issue\DeprecatedClass;
use Psalm\Issue\ImpureMethodCall;
@ -74,7 +75,8 @@ class AtomicStaticCallAnalyzer
bool $ignore_nullable_issues,
bool &$moved_call,
bool &$has_mock,
bool &$has_existing_method
bool &$has_existing_method,
?TemplateResult $inferred_template_result = null
): void {
$intersection_types = [];
@ -209,7 +211,8 @@ class AtomicStaticCallAnalyzer
$intersection_types ?: [],
$fq_class_name,
$moved_call,
$has_existing_method
$has_existing_method,
$inferred_template_result
);
} else {
if ($stmt->name instanceof PhpParser\Node\Expr) {
@ -271,7 +274,8 @@ class AtomicStaticCallAnalyzer
array $intersection_types,
string $fq_class_name,
bool &$moved_call,
bool &$has_existing_method
bool &$has_existing_method,
?TemplateResult $inferred_template_result = null
): bool {
$codebase = $statements_analyzer->getCodebase();
@ -829,7 +833,8 @@ class AtomicStaticCallAnalyzer
$method_id,
$cased_method_id,
$class_storage,
$moved_call
$moved_call,
$inferred_template_result
);
return true;

View File

@ -62,7 +62,8 @@ class ExistingAtomicStaticCallAnalyzer
MethodIdentifier $method_id,
string $cased_method_id,
ClassLikeStorage $class_storage,
bool &$moved_call
bool &$moved_call,
?TemplateResult $inferred_template_result = null
): void {
$fq_class_name = $method_id->fq_class_name;
$method_name_lc = $method_id->method_name;
@ -185,6 +186,10 @@ class ExistingAtomicStaticCallAnalyzer
$template_result = new TemplateResult([], $found_generic_params ?: []);
if ($inferred_template_result) {
$template_result->lower_bounds += $inferred_template_result->lower_bounds;
}
if (CallAnalyzer::checkMethodArgs(
$method_id,
$args,

View File

@ -47,6 +47,7 @@ use Psalm\Internal\Codebase\TaintFlowGraph;
use Psalm\Internal\DataFlow\DataFlowNode;
use Psalm\Internal\DataFlow\TaintSink;
use Psalm\Internal\FileManipulation\FileManipulationBuffer;
use Psalm\Internal\Type\TemplateResult;
use Psalm\Issue\ForbiddenCode;
use Psalm\Issue\UnrecognizedExpression;
use Psalm\IssueBuffer;
@ -70,7 +71,8 @@ class ExpressionAnalyzer
Context $context,
bool $array_assignment = false,
?Context $global_context = null,
bool $from_stmt = false
bool $from_stmt = false,
?TemplateResult $template_result = null
): bool {
$codebase = $statements_analyzer->getCodebase();
@ -80,9 +82,9 @@ class ExpressionAnalyzer
$context,
$array_assignment,
$global_context,
$from_stmt
) === false
) {
$from_stmt,
$template_result
) === false) {
return false;
}
@ -144,7 +146,8 @@ class ExpressionAnalyzer
Context $context,
bool $array_assignment,
?Context $global_context,
bool $from_stmt
bool $from_stmt,
?TemplateResult $template_result = null
): bool {
if ($stmt instanceof PhpParser\Node\Expr\Variable) {
return VariableFetchAnalyzer::analyze(
@ -183,11 +186,11 @@ class ExpressionAnalyzer
}
if ($stmt instanceof PhpParser\Node\Expr\MethodCall) {
return MethodCallAnalyzer::analyze($statements_analyzer, $stmt, $context);
return MethodCallAnalyzer::analyze($statements_analyzer, $stmt, $context, true, $template_result);
}
if ($stmt instanceof PhpParser\Node\Expr\StaticCall) {
return StaticCallAnalyzer::analyze($statements_analyzer, $stmt, $context);
return StaticCallAnalyzer::analyze($statements_analyzer, $stmt, $context, $template_result);
}
if ($stmt instanceof PhpParser\Node\Expr\ConstFetch) {
@ -295,7 +298,8 @@ class ExpressionAnalyzer
return FunctionCallAnalyzer::analyze(
$statements_analyzer,
$stmt,
$context
$context,
$template_result
);
}

View File

@ -10,6 +10,9 @@ use Psalm\Internal\Analyzer\StatementsAnalyzer;
use Psalm\Internal\Codebase\InternalCallMapHandler;
use Psalm\Internal\MethodIdentifier;
use Psalm\Internal\Provider\NodeDataProvider;
use Psalm\Internal\Type\TemplateInferredTypeReplacer;
use Psalm\Internal\Type\TemplateResult;
use Psalm\Internal\Type\TemplateStandinTypeReplacer;
use Psalm\Internal\Type\TypeExpander;
use Psalm\Type;
use Psalm\Type\Atomic;
@ -385,6 +388,35 @@ class CallableTypeComparator
if ($codebase->methods->methodExists($invoke_id)) {
$declaring_method_id = $codebase->methods->getDeclaringMethodId($invoke_id);
$template_result = null;
if ($input_type_part instanceof Atomic\TGenericObject) {
$invokable_storage = $codebase->methods->getClassLikeStorageForMethod(
$declaring_method_id ?? $invoke_id
);
$type_params = [];
foreach ($invokable_storage->template_types ?? [] as $template => $for_class) {
foreach ($for_class as $type) {
$type_params[] = new Type\Union([
new TTemplateParam($template, $type, $input_type_part->value)
]);
}
}
if (!empty($type_params)) {
$input_with_templates = new Atomic\TGenericObject($input_type_part->value, $type_params);
$template_result = new TemplateResult($invokable_storage->template_types ?? [], []);
TemplateStandinTypeReplacer::replace(
new Type\Union([$input_with_templates]),
$template_result,
$codebase,
null,
new Type\Union([$input_type_part])
);
}
}
if ($declaring_method_id) {
$method_storage = $codebase->methods->getStorage($declaring_method_id);
@ -400,12 +432,26 @@ class CallableTypeComparator
);
}
return new TCallable(
$callable = new TCallable(
'callable',
$method_storage->params,
$converted_return_type,
$method_storage->pure
);
if ($template_result) {
$replaced_callable = clone $callable;
TemplateInferredTypeReplacer::replace(
new Type\Union([$replaced_callable]),
$template_result,
$codebase
);
$callable = $replaced_callable;
}
return $callable;
}
}
}

View File

@ -233,6 +233,290 @@ class CallableTest extends TestCase
'$inferred' => 'list<Foo>',
],
],
'inferTemplateOfHighOrderFunctionArgByPreviousArg' => [
'<?php
/**
* @return list<int>
*/
function getList() { throw new RuntimeException("???"); }
/**
* @template T
* @return Closure(T): T
*/
function id() { throw new RuntimeException("???"); }
/**
* @template A
* @template B
*
* @param list<A> $_items
* @param callable(A): B $_ab
* @return list<B>
*/
function map(array $_items, callable $_ab) { throw new RuntimeException("???"); }
$result = map(getList(), id());
',
'assertions' => [
'$result' => 'list<int>',
],
],
'inferTemplateOfHighOrderFunctionArgByPreviousArgInClassContext' => [
'<?php
/**
* @template A
*/
final class ArrayList
{
/**
* @template B
*
* @param callable(A): B $ab
* @return ArrayList<B>
*/
public function map(callable $ab) { throw new RuntimeException("???"); }
}
/**
* @return ArrayList<int>
*/
function getList() { throw new RuntimeException("???"); }
/**
* @template T
* @return Closure(T): T
*/
function id() { throw new RuntimeException("???"); }
$result = getList()->map(id());
',
'assertions' => [
'$result' => 'ArrayList<int>',
],
],
'inferTemplateOfHighOrderFunctionFromMethodArgByPreviousArg' => [
'<?php
final class Ops
{
/**
* @template T
* @return Closure(list<T>): T
*/
public function flatten() { throw new RuntimeException("???"); }
}
/**
* @return list<list<int>>
*/
function getList() { throw new RuntimeException("???"); }
/**
* @template T
* @return Closure(list<T>): T
*/
function flatten() { throw new RuntimeException("???"); }
/**
* @template A
* @template B
*
* @param list<A> $_a
* @param callable(A): B $_ab
* @return list<B>
*/
function map(array $_a, callable $_ab) { throw new RuntimeException("???"); }
$ops = new Ops;
$result = map(getList(), $ops->flatten());
',
'assertions' => [
'$result' => 'list<int>',
],
],
'inferTemplateOfHighOrderFunctionFromStaticMethodArgByPreviousArg' => [
'<?php
final class StaticOps
{
/**
* @template T
* @return Closure(list<T>): T
*/
public static function flatten() { throw new RuntimeException("???"); }
}
/**
* @return list<list<int>>
*/
function getList() { throw new RuntimeException("???"); }
/**
* @template T
* @return Closure(list<T>): T
*/
function flatten() { throw new RuntimeException("???"); }
/**
* @template A
* @template B
*
* @param list<A> $_a
* @param callable(A): B $_ab
* @return list<B>
*/
function map(array $_a, callable $_ab) { throw new RuntimeException("???"); }
$result = map(getList(), StaticOps::flatten());
',
'assertions' => [
'$result' => 'list<int>',
],
],
'' => [
'<?php
/**
* @template A
* @template B
*/
final class MapOperator
{
/**
* @param Closure(A): B $ab
*/
public function __construct(private Closure $ab) { }
/**
* @param list<A> $a
* @return list<B>
*/
public function __invoke($a): array
{
$b = [];
foreach ($a as $item) {
$b[] = ($this->ab)($item);
}
return $b;
}
}
/**
* @template A
* @template B
*
* @param Closure(A): B $ab
* @return MapOperator<A, B>
*/
function map(Closure $ab): MapOperator
{
return new MapOperator($ab);
}
/**
* @template A
* @template B
*
* @param A $_a
* @param callable(A): B $_ab
* @return B
*/
function pipe(array $_a, callable $_ab): array
{
throw new RuntimeException("???");
}
$result1 = pipe(
["1", "2", "3"],
map(fn ($i) => (int) $i)
);
$result2 = pipe(
["1", "2", "3"],
new MapOperator(fn ($i) => (int) $i)
);
',
'assertions' => [
'$result1' => 'list<int>',
'$result2' => 'list<int>',
],
'error_levels' => [],
'8.0',
],
'inferPipelineWithPartiallyAppliedFunctions' => [
'<?php
/**
* @template T
*
* @param callable(T, int): bool $_predicate
* @return Closure(list<T>): list<T>
*/
function filter(callable $_predicate): Closure { throw new RuntimeException("???"); }
/**
* @template A
* @template B
*
* @param callable(A): B $_ab
* @return Closure(list<A>): list<B>
*/
function map(callable $_ab): Closure { throw new RuntimeException("???"); }
/**
* @template T
* @return (Closure(list<T>): (non-empty-list<T> | null))
*/
function asNonEmptyList(): Closure { throw new RuntimeException("???"); }
/**
* @template T
* @return Closure(T): T
*/
function id(): Closure { throw new RuntimeException("???"); }
/**
* @template A
* @template B
* @template C
* @template D
* @template E
* @template F
*
* @param A $arg
* @param callable(A): B $ab
* @param callable(B): C $bc
* @param callable(C): D $cd
* @param callable(D): E $de
* @param callable(E): F $ef
* @return F
*/
function pipe4(mixed $arg, callable $ab, callable $bc, callable $cd, callable $de, callable $ef): mixed
{
return $ef($de($cd($bc($ab($arg)))));
}
/**
* @template TFoo of string
* @template TBar of bool
*/
final class Item
{
/**
* @param TFoo $foo
* @param TBar $bar
*/
public function __construct(
public string $foo,
public bool $bar,
) { }
}
/**
* @return list<Item>
*/
function getList(): array { return []; }
$result = pipe4(
getList(),
filter(fn($i) => $i->bar),
filter(fn(Item $i) => $i->foo !== "bar"),
map(fn($i) => new Item("test: " . $i->foo, $i->bar)),
asNonEmptyList(),
id(),
);',
'assertions' => [
'$result' => 'non-empty-list<Item<string, bool>>|null',
],
'error_levels' => [],
'8.0',
],
'varReturnType' => [
'code' => '<?php
$add_one = function(int $a) : int {