// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.Collections;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Collections;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.CSharp;

internal partial class CSharpTypeInferenceService
{
    /// <summary>
    /// C# type inferrer. Given a syntax node or a text position, it derives candidate types from the
    /// surrounding syntax (the node's parent, or the parent of the token just before the position),
    /// querying the <see cref="SemanticModel"/> supplied to the constructor.
    /// </summary>
    private sealed class TypeInferrer : AbstractTypeInferrer
    {
        internal TypeInferrer(
            SemanticModel semanticModel,
            CancellationToken cancellationToken) : base(semanticModel, cancellationToken)
        {
        }

        /// <summary>
        /// Reports an error type whose name is empty or <c>"var"</c> as unusable; every other type
        /// (including other error types) is usable as far as this override is concerned.
        /// </summary>
        protected override bool IsUnusableType(ITypeSymbol otherSideType)
        {
            return otherSideType.IsErrorType() &&
                (otherSideType.Name == string.Empty || otherSideType.Name == "var");
        }

        /// <summary>
        /// Gets the usable type(s) of <paramref name="node"/> itself. Direct binding
        /// (<c>GetTypesSimple</c>) is tried first; only when it yields nothing usable is the
        /// operand-based fallback (<c>GetTypesComplex</c>) consulted.
        /// </summary>
        /// <remarks><paramref name="objectAsDefault"/> is not read by this override.</remarks>
        protected override IEnumerable<TypeInferenceInfo> GetTypes_DoNotCallDirectly(SyntaxNode node, bool objectAsDefault)
        {
            var types = GetTypesSimple(node).Where(IsUsableTypeFunc);
            if (types.Any())
            {
                return types;
            }

            return GetTypesComplex(node).Where(IsUsableTypeFunc);
        }

        /// <summary>
        /// Splits a binary or assignment expression into its operator token and operands.
        /// </summary>
        /// <returns>
        /// <see langword="true"/> if <paramref name="node"/> is a <see cref="BinaryExpressionSyntax"/> or an
        /// <see cref="AssignmentExpressionSyntax"/>; otherwise <see langword="false"/>, with
        /// <paramref name="operatorToken"/> defaulted and both operands set to <see langword="null"/>.
        /// </returns>
        private static bool DecomposeBinaryOrAssignmentExpression(SyntaxNode node, out SyntaxToken operatorToken, out ExpressionSyntax left, out ExpressionSyntax right)
        {
            if (node is BinaryExpressionSyntax binaryExpression)
            {
                operatorToken = binaryExpression.OperatorToken;
                left = binaryExpression.Left;
                right = binaryExpression.Right;
                return true;
            }

            if (node is AssignmentExpressionSyntax assignmentExpression)
            {
                operatorToken = assignmentExpression.OperatorToken;
                left = assignmentExpression.Left;
                right = assignmentExpression.Right;
                return true;
            }

            operatorToken = default;
            left = right = null;
            return false;
        }

        /// <summary>
        /// Fallback used when direct binding produced nothing usable. For a binary or assignment
        /// expression, asks <c>InferTypeInBinaryOrAssignmentExpression</c> about the left operand and,
        /// if that yields nothing usable, about the right operand. Any other node yields no types.
        /// </summary>
        private IEnumerable<TypeInferenceInfo> GetTypesComplex(SyntaxNode node)
        {
            if (DecomposeBinaryOrAssignmentExpression(node, out var operatorToken, out var left, out var right))
            {
                var types = InferTypeInBinaryOrAssignmentExpression((ExpressionSyntax)node, operatorToken, left, right, left).Where(IsUsableTypeFunc);
                if (types.IsEmpty())
                {
                    types = InferTypeInBinaryOrAssignmentExpression((ExpressionSyntax)node, operatorToken, left, right, right).Where(IsUsableTypeFunc);
                }

                return types;
            }

            // TODO(cyrusn): More cases if necessary.
            return [];
        }

        /// <summary>
        /// Gets the type <paramref name="node"/> binds to directly.
        /// <list type="bullet">
        /// <item>A <see cref="RefTypeSyntax"/> is unwrapped to its inner type.</item>
        /// <item>Otherwise the semantic model's type for the node is used.</item>
        /// <item>When there is no such type and the node binds to exactly one method, that method
        /// converted via <c>ConvertToType</c> is used instead.</item>
        /// <item>A node whose symbol lookup reports <see cref="CandidateReason.WrongArity"/> yields
        /// nothing, as does a <see langword="null"/> node.</item>
        /// </list>
        /// </summary>
        private IEnumerable<TypeInferenceInfo> GetTypesSimple(SyntaxNode node)
        {
            if (node is RefTypeSyntax refType)
            {
                return GetTypes(refType.Type);
            }
            else if (node != null)
            {
                var typeInfo = SemanticModel.GetTypeInfo(node, CancellationToken);
                var symbolInfo = SemanticModel.GetSymbolInfo(node, CancellationToken);

                // Bindings that failed with WrongArity are not used as a source of type information.
                if (symbolInfo.CandidateReason != CandidateReason.WrongArity)
                {
                    var typeInferenceInfo = new TypeInferenceInfo(typeInfo.Type);

                    // If it bound to a method, try to get the Action/Func form of that method.
// Only applies when the node has no type and binds to exactly one symbol, which is a method.
                    if (typeInferenceInfo.InferredType == null)
                    {
                        var allSymbols = symbolInfo.GetAllSymbols();
                        if (allSymbols is [IMethodSymbol method])
                            typeInferenceInfo = new TypeInferenceInfo(method.ConvertToType(this.Compilation));
                    }

                    if (IsUsableTypeFunc(typeInferenceInfo))
                        return [typeInferenceInfo];
                }
            }

            return [];
        }

        /// <summary>
        /// Infers the type(s) expected for <paramref name="node"/> from its parent. An expression node
        /// is first replaced by the result of <c>WalkUpParentheses</c> (presumably the outermost
        /// parenthesized expression wrapping it), and inference then dispatches on that node's parent.
        /// Parents not listed in the switch yield no types.
        /// </summary>
        /// <remarks>
        /// The switch arms are type patterns tested in order; the first matching arm wins.
        /// </remarks>
        protected override IEnumerable<TypeInferenceInfo> InferTypesWorker_DoNotCallDirectly(
            SyntaxNode node)
        {
            var expression = node as ExpressionSyntax;
            if (expression != null)
            {
                expression = expression.WalkUpParentheses();
                node = expression;
            }

            var parent = node.Parent;

            // `expression` stays null when the incoming node is not an expression.
            return parent switch
            {
                AnonymousObjectMemberDeclaratorSyntax memberDeclarator => InferTypeInMemberDeclarator(memberDeclarator),
                ArgumentSyntax argument => InferTypeInArgument(argument),
                ArrayCreationExpressionSyntax arrayCreationExpression => InferTypeInArrayCreationExpression(arrayCreationExpression),
                ArrayRankSpecifierSyntax arrayRankSpecifier => InferTypeInArrayRankSpecifier(arrayRankSpecifier),
                ArrayTypeSyntax arrayType => InferTypeInArrayType(arrayType),
                ArrowExpressionClauseSyntax arrowClause => InferTypeInArrowExpressionClause(arrowClause),
                AssignmentExpressionSyntax assignmentExpression => InferTypeInBinaryOrAssignmentExpression(assignmentExpression, assignmentExpression.OperatorToken, assignmentExpression.Left, assignmentExpression.Right, expression),
                AttributeArgumentSyntax attribute => InferTypeInAttributeArgument(attribute),
                AttributeSyntax _ => InferTypeInAttribute(),
                AwaitExpressionSyntax awaitExpression => InferTypeInAwaitExpression(awaitExpression),
                BinaryExpressionSyntax binaryExpression => InferTypeInBinaryOrAssignmentExpression(binaryExpression, binaryExpression.OperatorToken, binaryExpression.Left, binaryExpression.Right, expression),
                CastExpressionSyntax castExpression => InferTypeInCastExpression(castExpression, expression),
                CatchDeclarationSyntax catchDeclaration => InferTypeInCatchDeclaration(catchDeclaration),
                CatchFilterClauseSyntax catchFilterClause => InferTypeInCatchFilterClause(catchFilterClause),
                CheckedExpressionSyntax checkedExpression => InferTypes(checkedExpression),
                ConditionalAccessExpressionSyntax conditionalAccessExpression => InferTypeInConditionalAccessExpression(conditionalAccessExpression),
                ConditionalExpressionSyntax conditionalExpression => InferTypeInConditionalExpression(conditionalExpression, expression),
                ConstantPatternSyntax constantPattern => InferTypeInConstantPattern(constantPattern),
                DoStatementSyntax doStatement => InferTypeInDoStatement(doStatement),
                EqualsValueClauseSyntax equalsValue => InferTypeInEqualsValueClause(equalsValue),
                ExpressionColonSyntax expressionColon => InferTypeInExpressionColon(expressionColon),
                ExpressionStatementSyntax _ => InferTypeInExpressionStatement(),
                ForEachStatementSyntax forEachStatement => InferTypeInForEachStatement(forEachStatement, expression),
                ForStatementSyntax forStatement => InferTypeInForStatement(forStatement, expression),
                IfStatementSyntax ifStatement => InferTypeInIfStatement(ifStatement),
                InitializerExpressionSyntax initializerExpression => InferTypeInInitializerExpression(initializerExpression, expression),
                IsPatternExpressionSyntax isPatternExpression => InferTypeInIsPatternExpression(isPatternExpression, node),
                LockStatementSyntax lockStatement => InferTypeInLockStatement(lockStatement),
                MemberAccessExpressionSyntax memberAccessExpression => InferTypeInMemberAccessExpression(memberAccessExpression, expression),
                NameColonSyntax nameColon => InferTypeInNameColon(nameColon),
                NameEqualsSyntax nameEquals => InferTypeInNameEquals(nameEquals),
                LambdaExpressionSyntax lambdaExpression => InferTypeInLambdaExpression(lambdaExpression),
                PostfixUnaryExpressionSyntax postfixUnary => InferTypeInPostfixUnaryExpression(postfixUnary),
                PrefixUnaryExpressionSyntax prefixUnary => InferTypeInPrefixUnaryExpression(prefixUnary),
                RecursivePatternSyntax propertyPattern => InferTypeInRecursivePattern(propertyPattern),
                PropertyPatternClauseSyntax propertySubpattern => InferTypeInPropertyPatternClause(propertySubpattern),
                RefExpressionSyntax refExpression => InferTypeInRefExpression(refExpression),
                ReturnStatementSyntax returnStatement => InferTypeForReturnStatement(returnStatement),
                SubpatternSyntax subpattern => InferTypeInSubpattern(subpattern, node),
                SwitchExpressionArmSyntax arm => InferTypeInSwitchExpressionArm(arm),
                SwitchLabelSyntax switchLabel => InferTypeInSwitchLabel(switchLabel),
                SwitchStatementSyntax switchStatement => InferTypeInSwitchStatement(switchStatement),
                ThrowExpressionSyntax throwExpression => InferTypeInThrowExpression(throwExpression),
                ThrowStatementSyntax throwStatement => InferTypeInThrowStatement(throwStatement),
                UsingStatementSyntax usingStatement => InferTypeInUsingStatement(usingStatement),
                WhenClauseSyntax whenClause => InferTypeInWhenClause(whenClause),
                WhileStatementSyntax whileStatement => InferTypeInWhileStatement(whileStatement),
                YieldStatementSyntax yieldStatement => InferTypeInYieldStatement(yieldStatement),
                _ => [],
            };
        }

