File: BindingSourceGenerator.cs
Web Access
Project: src\src\Controls\src\BindingSourceGen\Controls.BindingSourceGen.csproj (Microsoft.Maui.Controls.BindingSourceGen)
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
 
namespace Microsoft.Maui.Controls.BindingSourceGen;
 
/// <summary>
/// Names given to steps of the binding generator's incremental pipeline through
/// <c>WithTrackingName</c>, so those steps can be identified individually.
/// Each constant's value is its own member name.
/// </summary>
public class TrackingNames
{
	/// <summary>Tracking name for the step yielding only bindings that were analyzed without diagnostics.</summary>
	public const string Bindings = nameof(Bindings);

	/// <summary>Tracking name for the step yielding every analyzed invocation along with its diagnostics.</summary>
	public const string BindingsWithDiagnostics = nameof(BindingsWithDiagnostics);
}
 
/// <summary>
/// Incremental source generator that finds <c>SetBinding(property, getter)</c> and <c>Create(getter)</c>
/// invocations whose getter is a static lambda, and emits a "GeneratedBindingInterceptors" source file
/// for each invocation it can analyze. Invocations that fail analysis produce diagnostics instead of code.
/// </summary>
[Generator(LanguageNames.CSharp)]
public class BindingSourceGenerator : IIncrementalGenerator
{
	/// <summary>
	/// Wires up the pipeline:
	/// (1) syntactically select candidate invocations and analyze them semantically,
	/// (2) report every diagnostic produced by that analysis,
	/// (3) add the shared helper source once, and
	/// (4) add one generated file per invocation that was analyzed without diagnostics.
	/// </summary>
	public void Initialize(IncrementalGeneratorInitializationContext context)
	{
		var bindingsWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider(
			predicate: static (node, _) => IsSetBindingMethod(node) || IsCreateMethod(node),
			transform: static (ctx, t) => GetBindingForGeneration(ctx, t)
		)
		.WithTrackingName(TrackingNames.BindingsWithDiagnostics);

		context.RegisterSourceOutput(bindingsWithDiagnostics, static (spc, bindingWithDiagnostic) =>
		{
			foreach (var diagnostic in bindingWithDiagnostic.Diagnostics)
			{
				spc.ReportDiagnostic(Diagnostic.Create(diagnostic.Descriptor, diagnostic.Location?.ToLocation()));
			}
		});

		// Code is generated only for invocations whose analysis produced no diagnostics.
		var bindings = bindingsWithDiagnostics
			.Where(static binding => !binding.HasDiagnostics)
			.Select(static (binding, _) => binding.Value)
			.WithTrackingName(TrackingNames.Bindings);

		context.RegisterPostInitializationOutput(static spc =>
		{
			spc.AddSource("GeneratedBindingInterceptorsCommon.g.cs", BindingCodeWriter.GenerateCommonCode());
		});

		context.RegisterImplementationSourceOutput(bindings, static (spc, binding) =>
		{
			// The hint name encodes the call site (file, line, column). Path separators and ':' (e.g. drive letters)
			// are replaced with '-' so the result is a single flat name.
			var fileName = $"{binding.Location.FilePath}-GeneratedBindingInterceptors-{binding.Location.Line}-{binding.Location.Column}.g.cs";
			var sanitizedFileName = fileName.Replace('/', '-').Replace('\\', '-').Replace(':', '-');
			var code = BindingCodeWriter.GenerateBinding(binding, ToBindingId(binding.Location.GetHashCode()));
			spc.AddSource(sanitizedFileName, code);
		});
	}

	/// <summary>
	/// Converts a hash code into the non-negative id passed to <c>BindingCodeWriter.GenerateBinding</c>.
	/// </summary>
	/// <remarks>
	/// <c>Math.Abs(int.MinValue)</c> throws <see cref="OverflowException"/>. Widening to <see cref="long"/> first
	/// avoids that, returning 2147483648 for <see cref="int.MinValue"/> and the same value as a plain
	/// <c>Math.Abs</c> for every other input.
	/// NOTE(review): if the location type's <c>GetHashCode</c> mixes in string hashes, the id may differ between
	/// compiler processes on runtimes with randomized string hashing. Confirm whether deterministic output is required.
	/// </remarks>
	private static uint ToBindingId(int hashCode) => (uint)Math.Abs((long)hashCode);