        /// <summary>
        /// Infers the type(s) expected at <paramref name="position"/>.
        /// <list type="number">
        /// <item>Takes the token to the left of the position, then adjusts it with
        /// <c>GetPreviousTokenIfTouchingWord</c>.</item>
        /// <item>Dispatches on that token's parent, passing the token along so each handler can tell
        /// where inside its construct the position lies.</item>
        /// </list>
        /// Parents not listed in the switch yield no types.
        /// </summary>
        protected override IEnumerable<TypeInferenceInfo> InferTypesWorker_DoNotCallDirectly(int position)
        {
            var syntaxTree = SemanticModel.SyntaxTree;
            var token = syntaxTree.FindTokenOnLeftOfPosition(position, CancellationToken);
            token = token.GetPreviousTokenIfTouchingWord(position);

            var parent = token.Parent;

            return parent switch
            {
                AnonymousObjectCreationExpressionSyntax anonymousObjectCreation => InferTypeInAnonymousObjectCreation(anonymousObjectCreation, token),
                AnonymousObjectMemberDeclaratorSyntax memberDeclarator => InferTypeInMemberDeclarator(memberDeclarator, token),
                ArgumentListSyntax argument => InferTypeInArgumentList(argument, token),
                ArgumentSyntax argument => InferTypeInArgument(argument, token),
                ArrayCreationExpressionSyntax arrayCreationExpression => InferTypeInArrayCreationExpression(arrayCreationExpression, token),
                ArrayRankSpecifierSyntax arrayRankSpecifier => InferTypeInArrayRankSpecifier(arrayRankSpecifier, token),
                ArrayTypeSyntax arrayType => InferTypeInArrayType(arrayType, token),
                ArrowExpressionClauseSyntax arrowClause => InferTypeInArrowExpressionClause(arrowClause),
                AssignmentExpressionSyntax assignmentExpression => InferTypeInBinaryOrAssignmentExpression(assignmentExpression, assignmentExpression.OperatorToken, assignmentExpression.Left, assignmentExpression.Right, previousToken: token),
                AttributeArgumentListSyntax attributeArgumentList => InferTypeInAttributeArgumentList(attributeArgumentList, token),
                AttributeArgumentSyntax argument => InferTypeInAttributeArgument(argument, token),
                AttributeListSyntax attributeDeclaration => InferTypeInAttributeDeclaration(attributeDeclaration, token),
                AttributeTargetSpecifierSyntax attributeTargetSpecifier => InferTypeInAttributeTargetSpecifier(attributeTargetSpecifier, token),
                AwaitExpressionSyntax awaitExpression => InferTypeInAwaitExpression(awaitExpression, token),
                BinaryExpressionSyntax binaryExpression => InferTypeInBinaryOrAssignmentExpression(binaryExpression, binaryExpression.OperatorToken, binaryExpression.Left, binaryExpression.Right, previousToken: token),
                BinaryPatternSyntax binaryPattern => GetPatternTypes(binaryPattern),
                BracketedArgumentListSyntax bracketedArgumentList => InferTypeInBracketedArgumentList(bracketedArgumentList, token),
                CastExpressionSyntax castExpression => InferTypeInCastExpression(castExpression, previousToken: token),
                CatchDeclarationSyntax catchDeclaration => InferTypeInCatchDeclaration(catchDeclaration, token),
                CatchFilterClauseSyntax catchFilterClause => InferTypeInCatchFilterClause(catchFilterClause, token),
                CheckedExpressionSyntax checkedExpression => InferTypes(checkedExpression),
                ConditionalExpressionSyntax conditionalExpression => InferTypeInConditionalExpression(conditionalExpression, previousToken: token),
                DefaultExpressionSyntax defaultExpression => InferTypeInDefaultExpression(defaultExpression),
                DoStatementSyntax doStatement => InferTypeInDoStatement(doStatement, token),
                EqualsValueClauseSyntax equalsValue => InferTypeInEqualsValueClause(equalsValue, token),
                ExpressionColonSyntax expressionColon => InferTypeInExpressionColon(expressionColon, token),
                ExpressionStatementSyntax _ => InferTypeInExpressionStatement(token),
                ForEachStatementSyntax forEachStatement => InferTypeInForEachStatement(forEachStatement, previousToken: token),
                ForStatementSyntax forStatement => InferTypeInForStatement(forStatement, previousToken: token),
                IfStatementSyntax ifStatement => InferTypeInIfStatement(ifStatement, token),
                ImplicitArrayCreationExpressionSyntax implicitArray => InferTypeInImplicitArrayCreation(implicitArray),
                InitializerExpressionSyntax initializerExpression => InferTypeInInitializerExpression(initializerExpression, previousToken: token),
                LockStatementSyntax lockStatement => InferTypeInLockStatement(lockStatement, token),
                MemberAccessExpressionSyntax memberAccessExpression => InferTypeInMemberAccessExpression(memberAccessExpression, previousToken: token),
                NameColonSyntax nameColon => InferTypeInNameColon(nameColon, token),
                NameEqualsSyntax nameEquals => InferTypeInNameEquals(nameEquals, token),
                BaseObjectCreationExpressionSyntax objectCreation => InferTypeInObjectCreationExpression(objectCreation, token),
                LambdaExpressionSyntax lambdaExpression => InferTypeInLambdaExpression(lambdaExpression, token),
                PostfixUnaryExpressionSyntax postfixUnary => InferTypeInPostfixUnaryExpression(postfixUnary, token),
                PrefixUnaryExpressionSyntax prefixUnary => InferTypeInPrefixUnaryExpression(prefixUnary, token),
                RelationalPatternSyntax relationalPattern => InferTypeInRelationalPattern(relationalPattern),
                ReturnStatementSyntax returnStatement => InferTypeForReturnStatement(returnStatement, token),
                SingleVariableDesignationSyntax singleVariableDesignationSyntax => InferTypeForSingleVariableDesignation(singleVariableDesignationSyntax),
                SwitchLabelSyntax switchLabel => InferTypeInSwitchLabel(switchLabel, token),
                SwitchExpressionSyntax switchExpression => InferTypeInSwitchExpression(switchExpression, token),
                SwitchStatementSyntax switchStatement => InferTypeInSwitchStatement(switchStatement, token),
                ThrowStatementSyntax throwStatement => InferTypeInThrowStatement(throwStatement, token),
                TupleExpressionSyntax tupleExpression => InferTypeInTupleExpression(tupleExpression, token),
                UnaryPatternSyntax unaryPattern => GetPatternTypes(unaryPattern),
                UsingStatementSyntax usingStatement => InferTypeInUsingStatement(usingStatement, token),
                WhenClauseSyntax whenClause => InferTypeInWhenClause(whenClause, token),
                WhileStatementSyntax whileStatement => InferTypeInWhileStatement(whileStatement, token),
                YieldStatementSyntax yieldStatement => InferTypeInYieldStatement(yieldStatement, token),
                _ => [],
            };
        }

        /// <summary>
        /// Right after the <c>new</c> keyword of an anonymous object creation, restarts position-based
        /// inference at the start of the creation expression. Any other token yields no types.
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInAnonymousObjectCreation(AnonymousObjectCreationExpressionSyntax expression, SyntaxToken previousToken)
        {
            if (previousToken == expression.NewKeyword)
            {
                return InferTypes(expression.SpanStart);
            }

            return [];
        }

        /// <summary>
        /// Infers the expected type of <paramref name="argument"/> from the construct that owns it. The
        /// owner can be one of:
        /// <list type="bullet">
        /// <item>a constructor initializer;</item>
        /// <item>an invocation;</item>
        /// <item>an object creation;</item>
        /// <item>a primary-constructor base type;</item>
        /// <item>an element access;</item>
        /// <item>a tuple;</item>
        /// <item>an implicit element access inside an object initializer, which yields the collection
        /// element type(s) of the created object.</item>
        /// </list>
        /// </summary>
        /// <param name="previousToken">
        /// When supplied (position-based inference), inference only proceeds if this is the colon of a
        /// named argument.
        /// </param>
        private IEnumerable<TypeInferenceInfo> InferTypeInArgument(
            ArgumentSyntax argument, SyntaxToken? previousToken = null)
        {
            if (previousToken.HasValue)
            {
                // If we have a position, then it must be after the colon in a named argument.
                if (argument.NameColon == null || argument.NameColon.ColonToken != previousToken)
                    return [];
            }

            if (argument is { Parent.Parent: ConstructorInitializerSyntax initializer })
            {
                // : this(Goo()) / : base(Goo())
                var index = initializer.ArgumentList.Arguments.IndexOf(argument);
                return InferTypeInConstructorInitializer(initializer, index, argument);
            }

            if (argument is { Parent.Parent: InvocationExpressionSyntax invocation })
            {
                // Outer(Goo());
                var index = invocation.ArgumentList.Arguments.IndexOf(argument);
                return InferTypeInInvocationExpression(invocation, index, argument);
            }

            if (argument is { Parent.Parent: BaseObjectCreationExpressionSyntax creation })
            {
                // new Outer(Goo());
                //
                // new Outer(a: Goo());
                //
                // etc.
// Position of the argument within the object creation's argument list.
                var index = creation.ArgumentList.Arguments.IndexOf(argument);
                return InferTypeInObjectCreationExpression(creation, index, argument);
            }

            if (argument is { Parent.Parent: PrimaryConstructorBaseTypeSyntax primaryConstructorBaseType })
            {
                // class C() : Base(Goo());
                var index = primaryConstructorBaseType.ArgumentList.Arguments.IndexOf(argument);
                return InferTypeInPrimaryConstructorBaseType(primaryConstructorBaseType, index, argument);
            }

            if (argument is { Parent.Parent: ElementAccessExpressionSyntax elementAccess })
            {
                // Outer[Goo()];
                //
                // Outer[a: Goo()];
                //
                // etc.
                var index = elementAccess.ArgumentList.Arguments.IndexOf(argument);
                return InferTypeInElementAccessExpression(elementAccess, index, argument);
            }

            if (argument is { Parent: TupleExpressionSyntax tupleExpression })
            {
                // (Goo(), Bar())
                return InferTypeInTupleExpression(tupleExpression, argument);
            }

            // new C { [Goo()] = ... }: an implicit element access assigned inside an object initializer.
            // Offer the collection element type(s) of the object being created.
            if (argument.Parent.IsParentKind(SyntaxKind.ImplicitElementAccess) &&
                argument.Parent.Parent.IsParentKind(SyntaxKind.SimpleAssignmentExpression) &&
                argument.Parent.Parent.Parent.IsParentKind(SyntaxKind.ObjectInitializerExpression) &&
                argument.Parent.Parent.Parent.Parent?.Parent is BaseObjectCreationExpressionSyntax objectCreation)
            {
                var types = GetTypes(objectCreation).Select(t => t.InferredType);

                if (types.Any(t => t is INamedTypeSymbol))
                {
                    return types.OfType<INamedTypeSymbol>().SelectMany(t =>
                        GetCollectionElementType(t));
                }
            }

            return [];
        }

        /// <summary>
        /// Position-based inference inside a tuple literal.
        /// <list type="bullet">
        /// <item>After <c>(</c>, the first element is targeted.</item>
        /// <item>After a comma, the element following that comma is targeted.</item>
        /// <item>Any other token yields no types.</item>
        /// </list>
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInTupleExpression(
            TupleExpressionSyntax tupleExpression, SyntaxToken previousToken)
        {
            if (previousToken == tupleExpression.OpenParenToken)
                return InferTypeInTupleExpression(tupleExpression, tupleExpression.Arguments[0]);

            if (previousToken.IsKind(SyntaxKind.CommaToken))
            {
                // Arguments and commas alternate in the separated list, so the element typed after a
                // comma is the entry right after it.
                var argsAndCommas = tupleExpression.Arguments.GetWithSeparators();
                var commaIndex = argsAndCommas.IndexOf(previousToken);
                return InferTypeInTupleExpression(tupleExpression, (ArgumentSyntax)argsAndCommas[commaIndex + 1]);
            }

            return [];
        }

        /// <summary>
        /// Infers the type of one tuple element. The types inferred for the whole tuple expression are
        /// examined, and the element type at the argument's index is taken from each tuple type that
        /// has that many elements.
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInTupleExpression(
            TupleExpressionSyntax tupleExpression, ArgumentSyntax argument)
        {
            var index = tupleExpression.Arguments.IndexOf(argument);
            var parentTypes = InferTypes(tupleExpression);

            return parentTypes.Select(typeInfo => typeInfo.InferredType)
                .OfType<INamedTypeSymbol>()
                .Where(namedType => namedType.IsTupleType && index < namedType.TupleElements.Length)
                .Select(tupleType => new TypeInferenceInfo(tupleType.TupleElements[index].Type));
        }

        /// <summary>
        /// Infers the expected type of an attribute argument from the attribute it belongs to.
        /// </summary>
        /// <param name="argument">The attribute argument whose type is being inferred.</param>
        /// <param name="previousToken">
        /// When supplied (position-based inference), inference only proceeds if this token is either
        /// the colon of a <c>name: value</c> argument or the equals sign of a <c>Name = value</c>
        /// argument.
        /// </param>
        /// <returns>The inferred types, or an empty sequence when nothing can be inferred.</returns>
        private IEnumerable<TypeInferenceInfo> InferTypeInAttributeArgument(AttributeArgumentSyntax argument, SyntaxToken? previousToken = null)
        {
            if (previousToken.HasValue)
            {
                // If we have a position, then it must be after the colon or equals in an argument.
                // NameColon and NameEquals may each be null, so each is checked on its own before its
                // token is compared.
                var token = previousToken.Value;
                var isAfterColon = argument.NameColon != null && argument.NameColon.ColonToken == token;
                var isAfterEquals = argument.NameEquals != null && argument.NameEquals.EqualsToken == token;
                if (!isAfterColon && !isAfterEquals)
                    return [];
            }

            if (argument is { Parent.Parent: AttributeSyntax attribute })
            {
                var index = attribute.ArgumentList.Arguments.IndexOf(argument);
                return InferTypeInAttribute(attribute, index, argument);
            }

            return [];
        }

        /// <summary>
        /// Infers the type of the argument at <paramref name="index"/> of a constructor initializer.
        /// The candidates are the method symbols it binds to (via <c>GetBestOrAllSymbols</c>).
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInConstructorInitializer(ConstructorInitializerSyntax initializer, int index, ArgumentSyntax argument = null)
        {
            var info = SemanticModel.GetSymbolInfo(initializer, CancellationToken);
            var methods = info.GetBestOrAllSymbols().OfType<IMethodSymbol>();
            return InferTypeInArgument(index, methods, argument, parentInvocationExpressionToTypeInfer: null);
        }

        /// <summary>
        /// Position-based inference at an object creation expression. Normally the type(s) expected
        /// for the whole creation are returned. The comment below describes the broken-code recovery
        /// cases that are handled differently.
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInObjectCreationExpression(BaseObjectCreationExpressionSyntax expression, SyntaxToken previousToken)
        {
            // There are a couple of broken-code scenarios where the new keyword in an
            // object creation expression appears to be part of a subsequent assignment. For example:
            //
            //     new Form
            //     {
            //         Location = new $$
            //         StartPosition = FormStartPosition.CenterParent
            //     };
            //
            // The 'new' token is parsed as part of the assignment to StartPosition,
            // but the user is really trying to assign to Location.
            //
            // Similarly:
            //
            //     bool b;
            //     Task task = new $$
            //     b = false;
            //
            // The 'new' token is parsed as part of the assignment to b, but the user
            // is really trying to assign to task.
            //
            // In both these cases, we simply back up before the 'new' if it follows an equals
            // and start the inference again.
            //
            // Analogously, but in a method call:
            //
            //     Test(new $$
            //     o = s
            //
            // or:
            //
            //     Test(1, new $$
            //     o = s
            //
            // The new is part of the assignment to o, but the user is really trying to
            // add a parameter to the method call. So we also back up when 'new' follows
            // an open paren or a comma.
            if (previousToken.Kind() == SyntaxKind.NewKeyword &&
                previousToken.GetPreviousToken().Kind() is SyntaxKind.EqualsToken or SyntaxKind.OpenParenToken or SyntaxKind.CommaToken)
            {
                return InferTypes(previousToken.SpanStart);
            }

            return InferTypes(expression);
        }

        /// <summary>
        /// Infers the type of the argument at <paramref name="index"/> of an object creation.
        /// <list type="bullet">
        /// <item>When the created type is a delegate, that delegate type is the result.</item>
        /// <item>Otherwise, the candidates are the instance constructors of the created type that have
        /// more than <paramref name="index"/> parameters.</item>
        /// <item>When the creation's type is not a named type, no types are inferred.</item>
        /// </list>
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInObjectCreationExpression(BaseObjectCreationExpressionSyntax creation, int index, ArgumentSyntax argumentOpt = null)
        {
            var info = SemanticModel.GetTypeInfo(creation, CancellationToken);
            if (info.Type is not INamedTypeSymbol type)
                return [];

            if (type.TypeKind == TypeKind.Delegate)
            {
                // new SomeDelegateType( here );
                //
                // They're actually instantiating a delegate, so the delegate type is
                // that type.
// The delegate type itself is returned as the inferred type.
                return CreateResult(type);
            }

            // Only constructors that have a parameter at this position can accept the argument.
            var constructors = type.InstanceConstructors.Where(m => m.Parameters.Length > index);
            return InferTypeInArgument(index, constructors, argumentOpt, parentInvocationExpressionToTypeInfer: null);
        }

        /// <summary>
        /// Infers the type of the argument at <paramref name="index"/> passed to a base type in a
        /// primary-constructor base list (<c>class C() : Base(arg)</c>). The candidates are the base
        /// type's instance constructors that have more than <paramref name="index"/> parameters. When the
        /// base type is not a named type, no types are inferred.
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInPrimaryConstructorBaseType(
            PrimaryConstructorBaseTypeSyntax primaryConstructorBaseType, int index, ArgumentSyntax argumentOpt = null)
        {
            var info = SemanticModel.GetTypeInfo(primaryConstructorBaseType.Type, CancellationToken);
            if (info.Type is not INamedTypeSymbol type)
                return [];

            var constructors = type.InstanceConstructors.Where(m => m.Parameters.Length > index);
            return InferTypeInArgument(index, constructors, argumentOpt, parentInvocationExpressionToTypeInfer: null);
        }

        /// <summary>
        /// Infers the type of the argument at <paramref name="index"/> of an invocation.
        /// <list type="bullet">
        /// <item>The candidates are the method symbols the invocation binds to, widened with the
        /// method group of the invoked expression in the recovery cases described below.</item>
        /// <item>An argument to <c>Enum.HasFlag</c> whose receiver is enum-typed is inferred as that
        /// enum type.</item>
        /// </list>
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInInvocationExpression(
            InvocationExpressionSyntax invocation, int index, ArgumentSyntax argumentOpt = null)
        {
            // Check all the methods that have at least enough arguments to support
            // being called with argument at this position. Note: if they're calling an
            // extension method then it will need one more argument in order for us to
            // call it. (The per-position filtering itself happens later, when parameters are
            // matched by index.)
            var info = SemanticModel.GetSymbolInfo(invocation, CancellationToken);
            var methods = info.GetBestOrAllSymbols().OfType<IMethodSymbol>();

            // We fall back to the method group info (which is what signature help already does) in
            // either of these cases:
            //
            // 1. Overload resolution (see DevDiv 611477) in certain extension method cases
            //    can result in GetSymbolInfo returning nothing.
            // 2. When trying to infer the type of the first argument, nothing may correspond to
            //    the argument yet, and an overload that takes 0 arguments may be a viable match.
// Note: && binds tighter than ||, i.e. `Symbol == null || (no argument && bound method has only optional/params parameters)`.
            if (info.Symbol == null ||
                argumentOpt == null && info.Symbol is IMethodSymbol method && method.Parameters.All(p => p.IsOptional || p.IsParams))
            {
                var memberGroupMethods =
                    SemanticModel.GetMemberGroup(invocation.Expression, CancellationToken)
                        .OfType<IMethodSymbol>();

                methods = methods.Concat(memberGroupMethods).Distinct().ToList();
            }

            // Special case: if this is an argument in Enum.HasFlag, infer the Enum type that we're invoking into,
            // as otherwise we infer "Enum" which isn't useful
            if (methods.Any(IsEnumHasFlag))
            {
                if (invocation.Expression is MemberAccessExpressionSyntax memberAccess)
                {
                    var typeInfo = SemanticModel.GetTypeInfo(memberAccess.Expression, CancellationToken);

                    if (typeInfo.Type != null && typeInfo.Type.IsEnumType())
                    {
                        return CreateResult(typeInfo.Type);
                    }
                }
            }

            return InferTypeInArgument(index, methods, argumentOpt, invocation);
        }

        /// <summary>
        /// Position-based inference inside a parenthesized argument list.
        /// <list type="bullet">
        /// <item>Only applies right after the open paren or a comma.</item>
        /// <item>The argument index is derived from that token, and the work is delegated to the
        /// owning invocation, object creation, or constructor initializer.</item>
        /// <item>Other owners yield no types.</item>
        /// </list>
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInArgumentList(ArgumentListSyntax argumentList, SyntaxToken previousToken)
        {
            // Has to follow the ( or a ,
            if (previousToken != argumentList.OpenParenToken && previousToken.Kind() != SyntaxKind.CommaToken)
                return [];

            switch (argumentList.Parent)
            {
                case InvocationExpressionSyntax invocation:
                    {
                        var index = GetArgumentListIndex(argumentList, previousToken);
                        return InferTypeInInvocationExpression(invocation, index);
                    }

                case BaseObjectCreationExpressionSyntax objectCreation:
                    {
                        var index = GetArgumentListIndex(argumentList, previousToken);
                        return InferTypeInObjectCreationExpression(objectCreation, index);
                    }

                case ConstructorInitializerSyntax constructorInitializer:
                    {
                        var index = GetArgumentListIndex(argumentList, previousToken);
                        return InferTypeInConstructorInitializer(constructorInitializer, index);
                    }
            }

            return [];
        }

        /// <summary>
        /// Position-based inference inside an attribute's argument list. Only applies right after the
        /// open paren or a comma, and delegates to <c>InferTypeInAttribute</c> with the derived index.
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInAttributeArgumentList(AttributeArgumentListSyntax attributeArgumentList, SyntaxToken previousToken)
        {
            // Has to follow the ( or a ,
            if (previousToken != attributeArgumentList.OpenParenToken && previousToken.Kind() != SyntaxKind.CommaToken)
                return [];

            if (attributeArgumentList.Parent is AttributeSyntax attribute)
            {
                var index = GetArgumentListIndex(attributeArgumentList, previousToken);
                return InferTypeInAttribute(attribute, index);
            }

            return [];
        }

        /// <summary>
        /// Infers the type of the attribute argument at <paramref name="index"/>. The candidates are
        /// the method symbol(s) the attribute binds to (via <c>GetBestOrAllSymbols</c>).
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInAttribute(AttributeSyntax attribute, int index, AttributeArgumentSyntax argumentOpt = null)
        {
            var info = SemanticModel.GetSymbolInfo(attribute, CancellationToken);
            var methods = info.GetBestOrAllSymbols().OfType<IMethodSymbol>();
            return InferTypeInAttributeArgument(index, methods, argumentOpt);
        }

        /// <summary>
        /// Infers the type of the index argument at <paramref name="index"/> of an element access.
        /// <list type="bullet">
        /// <item>If the receiver is a named type with indexers that have more than
        /// <paramref name="index"/> parameters, those indexers' parameter types are used.</item>
        /// <item>In every other case, <see cref="int"/> is assumed.</item>
        /// </list>
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInElementAccessExpression(
            ElementAccessExpressionSyntax elementAccess, int index, ArgumentSyntax argumentOpt = null)
        {
            var info = SemanticModel.GetTypeInfo(elementAccess.Expression, CancellationToken);
            if (info.Type is INamedTypeSymbol type)
            {
                var indexers = type.GetMembers().OfType<IPropertySymbol>()
                                                .Where(p => p.IsIndexer && p.Parameters.Length > index);

                if (indexers.Any())
                {
                    return indexers.SelectMany(i =>
                        InferTypeInArgument(index, [i.Parameters], argumentOpt));
                }
            }

            // For everything else, assume it's an integer. Note: this won't be correct for
            // type parameters that implement some interface, but that seems like a major
            // corner case for now.
            //
            // This does, however, cover the more common cases of
            // arrays/pointers/errors/dynamic.
// Fallback index type (System.Int32).
            return CreateResult(this.Compilation.GetSpecialType(SpecialType.System_Int32));
        }

        /// <summary>
        /// Converts a set of candidate methods into their parameter lists and defers to the
        /// parameter-list overload of <c>InferTypeInAttributeArgument</c>.
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInAttributeArgument(int index, IEnumerable<IMethodSymbol> methods, AttributeArgumentSyntax argumentOpt = null)
            => InferTypeInAttributeArgument(index, methods.SelectAsArray(m => m.Parameters), argumentOpt);

        /// <summary>
        /// Infers an argument's type from candidate methods.
        /// </summary>
        /// <param name="parentInvocationExpressionToTypeInfer">
        /// When non-null, generic candidates are first instantiated from the type(s) inferred for this
        /// invocation. Where possible, they are then narrowed to those whose return type implicitly
        /// converts to one of those types.
        /// </param>
        private IEnumerable<TypeInferenceInfo> InferTypeInArgument(int index, IEnumerable<IMethodSymbol> methods, ArgumentSyntax argumentOpt, InvocationExpressionSyntax parentInvocationExpressionToTypeInfer)
        {
            if (parentInvocationExpressionToTypeInfer != null)
            {
                // We're trying to figure out the signature of a method we're an argument to.
                // That method may be generic, and we might end up using one of its generic
                // type parameters in the type we infer. First, let's see if we can instantiate
                // the methods so that the type can be inferred better.
                var invocationTypes = this.InferTypes(parentInvocationExpressionToTypeInfer).Select(t => t.InferredType).ToList();
                var instantiatedMethods = methods.Select(m => Instantiate(m, invocationTypes)).ToList();

                // Now that we've instantiated the methods, filter down to the ones that
                // will actually return a viable type given where this invocation expression
                // is.
                var filteredMethods = instantiatedMethods.Where(m =>
                    invocationTypes.Any(t => Compilation.ClassifyConversion(m.ReturnType, t).IsImplicit)).ToList();

                // If we filtered down to nothing, then just fall back to the instantiated list.
                // this is a best effort after all.
                methods = filteredMethods.Any() ? filteredMethods : instantiatedMethods;
            }

            return InferTypeInArgument(index, methods.SelectAsArray(m => m.Parameters), argumentOpt);
        }

        /// <summary>
        /// Attempts to construct a generic <paramref name="method"/> by matching its return type
        /// structurally against each of <paramref name="invocationTypes"/>.
        /// <list type="bullet">
        /// <item>The method is returned unchanged when it is not generic, returns void, or no
        /// type-parameter mapping is found.</item>
        /// <item>Otherwise it is constructed from the first mapping with the most entries; type
        /// parameters left unmapped stay as themselves.</item>
        /// </list>
        /// </summary>
        private static IMethodSymbol Instantiate(IMethodSymbol method, IList<ITypeSymbol> invocationTypes)
        {
            // No need to instantiate if this isn't a generic method.
            if (method.TypeArguments.Length == 0)
            {
                return method;
            }

            // Can't infer the type parameters if this method doesn't have a return type.
// Note: this is because this code path is specifically flowing type information
            // backward through the return type. Type information is already flowed forward
            // through arguments by the compiler when we get the initial set of methods.
            if (method.ReturnsVoid)
            {
                return method;
            }

            // If the method has already been constructed poorly (i.e. with error types for type
            // arguments), then unconstruct it.
            if (method.TypeArguments.Any(static t => t.Kind == SymbolKind.ErrorType))
            {
                method = method.ConstructedFrom;
            }

            IDictionary<ITypeParameterSymbol, ITypeSymbol> bestMap = null;
            foreach (var type in invocationTypes)
            {
                // Ok. We inferred a type for this location, and we have the return type of this
                // method. See if we can then assign any values for type parameters.
                var map = DetermineTypeParameterMapping(type, method.ReturnType);
                // Strict '>' keeps the first mapping among ties.
                if (map.Count > 0 && (bestMap == null || map.Count > bestMap.Count))
                {
                    bestMap = map;
                }
            }

            if (bestMap == null)
            {
                return method;
            }

            // Unmapped type parameters are passed through as themselves.
            var typeArguments = method.ConstructedFrom.TypeParameters
                .Select(tp => bestMap.GetValueOrDefault(tp) ?? tp).ToArray();
            return method.ConstructedFrom.Construct(typeArguments);
        }

        /// <summary>
        /// Builds a fresh map from the type parameters in <paramref name="returnType"/> to the types
        /// found at the matching positions in <paramref name="inferredType"/>.
        /// </summary>
        private static Dictionary<ITypeParameterSymbol, ITypeSymbol> DetermineTypeParameterMapping(ITypeSymbol inferredType, ITypeSymbol returnType)
        {
            var result = new Dictionary<ITypeParameterSymbol, ITypeSymbol>();
            DetermineTypeParameterMapping(inferredType, returnType, result);
            return result;
        }

        /// <summary>
        /// Walks <paramref name="inferredType"/> and <paramref name="returnType"/> in parallel.
        /// <list type="bullet">
        /// <item>Where <paramref name="returnType"/> is a type parameter and
        /// <paramref name="inferredType"/> is not, the pair is recorded; the first mapping for a type
        /// parameter wins.</item>
        /// <item>Otherwise the walk recurses into array element types, pointed-at types, and the type
        /// arguments of named types.</item>
        /// <item>The walk stops at null inputs or mismatched symbol kinds.</item>
        /// </list>
        /// </summary>
        private static void DetermineTypeParameterMapping(ITypeSymbol inferredType, ITypeSymbol returnType, Dictionary<ITypeParameterSymbol, ITypeSymbol> result)
        {
            if (inferredType == null || returnType == null)
            {
                return;
            }

            if (returnType.Kind == SymbolKind.TypeParameter)
            {
                if (inferredType.Kind != SymbolKind.TypeParameter)
                {
                    var returnTypeParameter = (ITypeParameterSymbol)returnType;
                    if (!result.ContainsKey(returnTypeParameter))
                    {
                        result[returnTypeParameter] = inferredType;
                    }

                    return;
                }
            }

            if (inferredType.Kind != returnType.Kind)
            {
                return;
            }

            switch (inferredType.Kind)
            {
                case SymbolKind.ArrayType:
                    DetermineTypeParameterMapping(((IArrayTypeSymbol)inferredType).ElementType, ((IArrayTypeSymbol)returnType).ElementType, result);
                    return;
                case SymbolKind.PointerType:
                    DetermineTypeParameterMapping(((IPointerTypeSymbol)inferredType).PointedAtType, ((IPointerTypeSymbol)returnType).PointedAtType, result);
                    return;
                case SymbolKind.NamedType:
                    var inferredNamedType = (INamedTypeSymbol)inferredType;
                    var returnNamedType = (INamedTypeSymbol)returnType;
                    // Only the type-argument counts are compared here, not the generic definitions.
                    if (inferredNamedType.TypeArguments.Length == returnNamedType.TypeArguments.Length)
                    {
                        for (int i = 0, n = inferredNamedType.TypeArguments.Length; i < n; i++)
                        {
                            DetermineTypeParameterMapping(inferredNamedType.TypeArguments[i], returnNamedType.TypeArguments[i], result);
                        }
                    }

                    return;
            }
        }