	/// <summary>Cheap syntactic filter for <c>x.SetBinding(property, getter, ...)</c> candidates.</summary>
	private static bool IsSetBindingMethod(SyntaxNode node)
		=> IsCandidateInvocation(node, "SetBinding", getterArgumentIndex: 1);

	/// <summary>Cheap syntactic filter for <c>X.Create(getter, ...)</c> candidates.</summary>
	private static bool IsCreateMethod(SyntaxNode node)
		=> IsCandidateInvocation(node, "Create", getterArgumentIndex: 0);

	/// <summary>
	/// Returns true when <paramref name="node"/> is a member-access invocation named <paramref name="methodName"/>
	/// with an argument at <paramref name="getterArgumentIndex"/> that is neither a literal nor an object creation
	/// (explicit <c>new T(...)</c> or target-typed <c>new()</c>). Such arguments can never be a lambda getter.
	/// </summary>
	private static bool IsCandidateInvocation(SyntaxNode node, string methodName, int getterArgumentIndex)
	{
		if (node is not InvocationExpressionSyntax { Expression: MemberAccessExpressionSyntax method } invocation
			|| method.Name.Identifier.Text != methodName)
		{
			return false;
		}

		var arguments = invocation.ArgumentList.Arguments;
		if (arguments.Count <= getterArgumentIndex)
		{
			return false;
		}

		var getter = arguments[getterArgumentIndex].Expression;
		return getter is not LiteralExpressionSyntax and not BaseObjectCreationExpressionSyntax;
	}

	/// <summary>
	/// Semantic transform for one candidate invocation. It resolves which intercepted method is being called,
	/// validates the getter lambda, parses its path, and builds the <c>BindingInvocationDescription</c> used for
	/// code generation. The first failing step short-circuits and its diagnostics are returned instead.
	/// </summary>
	private static Result<BindingInvocationDescription> GetBindingForGeneration(GeneratorSyntaxContext context, CancellationToken t)
	{
		var enabledNullable = IsNullableContextEnabled(context);

		var invocation = (InvocationExpressionSyntax)context.Node;
		var method = (MemberAccessExpressionSyntax)invocation.Expression;

		var invocationParser = new InvocationParser(context);
		var interceptedMethodTypeResult = invocationParser.ParseInvocation(invocation, t);

		if (interceptedMethodTypeResult.HasDiagnostics)
		{
			return Result<BindingInvocationDescription>.Failure(interceptedMethodTypeResult.Diagnostics);
		}

		// The generated code is located by the method-name token of the call.
		var sourceCodeLocation = SourceCodeLocation.CreateFrom(method.Name.GetLocation());
		if (sourceCodeLocation == null)
		{
			return Result<BindingInvocationDescription>.Failure(DiagnosticsFactory.UnableToResolvePath(invocation.GetLocation()));
		}

		var lambdaResult = GetLambda(invocation, interceptedMethodTypeResult.Value);
		if (lambdaResult.HasDiagnostics)
		{
			return Result<BindingInvocationDescription>.Failure(lambdaResult.Diagnostics);
		}

		var lambdaParamTypeResult = GetLambdaParameterType(lambdaResult.Value, context.SemanticModel, t);
		if (lambdaParamTypeResult.HasDiagnostics)
		{
			return Result<BindingInvocationDescription>.Failure(lambdaParamTypeResult.Diagnostics);
		}

		var lambdaReturnTypeResult = GetLambdaReturnType(lambdaResult.Value, context.SemanticModel, t);
		if (lambdaReturnTypeResult.HasDiagnostics)
		{
			return Result<BindingInvocationDescription>.Failure(lambdaReturnTypeResult.Diagnostics);
		}

		var pathParser = new PathParser(context, enabledNullable);
		var pathParseResult = pathParser.ParsePath(lambdaResult.Value.ExpressionBody);
		if (pathParseResult.HasDiagnostics)
		{
			return Result<BindingInvocationDescription>.Failure(pathParseResult.Diagnostics);
		}

		var binding = new BindingInvocationDescription(
			Location: sourceCodeLocation.ToInterceptorLocation(),
			SourceType: lambdaParamTypeResult.Value.CreateTypeDescription(enabledNullable),
			PropertyType: lambdaReturnTypeResult.Value.CreateTypeDescription(enabledNullable),
			Path: new EquatableArray<IPathPart>([.. pathParseResult.Value]),
			SetterOptions: DeriveSetterOptions(lambdaResult.Value.ExpressionBody, context.SemanticModel, enabledNullable, t),
			NullableContextEnabled: enabledNullable,
			MethodType: interceptedMethodTypeResult.Value);
		return Result<BindingInvocationDescription>.Success(binding);
	}