        /// <summary>
        /// Infers an attribute argument's type from the candidate parameter lists.
        /// <list type="bullet">
        /// <item>A <c>Prop = value</c> argument is delegated to <c>InferTypeInNameEquals</c>.</item>
        /// <item>A positional or <c>name:</c> argument is matched against the parameter lists by index
        /// or by name, with <see cref="RefKind.None"/>.</item>
        /// </list>
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInAttributeArgument(
            int index, ImmutableArray<ImmutableArray<IParameterSymbol>> parameterizedSymbols, AttributeArgumentSyntax argumentOpt = null)
        {
            if (argumentOpt != null && argumentOpt.NameEquals != null)
            {
                // [MyAttribute(Prop = ...
// A named field/property assignment is handled by the NameEquals inference.
                return InferTypeInNameEquals(argumentOpt.NameEquals, argumentOpt.NameEquals.EqualsToken);
            }

            var name = argumentOpt != null && argumentOpt.NameColon != null ? argumentOpt.NameColon.Name.Identifier.ValueText : null;
            return InferTypeInArgument(index, parameterizedSymbols, name, RefKind.None);
        }

        /// <summary>
        /// Splits <paramref name="parameterizedSymbols"/> into two groups: parameter lists whose arity
        /// can accept the call's argument count, and the rest. Results from the first group come before
        /// results from the second.
        /// </summary>
        /// <remarks>
        /// When the argument count is unknown (no argument, or its parent is not an argument list), every
        /// list goes into the first group.
        /// </remarks>
        private static IEnumerable<TypeInferenceInfo> InferTypeInArgument(
            int index, ImmutableArray<ImmutableArray<IParameterSymbol>> parameterizedSymbols, ArgumentSyntax argumentOpt)
        {
            // Prefer parameter lists that match the original number of arguments passed.
            using var _1 = ArrayBuilder<ImmutableArray<IParameterSymbol>>.GetInstance(out var parameterListsWithMatchingCount);
            using var _2 = ArrayBuilder<ImmutableArray<IParameterSymbol>>.GetInstance(out var parameterListsWithoutMatchingCount);

            var argumentCount = argumentOpt?.Parent is BaseArgumentListSyntax baseArgumentList ? baseArgumentList.Arguments.Count : -1;
            foreach (var parameterList in parameterizedSymbols)
            {
                if (argumentCount == -1)
                {
                    // don't have a known argument count. Just add this all to one of the lists.
                    parameterListsWithMatchingCount.Add(parameterList);
                }
                else
                {
                    // Required parameters set the lower bound; a params parameter removes the upper bound.
                    var minParameterCount = parameterList.Count(p => !p.IsParams && !p.IsOptional);
                    var maxParameterCount = parameterList.Any(p => p.IsParams) ? int.MaxValue : parameterList.Length;

                    var list = argumentCount >= minParameterCount && argumentCount <= maxParameterCount
                        ? parameterListsWithMatchingCount
                        : parameterListsWithoutMatchingCount;
                    list.Add(parameterList);
                }
            }

            var name = argumentOpt != null && argumentOpt.NameColon != null ? argumentOpt.NameColon.Name.Identifier.ValueText : null;
            // NOTE(review): called even when argumentOpt is null; presumably the GetRefKind extension
            // tolerates a null receiver — confirm.
            var refKind = argumentOpt.GetRefKind();
            // ToImmutable() copies the pooled builders before they are disposed at the end of scope.
            return InferTypeInArgument(index, parameterListsWithMatchingCount.ToImmutable(), name, refKind).Concat(
                   InferTypeInArgument(index, parameterListsWithoutMatchingCount.ToImmutable(), name, refKind));
        }

        /// <summary>
        /// Core argument matcher.
        /// <list type="bullet">
        /// <item>With a <paramref name="name"/>, returns the types of every parameter with that name
        /// across all lists; <paramref name="index"/> is ignored.</item>
        /// <item>Otherwise returns the parameter at <paramref name="index"/> from each list long
        /// enough to have one. If any of those parameters has a matching <paramref name="refKind"/>,
        /// only the matching ones are returned.</item>
        /// </list>
        /// </summary>
        private static IEnumerable<TypeInferenceInfo> InferTypeInArgument(
            int index, ImmutableArray<ImmutableArray<IParameterSymbol>> parameterizedSymbols, string name, RefKind refKind)
        {
            // If the callsite has a named argument, then try to find a method overload that has a
            // parameter with that name. If we can find one, then return the type of that one.
            if (name != null)
            {
                var matchingNameParameters = parameterizedSymbols.SelectMany(m => m)
                    .Where(p => p.Name == name)
                    .Select(p => new TypeInferenceInfo(p.Type, p.IsParams));

                return matchingNameParameters;
            }

            using var _1 = ArrayBuilder<TypeInferenceInfo>.GetInstance(out var allParameters);
            using var _2 = ArrayBuilder<TypeInferenceInfo>.GetInstance(out var matchingRefParameters);

            foreach (var parameterSet in parameterizedSymbols)
            {
                if (index < parameterSet.Length)
                {
                    var parameter = parameterSet[index];
                    var info = new TypeInferenceInfo(parameter.Type, parameter.IsParams);
                    allParameters.Add(info);

                    if (parameter.RefKind == refKind)
                    {
                        matchingRefParameters.Add(info);
                    }
                }
            }

            return matchingRefParameters.Count > 0 ? matchingRefParameters.ToImmutable() : allParameters.ToImmutable();
        }

        /// <summary>
        /// Infers the type expected for an array creation.
        /// <list type="bullet">
        /// <item>For position-based calls, only the token right after <c>new</c> applies.</item>
        /// <item>If that <c>new</c> follows <c>=</c>, inference restarts before <c>new</c>
        /// (broken-code recovery).</item>
        /// <item>Otherwise, the outer expected types are filtered down to array types.</item>
        /// </list>
        /// </summary>
        private IEnumerable<TypeInferenceInfo> InferTypeInArrayCreationExpression(
            ArrayCreationExpressionSyntax arrayCreationExpression, SyntaxToken? previousToken = null)
        {
            if (previousToken.HasValue && previousToken.Value != arrayCreationExpression.NewKeyword)
            {
                // Has to follow the 'new' keyword.
                return [];
            }

            if (previousToken.HasValue && previousToken.Value.GetPreviousToken().Kind() == SyntaxKind.EqualsToken)
            {
                // We parsed an array creation but the token before `new` is `=`.
// This could be a case like: // // int[] array; // Program p = new | // array[4] = 4; // // This is similar to the cases described in `InferTypeInObjectCreationExpression`. // Again, all we have to do is back up to before `new`. return InferTypes(previousToken.Value.SpanStart); } var outerTypes = InferTypes(arrayCreationExpression); return outerTypes.Where(o => o.InferredType is IArrayTypeSymbol); } private IEnumerable<TypeInferenceInfo> InferTypeInArrayRankSpecifier(ArrayRankSpecifierSyntax arrayRankSpecifier, SyntaxToken? previousToken = null) { // If we have a token, and it's not the open bracket or one of the commas, then no // inference. if (previousToken == arrayRankSpecifier.CloseBracketToken) return []; return CreateResult(this.Compilation.GetSpecialType(SpecialType.System_Int32)); } private IEnumerable<TypeInferenceInfo> InferTypeInArrayType(ArrayTypeSyntax arrayType, SyntaxToken? previousToken = null) { if (previousToken.HasValue) { // TODO(cyrusn): NYI. Handle this appropriately if we need to. return []; } // Bind the array type, then unwrap whatever we get back based on the number of rank // specifiers we see. var currentTypes = InferTypes(arrayType); for (var i = 0; i < arrayType.RankSpecifiers.Count; i++) { currentTypes = currentTypes.Select(t => t.InferredType).OfType<IArrayTypeSymbol>() .SelectAsArray(a => new TypeInferenceInfo(a.ElementType)); } return currentTypes; } private IEnumerable<TypeInferenceInfo> InferTypeInAttribute() => CreateResult(this.Compilation.AttributeType()); private IEnumerable<TypeInferenceInfo> InferTypeInAttributeDeclaration(AttributeListSyntax attributeDeclaration, SyntaxToken? previousToken) { // If we have a position, then it has to be after the open bracket. 
if (previousToken.HasValue && previousToken.Value != attributeDeclaration.OpenBracketToken)
        return [];

    return CreateResult(this.Compilation.AttributeType());
}

/// <summary>
/// Infers the type expected after the colon of an attribute target specifier
/// (for example <c>return:</c>). Only the compilation's attribute type is ever returned.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInAttributeTargetSpecifier(
    AttributeTargetSpecifierSyntax attributeTargetSpecifier,
    SyntaxToken? previousToken)
{
    // If we have a position, then it has to be after the colon.
    if (previousToken.HasValue && previousToken.Value != attributeTargetSpecifier.ColonToken)
        return [];

    return CreateResult(this.Compilation.AttributeType());
}

/// <summary>
/// Infers the type of an argument inside <c>[...]</c>. Only element access parents
/// (<c>x[$$]</c>) are handled; any other parent yields no inference.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInBracketedArgumentList(BracketedArgumentListSyntax bracketedArgumentList, SyntaxToken previousToken)
{
    // Has to follow the [ or a ,
    if (previousToken != bracketedArgumentList.OpenBracketToken && previousToken.Kind() != SyntaxKind.CommaToken)
        return [];

    if (bracketedArgumentList.Parent is ElementAccessExpressionSyntax elementAccess)
    {
        var index = GetArgumentListIndex(bracketedArgumentList, previousToken);
        return InferTypeInElementAccessExpression(
            elementAccess, index);
    }

    return [];
}

/// <summary>
/// Returns the zero-based index of the argument that follows <paramref name="previousToken"/>
/// (the open token or one of the separating commas) in <paramref name="argumentList"/>.
/// </summary>
private static int GetArgumentListIndex(BaseArgumentListSyntax argumentList, SyntaxToken previousToken)
{
    if (previousToken == argumentList.GetOpenToken())
    {
        return 0;
    }

    // The separated list interleaves arguments and commas:
    //
    //   ( node0 , node1 , node2 , node3 ,
    //     tokIdx: node0=0 ,=1 node1=2 ,=3 node2=4 ,=5 node3=6 ,=7
    //
    // so the comma at token index i is followed by argument (i + 1) / 2
    // (commas at 1, 3, 5 give indices 1, 2, 3).
    var tokenIndex = argumentList.Arguments.GetWithSeparators().IndexOf(previousToken);
    return (tokenIndex + 1) / 2;
}

/// <summary>
/// Attribute-argument counterpart of the <see cref="BaseArgumentListSyntax"/> overload
/// (attribute argument lists do not derive from that base type).
/// </summary>
private static int GetArgumentListIndex(AttributeArgumentListSyntax attributeArgumentList, SyntaxToken previousToken)
{
    if (previousToken == attributeArgumentList.OpenParenToken)
    {
        return 0;
    }

    // Same interleaving as above: the comma at token index i is followed by
    // argument (i + 1) / 2.
    var tokenIndex = attributeArgumentList.Arguments.GetWithSeparators().IndexOf(previousToken);
    return (tokenIndex + 1) / 2;
}

/// <summary>
/// Infers the type of one operand of a binary or assignment expression.
/// Heuristics are applied in order: coalesce delegation, int for shift counts, bool for
/// logical operands, deconstruction, the type of the other operand, the type of the whole
/// expression for bitwise operators (and for + when it is string or delegate), and finally
/// per-operator defaults.
/// </summary>
/// <param name="expressionOpt">The operand being inferred, when inferring by node.</param>
/// <param name="previousToken">When inferring by position, must be <paramref name="operatorToken"/>.</param>
private IEnumerable<TypeInferenceInfo> InferTypeInBinaryOrAssignmentExpression(ExpressionSyntax binop, SyntaxToken operatorToken, ExpressionSyntax left, ExpressionSyntax right, ExpressionSyntax expressionOpt = null, SyntaxToken? previousToken = null)
{
    // If we got here through a token, then it must have actually been the binary
    // operator's token.
    Contract.ThrowIfTrue(previousToken.HasValue && previousToken.Value != operatorToken);

    if (binop.Kind() == SyntaxKind.CoalesceExpression)
    {
        return InferTypeInCoalesceExpression((BinaryExpressionSyntax)binop, expressionOpt, previousToken);
    }

    // A position after the operator token is always on the right-hand side.
    var onRightOfToken = right == expressionOpt || previousToken.HasValue;
    switch (operatorToken.Kind())
    {
        case SyntaxKind.LessThanLessThanToken:
        case SyntaxKind.GreaterThanGreaterThanToken:
        case SyntaxKind.GreaterThanGreaterThanGreaterThanToken:
        case SyntaxKind.LessThanLessThanEqualsToken:
        case SyntaxKind.GreaterThanGreaterThanEqualsToken:
        case SyntaxKind.GreaterThanGreaterThanGreaterThanEqualsToken:

            if (onRightOfToken)
            {
                // x << Goo(), x >> Goo(), x >>> Goo(), x <<= Goo(), x >>= Goo(), x >>>= Goo()
                return CreateResult(this.Compilation.GetSpecialType(SpecialType.System_Int32));
            }

            break;
    }

    // Infer operands of && and || as bool regardless of the other operand.
    if (operatorToken.Kind() is SyntaxKind.AmpersandAmpersandToken or SyntaxKind.BarBarToken)
    {
        return CreateResult(SpecialType.System_Boolean);
    }

    // Infer type for deconstruction
    if (binop.Kind() == SyntaxKind.SimpleAssignmentExpression &&
        ((AssignmentExpressionSyntax)binop).IsDeconstruction())
    {
        return InferTypeInVariableComponentAssignment(left);
    }

    // Try to figure out what's on the other side of the binop.  If we can, then just that
    // type.  This is often a reasonable heuristics to use for most operators.  NOTE(cyrusn):
    // we could try to bind the token to see what overloaded operators it corresponds to.
    // But the gain is pretty marginal IMO.
    var otherSide = onRightOfToken ? left : right;

    var otherSideTypes = GetTypes(otherSide);
    if (otherSideTypes.Any())
    {
        // Don't infer delegate types except in assignments. They're unlikely to be what the
        // user needs and can cause lambda suggestion mode while
        // typing type arguments:
        // https://github.com/dotnet/roslyn/issues/14492
        if (binop is not AssignmentExpressionSyntax)
        {
            otherSideTypes = otherSideTypes.Where(t => !t.InferredType.IsDelegateType());
        }

        return otherSideTypes;
    }

    // For &, &=, |, |=, ^, and ^=, since we couldn't infer the type of either side,
    // try to infer the type of the entire binary expression.
    if (operatorToken.Kind() is SyntaxKind.AmpersandToken or SyntaxKind.AmpersandEqualsToken or SyntaxKind.BarToken or SyntaxKind.BarEqualsToken or SyntaxKind.CaretToken or SyntaxKind.CaretEqualsToken)
    {
        var parentTypes = InferTypes(binop);
        if (parentTypes.Any())
        {
            return parentTypes;
        }
    }

    // If it's a plus operator, then do some smarts in case it might be a string or
    // delegate.
    if (operatorToken.Kind() == SyntaxKind.PlusToken)
    {
        // See Bug 6045.  Note: we've already checked the other side of the operator.  So this
        // is the case where the other side was also unknown.  So we walk one higher and if
        // we get a delegate or a string type, then use that type here.
        var parentTypes = InferTypes(binop);
        if (parentTypes.Any(static parentType => parentType.InferredType.SpecialType == SpecialType.System_String || parentType.InferredType.TypeKind == TypeKind.Delegate))
        {
            return parentTypes.Where(parentType => parentType.InferredType.SpecialType == SpecialType.System_String || parentType.InferredType.TypeKind == TypeKind.Delegate);
        }
    }

    // Otherwise pick some sane defaults for certain common cases.
    switch (operatorToken.Kind())
    {
        // Arithmetic, relational, shift, and the non-compound bitwise operators (plus ^=)
        // default to int.
        case SyntaxKind.BarToken:
        case SyntaxKind.CaretToken:
        case SyntaxKind.AmpersandToken:
        case SyntaxKind.LessThanToken:
        case SyntaxKind.LessThanEqualsToken:
        case SyntaxKind.GreaterThanToken:
        case SyntaxKind.GreaterThanEqualsToken:
        case SyntaxKind.PlusToken:
        case SyntaxKind.MinusToken:
        case SyntaxKind.AsteriskToken:
        case SyntaxKind.SlashToken:
        case SyntaxKind.PercentToken:
        case SyntaxKind.CaretEqualsToken:
        case SyntaxKind.PlusEqualsToken:
        case SyntaxKind.MinusEqualsToken:
        case SyntaxKind.AsteriskEqualsToken:
        case SyntaxKind.SlashEqualsToken:
        case SyntaxKind.PercentEqualsToken:
        case SyntaxKind.LessThanLessThanToken:
        case SyntaxKind.GreaterThanGreaterThanToken:
        case SyntaxKind.GreaterThanGreaterThanGreaterThanToken:
        case SyntaxKind.LessThanLessThanEqualsToken:
        case SyntaxKind.GreaterThanGreaterThanEqualsToken:
        case SyntaxKind.GreaterThanGreaterThanGreaterThanEqualsToken:
            return CreateResult(this.Compilation.GetSpecialType(SpecialType.System_Int32));

        case SyntaxKind.BarEqualsToken:
        case SyntaxKind.AmpersandEqualsToken:
            // NOTE(cyrusn): |= and &= can be used for both ints and bools  However, in the
            // case where there isn't enough information to determine which the user wanted,
            // I'm just defaulting to bool based on personal preference.
            // NOTE(review): ^= also accepts bools but is grouped with the int defaults above.
            return CreateResult(SpecialType.System_Boolean);
    }

    return [];
}

/// <summary>
/// Infers the type of a cast's operand: it is taken to be the cast's target type.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInCastExpression(CastExpressionSyntax castExpression, ExpressionSyntax expressionOpt = null, SyntaxToken? previousToken = null)
{
    if (expressionOpt != null && castExpression.Expression != expressionOpt)
        return [];

    // If we have a position, then it has to be after the close paren.
    if (previousToken.HasValue && previousToken.Value != castExpression.CloseParenToken)
        return [];

    return this.GetTypes(castExpression.Type);
}

private IEnumerable<TypeInferenceInfo> InferTypeInCatchDeclaration(CatchDeclarationSyntax catchDeclaration, SyntaxToken?
previousToken = null)
{
    // If we have a position, it has to be after "catch("
    if (previousToken.HasValue && previousToken.Value != catchDeclaration.OpenParenToken)
        return [];

    return CreateResult(this.Compilation.ExceptionType());
}

/// <summary>
/// Infers the type of a catch filter condition (<c>catch (...) when ($$)</c>): always bool.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInCatchFilterClause(CatchFilterClauseSyntax catchFilterClause, SyntaxToken? previousToken = null)
{
    // If we have a position, it has to be after "if ("
    if (previousToken.HasValue && previousToken.Value != catchFilterClause.OpenParenToken)
        return [];

    return CreateResult(SpecialType.System_Boolean);
}

/// <summary>
/// Infers the type of one operand of <c>a ?? b</c>.
/// For the right operand: the left operand's type, with <c>Nullable&lt;T&gt;</c> unwrapped to <c>T</c>.
/// For the left operand: a nullable form of the right operand's type, or <c>object?</c>
/// when the right operand's type is unknown.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInCoalesceExpression(
    BinaryExpressionSyntax coalesceExpression,
    ExpressionSyntax expressionOpt = null,
    SyntaxToken? previousToken = null)
{
    // If we got here through a token, then it must have actually been the binary
    // operator's token.
    Contract.ThrowIfTrue(previousToken.HasValue && previousToken.Value != coalesceExpression.OperatorToken);

    var onRightSide = coalesceExpression.Right == expressionOpt || previousToken.HasValue;
    if (onRightSide)
    {
        var leftTypes = GetTypes(coalesceExpression.Left);
        return leftTypes.Select(x => x.InferredType.IsNullable(out var underlying)
            ? new TypeInferenceInfo(underlying) // nullableExpr ?? Goo()
            : x); // normalExpr ?? Goo() 
    }

    var rightTypes = GetTypes(coalesceExpression.Right);
    if (!rightTypes.Any())
        return CreateResult(SpecialType.System_Object, NullableAnnotation.Annotated);

    // Goo() ?? ""
    return rightTypes.Select(x => new TypeInferenceInfo(MakeNullable(x.InferredType, this.Compilation)));

    // Produces the nullable counterpart of a type: T? for value types (unless already
    // Nullable<T>), an annotated reference for reference types, and the type unchanged for
    // error types and unconstrained type parameters.
    static ITypeSymbol MakeNullable(ITypeSymbol symbol, Compilation compilation)
    {
        if (symbol.IsErrorType())
        {
            // We could be smart and infer this as an ErrorType?, but in the #nullable disable case we don't know if this is intended to be
            // a struct (where the question mark is legal) or a class (where it isn't). We'll thus avoid sticking question marks in this case.
            // https://github.com/dotnet/roslyn/issues/37852 tracks fixing this in a much fancier way.
            return symbol;
        }
        else if (symbol.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T)
        {
            // We already have something nullable. Don't wrap in another nullable layer.
            return symbol;
        }
        else if (symbol.IsValueType)
        {
            return compilation.GetSpecialType(SpecialType.System_Nullable_T).Construct(symbol);
        }
        else if (symbol.IsReferenceType)
        {
            return symbol.WithNullableAnnotation(NullableAnnotation.Annotated);
        }
        else // it's neither a value nor reference type, so is an unconstrained generic
        {
            return symbol;
        }
    }
}

/// <summary>Infers the type of a conditional access from the type expected for the whole expression.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInConditionalAccessExpression(ConditionalAccessExpressionSyntax expression)
    => InferTypes(expression);

/// <summary>
/// Infers the type of a part of <c>c ? a : b</c>: bool for the condition; for a branch, the
/// other branch's type, falling back to the type expected for the whole conditional.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInConditionalExpression(ConditionalExpressionSyntax conditional, ExpressionSyntax expressionOpt = null, SyntaxToken? previousToken = null)
{
    if (expressionOpt != null && conditional.Condition == expressionOpt)
    {
        // Goo() ? a : b
        return CreateResult(SpecialType.System_Boolean);
    }

    // a ? Goo() : b
    //
    // a ? b : Goo()
    var inTrueClause =
        (conditional.WhenTrue == expressionOpt) ||
        (previousToken == conditional.QuestionToken);

    var inFalseClause =
        (conditional.WhenFalse == expressionOpt) ||
        (previousToken == conditional.ColonToken);

    var otherTypes = inTrueClause
        ? GetTypes(conditional.WhenFalse)
        : inFalseClause
            ? GetTypes(conditional.WhenTrue)
            : [];

    return otherTypes.IsEmpty()
        ? InferTypes(conditional)
        : otherTypes;
}

/// <summary>Infers the type of a <c>default(...)</c> expression from its context.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInDefaultExpression(DefaultExpressionSyntax defaultExpression)
    => InferTypes(defaultExpression);

private IEnumerable<TypeInferenceInfo> InferTypeInDoStatement(DoStatementSyntax doStatement, SyntaxToken?
previousToken = null)
{
    // If we have a position, we need to be after "do { } while("
    if (previousToken.HasValue && previousToken.Value != doStatement.OpenParenToken)
        return [];

    return CreateResult(SpecialType.System_Boolean);
}

/// <summary>
/// Infers the type of the value after <c>=</c> from the owning variable declarator, property
/// declaration, or parameter (default value).
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInEqualsValueClause(EqualsValueClauseSyntax equalsValue, SyntaxToken? previousToken = null)
{
    // If we have a position, it has to be after the =
    if (previousToken.HasValue && previousToken.Value != equalsValue.EqualsToken)
        return [];

    if (equalsValue?.Parent is VariableDeclaratorSyntax varDecl)
        return InferTypeInVariableDeclarator(varDecl);

    if (equalsValue?.Parent is PropertyDeclarationSyntax propertyDecl)
        return InferTypeInPropertyDeclaration(propertyDecl);

    if (equalsValue.IsParentKind(SyntaxKind.Parameter) &&
        SemanticModel.GetDeclaredSymbol(equalsValue.Parent, CancellationToken) is IParameterSymbol parameter)
    {
        return CreateResult(parameter.Type);
    }

    return [];
}

/// <summary>Infers a property initializer's type as the property's declared type.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInPropertyDeclaration(PropertyDeclarationSyntax propertyDeclaration)
{
    Debug.Assert(propertyDeclaration?.Type != null, "Property type should never be null");

    // Flow the cancellation token like the other semantic-model queries in this class.
    var typeInfo = SemanticModel.GetTypeInfo(propertyDeclaration.Type, CancellationToken);
    return CreateResult(typeInfo.Type);
}

/// <summary>An expression statement's expression is inferred as void.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInExpressionStatement(SyntaxToken? previousToken = null)
{
    // If we're position based, then that means we're after the semicolon.  In this case
    // we don't have any sort of type to infer.
    if (previousToken.HasValue)
        return [];

    return CreateResult(SpecialType.System_Void);
}

/// <summary>
/// Infers the type of the collection in <c>foreach (T v in $$)</c> as IEnumerable&lt;T&gt;
/// (IAsyncEnumerable&lt;T&gt; for <c>await foreach</c> when available), using object when
/// the loop variable's type is unknown.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInForEachStatement(ForEachStatementSyntax forEachStatementSyntax, ExpressionSyntax expressionOpt = null, SyntaxToken? previousToken = null)
{
    // If we have a position, then we have to be after "foreach(... in"
    if (previousToken.HasValue && previousToken.Value != forEachStatementSyntax.InKeyword)
        return [];

    if (expressionOpt != null && expressionOpt != forEachStatementSyntax.Expression)
        return [];

    var enumerableType = forEachStatementSyntax.AwaitKeyword == default
        ? this.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T)
        : this.Compilation.GetTypeByMetadataName(typeof(IAsyncEnumerable<>).FullName);

    // IAsyncEnumerable<T> may not exist in the referenced libraries.
    enumerableType ??= this.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T);

    // foreach (int v in Goo())
    var variableTypes = GetTypes(forEachStatementSyntax.Type);
    if (!variableTypes.Any())
    {
        return CreateResult(
            enumerableType
            .Construct(Compilation.GetSpecialType(SpecialType.System_Object)));
    }

    return variableTypes.Select(v => new TypeInferenceInfo(enumerableType.Construct(v.InferredType)));
}

/// <summary>Infers the condition of a for loop (<c>for (...; $$; ...)</c>) as bool.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInForStatement(ForStatementSyntax forStatement, ExpressionSyntax expressionOpt = null, SyntaxToken? previousToken = null)
{
    // If we have a position, it has to be after "for(...;"
    if (previousToken.HasValue && previousToken.Value != forStatement.FirstSemicolonToken)
        return [];

    if (expressionOpt != null && forStatement.Condition != expressionOpt)
        return [];

    return CreateResult(SpecialType.System_Boolean);
}

/// <summary>Infers the condition of an if statement as bool.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInIfStatement(IfStatementSyntax ifStatement, SyntaxToken? previousToken = null)
{
    // If we have a position, we have to be after the "if("
    if (previousToken.HasValue && previousToken.Value != ifStatement.OpenParenToken)
        return [];

    return CreateResult(SpecialType.System_Boolean);
}