	/// <summary>
	/// True only when both nullable annotations and nullable warnings are enabled at the start of the
	/// invocation (the <see cref="NullableContext.Enabled"/> flag combination).
	/// </summary>
	private static bool IsNullableContextEnabled(GeneratorSyntaxContext context)
	{
		NullableContext nullableContext = context.SemanticModel.GetNullableContext(context.Node.Span.Start);
		return (nullableContext & NullableContext.Enabled) == NullableContext.Enabled;
	}

	/// <summary>
	/// Extracts the getter argument for <paramref name="methodType"/> and requires it to be a <c>static</c> lambda.
	/// </summary>
	/// <remarks>
	/// Arguments are read by position (index 1 for SetBinding, index 0 for Create), matching the syntactic filter.
	/// NOTE(review): a call that passes the getter as a named argument in a different position is not recognized here.
	/// </remarks>
	private static Result<LambdaExpressionSyntax> GetLambda(InvocationExpressionSyntax invocation, InterceptedMethodType methodType)
	{
		var argumentList = invocation.ArgumentList.Arguments;
		var expression = methodType switch
		{
			InterceptedMethodType.SetBinding => argumentList[1].Expression,
			InterceptedMethodType.Create => argumentList[0].Expression,
			_ => throw new NotSupportedException($"Unsupported intercepted method type: {methodType}"),
		};

		if (expression is not LambdaExpressionSyntax lambda)
		{
			return Result<LambdaExpressionSyntax>.Failure(DiagnosticsFactory.GetterIsNotLambda(expression.GetLocation()));
		}

		// We only support static lambdas
		if (!lambda.Modifiers.Any(SyntaxKind.StaticKeyword))
		{
			return Result<LambdaExpressionSyntax>.Failure(DiagnosticsFactory.LambdaIsNotStatic(lambda.GetLocation()));
		}

		return Result<LambdaExpressionSyntax>.Success(lambda);
	}

	/// <summary>
	/// Resolves the type of the lambda's expression body. Fails for block-bodied lambdas and for bodies whose
	/// type is missing or an error type.
	/// </summary>
	private static Result<ITypeSymbol> GetLambdaReturnType(LambdaExpressionSyntax lambda, SemanticModel semanticModel, CancellationToken t)
	{
		if (lambda.Body is not ExpressionSyntax lambdaBody)
		{
			return Result<ITypeSymbol>.Failure(DiagnosticsFactory.GetterLambdaBodyIsNotExpression(lambda.Body.GetLocation()));
		}

		var lambdaResultType = semanticModel.GetTypeInfo(lambdaBody, t).Type;
		if (lambdaResultType is null or IErrorTypeSymbol)
		{
			return Result<ITypeSymbol>.Failure(DiagnosticsFactory.LambdaResultCannotBeResolved(lambdaBody.GetLocation()));
		}

		return Result<ITypeSymbol>.Success(lambdaResultType);
	}