/// <summary>Infers the type of <c>new[] { ... }</c> by position, from just before the expression.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInImplicitArrayCreation(ImplicitArrayCreationExpressionSyntax implicitArray)
    => InferTypes(implicitArray.SpanStart);

/// <summary>
/// Infers the type of an element of an initializer (<c>{ ... }</c>). Depending on the
/// initializer's kind and parent this uses the matching <c>Add</c> method's parameter types,
/// sibling element types, or the element type of the inferred array/collection type.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInInitializerExpression(
    InitializerExpressionSyntax initializerExpression,
    ExpressionSyntax expressionOpt = null,
    SyntaxToken? previousToken = null)
{
    if (initializerExpression.IsKind(SyntaxKind.ComplexElementInitializerExpression))
    {
        // new Dictionary<K,V> { { x, ... } }
        // new C { Prop = { { x, ... } } }
        var parameterIndex = previousToken.HasValue
            ? initializerExpression.Expressions.GetSeparators().ToList().IndexOf(previousToken.Value) + 1
            : initializerExpression.Expressions.IndexOf(expressionOpt);

        var addMethodSymbols = SemanticModel.GetCollectionInitializerSymbolInfo(initializerExpression, CancellationToken).GetAllSymbols();
        var addMethodParameterTypes = addMethodSymbols
            .OfType<IMethodSymbol>()
            .Where(a => a.Parameters.Length == initializerExpression.Expressions.Count)
            .Select(a => new TypeInferenceInfo(a.Parameters.ElementAtOrDefault(parameterIndex)?.Type))
            .Where(t => t.InferredType != null);

        if (addMethodParameterTypes.Any())
        {
            return addMethodParameterTypes;
        }
    }
    else if (initializerExpression.IsKind(SyntaxKind.CollectionInitializerExpression))
    {
        if (expressionOpt != null)
        {
            // new List<T> { x, ... }
            // new C { Prop = { x, ... } }
            var addMethodSymbols = SemanticModel.GetCollectionInitializerSymbolInfo(expressionOpt, CancellationToken).GetAllSymbols();
            var addMethodParameterTypes = addMethodSymbols
                .OfType<IMethodSymbol>()
                .Where(a => a.Parameters.Length == 1)
                .Select(a => new TypeInferenceInfo(a.Parameters[0].Type));

            if (addMethodParameterTypes.Any())
            {
                return addMethodParameterTypes;
            }
        }
        else
        {
            // new List<T> { x,
            // new C { Prop = { x,

            foreach (var sibling in initializerExpression.Expressions.Where(e => e.Kind() != SyntaxKind.ComplexElementInitializerExpression))
            {
                var types = GetTypes(sibling);
                if (types.Any())
                {
                    return types;
                }
            }
        }
    }

    if (initializerExpression?.Parent is ImplicitArrayCreationExpressionSyntax implicitArray)
    {
        // new[] { 1, x }

        // First, try to infer the type that the array should be.  If we can infer an
        // appropriate array type, then use the element type of the array.  Otherwise,
        // look at the siblings of this expression and use their type instead.

        // TypeInferenceInfo is a wrapper, so project to the symbol before filtering for
        // array types (filtering the wrappers directly would never match).
        var arrayTypes = this.InferTypes(implicitArray);
        var elementTypes = arrayTypes
            .Select(t => t.InferredType)
            .OfType<IArrayTypeSymbol>()
            .Select(a => new TypeInferenceInfo(a.ElementType))
            .Where(IsUsableTypeFunc);

        if (elementTypes.Any())
        {
            return elementTypes;
        }

        foreach (var sibling in initializerExpression.Expressions)
        {
            if (sibling != expressionOpt)
            {
                var types = GetTypes(sibling);
                if (types.Any())
                {
                    return types;
                }
            }
        }
    }
    else if (initializerExpression?.Parent is EqualsValueClauseSyntax equalsValueClause)
    {
        // = { Goo() }
        var types = InferTypeInEqualsValueClause(equalsValueClause).Select(t => t.InferredType);

        if (types.Any(t => t is IArrayTypeSymbol))
        {
            return types.OfType<IArrayTypeSymbol>().Select(t => new TypeInferenceInfo(t.ElementType));
        }
    }
    else if (initializerExpression?.Parent is ArrayCreationExpressionSyntax arrayCreation)
    {
        // new int[] { Goo() } 
        var types = GetTypes(arrayCreation).Select(t => t.InferredType);

        if (types.Any(t => t is IArrayTypeSymbol))
        {
            return types.OfType<IArrayTypeSymbol>().Select(t => new TypeInferenceInfo(t.ElementType));
        }
    }
    else if (initializerExpression?.Parent is ObjectCreationExpressionSyntax objectCreation)
    {
        // new List<T> { Goo() } 
        var types = GetTypes(objectCreation).Select(t => t.InferredType);

        if (types.Any(t => t is INamedTypeSymbol))
        {
            return types.OfType<INamedTypeSymbol>().SelectMany(t => GetCollectionElementType(t));
        }
    }
    else if (initializerExpression.IsParentKind(SyntaxKind.SimpleAssignmentExpression))
    {
        // new Goo { a = { Goo() } }

        if (expressionOpt != null)
        {
            // Skip non-method and parameterless candidates instead of throwing on them.
            var addMethodSymbols = SemanticModel.GetCollectionInitializerSymbolInfo(expressionOpt, CancellationToken).GetAllSymbols();
            var addMethodParameterTypes = addMethodSymbols
                .OfType<IMethodSymbol>()
                .Where(m => m.Parameters.Length > 0)
                .Select(m => new TypeInferenceInfo(m.Parameters[0].Type));

            if (addMethodParameterTypes.Any())
            {
                return addMethodParameterTypes;
            }
        }

        var assignExpression = (AssignmentExpressionSyntax)initializerExpression.Parent;
        var types = GetTypes(assignExpression.Left).Select(t => t.InferredType);

        if (types.Any(t => t is INamedTypeSymbol))
        {
            return types.OfType<INamedTypeSymbol>().SelectMany(t => GetCollectionElementType(t));
        }
    }

    return [];
}

/// <summary>Infers a recursive pattern's type as its converted type.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInRecursivePattern(RecursivePatternSyntax recursivePattern)
{
    var type = this.SemanticModel.GetTypeInfo(recursivePattern, CancellationToken).ConvertedType;
    return CreateResult(type);
}

/// <summary>Infers a constant pattern's type from its context.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInConstantPattern(
    ConstantPatternSyntax constantPattern)
{
    return InferTypes(constantPattern);
}

/// <summary>Infers a property pattern clause's type from its context.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInPropertyPatternClause(
    PropertyPatternClauseSyntax propertySubpattern)
{
    return InferTypes(propertySubpattern);
}

/// <summary>
/// Infers the type of the pattern in <c>{ X: $$ }</c> as the type of the field or property
/// that <c>X</c> binds to.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInSubpattern(
    SubpatternSyntax subpattern,
    SyntaxNode child)
{
    // we have  { X: ... }.  The type of ... is whatever the type of 'X' is in its
    // parent type.  So look up the parent type first, then find the X member in it
    // and use that type.
    if (child == subpattern.Pattern &&
        subpattern.ExpressionColon != null)
    {
        using var result = TemporaryArray<TypeInferenceInfo>.Empty;

        foreach (var symbol in this.SemanticModel.GetSymbolInfo(subpattern.ExpressionColon.Expression, CancellationToken).GetAllSymbols())
        {
            switch (symbol)
            {
                case IFieldSymbol field:
                    result.Add(new TypeInferenceInfo(field.Type));
                    break;
                case IPropertySymbol property:
                    result.Add(new TypeInferenceInfo(property.Type));
                    break;
            }
        }

        return result.ToImmutableAndClear();
    }

    return [];
}

private IEnumerable<TypeInferenceInfo> InferTypeForSingleVariableDesignation(SingleVariableDesignationSyntax singleVariableDesignation)
{
    if (singleVariableDesignation.Parent is DeclarationPatternSyntax declarationPattern)
    {
        // c is Color.Red or $$
        // "or" is not parsed as part of a BinaryPattern until the right hand side
        // is written.
// By checking that the identifier
        // is "or" or "and", we can assume a BinaryPattern is upcoming.
        var identifier = singleVariableDesignation.Identifier;
        if (identifier.HasMatchingText(SyntaxKind.OrKeyword) ||
            identifier.HasMatchingText(SyntaxKind.AndKeyword))
        {
            return GetPatternTypes(declarationPattern);
        }
    }

    return [];
}

/// <summary>
/// Infers types within <c>expr is pattern</c>: the expression gets the pattern's type and
/// the pattern gets the expression's type.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInIsPatternExpression(
    IsPatternExpressionSyntax isPatternExpression,
    SyntaxNode child)
{
    if (child == isPatternExpression.Expression)
    {
        return GetPatternTypes(isPatternExpression.Pattern);
    }
    else if (child == isPatternExpression.Pattern)
    {
        return GetTypes(isPatternExpression.Expression);
    }

    return [];
}

/// <summary>Returns the type(s) a pattern matches against.</summary>
private IEnumerable<TypeInferenceInfo> GetPatternTypes(PatternSyntax pattern)
{
    return pattern switch
    {
        ConstantPatternSyntax constantPattern => GetTypes(constantPattern.Expression),
        RecursivePatternSyntax recursivePattern => GetTypesForRecursivePattern(recursivePattern),
        _ when SemanticModel.GetOperation(pattern, CancellationToken) is IPatternOperation patternOperation =>
            // In cases like this: c is Color.Green or $$
            // "pattern" is a DeclarationPatternSyntax and Color.Green is assumed to be the narrowed type.
            // If the narrowed type can not be resolved, we fall back to the input type of the pattern, which
            // is a good default for any related case.
            CreateResult(patternOperation.NarrowedType.IsErrorType() ? patternOperation.InputType : patternOperation.NarrowedType),
        _ => [],
    };
}

/// <summary>
/// Infers the type of a recursive pattern: its explicit type when written (<c>X (...)</c>),
/// otherwise a tuple built from the positional subpatterns when there are at least two.
/// </summary>
private IEnumerable<TypeInferenceInfo> GetTypesForRecursivePattern(RecursivePatternSyntax recursivePattern)
{
    // if it's of the form "X (...)" then just infer 'X' as the type.
    if (recursivePattern.Type != null)
    {
        var typeInfo = SemanticModel.GetTypeInfo(recursivePattern, CancellationToken);
        return CreateResult(typeInfo.GetConvertedTypeWithAnnotatedNullability());
    }

    // If it's of the form (...) then infer that the type should be a
    // tuple, whose elements are inferred from the individual patterns
    // in the deconstruction.
    var positionalPart = recursivePattern.PositionalPatternClause;
    if (positionalPart != null)
    {
        var subPatternCount = positionalPart.Subpatterns.Count;
        if (subPatternCount >= 2)
        {
            // infer a tuple type for this deconstruction.  The pooled builders are returned on
            // every path, including the early exit when an element type cannot be inferred.
            using var _1 = ArrayBuilder<ITypeSymbol>.GetInstance(out var elementTypesBuilder);
            using var _2 = ArrayBuilder<string>.GetInstance(out var elementNamesBuilder);

            foreach (var subPattern in positionalPart.Subpatterns)
            {
                elementNamesBuilder.Add(subPattern.NameColon?.Name.Identifier.ValueText);

                var patternType = GetPatternTypes(subPattern.Pattern).FirstOrDefault();
                if (patternType.InferredType == null)
                    return [];

                elementTypesBuilder.Add(patternType.InferredType);
            }

            // Pass the nullable annotations explicitly to work around https://github.com/dotnet/roslyn/issues/40105
            var elementTypes = elementTypesBuilder.ToImmutable();
            var type = Compilation.CreateTupleTypeSymbol(
                elementTypes, elementNamesBuilder.ToImmutable(), elementNullableAnnotations: GetNullableAnnotations(elementTypes));
            return CreateResult(type);
        }
    }

    return [];
}

/// <summary>Collects the nullable annotation of each element type, in order.</summary>
private static ImmutableArray<NullableAnnotation> GetNullableAnnotations(ImmutableArray<ITypeSymbol> elementTypes)
    => elementTypes.SelectAsArray(e => e.NullableAnnotation);

/// <summary>Infers the expression of a lock statement as object.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInLockStatement(LockStatementSyntax lockStatement, SyntaxToken? previousToken = null)
{
    // If we're position based, then we have to be after the "lock("
    if (previousToken.HasValue && previousToken.Value != lockStatement.OpenParenToken)
        return [];

    return CreateResult(SpecialType.System_Object);
}

/// <summary>Infers the type of a lambda body (after <c>=&gt;</c>).</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInLambdaExpression(LambdaExpressionSyntax lambdaExpression, SyntaxToken? previousToken = null)
{
    // If we have a position, it has to be after the lambda arrow.
    if (previousToken.HasValue && previousToken.Value != lambdaExpression.ArrowToken)
        return [];

    return InferTypeInAnonymousFunctionExpression(lambdaExpression);
}

/// <summary>
/// Infers the type of an anonymous function's body as the return type of the delegate the
/// function is converted to, unwrapping a task-like return type for async functions.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInAnonymousFunctionExpression(AnonymousFunctionExpressionSyntax anonymousFunction)
{
    // Func<int,string> = i => Goo();
    // Func<int,string> = delegate (int i) { return Goo(); };
    var types = InferTypes(anonymousFunction);
    var type = types.FirstOrDefault().InferredType.GetDelegateType(this.Compilation);

    if (type != null)
    {
        var invoke = type.DelegateInvokeMethod;
        if (invoke != null)
        {
            var isAsync = anonymousFunction.AsyncKeyword.Kind() != SyntaxKind.None;
            return [new TypeInferenceInfo(UnwrapTaskLike(invoke.ReturnType, isAsync))];
        }
    }

    return [];
}

/// <summary>
/// Infers the value of <c>new { Name = $$ }</c> from the property of the same name on the
/// anonymous type inferred for the whole creation expression.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInMemberDeclarator(AnonymousObjectMemberDeclaratorSyntax memberDeclarator, SyntaxToken? previousTokenOpt = null)
{
    if (memberDeclarator.NameEquals != null &&
        memberDeclarator.Parent is AnonymousObjectCreationExpressionSyntax anonymousObjectCreation)
    {
        // If we're position based, then we have to be after the =
        if (previousTokenOpt.HasValue && previousTokenOpt.Value != memberDeclarator.NameEquals.EqualsToken)
            return [];

        var types = InferTypes(anonymousObjectCreation);

        return types.Where(t => t.InferredType.IsAnonymousType())
            .SelectMany(t => t.InferredType.GetValidAnonymousTypeProperties()
                .Where(p => p.Name == memberDeclarator.NameEquals.Name.Identifier.ValueText)
                .Select(p => new TypeInferenceInfo(p.Type)));
    }

    return [];
}

private IEnumerable<TypeInferenceInfo> InferTypeInNameColon(NameColonSyntax nameColon, SyntaxToken previousToken)
{
    if (previousToken != nameColon.ColonToken)
    {
        // Must follow the colon token.
return [];
    }

    return nameColon.Parent switch
    {
        ArgumentSyntax argumentSyntax => InferTypeInArgument(argumentSyntax),
        SubpatternSyntax subPattern => InferTypeInSubpattern(subPattern, subPattern.Pattern),
        _ => [],
    };
}

/// <summary>
/// Infers the type after the colon of an expression colon (<c>{ A.B: $$ }</c>); only
/// subpattern parents are handled.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInExpressionColon(ExpressionColonSyntax expressionColon, SyntaxToken previousToken)
{
    if (previousToken != expressionColon.ColonToken)
    {
        // Must follow the colon token.
        return [];
    }

    return expressionColon.Parent switch
    {
        SubpatternSyntax subPattern => InferTypeInSubpattern(subPattern, subPattern.Pattern),
        _ => [],
    };
}

/// <summary>
/// Infers the type of a part of <c>Goo.Bar</c>. The name after the dot gets whatever type is
/// expected for the whole member access; the expression before the dot is handled by
/// <see cref="InferTypeForExpressionOfMemberAccessExpression"/>.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInMemberAccessExpression(
    MemberAccessExpressionSyntax memberAccessExpression,
    ExpressionSyntax expressionOpt = null,
    SyntaxToken? previousToken = null)
{
    // We need to be on the right of the dot to infer an appropriate type for
    // the member access expression.  i.e. if we have "Goo.Bar" then we can
    // def infer what the type of 'Bar' should be (it's whatever type we infer
    // for 'Goo.Bar' itself.  However, if we're on 'Goo' then we can't figure
    // out anything about its type.
    if (previousToken != null)
    {
        if (previousToken.Value != memberAccessExpression.OperatorToken)
            return [];

        // We're right after the dot in "Goo.Bar".  The type for "Bar" should be
        // whatever type we'd infer for "Goo.Bar" itself.
        return InferTypes(memberAccessExpression);
    }
    else
    {
        Debug.Assert(expressionOpt != null);
        if (expressionOpt == memberAccessExpression.Expression)
        {
            return InferTypeForExpressionOfMemberAccessExpression(memberAccessExpression);
        }

        // We're right after the dot in "Goo.Bar".  The type for "Bar" should be
        // whatever type we'd infer for "Goo.Bar" itself.
        return InferTypes(memberAccessExpression);
    }
}

/// <summary>
/// Infers the type of <c>goo</c> in <c>goo.Name(...)</c> for a few well-known member names:
/// <c>ConfigureAwait</c> under an await, <c>ContinueWith</c> (Task&lt;object&gt;), and
/// <c>Select</c>/<c>Where</c> (IEnumerable&lt;T&gt;, with T taken from the first argument).
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeForExpressionOfMemberAccessExpression(
    MemberAccessExpressionSyntax memberAccessExpression)
{
    // If we're on the left side of a dot, it's possible in a few cases
    // to figure out what type we should be.  Specifically, if we have
    //
    //      await goo.ConfigureAwait()
    //
    // then we can figure out what 'goo' should be based on the await
    // context.
    var name = memberAccessExpression.Name.Identifier.ValueText;
    if (name == nameof(Task<>.ConfigureAwait) &&
        memberAccessExpression.Parent is InvocationExpressionSyntax invocation &&
        memberAccessExpression.Parent.IsParentKind(SyntaxKind.AwaitExpression))
    {
        return InferTypes(invocation);
    }
    else if (name == nameof(Task<>.ContinueWith))
    {
        // goo.ContinueWith(...)
        // We want to infer Task<T>.  For now, we'll just do Task<object>,
        // in the future it would be nice to figure out the actual result
        // type based on the argument to ContinueWith.
        var taskOfT = this.Compilation.TaskOfTType();
        if (taskOfT != null)
        {
            return CreateResult(taskOfT.Construct(this.Compilation.ObjectType));
        }
    }
    else if (name == nameof(Enumerable.Select) ||
             name == nameof(Enumerable.Where))
    {
        var ienumerableType = this.Compilation.IEnumerableOfTType();

        // goo.Select
        // We want to infer IEnumerable<T>.  We can try to figure out what 
        // T if we get a delegate as the first argument to Select/Where.
        if (ienumerableType != null && memberAccessExpression.IsParentKind(SyntaxKind.InvocationExpression, out invocation))
        {
            if (invocation.ArgumentList.Arguments.Count > 0)
            {
                var argumentExpression = invocation.ArgumentList.Arguments[0].Expression;

                if (argumentExpression != null)
                {
                    var argumentTypes = GetTypes(argumentExpression);
                    var delegateType = argumentTypes.FirstOrDefault().InferredType.GetDelegateType(this.Compilation);
                    var typeArg = delegateType?.TypeArguments.Length > 0
                        ? delegateType.TypeArguments[0]
                        : this.Compilation.ObjectType;

                    // For an untyped lambda, look at how its parameter is used in the body.
                    if (IsUnusableType(typeArg) && argumentExpression is LambdaExpressionSyntax lambdaExpression)
                    {
                        typeArg = InferTypeForFirstParameterOfLambda(lambdaExpression) ??
                            this.Compilation.ObjectType;
                    }

                    return CreateResult(ienumerableType.Construct(typeArg));
                }
            }
        }
    }

    return [];
}

/// <summary>
/// Infers the type of a lambda's first parameter from how it is used in the body, or null.
/// </summary>
private ITypeSymbol InferTypeForFirstParameterOfLambda(
    LambdaExpressionSyntax lambdaExpression)
{
    return lambdaExpression switch
    {
        ParenthesizedLambdaExpressionSyntax parenLambda => InferTypeForFirstParameterOfParenthesizedLambda(parenLambda),
        SimpleLambdaExpressionSyntax simpleLambda => InferTypeForFirstParameterOfSimpleLambda(simpleLambda),
        _ => null,
    };
}

private ITypeSymbol InferTypeForFirstParameterOfParenthesizedLambda(
    ParenthesizedLambdaExpressionSyntax lambdaExpression)
{
    return lambdaExpression.ParameterList.Parameters.Count == 0
        ? null
        : InferTypeForFirstParameterOfLambda(
            lambdaExpression, lambdaExpression.ParameterList.Parameters[0]);
}

private ITypeSymbol InferTypeForFirstParameterOfSimpleLambda(
    SimpleLambdaExpressionSyntax lambdaExpression)
{
    return InferTypeForFirstParameterOfLambda(
        lambdaExpression, lambdaExpression.Parameter);
}

private ITypeSymbol InferTypeForFirstParameterOfLambda(
    LambdaExpressionSyntax lambdaExpression, ParameterSyntax parameter)
{
    return InferTypeForFirstParameterOfLambda(
        parameter.Identifier.ValueText, lambdaExpression.Body);
}

/// <summary>
/// Depth-first search of <paramref name="node"/> for a reference to the parameter named
/// <paramref name="parameterName"/>; returns the first non-null type inferred at such a
/// reference, or null.
/// </summary>
private ITypeSymbol InferTypeForFirstParameterOfLambda(
    string parameterName,
    SyntaxNode node)
{
    if (node is IdentifierNameSyntax identifierName)
    {
        if (identifierName.Identifier.ValueText.Equals(parameterName) &&
            SemanticModel.GetSymbolInfo(identifierName.Identifier).Symbol?.Kind == SymbolKind.Parameter)
        {
            return InferTypes(identifierName).FirstOrDefault().InferredType;
        }
    }
    else
    {
        foreach (var child in node.ChildNodes())
        {
            var type = InferTypeForFirstParameterOfLambda(parameterName, child);
            if (type != null)
            {
                return type;
            }
        }
    }

    return null;
}

/// <summary>Node-based counterpart: infers the pattern type of a subpattern's name colon.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInNameColon(NameColonSyntax nameColon)
{
    if (nameColon.Parent is SubpatternSyntax subpattern)
    {
        return GetPatternTypes(subpattern.Pattern);
    }

    return [];
}

/// <summary>Node-based counterpart: infers the pattern type of a subpattern's expression colon.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInExpressionColon(ExpressionColonSyntax expressionColon)
{
    if (expressionColon.Parent is SubpatternSyntax subpattern)
    {
        return GetPatternTypes(subpattern.Pattern);
    }

    return [];
}

/// <summary>
/// Infers the value of <c>Name = $$</c> from the type of <c>Name</c>; for attribute
/// arguments in other positions, from the argument expression's type.
/// </summary>
private IEnumerable<TypeInferenceInfo> InferTypeInNameEquals(NameEqualsSyntax nameEquals, SyntaxToken? previousToken = null)
{
    if (previousToken == nameEquals.EqualsToken)
    {
        // we're on the right of the equals.  Try to bind the left name to see if it
        // gives us anything useful.
        return GetTypes(nameEquals.Name);
    }

    if (nameEquals.Parent is AttributeArgumentSyntax attributeArgumentSyntax)
    {
        var argumentExpression = attributeArgumentSyntax.Expression;
        return this.GetTypes(argumentExpression);
    }

    return [];
}

/// <summary>Infers the operand of postfix <c>++</c>/<c>--</c> as int.</summary>
private IEnumerable<TypeInferenceInfo> InferTypeInPostfixUnaryExpression(PostfixUnaryExpressionSyntax postfixUnaryExpressionSyntax, SyntaxToken? previousToken = null)
{
    // If we're after a postfix ++ or -- then we can't infer anything.
    if (previousToken.HasValue)
        return [];

    switch (postfixUnaryExpressionSyntax.Kind())
    {
        case SyntaxKind.PostDecrementExpression:
        case SyntaxKind.PostIncrementExpression:
            return CreateResult(this.Compilation.GetSpecialType(SpecialType.System_Int32));
    }

    return [];
}