	/// <summary>
	/// Resolves the type of the lambda's first parameter. Fails when the lambda does not bind to a method symbol,
	/// when it has no parameters or the parameter type is an error type, or when that type is not accessible.
	/// </summary>
	private static Result<ITypeSymbol> GetLambdaParameterType(LambdaExpressionSyntax lambda, SemanticModel semanticModel, CancellationToken t)
	{
		if (semanticModel.GetSymbolInfo(lambda, t).Symbol is not IMethodSymbol lambdaSymbol)
		{
			return Result<ITypeSymbol>.Failure(DiagnosticsFactory.GetterIsNotLambda(lambda.GetLocation()));
		}

		var parameters = lambdaSymbol.Parameters;
		if (parameters.Length == 0 || parameters[0].Type is IErrorTypeSymbol)
		{
			return Result<ITypeSymbol>.Failure(DiagnosticsFactory.LambdaParameterCannotBeResolved(lambda.GetLocation()));
		}

		var lambdaParamType = parameters[0].Type;
		if (!lambdaParamType.IsAccessible())
		{
			return Result<ITypeSymbol>.Failure(DiagnosticsFactory.UnaccessibleTypeUsedAsLambdaParameter(lambda.GetLocation()));
		}

		return Result<ITypeSymbol>.Success(lambdaParamType);
	}

	/// <summary>
	/// Decides whether the value produced by the getter can be written back, and whether null may be assigned.
	/// The method unwraps member access, conditional access, member binding, <c>as</c>, casts and parentheses
	/// until it reaches the terminal identifier or element access, then inspects that symbol.
	/// Any other shape (or a null expression) yields a non-writable, non-nullable setter.
	/// </summary>
	private static SetterOptions DeriveSetterOptions(ExpressionSyntax? lambdaBodyExpression, SemanticModel semanticModel, bool enabledNullable, CancellationToken t)
	{
		if (lambdaBodyExpression is null)
		{
			return new SetterOptions(IsWritable: false, AcceptsNullValue: false);
		}

		if (lambdaBodyExpression is IdentifierNameSyntax or ElementAccessExpressionSyntax or ElementBindingExpressionSyntax)
		{
			var symbol = semanticModel.GetSymbolInfo(lambdaBodyExpression, t).Symbol;
			return new SetterOptions(IsWritable(symbol), AcceptsNullValue(symbol, enabledNullable));
		}

		var nestedExpression = lambdaBodyExpression switch
		{
			MemberAccessExpressionSyntax memberAccess => memberAccess.Name,
			ConditionalAccessExpressionSyntax conditionalAccess => conditionalAccess.WhenNotNull,
			MemberBindingExpressionSyntax memberBinding => memberBinding.Name,
			BinaryExpressionSyntax binary when binary.IsKind(SyntaxKind.AsExpression) => binary.Left,
			CastExpressionSyntax cast => cast.Expression,
			ParenthesizedExpressionSyntax parenthesized => parenthesized.Expression,
			_ => null,
		};

		return DeriveSetterOptions(nestedExpression, semanticModel, enabledNullable, t);

		// Properties need a setter and fields must not be readonly. Anything else, including an unresolved symbol
		// (e.g. array element access, which has no symbol), is treated as writable.
		static bool IsWritable(ISymbol? symbol)
			=> symbol switch
			{
				IPropertySymbol propertySymbol => propertySymbol.SetMethod != null,
				IFieldSymbol fieldSymbol => !fieldSymbol.IsReadOnly,
				_ => true,
			};

		// Only properties and fields are inspected for nullability; other symbols are treated as non-nullable.
		static bool AcceptsNullValue(ISymbol? symbol, bool enabledNullable)
			=> symbol switch
			{
				IPropertySymbol propertySymbol => propertySymbol.Type.IsTypeNullable(enabledNullable),
				IFieldSymbol fieldSymbol => fieldSymbol.Type.IsTypeNullable(enabledNullable),
				_ => false,
			};
	}
}