private IEnumerable<TypeInferenceInfo> InferTypeInPrefixUnaryExpression(PrefixUnaryExpressionSyntax prefixUnaryExpression, SyntaxToken? previousToken = null)
{
    // If we have a position, then we must be after the prefix token.
Contract.ThrowIfTrue(previousToken.HasValue && previousToken.Value != prefixUnaryExpression.OperatorToken); switch (prefixUnaryExpression.Kind()) { case SyntaxKind.PreDecrementExpression: case SyntaxKind.PreIncrementExpression: case SyntaxKind.UnaryPlusExpression: case SyntaxKind.UnaryMinusExpression: // ++, --, +Goo(), -Goo(); return CreateResult(this.Compilation.GetSpecialType(SpecialType.System_Int32)); case SyntaxKind.BitwiseNotExpression: // ~Goo() var types = InferTypes(prefixUnaryExpression); if (!types.Any()) { return CreateResult(this.Compilation.GetSpecialType(SpecialType.System_Int32)); } else { return types; } case SyntaxKind.LogicalNotExpression: // !Goo() return CreateResult(SpecialType.System_Boolean); case SyntaxKind.AddressOfExpression: return InferTypeInAddressOfExpression(prefixUnaryExpression); } return []; } private IEnumerable<TypeInferenceInfo> InferTypeInAddressOfExpression(PrefixUnaryExpressionSyntax prefixUnaryExpression) { foreach (var inferredType in InferTypes(prefixUnaryExpression)) { if (inferredType.InferredType is IPointerTypeSymbol pointerType) { // If the code is `int* x = &...` then we want to infer `int` for `...` yield return new TypeInferenceInfo(pointerType.PointedAtType); } else if (inferredType.InferredType is IFunctionPointerTypeSymbol functionPointerType) { // If the code is `delegate*<int, void> x = &...` then we want to infer a signature of `void // M(int)` here (which we encode as Action/Func as necessary). Higher layers (like // generate-method), then can figure out what to do with that signature. yield return new TypeInferenceInfo(functionPointerType.Signature.ConvertToType(this.Compilation)); } } } private IEnumerable<TypeInferenceInfo> InferTypeInAwaitExpression(AwaitExpressionSyntax awaitExpression, SyntaxToken? previousToken = null) { // If we have a position, then we must be after the prefix token. 
Contract.ThrowIfTrue(previousToken.HasValue && previousToken.Value != awaitExpression.AwaitKeyword); // await <expression> var types = InferTypes(awaitExpression); var task = this.Compilation.TaskType(); var taskOfT = this.Compilation.TaskOfTType(); if (task == null || taskOfT == null) return []; if (!types.Any()) { return CreateResult(task); } return types.Select(t => t.InferredType.SpecialType == SpecialType.System_Void ? new TypeInferenceInfo(task) : new TypeInferenceInfo(taskOfT.Construct(t.InferredType))); } private IEnumerable<TypeInferenceInfo> InferTypeInYieldStatement(YieldStatementSyntax yieldStatement, SyntaxToken? previousToken = null) { // If we are position based, then we have to be after the return keyword if (previousToken.HasValue && (previousToken.Value != yieldStatement.ReturnOrBreakKeyword || yieldStatement.ReturnOrBreakKeyword.IsKind(SyntaxKind.BreakKeyword))) return []; var declaration = yieldStatement.FirstAncestorOrSelf<SyntaxNode>(n => n.IsReturnableConstruct()); var memberSymbol = GetDeclaredMemberSymbolFromOriginalSemanticModel(declaration); var memberType = memberSymbol.GetMemberType(); // We don't care what the type is, as long as it has 1 type argument. This will work for IEnumerable, IEnumerator, // IAsyncEnumerable, IAsyncEnumerator and it's also good for error recovery in case there is a missing using. return memberType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1 ? 
[new TypeInferenceInfo(namedType.TypeArguments[0])] : []; } private IEnumerable<TypeInferenceInfo> InferTypeInRefExpression(RefExpressionSyntax refExpression) => InferTypes(refExpression); private ITypeSymbol UnwrapTaskLike(ITypeSymbol type, bool isAsync) { if (isAsync) { if (type.OriginalDefinition.Equals(this.Compilation.TaskOfTType()) || type.OriginalDefinition.Equals(this.Compilation.ValueTaskOfTType())) { var namedTypeSymbol = (INamedTypeSymbol)type; return namedTypeSymbol.TypeArguments[0]; } if (type.OriginalDefinition.Equals(this.Compilation.TaskType()) || type.OriginalDefinition.Equals(this.Compilation.ValueTaskType())) { return this.Compilation.GetSpecialType(SpecialType.System_Void); } } return type; } private IEnumerable<TypeInferenceInfo> InferTypeForReturnStatement( ReturnStatementSyntax returnStatement, SyntaxToken? previousToken = null) { // If we are position based, then we have to be after the return statement. if (previousToken.HasValue && previousToken.Value != returnStatement.ReturnKeyword) return []; var ancestor = returnStatement.FirstAncestorOrSelf<SyntaxNode>(n => n.IsReturnableConstruct()); return ancestor is AnonymousFunctionExpressionSyntax anonymousFunction ? InferTypeInAnonymousFunctionExpression(anonymousFunction) : InferTypeInMethodLikeDeclaration(ancestor); } private IEnumerable<TypeInferenceInfo> InferTypeInArrowExpressionClause(ArrowExpressionClauseSyntax arrowClause) => InferTypeInMethodLikeDeclaration(arrowClause.Parent); private IEnumerable<TypeInferenceInfo> InferTypeInMethodLikeDeclaration(SyntaxNode declaration) { // `declaration` can be a base-method member, property, accessor or local function var symbol = GetDeclaredMemberSymbolFromOriginalSemanticModel(declaration); var type = symbol.GetMemberType(); var isAsync = symbol is IMethodSymbol methodSymbol && methodSymbol.IsAsync; return type != null ? 
[new TypeInferenceInfo(UnwrapTaskLike(type, isAsync))] : []; } private ISymbol GetDeclaredMemberSymbolFromOriginalSemanticModel(SyntaxNode declarationInCurrentTree) { var currentSemanticModel = SemanticModel; var originalSemanticModel = currentSemanticModel.GetOriginalSemanticModel(); if (declarationInCurrentTree is MemberDeclarationSyntax && currentSemanticModel.IsSpeculativeSemanticModel) { var tokenInOriginalTree = originalSemanticModel.SyntaxTree.GetRoot(CancellationToken).FindToken(currentSemanticModel.OriginalPositionForSpeculation); var declaration = tokenInOriginalTree.GetAncestor<MemberDeclarationSyntax>(); return originalSemanticModel.GetDeclaredSymbol(declaration, CancellationToken); } return declarationInCurrentTree != null ? currentSemanticModel.GetDeclaredSymbol(declarationInCurrentTree, CancellationToken) : null; } private IEnumerable<TypeInferenceInfo> InferTypeInSwitchExpressionArm( SwitchExpressionArmSyntax arm) { if (arm.Parent is SwitchExpressionSyntax switchExpression) { // see if we can figure out an appropriate type from a prior/next arm. var armIndex = switchExpression.Arms.IndexOf(arm); if (armIndex > 0) { var previousArm = switchExpression.Arms[armIndex - 1]; var priorArmTypes = GetTypes(previousArm.Expression, objectAsDefault: false); if (priorArmTypes.Any()) return priorArmTypes; } if (armIndex < switchExpression.Arms.Count - 1) { var nextArm = switchExpression.Arms[armIndex + 1]; var priorArmTypes = GetTypes(nextArm.Expression, objectAsDefault: false); if (priorArmTypes.Any()) return priorArmTypes; } // if a prior arm gave us nothing useful, or we're the first arm, then try to infer looking at // what type gets inferred for the switch expression itself. 
return InferTypes(switchExpression); } return []; } private IEnumerable<TypeInferenceInfo> InferTypeInSwitchExpression(SwitchExpressionSyntax switchExpression, SyntaxToken token) { if (token.Kind() is SyntaxKind.OpenBraceToken or SyntaxKind.CommaToken) return GetTypes(switchExpression.GoverningExpression); return []; } private IEnumerable<TypeInferenceInfo> InferTypeInSwitchLabel( SwitchLabelSyntax switchLabel, SyntaxToken? previousToken = null) { if (previousToken.HasValue) { if (previousToken.Value != switchLabel.Keyword || switchLabel.Kind() != SyntaxKind.CaseSwitchLabel) { return []; } } var switchStatement = (SwitchStatementSyntax)switchLabel.Parent.Parent; return GetTypes(switchStatement.Expression); } private IEnumerable<TypeInferenceInfo> InferTypeInSwitchStatement( SwitchStatementSyntax switchStatement, SyntaxToken? previousToken = null) { // If we have a position, then it has to be after "switch(" if (previousToken.HasValue && previousToken.Value != switchStatement.OpenParenToken) return []; // Use the first case label to determine the return type. if (switchStatement.Sections.SelectMany(ss => ss.Labels) .FirstOrDefault(label => label.Kind() == SyntaxKind.CaseSwitchLabel) is CaseSwitchLabelSyntax firstCase) { var result = GetTypes(firstCase.Value); if (result.Any()) { return result; } } return CreateResult(this.Compilation.GetSpecialType(SpecialType.System_Int32)); } private IEnumerable<TypeInferenceInfo> InferTypeInThrowExpression(ThrowExpressionSyntax throwExpression, SyntaxToken? previousToken = null) { // If we have a position, it has to be after the 'throw' keyword. if (previousToken.HasValue && previousToken.Value != throwExpression.ThrowKeyword) return []; return CreateResult(this.Compilation.ExceptionType()); } private IEnumerable<TypeInferenceInfo> InferTypeInThrowStatement(ThrowStatementSyntax throwStatement, SyntaxToken? previousToken = null) { // If we have a position, it has to be after the 'throw' keyword. 
if (previousToken.HasValue && previousToken.Value != throwStatement.ThrowKeyword) return []; return CreateResult(this.Compilation.ExceptionType()); } private IEnumerable<TypeInferenceInfo> InferTypeInUsingStatement(UsingStatementSyntax usingStatement, SyntaxToken? previousToken = null) { // If we have a position, it has to be after "using(" if (previousToken.HasValue && previousToken.Value != usingStatement.OpenParenToken) return []; return CreateResult(SpecialType.System_IDisposable); } private IEnumerable<TypeInferenceInfo> InferTypeInVariableDeclarator(VariableDeclaratorSyntax variableDeclarator) { var variableType = variableDeclarator.GetVariableType(); if (variableType == null) return []; var symbol = SemanticModel.GetDeclaredSymbol(variableDeclarator); if (symbol == null) return []; var type = symbol.GetSymbolType(); var types = CreateResult(type).Where(IsUsableTypeFunc); if (!variableType.IsVar || variableDeclarator.Parent is not VariableDeclarationSyntax variableDeclaration) { return types; } // using (var v = Goo()) if (variableDeclaration.IsParentKind(SyntaxKind.UsingStatement)) return CreateResult(SpecialType.System_IDisposable); // for (var v = Goo(); .. if (variableDeclaration.IsParentKind(SyntaxKind.ForStatement)) return CreateResult(this.Compilation.GetSpecialType(SpecialType.System_Int32)); var laterUsageInference = InferTypeBasedOnLaterUsage(symbol, variableDeclaration); if (laterUsageInference is not [] and not [{ InferredType.SpecialType: SpecialType.System_Object }]) return laterUsageInference; // Return the types here if they actually bound to a type called 'var'. return types.Where(t => t.InferredType.Name == "var"); } private ImmutableArray<TypeInferenceInfo> InferTypeBasedOnLaterUsage(ISymbol symbol, SyntaxNode afterNode) { // var v = expr. // Attempt to see how 'v' is used later in the current scope to determine what to do. 
var container = afterNode.AncestorsAndSelf().FirstOrDefault(a => a is BlockSyntax or SwitchSectionSyntax); if (container != null) { foreach (var descendant in container.DescendantNodesAndSelf().OfType<IdentifierNameSyntax>()) { // only look after the variable we're declaring. if (descendant.SpanStart <= afterNode.Span.End) continue; if (descendant.Identifier.ValueText != symbol.Name) continue; // Make sure it's actually a match for this variable. var descendantSymbol = SemanticModel.GetSymbolInfo(descendant, CancellationToken).GetAnySymbol(); if (symbol.Equals(descendantSymbol)) { // See if we can infer something interesting about this location. var inferredDescendantTypes = InferTypes(descendant, filterUnusable: true); if (inferredDescendantTypes is not [] and not [{ InferredType.SpecialType: SpecialType.System_Object }]) return inferredDescendantTypes; } } } return []; } private IEnumerable<TypeInferenceInfo> InferTypeInVariableComponentAssignment(ExpressionSyntax left) { if (left is DeclarationExpressionSyntax declExpr) { // var (x, y) = Expr(); // Attempt to determine what x and y are based on their future usage. if (declExpr.Type.IsVar && declExpr.Designation is ParenthesizedVariableDesignationSyntax parenthesizedVariableDesignation && parenthesizedVariableDesignation.Variables.All(v => v is SingleVariableDesignationSyntax { Identifier.ValueText: not "" })) { var elementNames = parenthesizedVariableDesignation.Variables.SelectAsArray(v => ((SingleVariableDesignationSyntax)v).Identifier.ValueText); var elementTypes = parenthesizedVariableDesignation.Variables.SelectAsArray(v => { var designation = (SingleVariableDesignationSyntax)v; var symbol = SemanticModel.GetRequiredDeclaredSymbol(designation, CancellationToken); var inferredFutureUsage = InferTypeBasedOnLaterUsage(symbol, afterNode: left.Parent); return inferredFutureUsage.Length > 0 ? 
inferredFutureUsage[0].InferredType : Compilation.ObjectType; }); return [new TypeInferenceInfo( Compilation.CreateTupleTypeSymbol(elementTypes, elementNames))]; } return GetTypes(declExpr.Type); } else if (left is TupleExpressionSyntax tupleExpression) { // We have something of the form: // (int a, int b) = ... // // This is a deconstruction, and a decent deconstructable type we can infer here is // ValueTuple<int,int>. var tupleType = GetTupleType(tupleExpression); if (tupleType != null) return CreateResult(tupleType); } return []; } private ITypeSymbol GetTupleType( TupleExpressionSyntax tuple) { if (!TryGetTupleTypesAndNames(tuple.Arguments, out var elementTypes, out var elementNames)) { return null; } // Pass the nullable annotations explicitly to work around https://github.com/dotnet/roslyn/issues/40105 return Compilation.CreateTupleTypeSymbol(elementTypes, elementNames, elementNullableAnnotations: GetNullableAnnotations(elementTypes)); } private bool TryGetTupleTypesAndNames( SeparatedSyntaxList<ArgumentSyntax> arguments, out ImmutableArray<ITypeSymbol> elementTypes, out ImmutableArray<string> elementNames) { elementTypes = default; elementNames = default; using var _1 = ArrayBuilder<ITypeSymbol>.GetInstance(out var elementTypesBuilder); using var _2 = ArrayBuilder<string>.GetInstance(out var elementNamesBuilder); foreach (var arg in arguments) { var expr = arg.Expression; if (expr is DeclarationExpressionSyntax declExpr) { AddTypeAndName(declExpr, elementTypesBuilder, elementNamesBuilder); } else if (expr is TupleExpressionSyntax tupleExpr) { AddTypeAndName(tupleExpr, elementTypesBuilder, elementNamesBuilder); } else if (expr is IdentifierNameSyntax name) { elementNamesBuilder.Add(name.Identifier.ValueText == "" ? null : name.Identifier.ValueText); elementTypesBuilder.Add(GetTypes(expr).FirstOrDefault().InferredType ?? 
this.Compilation.ObjectType); } else { return false; } } if (elementTypesBuilder.Contains(null) || elementTypesBuilder.Count != arguments.Count) { return false; } elementTypes = elementTypesBuilder.ToImmutable(); elementNames = elementNamesBuilder.ToImmutable(); return true; } private void AddTypeAndName( DeclarationExpressionSyntax declaration, ArrayBuilder<ITypeSymbol> elementTypesBuilder, ArrayBuilder<string> elementNamesBuilder) { elementTypesBuilder.Add(GetTypes(declaration.Type).FirstOrDefault().InferredType); var designation = declaration.Designation; if (designation is SingleVariableDesignationSyntax singleVariable) { var name = singleVariable.Identifier.ValueText; if (name != string.Empty) { elementNamesBuilder.Add(name); return; } } elementNamesBuilder.Add(null); } private void AddTypeAndName( TupleExpressionSyntax tuple, ArrayBuilder<ITypeSymbol> elementTypesBuilder, ArrayBuilder<string> elementNamesBuilder) { var tupleType = GetTupleType(tuple); elementTypesBuilder.Add(tupleType); elementNamesBuilder.Add(null); } private IEnumerable<TypeInferenceInfo> InferTypeInWhenClause(WhenClauseSyntax whenClause, SyntaxToken? previousToken = null) { // If we have a position, we have to be after the "when" if (previousToken.HasValue && previousToken.Value != whenClause.WhenKeyword) return []; return [new TypeInferenceInfo(Compilation.GetSpecialType(SpecialType.System_Boolean))]; } private IEnumerable<TypeInferenceInfo> InferTypeInWhileStatement(WhileStatementSyntax whileStatement, SyntaxToken? previousToken = null) { // If we're position based, then we have to be after the "while(" if (previousToken.HasValue && previousToken.Value != whileStatement.OpenParenToken) return []; return CreateResult(SpecialType.System_Boolean); } private IEnumerable<TypeInferenceInfo> InferTypeInRelationalPattern(RelationalPatternSyntax relationalPattern) => InferTypes(relationalPattern); } } |