|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
#nullable disable
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
namespace BoundTreeGenerator
{
/// <summary>
/// The language in which the bound-node source file is generated.
/// </summary>
internal enum TargetLanguage
{
    /// <summary>Emit Visual Basic source.</summary>
    VB,

    /// <summary>Emit C# source.</summary>
    CSharp,
}
/// <summary>
/// How the generator treats null for a bound-node field.
/// </summary>
internal enum NullHandling
{
    /// <summary>The field may be null; no null check is generated for it.</summary>
    Allow,

    /// <summary>The field must not be null; the generated constructor asserts it.</summary>
    Disallow,

    /// <summary>The field is always null; it is not a constructor parameter and is assigned null.</summary>
    Always,

    /// <summary>Null checks do not apply (value types).</summary>
    NotApplicable,
}
internal sealed class BoundNodeClassWriter
{
private readonly TextWriter _writer;
private readonly Tree _tree;
private readonly Dictionary<string, string> _typeMap;
private HashSet<string> _valueTypes;
private readonly TargetLanguage _targetLang;

/// <summary>
/// Creates a writer that emits the classes described by <paramref name="tree"/> to
/// <paramref name="writer"/> in the language <paramref name="targetLang"/>.
/// </summary>
private BoundNodeClassWriter(TextWriter writer, Tree tree, TargetLanguage targetLang)
{
    _writer = writer;
    _tree = tree;
    _targetLang = targetLang;

    // Map each type that is neither an enum type nor a value type to the name of its base type.
    var inheritableTypes = tree.Types.Where(t => !(t is EnumType || t is ValueType));
    _typeMap = inheritableTypes.ToDictionary(t => t.Name, t => t.Base);

    // The tree root has no base of its own.
    _typeMap.Add(tree.Root, null);

    InitializeValueTypes();
}
/// <summary>
/// Builds <c>_valueTypes</c>: the set of type names the generator treats as value types.
/// It contains every <c>ValueType</c> declared in the tree, the target language's built-in
/// primitive type names, and a fixed list of additional names.
/// </summary>
private void InitializeValueTypes()
{
    // Value types declared by the tree definition itself.
    var valueTypeNames = new HashSet<string>(_tree.Types.OfType<ValueType>().Select(v => v.Name));

    // Primitive type names as spelled in the target language (nothing for an unknown language).
    string[] primitiveNames = _targetLang switch
    {
        TargetLanguage.CSharp => new[] { "bool", "int", "uint", "short", "ushort", "long", "ulong", "byte", "sbyte", "char", "Boolean" },
        TargetLanguage.VB => new[] { "Boolean", "Integer", "UInteger", "Short", "UShort", "Long", "ULong", "Byte", "SByte", "Char" },
        _ => Array.Empty<string>(),
    };
    valueTypeNames.UnionWith(primitiveNames);

    // Additional names treated as value types regardless of the target language.
    valueTypeNames.UnionWith(new[]
    {
        "Int8", "Int16", "Int32", "Int64",
        "UInt8", "UInt16", "UInt32", "UInt64",
        "ImmutableArray", "PropertyAccessKind", "TypeWithAnnotations", "BitVector",
    });

    _valueTypes = valueTypeNames;
}
/// <summary>
/// Generates the complete bound-node source file for <paramref name="tree"/> into
/// <paramref name="writer"/>, in the language <paramref name="targetLang"/>.
/// </summary>
public static void Write(TextWriter writer, Tree tree, TargetLanguage targetLang)
    => new BoundNodeClassWriter(writer, tree, targetLang).WriteFile();
// Current indentation depth; each level is four spaces.
private int _indent;

// True when nothing has been written on the current line yet, so indentation is still owed.
private bool _needsIndent = true;

/// <summary>
/// Writes composite-formatted text, first emitting the current indentation if at the start of a line.
/// </summary>
/// <remarks>
/// The text always goes through composite formatting, so literal braces must be doubled ("{{", "}}").
/// </remarks>
private void Write(string format, params object[] args)
{
    if (!_needsIndent)
    {
        _writer.Write(format, args);
        return;
    }

    _writer.Write(new string(' ', 4 * _indent));
    _needsIndent = false;
    _writer.Write(format, args);
}

/// <summary>Writes composite-formatted text (indented if needed) and ends the line.</summary>
private void WriteLine(string format, params object[] args)
{
    Write(format, args);
    Blank();
}

/// <summary>Ends the current line without writing any indentation.</summary>
private void Blank()
{
    _writer.WriteLine();
    _needsIndent = true;
}

/// <summary>Writes an opening brace on its own line and indents one level.</summary>
private void Brace()
{
    WriteLine("{{");
    Indent();
}

/// <summary>Outdents one level and writes a closing brace on its own line.</summary>
private void Unbrace()
{
    Outdent();
    WriteLine("}}");
}

/// <summary>Increases the indentation depth by one level.</summary>
private void Indent() => _indent++;

/// <summary>Decreases the indentation depth by one level.</summary>
private void Outdent() => _indent--;
/// <summary>
/// Emits the whole generated file. The order is: the auto-generated header, the imports, and then
/// the namespace containing the <c>BoundKind</c> enum, node classes, visitor, walker, rewriter,
/// nullability rewriter, and tree-dumper node producer.
/// </summary>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteFile()
{
    if (_targetLang == TargetLanguage.CSharp)
    {
        WriteLine("// <auto-generated />");
        WriteLine("#nullable enable");
    }
    else if (_targetLang == TargetLanguage.VB)
    {
        WriteLine("' <auto-generated />");
    }
    else
    {
        throw new ArgumentException("Unexpected target language", nameof(_targetLang));
    }

    Blank();

    string[] sharedLeadingImports =
    {
        "System",
        "System.Collections",
        "System.Collections.Generic",
        "System.Collections.Immutable",
        "System.Diagnostics",
        "System.Linq",
        "System.Runtime.CompilerServices",
        "System.Text",
        "System.Threading",
        "Microsoft.CodeAnalysis.Collections",
    };
    foreach (string ns in sharedLeadingImports)
    {
        WriteUsing(ns);
    }

    // Language-specific namespaces are placed around Microsoft.CodeAnalysis.Text so that the
    // import list stays alphabetically ordered (CSharp < Text < VisualBasic).
    if (_targetLang == TargetLanguage.CSharp)
    {
        WriteUsing("Microsoft.CodeAnalysis.CSharp.Symbols");
        WriteUsing("Microsoft.CodeAnalysis.CSharp.Syntax");
    }

    WriteUsing("Microsoft.CodeAnalysis.Text");

    if (_targetLang == TargetLanguage.VB)
    {
        WriteUsing("Microsoft.CodeAnalysis.VisualBasic.Symbols");
        WriteUsing("Microsoft.CodeAnalysis.VisualBasic.Syntax");
    }

    WriteUsing("Roslyn.Utilities");
    Blank();

    WriteStartNamespace();
    WriteKinds();
    WriteTypes();
    WriteVisitor();
    WriteWalker();
    WriteRewriter();
    WriteNullabilityRewriter();
    WriteTreeDumperNodeProducer();
    WriteEndNamespace();
}
/// <summary>
/// Emits one import line: <c>using ns;</c> for C#, <c>Imports ns</c> for VB.
/// </summary>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteUsing(string nsName)
{
    string importFormat = _targetLang switch
    {
        TargetLanguage.CSharp => "using {0};",
        TargetLanguage.VB => "Imports {0}",
        _ => throw new ArgumentException("Unexpected target language", nameof(_targetLang)),
    };

    WriteLine(importFormat, nsName);
}
/// <summary>
/// Opens the compiler namespace for the target language and indents its contents.
/// </summary>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteStartNamespace()
{
    if (_targetLang == TargetLanguage.CSharp)
    {
        WriteLine("namespace Microsoft.CodeAnalysis.CSharp");
        Brace();
    }
    else if (_targetLang == TargetLanguage.VB)
    {
        // VB blocks are keyword-delimited, so only the indentation changes here.
        WriteLine("Namespace Microsoft.CodeAnalysis.VisualBasic");
        Indent();
    }
    else
    {
        throw new ArgumentException("Unexpected target language", nameof(_targetLang));
    }
}
/// <summary>
/// Closes the namespace opened by WriteStartNamespace and restores the indentation.
/// </summary>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteEndNamespace()
{
    if (_targetLang == TargetLanguage.CSharp)
    {
        Unbrace();
        return;
    }

    if (_targetLang == TargetLanguage.VB)
    {
        Outdent();
        WriteLine("End Namespace");
        return;
    }

    throw new ArgumentException("Unexpected target language", nameof(_targetLang));
}
/// <summary>
/// Emits the byte-backed <c>BoundKind</c> enum. It has one member per <c>Node</c> in the tree,
/// named after the node with its "Bound" prefix removed.
/// </summary>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteKinds()
{
    IEnumerable<Node> kindNodes = _tree.Types.OfType<Node>();

    switch (_targetLang)
    {
        case TargetLanguage.CSharp:
            WriteLine("internal enum BoundKind : byte");
            Brace();
            WriteMembers("{0},");
            Unbrace();
            break;

        case TargetLanguage.VB:
            WriteLine("Friend Enum BoundKind as Byte");
            Indent();
            WriteMembers("{0}");
            Outdent();
            WriteLine("End Enum");
            break;

        default:
            throw new ArgumentException("Unexpected target language", nameof(_targetLang));
    }

    // Writes one enum member line per node using the language-specific member format.
    void WriteMembers(string memberFormat)
    {
        foreach (Node kindNode in kindNodes)
        {
            WriteLine(memberFormat, FixKeyword(StripBound(kindNode.Name)));
        }
    }
}
/// <summary>
/// Writes every tree type except predefined nodes, each preceded by a blank line.
/// WriteType itself emits nothing for types that are not abstract nodes.
/// </summary>
private void WriteTypes()
{
    foreach (var type in _tree.Types)
    {
        if (type is PredefinedNode)
        {
            continue;
        }

        Blank();
        WriteType(type);
    }
}
/// <summary>
/// Returns true when no type in the tree names <paramref name="node"/> as its base. In that
/// case the generated class can be sealed (NotInheritable in VB).
/// </summary>
private bool CanBeSealed(TreeType node) => !_typeMap.ContainsValue(node.Name);
/// <summary>
/// Opens the partial class declaration for <paramref name="node"/>.
/// </summary>
/// <remarks>
/// An <c>AbstractNode</c> that is not a concrete <c>Node</c> becomes abstract (MustInherit).
/// Otherwise, a type nothing derives from becomes sealed (NotInheritable). VB emits the base type
/// on a separate <c>Inherits</c> line followed by a blank line.
/// </remarks>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteClassHeader(TreeType node)
{
    switch (_targetLang)
    {
        case TargetLanguage.CSharp:
            WriteLine("internal {2}partial class {0} : {1}", node.Name, node.Base, InheritanceModifier("abstract ", "sealed "));
            Brace();
            break;

        case TargetLanguage.VB:
            WriteLine("Partial Friend {1}Class {0}", node.Name, InheritanceModifier("MustInherit ", "NotInheritable "));
            Indent();
            WriteLine("Inherits {0}", node.Base);
            Blank();
            break;

        default:
            throw new ArgumentException("Unexpected target language", nameof(_targetLang));
    }

    // Picks the modifier text (with its trailing space), or "" for an inheritable concrete class.
    string InheritanceModifier(string abstractText, string sealedText)
    {
        if (node is AbstractNode and not Node)
        {
            return abstractText;
        }

        return CanBeSealed(node) ? sealedText : "";
    }
}
/// <summary>
/// Closes the class opened by WriteClassHeader and restores the indentation.
/// </summary>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteClassFooter()
{
    if (_targetLang == TargetLanguage.CSharp)
    {
        Unbrace();
        return;
    }

    if (_targetLang == TargetLanguage.VB)
    {
        Outdent();
        WriteLine("End Class");
        return;
    }

    throw new ArgumentException("Unexpected target language", nameof(_targetLang));
}
/// <summary>Writes an opening parenthesis (indenting first if at the start of a line).</summary>
private void Paren() => Write("(");

/// <summary>Writes a closing parenthesis.</summary>
private void UnParen() => Write(")");
/// <summary>
/// Writes <paramref name="func"/> applied to each element of <paramref name="items"/>, with
/// <paramref name="separator"/> between consecutive elements. Writes nothing for an empty sequence.
/// </summary>
/// <remarks>
/// Output goes through the indentation-aware Write helper, so a list that starts a fresh line is
/// indented like all other generated text. The original wrote straight to the underlying writer and
/// skipped indentation. Each piece is passed as a "{0}" argument rather than as the format string,
/// so braces inside item text are written verbatim.
/// </remarks>
private void SeparatedList<T>(string separator, IEnumerable<T> items, Func<T, string> func)
{
    var first = true;
    foreach (T item in items)
    {
        if (!first)
        {
            Write("{0}", separator);
        }

        first = false;
        Write("{0}", func(item));
    }
}
/// <summary>Writes the projected elements separated by ", ".</summary>
private void Comma<T>(IEnumerable<T> items, Func<T, string> func) => SeparatedList(", ", items, func);
/// <summary>
/// Writes the projected elements joined by the target language's short-circuiting "or":
/// <c>||</c> in C#, <c>OrElse</c> in VB.
/// </summary>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void Or<T>(IEnumerable<T> items, Func<T, string> func)
{
    string orOperator = _targetLang switch
    {
        TargetLanguage.CSharp => " || ",
        TargetLanguage.VB => " OrElse ",
        _ => throw new ArgumentException("Unexpected target language", nameof(_targetLang)),
    };

    SeparatedList(orOperator, items, func);
}
/// <summary>Writes the projected elements as a parenthesized, comma-separated list.</summary>
private void ParenList<T>(IEnumerable<T> items, Func<T, string> func)
{
    Paren();
    Comma(items, func);
    UnParen();
}

/// <summary>Writes the strings as a parenthesized, comma-separated list.</summary>
private void ParenList(IEnumerable<string> items) => ParenList(items, text => text);
/// <summary>
/// Emits the constructor overload(s) for <paramref name="node"/>.
/// </summary>
/// <param name="node">The node type.</param>
/// <param name="isPublic">True for the public constructor, false for the protected one used by derived classes.</param>
/// <param name="hasChildNodes">
/// True when the node has node or node-list fields. Such nodes get a single constructor whose
/// <c>hasErrors</c> parameter is optional. Nodes without children get a constructor with a required
/// <c>hasErrors</c> plus a second overload that omits it.
/// </param>
private void WriteConstructor(TreeType node, bool isPublic, bool hasChildNodes)
{
    WriteConstructorWithHasErrors(node, isPublic, hasErrorsIsOptional: hasChildNodes);

    // The hasErrors-free overload does not merge child errors, so only childless nodes get it.
    if (!hasChildNodes)
    {
        WriteConstructorWithoutHasErrors(node, isPublic);
    }
}
/// <summary>
/// Emits a constructor for <paramref name="node"/> whose last parameter is <c>hasErrors</c>.
/// </summary>
/// <param name="node">The node type whose constructor is generated.</param>
/// <param name="isPublic">
/// True: the public constructor. It passes the node's own <c>BoundKind</c> to the base constructor and
/// ORs <c>hasErrors</c> with the errors of every node/node-list field.
/// False: the protected constructor for derived classes. It takes <c>kind</c> explicitly and forwards
/// <c>hasErrors</c> unchanged.
/// </param>
/// <param name="hasErrorsIsOptional">When true, <c>hasErrors</c> defaults to false.</param>
/// <remarks>
/// Parameters are the node's specifiable fields, i.e. fields whose null handling is not Always.
/// Always-null fields are passed/assigned as null (C#) or Nothing (VB).
/// When HasValidate(node) is true, the constructor body ends with a call to a partial Validate()
/// method, and that method's declaration is emitted right after the constructor.
/// </remarks>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteConstructorWithHasErrors(TreeType node, bool isPublic, bool hasErrorsIsOptional)
{
switch (_targetLang)
{
case TargetLanguage.CSharp:
{
// A public constructor does not have an explicit kind parameter.
Write("{0} {1}", isPublic ? "public" : "protected", node.Name);
// Parameter types come from the most specific (possibly overriding) declaration of each field.
IEnumerable<string> fields = isPublic ? ["SyntaxNode syntax"] : ["BoundKind kind", "SyntaxNode syntax"];
fields = fields.Concat(from field in AllSpecifiableFields(node)
let mostSpecific = GetField(node, field.Name)
select mostSpecific.Type + " " + ToCamelCase(field.Name));
if (hasErrorsIsOptional)
fields = fields.Concat(new[] { "bool hasErrors = false" });
else
fields = fields.Concat(new[] { "bool hasErrors" });
ParenList(fields, x => x);
Blank();
// The ": base(...)" initializer goes on its own, further-indented line.
Indent();
Write(": base(");
if (isPublic)
{
// Base call has bound kind, syntax, all fields in base type, plus merged HasErrors.
Write(string.Format("BoundKind.{0}", StripBound(node.Name)));
Write(", syntax, ");
foreach (Field baseField in AllSpecifiableFields(BaseType(node)))
Write("{0}, ", FieldNullHandling(node, baseField.Name) == NullHandling.Always ? "null" : ToCamelCase(baseField.Name));
// Emits "hasErrors || child1.HasErrors() || child2.HasErrors() ...".
Or((new[] { "hasErrors" })
.Concat(from field in AllNodeOrNodeListFields(node)
select ToCamelCase(field.Name) + ".HasErrors()"), x => x);
}
else
{
// Base call has kind, syntax, and hasErrors. No merging of hasErrors because derived class already did the merge.
Write("kind, syntax, ");
foreach (Field baseField in AllSpecifiableFields(BaseType(node)))
Write("{0}, ", FieldNullHandling(node, baseField.Name) == NullHandling.Always ? "null" : ToCamelCase(baseField.Name));
Write("hasErrors");
}
Write(")");
Blank();
Outdent();
Brace();
// Body: debug null assertions, then assignment of the fields declared on this type.
WriteNullChecks(node);
foreach (var field in Fields(node))
{
WriteLine("this.{0} = {1};", field.Name, FieldNullHandling(node, field.Name) == NullHandling.Always ? "null" : ToCamelCase(field.Name));
}
bool hasValidate = HasValidate(node);
if (hasValidate)
{
WriteLine("Validate();");
}
Unbrace();
Blank();
// Declare the DEBUG-only partial Validate() method that the constructor body calls.
if (hasValidate)
{
WriteLine(@"[Conditional(""DEBUG"")]");
WriteLine("private partial void Validate();");
Blank();
}
break;
}
case TargetLanguage.VB:
{
// A public constructor does not have an explicit kind parameter.
Write("{0} {1}", isPublic ? "Public" : "Protected", "Sub New");
IEnumerable<string> fields = isPublic ? ["syntax As SyntaxNode"] : ["kind As BoundKind", "syntax as SyntaxNode"];
fields = fields.Concat(from field in AllSpecifiableFields(node)
select ToCamelCase(field.Name) + " As " + field.Type);
if (hasErrorsIsOptional)
fields = fields.Concat(new[] { "Optional hasErrors As Boolean = False" });
else
fields = fields.Concat(new[] { "hasErrors As Boolean" });
ParenList(fields, x => x);
Blank();
// The whole Sub body (MyBase.New call included) is indented one level until "End Sub".
Indent();
Write("MyBase.New(");
if (isPublic)
{
// Base call has bound kind, syntax, all fields in base type, plus merged HasErrors.
Write(string.Format("BoundKind.{0}", StripBound(node.Name)));
Write(", syntax, ");
foreach (Field baseField in AllSpecifiableFields(BaseType(node)))
Write("{0}, ", FieldNullHandling(node, baseField.Name) == NullHandling.Always ? "Nothing" : ToCamelCase(baseField.Name));
// Emits "hasErrors OrElse child1.NonNullAndHasErrors() OrElse ...".
Or((new[] { "hasErrors" })
.Concat(from field in AllNodeOrNodeListFields(node)
select ToCamelCase(field.Name) + ".NonNullAndHasErrors()"), x => x);
}
else
{
// Base call has kind, syntax, and hasErrors. No merging of hasErrors because derived class already did the merge.
Write("kind, syntax, ");
foreach (Field baseField in AllSpecifiableFields(BaseType(node)))
Write("{0}, ", FieldNullHandling(node, baseField.Name) == NullHandling.Always ? "Nothing" : ToCamelCase(baseField.Name));
Write("hasErrors");
}
Write(")");
Blank();
WriteNullChecks(node);
foreach (var field in Fields(node))
WriteLine("Me._{0} = {1}", field.Name, FieldNullHandling(node, field.Name) == NullHandling.Always ? "Nothing" : ToCamelCase(field.Name));
bool hasValidate = HasValidate(node);
if (hasValidate)
{
Blank();
WriteLine("Validate()");
}
Outdent();
WriteLine("End Sub");
Blank();
// Declare the partial Validate() method that the constructor body calls.
if (hasValidate)
{
WriteLine("Private Partial Sub Validate()");
WriteLine("End Sub");
Blank();
}
break;
}
default:
throw new ArgumentException("Unexpected target language", nameof(_targetLang));
}
}
/// <summary>
/// Emits a constructor for <paramref name="node"/> that has no <c>hasErrors</c> parameter.
/// </summary>
/// <param name="node">The node type whose constructor is generated.</param>
/// <param name="isPublic">
/// True: the public constructor, which passes the node's own <c>BoundKind</c>.
/// False: the protected constructor, which takes <c>kind</c> explicitly.
/// </param>
/// <remarks>
/// This constructor should only be created for nodes with no node or list fields, since it just calls
/// the base class constructor without merging hasErrors.
/// Like the hasErrors constructor, it calls the node's partial Validate() method when HasValidate(node)
/// is true, in both C# and VB. The C# branch previously omitted this call, so validation was silently
/// skipped for C# nodes built through this overload. Validate() itself is declared by
/// WriteConstructorWithHasErrors, which WriteConstructor always runs before this method.
/// </remarks>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteConstructorWithoutHasErrors(TreeType node, bool isPublic)
{
    switch (_targetLang)
    {
        case TargetLanguage.CSharp:
            {
                // A public constructor does not have an explicit kind parameter.
                Write("{0} {1}", isPublic ? "public" : "protected", node.Name);

                // Parameter types come from the most specific (possibly overriding) declaration of each field.
                IEnumerable<string> fields = isPublic ? ["SyntaxNode syntax"] : ["BoundKind kind", "SyntaxNode syntax"];
                fields = fields.Concat(from field in AllSpecifiableFields(node)
                                       let mostSpecific = GetField(node, field.Name)
                                       select mostSpecific.Type + " " + ToCamelCase(field.Name));
                ParenList(fields, x => x);
                Blank();

                Indent();
                Write(": base(");
                if (isPublic)
                {
                    // Base call has bound kind, syntax, fields.
                    Write("BoundKind.{0}", StripBound(node.Name));
                    Write(", syntax");
                }
                else
                {
                    // Base call has kind, syntax, fields.
                    Write("kind, syntax");
                }

                foreach (Field baseField in AllSpecifiableFields(BaseType(node)))
                {
                    Write(", {0}", FieldNullHandling(node, baseField.Name) == NullHandling.Always ? "null" : ToCamelCase(baseField.Name));
                }

                Write(")");
                Blank();
                Outdent();

                Brace();
                WriteNullChecks(node);
                foreach (var field in Fields(node))
                {
                    WriteLine("this.{0} = {1};", field.Name, FieldNullHandling(node, field.Name) == NullHandling.Always ? "null" : ToCamelCase(field.Name));
                }

                // Run the same debug validation as the hasErrors constructor (and as the VB branch below).
                if (HasValidate(node))
                {
                    WriteLine("Validate();");
                }

                Unbrace();
                Blank();
                break;
            }

        case TargetLanguage.VB:
            {
                // A public constructor does not have an explicit kind parameter.
                Write("{0} {1}", isPublic ? "Public" : "Protected", "Sub New");
                IEnumerable<string> fields = isPublic ? ["syntax As SyntaxNode"] : ["kind As BoundKind", "syntax as SyntaxNode"];
                fields = fields.Concat(from field in AllSpecifiableFields(node)
                                       select ToCamelCase(field.Name) + " As " + field.Type);
                ParenList(fields, x => x);
                Blank();

                // The whole Sub body (MyBase.New call included) stays indented until "End Sub".
                Indent();
                Write("MyBase.New(");
                if (isPublic)
                {
                    // Base call has bound kind, syntax, fields.
                    Write("BoundKind.{0}", StripBound(node.Name));
                    Write(", syntax");
                }
                else
                {
                    // Base call has kind, syntax, fields.
                    Write("kind, syntax");
                }

                foreach (Field baseField in AllSpecifiableFields(BaseType(node)))
                {
                    Write(", {0}", FieldNullHandling(node, baseField.Name) == NullHandling.Always ? "Nothing" : ToCamelCase(baseField.Name));
                }

                Write(")");
                Blank();

                WriteNullChecks(node);
                foreach (var field in Fields(node))
                {
                    WriteLine("Me._{0} = {1}", field.Name, FieldNullHandling(node, field.Name) == NullHandling.Always ? "Nothing" : ToCamelCase(field.Name));
                }

                if (HasValidate(node))
                {
                    Blank();
                    WriteLine("Validate()");
                }

                Outdent();
                WriteLine("End Sub");
                Blank();
                break;
            }

        default:
            throw new ArgumentException("Unexpected target language", nameof(_targetLang));
    }
}
/// <summary>
/// Writes the null checks for any fields that can't be null. This covers every field of
/// <paramref name="node"/>, inherited ones included, whose null handling is Disallow. The assertions
/// are surrounded by blank lines. Nothing is written when no field qualifies.
/// </summary>
/// <remarks>
/// ImmutableArray and OneOrMany fields are checked via IsDefault; all other fields via a null comparison.
/// The qualifying fields are materialized once. The deferred query otherwise re-ran FieldNullHandling
/// for every field on each enumeration (once for Any, once for the loop).
/// </remarks>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteNullChecks(TreeType node)
{
    List<Field> nullCheckFields = AllFields(node)
        .Where(f => FieldNullHandling(node, f.Name) == NullHandling.Disallow)
        .ToList();

    if (nullCheckFields.Count == 0)
    {
        return;
    }

    Blank();
    foreach (Field field in nullCheckFields)
    {
        bool useIsDefaultProperty = GetGenericType(field.Type) is "ImmutableArray" or "OneOrMany";
        string parameterName = ToCamelCase(field.Name);
        switch (_targetLang)
        {
            case TargetLanguage.CSharp:
                if (useIsDefaultProperty)
                    WriteLine("RoslynDebug.Assert(!{0}.IsDefault, \"Field '{0}' cannot be null (use Null=\\\"allow\\\" in BoundNodes.xml to remove this check)\");", parameterName);
                else
                    WriteLine("RoslynDebug.Assert({0} is object, \"Field '{0}' cannot be null (make the type nullable in BoundNodes.xml to remove this check)\");", parameterName);
                break;

            case TargetLanguage.VB:
                // VB escapes a quote inside a string literal by doubling it.
                if (useIsDefaultProperty)
                    WriteLine("Debug.Assert(Not ({0}.IsDefault), \"Field '{0}' cannot be null (use Null=\"\"allow\"\" in BoundNodes.xml to remove this check)\")", parameterName);
                else
                    WriteLine("Debug.Assert({0} IsNot Nothing, \"Field '{0}' cannot be null (use Null=\"\"allow\"\" in BoundNodes.xml to remove this check)\")", parameterName);
                break;

            default:
                throw new ArgumentException("Unexpected target language", nameof(_targetLang));
        }
    }
    Blank();
}
/// <summary>
/// The fields declared directly on <paramref name="node"/>, excluding those marked Override.
/// Types that are not abstract nodes have no fields.
/// </summary>
private static IEnumerable<Field> Fields(TreeType node)
    => node is AbstractNode aNode
        ? aNode.Fields.Where(f => !f.Override)
        : Enumerable.Empty<Field>();
/// <summary>
/// All fields declared directly on <paramref name="node"/>, including those marked Override.
/// Types that are not abstract nodes have no fields.
/// </summary>
private static IEnumerable<Field> FieldsIncludingOverrides(TreeType node)
{
    if (node is not AbstractNode aNode)
    {
        return Enumerable.Empty<Field>();
    }

    return aNode.Fields;
}
/// <summary>
/// Finds the declaration of <paramref name="fieldName"/> (overrides included) on <paramref name="node"/>,
/// or on the nearest base type that declares it. Returns null when no type in the chain declares it.
/// </summary>
private Field GetMostDerivedField(Node node, string fieldName)
    => TypeAndBaseTypes(node)
        .Select(type => FieldsIncludingOverrides(type).SingleOrDefault(f => f.Name == fieldName))
        .FirstOrDefault(declared => declared is not null);
/// <summary>
/// Returns the tree type that <paramref name="node"/> derives from. Returns null when its base
/// is the tree root.
/// </summary>
private TreeType BaseType(TreeType node)
{
    string baseName = _typeMap[node.Name];
    return baseName == _tree.Root
        ? null
        : _tree.Types.Single(t => t.Name == baseName);
}
/// <summary>
/// True when the node's HasValidate attribute is "true" (case-insensitive). In that case the
/// generated constructors call a partial Validate() method.
/// </summary>
/// <remarks>
/// Uses an ordinal, case-insensitive comparison. The attribute is machine-readable input, so
/// the result must not depend on the current culture. The previous
/// <c>string.Compare(a, b, ignoreCase: true)</c> overload performs a culture-sensitive comparison.
/// A missing (null) attribute yields false.
/// </remarks>
private static bool HasValidate(TreeType node)
{
    return string.Equals(node.HasValidate, "true", StringComparison.OrdinalIgnoreCase);
}
/// <summary>
/// Lazily yields <paramref name="node"/> and then each of its base types, most derived first,
/// stopping before the tree root. Yields nothing for null.
/// </summary>
private IEnumerable<TreeType> TypeAndBaseTypes(TreeType node)
{
    for (TreeType current = node; current != null; current = BaseType(current))
    {
        yield return current;
    }
}
/// <summary>
/// All non-override fields declared on <paramref name="node"/> and its base types, most derived
/// type first. Returns an empty sequence for null.
/// </summary>
private IEnumerable<Field> AllFields(TreeType node)
{
    if (node == null)
    {
        return Enumerable.Empty<Field>();
    }

    return TypeAndBaseTypes(node).SelectMany(type => Fields(type));
}
/// <summary>
/// Specifiable fields are those that aren't always null. This returns every field of
/// <paramref name="node"/> (inherited ones included) whose null handling is not Always.
/// These are the fields that appear as constructor and Update parameters.
/// </summary>
private IEnumerable<Field> AllSpecifiableFields(TreeType node)
    => AllFields(node).Where(f => FieldNullHandling(node, f.Name) != NullHandling.Always);
/// <summary>
/// All fields (inherited ones included) whose type IsDerivedOrListOfDerived reports as a BoundNode,
/// or as a list of BoundNodes.
/// </summary>
private IEnumerable<Field> AllNodeOrNodeListFields(TreeType node)
    => AllFields(node).Where(f => IsDerivedOrListOfDerived("BoundNode", f.Type));

/// <summary>All fields (inherited ones included) typed exactly TypeSymbol, nullable or not.</summary>
private IEnumerable<Field> AllTypeFields(TreeType node)
    => AllFields(node).Where(TypeIsTypeSymbol);

/// <summary>
/// All fields (inherited ones included) whose type is a symbol, or an ImmutableArray of symbols.
/// </summary>
private IEnumerable<Field> AllSymbolOrSymbolListFields(TreeType node)
{
    return AllFields(node).Where(field =>
    {
        if (TypeIsSymbol(field))
        {
            return true;
        }

        return IsImmutableArray(field.Type, out var elementType) && TypeIsSymbol(elementType);
    });
}
/// <summary>
/// Determines how null is handled for field <paramref name="fieldName"/> of <paramref name="node"/>.
/// </summary>
/// <remarks>
/// An explicit Null attribute (allow / disallow / always / notapplicable, case-insensitive) wins.
/// An empty attribute falls through to the defaults, which are:
/// a type spelled with a trailing '?' allows null;
/// a field marked Override inherits the base field's handling;
/// reference types and ImmutableArray disallow null;
/// other value types are NotApplicable.
/// </remarks>
/// <exception cref="ArgumentException">
/// Thrown when, for C#, a non-value-type field marked allow/always is declared without '?'.
/// Also thrown when the Null attribute has an unrecognized value.
/// </exception>
private NullHandling FieldNullHandling(TreeType node, string fieldName)
{
    Field field = GetField(node, fieldName);

    if (field.Null != null)
    {
        string requested = field.Null.ToUpperInvariant();

        // In C#, a reference-typed field that may hold null must be declared with a nullable type.
        if (_targetLang == TargetLanguage.CSharp
            && requested is ("ALLOW" or "ALWAYS")
            && !field.Type.EndsWith('?')
            && !IsValueType(field.Type))
        {
            throw new ArgumentException($"Field '{fieldName}' on node '{node.Name}' should have a nullable type, since it isn't a value type and it is marked null=allow or null=always");
        }

        switch (requested)
        {
            case "ALLOW":
                return NullHandling.Allow;
            case "DISALLOW":
                return NullHandling.Disallow;
            case "ALWAYS":
                return NullHandling.Always;
            case "NOTAPPLICABLE":
                return NullHandling.NotApplicable;
            case "":
                // An empty attribute means "use the defaults below".
                break;
            default:
                throw new ArgumentException("Unexpected value", nameof(field.Null));
        }
    }

    if (field.Type.EndsWith('?'))
    {
        return NullHandling.Allow;
    }

    if (field.Override)
    {
        return FieldNullHandling(BaseType(node), fieldName);
    }

    // Value types cannot be null-checked. ImmutableArray is the exception: its IsDefault is asserted instead.
    bool canBeNullChecked = !IsValueType(field.Type) || GetGenericType(field.Type) == "ImmutableArray";
    return canBeNullChecked ? NullHandling.Disallow : NullHandling.NotApplicable;
}
/// <summary>
/// Returns the declaration of <paramref name="fieldName"/> (overrides included) on
/// <paramref name="node"/>, or on the nearest base type that declares it.
/// </summary>
/// <exception cref="InvalidOperationException">No type in the base chain declares the field.</exception>
private Field GetField(TreeType node, string fieldName)
{
    var declarations = FieldsIncludingOverrides(node).Where(f => f.Name == fieldName);
    if (declarations.Any())
    {
        return declarations.Single();
    }

    TreeType baseType = BaseType(node);
    if (baseType == null)
    {
        throw new InvalidOperationException($"Field {fieldName} not found in type {node.Name}");
    }

    return GetField(baseType, fieldName);
}
/// <summary>
/// Emits the member(s) that expose <paramref name="field"/> on the generated class for <paramref name="node"/>.
/// </summary>
/// <remarks>
/// C#:
/// a <c>public override</c> get-only property when IsPropertyOverrides is true;
/// for a field marked Override, a <c>public new</c> property that forwards to the base property;
/// otherwise a <c>public</c> get-only property.
/// VB: a private read-only backing field plus a public read-only property that returns it.
/// IsNew adds <c>new</c> (C#) or <c>Shadows</c> (VB).
/// </remarks>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteField(TreeType node, Field field)
{
    if (_targetLang == TargetLanguage.CSharp)
    {
        if (IsPropertyOverrides(field))
        {
            WriteLine("public override {0}{1} {2} {{ get; }}", IsNew(field) ? "new " : "", field.Type, field.Name);
        }
        else if (field.Override)
        {
            // The point of overriding a property is to change its nullability, usually from nullable
            // to non-nullable. The base property is annotated as nullable, but it is always set via
            // the base constructor call, so it is non-null whenever this class's constructor
            // parameter is non-nullable. That pattern is safe but can't be tracked, hence the '!'.
            bool mayBeNull = FieldNullHandling(node, field.Name) is (NullHandling.Allow or NullHandling.Always);
            string suppression = mayBeNull ? "" : "!";
            WriteLine($"public new {field.Type} {field.Name} => base.{field.Name}{suppression};");
        }
        else
        {
            WriteLine("public {0}{1} {2} {{ get; }}", IsNew(field) ? "new " : "", field.Type, field.Name);
        }
    }
    else if (_targetLang == TargetLanguage.VB)
    {
        Blank();
        WriteLine("Private {0}ReadOnly _{2} As {1}", IsNew(field) ? "Shadows " : "", field.Type, field.Name);

        string propertyModifier = IsNew(field) ? "Shadows " : (IsPropertyOverrides(field) ? "Overrides " : "");
        WriteLine("Public {0}ReadOnly Property {2} As {1}", propertyModifier, field.Type, field.Name);
        Indent();
        WriteLine("Get");
        Indent();
        WriteLine("Return _{0}", field.Name);
        Outdent();
        WriteLine("End Get");
        Outdent();
        WriteLine("End Property");
    }
    else
    {
        throw new ArgumentException("Unexpected target language", nameof(_targetLang));
    }
}
/// <summary>
/// Emits the <c>Accept</c> override for node class <paramref name="name"/>. It dispatches to the
/// visitor's Visit method named after the node without its "Bound" prefix.
/// </summary>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteAccept(string name)
{
    string visitSuffix = StripBound(name);

    if (_targetLang == TargetLanguage.CSharp)
    {
        Blank();
        WriteLine("[DebuggerStepThrough]");
        WriteLine("public override BoundNode? Accept(BoundTreeVisitor visitor) => visitor.Visit{0}(this);", visitSuffix);
    }
    else if (_targetLang == TargetLanguage.VB)
    {
        Blank();
        WriteLine("<DebuggerStepThrough>");
        WriteLine("Public Overrides Function Accept(visitor as BoundTreeVisitor) As BoundNode");
        Indent();
        WriteLine("Return visitor.Visit{0}(Me)", visitSuffix);
        Outdent();
        WriteLine("End Function");
    }
    else
    {
        throw new ArgumentException("Unexpected target language", nameof(_targetLang));
    }
}
/// <summary>
/// Emits the complete class for <paramref name="node"/>. Types that are not AbstractNodes produce no output.
/// </summary>
private void WriteType(TreeType node)
{
    if (node is not AbstractNode)
    {
        return;
    }

    WriteClassHeader(node);

    bool hasChildNodes = AllNodeOrNodeListFields(node).Any();

    // Types that other types derive from get a protected constructor that takes an explicit kind.
    if (!CanBeSealed(node))
    {
        WriteConstructor(node, isPublic: false, hasChildNodes);
    }

    // Concrete nodes additionally get a public constructor, Accept, and Update.
    var concreteNode = node as Node;
    if (concreteNode != null)
    {
        WriteConstructor(node, isPublic: true, hasChildNodes);
    }

    // Only C# can express nullable reference types, so only C# emits the overriding fields
    // (whose purpose is to change nullability). VB emits just the newly declared fields.
    var fieldsToEmit = _targetLang == TargetLanguage.CSharp ? FieldsIncludingOverrides(node) : Fields(node);
    foreach (var field in fieldsToEmit)
    {
        WriteField(node, field);
    }

    if (concreteNode != null)
    {
        WriteAccept(node.Name);
        WriteUpdateMethod(concreteNode);
    }

    WriteClassFooter();
}
/// <summary>
/// Emits the <c>Update</c> method for concrete node <paramref name="node"/>. Nodes without any
/// fields (inherited included) get none.
/// </summary>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteUpdateMethod(Node node)
{
    if (!AllFields(node).Any())
    {
        return;
    }

    // A node that declares no fields of its own and derives from another concrete node marks its
    // Update as new/Shadows, presumably because it hides the base node's Update.
    bool emitNew = !Fields(node).Any() && BaseType(node) is Node;

    if (_targetLang == TargetLanguage.CSharp)
    {
        WriteUpdatedMethodCSharp(node, emitNew);
    }
    else if (_targetLang == TargetLanguage.VB)
    {
        WriteUpdatedMethodVB(node, emitNew);
    }
    else
    {
        throw new ArgumentException("Unexpected target language", nameof(_targetLang));
    }
}
/// <summary>
/// Emits the C# <c>Update</c> method for <paramref name="node"/>. It takes one parameter per
/// specifiable field. If any parameter differs from the current field value, it constructs a new node
/// (keeping this.Syntax and this.HasErrors), copies attributes via CopyAttributes, and returns the new
/// node; otherwise it returns <c>this</c>.
/// </summary>
/// <param name="node">The concrete node type.</param>
/// <param name="emitNew">When true, the method is declared with the <c>new</c> modifier.</param>
private void WriteUpdatedMethodCSharp(Node node, bool emitNew)
{
Blank();
Write("public{1} {0} Update", node.Name, emitNew ? " new" : "");
Paren();
// Parameter types come from the most specific (possibly overriding) declaration of each field.
Comma(AllSpecifiableFields(node), field => string.Format("{0} {1}", GetField(node, field.Name).Type, ToCamelCase(field.Name)));
UnParen();
Blank();
Brace();
// With no parameters nothing can change, so the method body is just "return this;".
if (AllSpecifiableFields(node).Any())
{
Write("if ");
Paren();
Or(AllSpecifiableFields(node), notEquals);
UnParen();
Blank();
Brace();
Write("var result = new {0}", node.Name);
var fields = new[] { "this.Syntax" }.Concat(AllSpecifiableFields(node).Select(f => ToCamelCase(f.Name))).Concat(new[] { "this.HasErrors" });
ParenList(fields);
WriteLine(";");
WriteLine("result.CopyAttributes(this);");
WriteLine("return result;");
Unbrace();
}
WriteLine("return this;");
Unbrace();
// Builds the C# "parameter differs from the current value" test for one field.
// TypeSymbol fields compare via TypeSymbol.Equals(..., TypeCompareKind.ConsiderEverything).
// Other symbol fields compare via SymbolEqualityComparer.ConsiderEverything.
// Nullable value types compare via Equals; OneOrMany via SequenceEqual; everything else via !=.
string notEquals(Field field)
{
var parameterName = ToCamelCase(field.Name);
var fieldName = field.Name;
if (TypeIsTypeSymbol(field))
return $"!TypeSymbol.Equals({parameterName}, this.{fieldName}, TypeCompareKind.ConsiderEverything)";
if (TypeIsSymbol(field))
return $"!Symbols.SymbolEqualityComparer.ConsiderEverything.Equals({parameterName}, this.{fieldName})";
if (IsValueType(field.Type) && field.Type[^1] == '?')
return $"!{parameterName}.Equals(this.{fieldName})";
if (GetGenericType(field.Type) == "OneOrMany")
return $"!{parameterName}.SequenceEqual({fieldName})";
return $"{parameterName} != this.{fieldName}";
}
}
/// <summary>
/// Emits the VB <c>Update</c> function for <paramref name="node"/>. It takes one parameter per
/// specifiable field. If any parameter differs from the current field value, it constructs a new node
/// (keeping Me.Syntax and Me.HasErrors), copies attributes via CopyAttributes, and returns the new
/// node; otherwise it returns <c>Me</c>.
/// </summary>
/// <param name="node">The concrete node type.</param>
/// <param name="emitNew">When true, the function is declared with the <c>Shadows</c> modifier.</param>
private void WriteUpdatedMethodVB(Node node, bool emitNew)
{
Blank();
Write("Public{0} Function Update", emitNew ? " Shadows" : "");
Paren();
Comma(AllSpecifiableFields(node), field => string.Format("{1} As {0}", field.Type, ToCamelCase(field.Name)));
UnParen();
WriteLine(" As {0}", node.Name);
Indent();
// With no parameters nothing can change, so the body is just "Return Me".
if (AllSpecifiableFields(node).Any())
{
Write("If ");
Or(AllSpecifiableFields(node), notEquals);
WriteLine(" Then");
Indent();
Write("Dim result = New {0}", node.Name);
var fields = new[] { "Me.Syntax" }.Concat(AllSpecifiableFields(node).Select(f => ToCamelCase(f.Name))).Concat(new[] { "Me.HasErrors" });
ParenList(fields);
WriteLine("");
WriteLine("result.CopyAttributes(Me)");
WriteLine("Return result");
Outdent();
WriteLine("End If");
}
WriteLine("Return Me");
Outdent();
WriteLine("End Function");
// Builds the VB "parameter differs from the current value" test for one field.
// Non-value types compare by reference (IsNot); value-type OneOrMany fields via SequenceEqual;
// other value types via <>.
string notEquals(Field field)
{
var parameterName = ToCamelCase(field.Name);
var fieldName = field.Name;
if (!IsValueType(field.Type))
return $"{parameterName} IsNot Me.{fieldName}";
if (GetGenericType(field.Type) == "OneOrMany")
return $"Not {parameterName}.SequenceEqual({fieldName})";
return $"{parameterName} <> Me.{fieldName}";
}
}
/// <summary>True when the field's type, ignoring trailing '?', is exactly <c>TypeSymbol</c>.</summary>
private static bool TypeIsTypeSymbol(Field field) => field.Type.TrimEnd('?') == "TypeSymbol";

/// <summary>True when the field's type name, ignoring trailing '?', ends with "Symbol".</summary>
private static bool TypeIsSymbol(Field field) => TypeIsSymbol(field.Type);

/// <summary>
/// True when <paramref name="type"/>, ignoring trailing '?', ends with "Symbol".
/// Uses an ordinal comparison because type names are identifiers, not linguistic text. This matches
/// StripBound's ordinal prefix check; the parameterless-comparison EndsWith overload is culture-sensitive.
/// </summary>
private static bool TypeIsSymbol(string type) => type.TrimEnd('?').EndsWith("Symbol", StringComparison.Ordinal);
/// <summary>
/// Removes a leading "Bound" (ordinal, case-sensitive) from a node name, e.g. "BoundCall" becomes "Call".
/// Other names are returned unchanged.
/// </summary>
private string StripBound(string name)
{
    const string boundPrefix = "Bound";
    return name.StartsWith(boundPrefix, StringComparison.Ordinal)
        ? name.Substring(boundPrefix.Length)
        : name;
}
/// <summary>
/// Emits three partial visitor class declarations:
/// (1) BoundTreeVisitor&lt;A, R&gt; with VisitInternal, which switches on node.Kind and calls the matching
///     Visit method;
/// (2) BoundTreeVisitor&lt;A, R&gt; with one overridable Visit method per node, defaulting to DefaultVisit(node, arg);
/// (3) the argument-less BoundTreeVisitor with one overridable Visit method per node, defaulting to DefaultVisit(node).
/// </summary>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteVisitor()
{
switch (_targetLang)
{
case TargetLanguage.CSharp:
Blank();
// (1) Kind-based dispatch.
WriteLine("internal abstract partial class BoundTreeVisitor<A, R>");
Brace();
Blank();
WriteLine("[MethodImpl(MethodImplOptions.NoInlining), DebuggerStepThrough]");
WriteLine("internal R VisitInternal(BoundNode node, A arg)");
Brace();
WriteLine("switch (node.Kind)");
Brace();
foreach (var node in _tree.Types.OfType<Node>())
{
WriteLine("case BoundKind.{0}:", FixKeyword(StripBound(node.Name)));
Indent();
WriteLine("return Visit{0}(({1})node, arg);", StripBound(node.Name), node.Name);
Outdent();
}
Unbrace();
Blank(); // end switch
// In C#, an unmatched kind yields default(R).
WriteLine("return default(R)!;");
Unbrace(); // end method
Unbrace(); // end class
Blank();
// (2) Per-node virtual Visit methods taking an argument.
WriteLine("internal abstract partial class BoundTreeVisitor<A, R>");
Brace();
foreach (var node in _tree.Types.OfType<Node>())
{
WriteLine($"public virtual R Visit{StripBound(node.Name)}({node.Name} node, A arg) => this.DefaultVisit(node, arg);");
}
Unbrace();
Blank();
// (3) Per-node virtual Visit methods without an argument.
WriteLine("internal abstract partial class BoundTreeVisitor");
Brace();
foreach (var node in _tree.Types.OfType<Node>())
{
WriteLine($"public virtual BoundNode? Visit{StripBound(node.Name)}({node.Name} node) => this.DefaultVisit(node);");
}
Unbrace();
break;
case TargetLanguage.VB:
Blank();
// (1) Kind-based dispatch.
WriteLine("Friend MustInherit Partial Class BoundTreeVisitor(Of A, R)");
Indent();
Blank();
WriteLine("<MethodImpl(MethodImplOptions.NoInlining), DebuggerStepThrough>");
WriteLine("Friend Function VisitInternal(node As BoundNode, arg As A) As R");
Indent();
WriteLine("Select Case node.Kind");
Indent();
foreach (var node in _tree.Types.OfType<Node>())
{
WriteLine("Case BoundKind.{0}", FixKeyword(StripBound(node.Name)));
Indent();
WriteLine("Return Visit{0}(CType(node, {1}), arg)", StripBound(node.Name), node.Name);
Outdent();
}
Outdent();
WriteLine("End Select");
// In VB, an unmatched kind falls back to DefaultVisit(node, arg).
WriteLine("Return DefaultVisit(node, arg)");
Outdent();
WriteLine("End Function");
Blank();
Outdent();
WriteLine("End Class");
Blank();
// (2) Per-node overridable Visit functions taking an argument.
WriteLine("Friend MustInherit Partial Class BoundTreeVisitor(Of A, R)");
Indent();
foreach (var node in _tree.Types.OfType<Node>())
{
WriteLine("Public Overridable Function Visit{0}(node As {1}, arg As A) As R", StripBound(node.Name), node.Name);
Indent();
WriteLine("Return Me.DefaultVisit(node, arg)");
Outdent();
WriteLine("End Function");
Blank();
}
Outdent();
WriteLine("End Class");
Blank();
// (3) Per-node overridable Visit functions without an argument.
WriteLine("Friend MustInherit Partial Class BoundTreeVisitor");
Indent();
foreach (var node in _tree.Types.OfType<Node>())
{
WriteLine(GetVisitFunctionDeclaration(node.Name, isOverride: false));
Indent();
WriteLine("Return Me.DefaultVisit(node)");
Outdent();
WriteLine("End Function");
Blank();
}
Outdent();
WriteLine("End Class");
break;
default:
throw new ArgumentException("Unexpected target language", nameof(_targetLang));
}
}
/// <summary>
/// Emits the abstract BoundTreeWalker, which overrides the Visit method of every node. Each override
/// visits the node's BoundNode and BoundNode-list fields, skipping those for which SkipInVisitor is
/// true, and then returns null/Nothing.
/// </summary>
/// <remarks>
/// In C#, nodes with no fields to visit get a compact expression-bodied override (<c>=&gt; null;</c>).
/// VB always emits a full function body.
/// </remarks>
/// <exception cref="ArgumentException">The target language is neither C# nor VB.</exception>
private void WriteWalker()
{
switch (_targetLang)
{
case TargetLanguage.CSharp:
Blank();
WriteLine("internal abstract partial class BoundTreeWalker : BoundTreeVisitor");
Brace();
foreach (var node in _tree.Types.OfType<Node>())
{
var fields = AllFields(node).Where(f => IsDerivedOrListOfDerived("BoundNode", f.Type) && !SkipInVisitor(f));
if (!fields.Any())
{
WriteLine($"{GetVisitFunctionDeclaration(node.Name, isOverride: true)} => null;");
continue;
}
WriteLine(GetVisitFunctionDeclaration(node.Name, isOverride: true));
Brace();
// Single nodes use Visit; node lists use VisitList.
foreach (Field field in fields)
{
WriteLine("this.Visit{1}(node.{0});", field.Name, IsNodeList(field.Type) ? "List" : "");
}
WriteLine("return null;");
Unbrace();
}
Unbrace();
break;
case TargetLanguage.VB:
Blank();
WriteLine("Friend MustInherit Partial Class BoundTreeWalker");
Indent();
WriteLine("Inherits BoundTreeVisitor");
Blank();
foreach (var node in _tree.Types.OfType<Node>())
{
WriteLine(GetVisitFunctionDeclaration(node.Name, isOverride: true));
Indent();
// Single nodes use Visit; node lists use VisitList.
foreach (Field field in AllFields(node).Where(f => IsDerivedOrListOfDerived("BoundNode", f.Type) && !SkipInVisitor(f)))
WriteLine("Me.Visit{1}(node.{0})", field.Name, IsNodeList(field.Type) ? "List" : "");
WriteLine("Return Nothing");
Outdent();
WriteLine("End Function");
Blank();
}
Outdent();
WriteLine("End Class");
break;
default:
throw new ArgumentException("Unexpected target language", nameof(_targetLang));
}
}
// Emits BoundTreeDumperNodeProducer: a visitor with one Visit override per node type. Each override
// builds a TreeDumperNode labelled with the node's camel-cased name (without the "Bound" prefix) whose
// children are one TreeDumperNode per field of the node.
private void WriteTreeDumperNodeProducer()
{
switch (_targetLang)
{
case TargetLanguage.CSharp:
Blank();
WriteLine("internal sealed class BoundTreeDumperNodeProducer : BoundTreeVisitor<object?, TreeDumperNode>");
Brace();
// Private constructor; callers use the static MakeTree entry point emitted next.
WriteLine("private BoundTreeDumperNodeProducer()");
Brace();
Unbrace();
WriteLine("public static TreeDumperNode MakeTree(BoundNode node) => (new BoundTreeDumperNodeProducer()).Visit(node, null);");
foreach (var node in _tree.Types.OfType<Node>())
{
// Start an expression-bodied override; the children array (or an empty array) is appended below.
Write("public override TreeDumperNode Visit{0}({1} node, object? arg) => new TreeDumperNode(\"{2}\", null, ", StripBound(node.Name), node.Name, ToCamelCase(StripBound(node.Name)));
var allFields = AllFields(node).ToArray();
if (allFields.Length > 0)
{
WriteLine("new TreeDumperNode[]");
Brace();
for (int i = 0; i < allFields.Length; ++i)
{
Field field = allFields[i];
// Child nodes are visited recursively, node lists become one visited child per element,
// and every other field becomes a leaf carrying the field's value.
if (IsDerivedType("BoundNode", field.Type))
Write("new TreeDumperNode(\"{0}\", null, new TreeDumperNode[] {{ Visit(node.{1}, null) }})", ToCamelCase(field.Name), field.Name);
else if (IsListOfDerived("BoundNode", field.Type))
{
// Immutable arrays whose null handling is Disallow are enumerated directly; every other list
// gets an IsDefault guard so a default array dumps as an empty child list.
// NOTE(review): the guarded form is also chosen for non-ImmutableArray lists (IList<> is accepted by
// IsNodeList), where `.IsDefault` would not exist — presumably no such C# fields occur; confirm.
if (IsImmutableArray(field.Type, out _) && FieldNullHandling(node, field.Name) == NullHandling.Disallow)
{
Write("new TreeDumperNode(\"{0}\", null, from x in node.{1} select Visit(x, null))", ToCamelCase(field.Name), field.Name);
}
else
{
Write("new TreeDumperNode(\"{0}\", null, node.{1}.IsDefault ? Array.Empty<TreeDumperNode>() : from x in node.{1} select Visit(x, null))", ToCamelCase(field.Name), field.Name);
}
}
else
Write("new TreeDumperNode(\"{0}\", node.{1}, null)", ToCamelCase(field.Name), field.Name);
// Separate field entries; the last field is separated from the trailing flag entries below.
if (i != allFields.Length - 1)
WriteLine(",");
}
// NOTE(review): always true here (we are inside `allFields.Length > 0`); this comma separates the
// last field entry from the isSuppressed/hasErrors entries that follow.
if (allFields.Length != 0)
{
WriteLine(",");
}
// Expressions also dump IsSuppressed; HasErrors is always the last child.
if (IsDerivedType("BoundExpression", node.Name))
{
Write("new TreeDumperNode(\"isSuppressed\", node.IsSuppressed, null)");
WriteLine(",");
}
Write("new TreeDumperNode(\"hasErrors\", node.HasErrors, null)");
WriteLine("");
Unbrace();
}
else
{
// NOTE(review): nodes without fields get no children at all, so hasErrors/isSuppressed are not
// dumped for them — confirm this asymmetry is intended.
WriteLine("Array.Empty<TreeDumperNode>()");
}
// Close the TreeDumperNode constructor call and the expression-bodied member.
WriteLine(");");
}
Unbrace();
break;
case TargetLanguage.VB:
// The VB producer dumps field entries only (no isSuppressed/hasErrors children).
Blank();
WriteLine("Friend NotInheritable Class BoundTreeDumperNodeProducer");
Indent();
WriteLine("Inherits BoundTreeVisitor(Of Object, TreeDumperNode)");
Blank();
WriteLine("Private Sub New()");
WriteLine("End Sub");
Blank();
WriteLine("Public Shared Function MakeTree(node As BoundNode) As TreeDumperNode");
Indent();
WriteLine("Return (New BoundTreeDumperNodeProducer()).Visit(node, Nothing)");
Outdent();
WriteLine("End Function");
Blank();
foreach (var node in _tree.Types.OfType<Node>())
{
WriteLine("Public Overrides Function Visit{0}(node As {1}, arg As Object) As TreeDumperNode", StripBound(node.Name), node.Name);
Indent();
Write("Return New TreeDumperNode(\"{0}\", Nothing, ", ToCamelCase(StripBound(node.Name)));
var allFields = AllFields(node).ToArray();
if (allFields.Length > 0)
{
// NOTE(review): "{{" / "}}" here look like format-string escapes consumed by WriteLine (emitting a
// single brace) — confirm against the Write/WriteLine helpers.
WriteLine("New TreeDumperNode() {{");
Indent();
for (int i = 0; i < allFields.Length; ++i)
{
Field field = allFields[i];
// Same field classification as the C# path: child node, node list, or leaf value.
// NOTE(review): unlike the C# path, node lists are enumerated without an IsDefault guard — confirm
// VB list fields can never hold a default ImmutableArray.
if (IsDerivedType("BoundNode", field.Type))
Write("New TreeDumperNode(\"{0}\", Nothing, new TreeDumperNode() {{Visit(node.{1}, Nothing)}})", ToCamelCase(field.Name), field.Name);
else if (IsListOfDerived("BoundNode", field.Type))
Write("New TreeDumperNode(\"{0}\", Nothing, From x In node.{1} Select Visit(x, Nothing))", ToCamelCase(field.Name), field.Name);
else
Write("New TreeDumperNode(\"{0}\", node.{1}, Nothing)", ToCamelCase(field.Name), field.Name);
// Comma after every entry except the last.
if (i == allFields.Length - 1)
WriteLine("");
else
WriteLine(",");
}
Outdent();
WriteLine("}})");
}
else
{
WriteLine("Array.Empty(Of TreeDumperNode)())");
}
Outdent();
WriteLine("End Function");
Blank();
}
Outdent();
WriteLine("End Class");
break;
default:
throw new ArgumentException("Unexpected target language", nameof(_targetLang));
}
}
/// <summary>
/// Generates the BoundTreeRewriter class. Every node type gets a Visit override that rewrites its
/// child node / node-list fields (through WriteNodeVisitCall) and its TypeSymbol fields (through
/// VisitType), then rebuilds the node with Update(...) from those rewritten values.
/// </summary>
private void WriteRewriter()
{
    switch (_targetLang)
    {
        case TargetLanguage.CSharp:
            emitCSharpRewriter();
            break;
        case TargetLanguage.VB:
            emitVisualBasicRewriter();
            break;
        default:
            throw new ArgumentException("Unexpected target language", nameof(_targetLang));
    }

    // C# flavor: nodes with nothing to rewrite get a one-line "=> node" override.
    void emitCSharpRewriter()
    {
        Blank();
        WriteLine("internal abstract partial class BoundTreeRewriter : BoundTreeVisitor");
        Brace();
        foreach (var node in _tree.Types.OfType<Node>())
        {
            var childFields = AllNodeOrNodeListFields(node).ToList();
            var typeFields = AllTypeFields(node).ToList();
            string declaration = GetVisitFunctionDeclaration(node.Name, isOverride: true);

            if (childFields.Count == 0 && typeFields.Count == 0)
            {
                WriteLine($"{declaration} => node;");
                continue;
            }

            WriteLine(declaration);
            Brace();
            foreach (Field child in childFields)
            {
                WriteNodeVisitCall(child);
            }
            foreach (Field typeField in typeFields)
            {
                WriteLine("TypeSymbol? {0} = this.VisitType(node.{1});", ToCamelCase(typeField.Name), typeField.Name);
            }

            // At least one rewritten local exists at this point, so the node is always rebuilt.
            Write("return node.Update");
            ParenList(AllSpecifiableFields(node), f => IsDerivedOrListOfDerived("BoundNode", f.Type) || TypeIsTypeSymbol(f) ? ToCamelCase(f.Name) : string.Format("node.{0}", f.Name));
            WriteLine(";");
            Unbrace();
        }
        Unbrace();
    }

    // VB flavor: every node gets a full Function body; nodes with nothing to rewrite return node as-is.
    void emitVisualBasicRewriter()
    {
        Blank();
        WriteLine("Friend MustInherit Partial Class BoundTreeRewriter");
        Indent();
        WriteLine("Inherits BoundTreeVisitor");
        Blank();
        foreach (var node in _tree.Types.OfType<Node>())
        {
            WriteLine(GetVisitFunctionDeclaration(node.Name, isOverride: true));
            Indent();

            bool anyRewritten = false;
            foreach (Field child in AllNodeOrNodeListFields(node))
            {
                anyRewritten = true;
                WriteNodeVisitCall(child);
            }
            foreach (Field typeField in AllTypeFields(node))
            {
                anyRewritten = true;
                WriteLine("Dim {0} as TypeSymbol = Me.VisitType(node.{1})", ToCamelCase(typeField.Name), typeField.Name);
            }

            if (!anyRewritten)
            {
                WriteLine("Return node");
            }
            else
            {
                Write("Return node.Update");
                ParenList(AllSpecifiableFields(node), f => IsDerivedOrListOfDerived("BoundNode", f.Type) || f.Type == "TypeSymbol" ? ToCamelCase(f.Name) : string.Format("node.{0}", f.Name));
                WriteLine("");
            }

            Outdent();
            WriteLine("End Function");
            Blank();
        }
        Outdent();
        WriteLine("End Class");
    }
}
// Emits NullabilityRewriter (C# only; nothing is generated for VB). For each node type not marked
// SkipInNullabilityRewriter that is an expression, or that has a symbol / symbol-array field which may
// need remapping, it emits a Visit override that:
//   - replaces those symbol fields via GetUpdatedSymbol / GetUpdatedArray (helpers presumably hand-written
//     in the partial class — not shown here),
//   - visits child node / node-list fields (including fields visited only by this rewriter), and
//   - for expressions, applies the (NullabilityInfo, TypeSymbol?) entry recorded in _updatedNullabilities.
private void WriteNullabilityRewriter()
{
switch (_targetLang)
{
case TargetLanguage.VB:
break;
case TargetLanguage.CSharp:
{
Blank();
WriteLine("internal sealed partial class NullabilityRewriter : BoundTreeRewriter");
Brace();
// Names of the generated fields; reused by the emitted constructor and Visit bodies.
var updatedNullabilities = "_updatedNullabilities";
var snapshotManager = "_snapshotManager";
var remappedSymbols = "_remappedSymbols";
WriteLine($"private readonly ImmutableDictionary<BoundExpression, (NullabilityInfo Info, TypeSymbol? Type)> {updatedNullabilities};");
WriteLine($"private readonly NullableWalker.SnapshotManager? {snapshotManager};");
WriteLine($"private readonly ImmutableDictionary<Symbol, Symbol>.Builder {remappedSymbols};");
Blank();
WriteLine("public NullabilityRewriter(ImmutableDictionary<BoundExpression, (NullabilityInfo Info, TypeSymbol? Type)> updatedNullabilities, NullableWalker.SnapshotManager? snapshotManager, ImmutableDictionary<Symbol, Symbol>.Builder remappedSymbols)");
Brace();
WriteLine($"{updatedNullabilities} = updatedNullabilities;");
WriteLine($"{snapshotManager} = snapshotManager;");
WriteLine($"{remappedSymbols} = remappedSymbols;");
Unbrace();
foreach (var node in _tree.Types.OfType<Node>())
{
if (SkipInNullabilityRewriter(node))
continue;
var allSpecifiableFields = AllSpecifiableFields(node).ToList();
var isExpression = IsDerivedType("BoundExpression", node.Name);
// Non-expressions only need an override when some field's symbol(s) may be remapped; otherwise
// no override is emitted and the inherited BoundTreeRewriter.Visit* is used.
if (!isExpression && !allSpecifiableFields.Any(f => symbolIsPotentiallyUpdated(f) || immutableArrayIsPotentiallyUpdated(f)))
{
continue;
}
Blank();
WriteLine(GetVisitFunctionDeclaration(node.Name, isOverride: true));
Brace();
// Tracks whether any local (remapped symbol or visited child) was declared in the emitted body.
bool hadField = false;
foreach (var field in AllSymbolOrSymbolListFields(node))
{
if (symbolIsPotentiallyUpdated(field))
{
WriteLine($"{field.Type} {ToCamelCase(field.Name)} = GetUpdatedSymbol(node, node.{field.Name});");
hadField = true;
}
else if (immutableArrayIsPotentiallyUpdated(field))
{
WriteLine($"{field.Type} {ToCamelCase(field.Name)} = GetUpdatedArray(node, node.{field.Name});");
hadField = true;
}
}
foreach (var field in AllNodeOrNodeListFields(node))
{
hadField = true;
// Fields marked SkipInVisitor="ExceptNullabilityRewriter" are still visited by this rewriter.
WriteNodeVisitCall(field, forceVisit: VisitFieldOnlyInNullabilityRewriter(field));
}
if (isExpression)
{
if (hadField)
{
// Locals may differ from the original fields, so always rebuild; when a nullability entry was
// recorded, also substitute its type and set TopLevelNullability.
WriteLine($"{node.Name} updatedNode;");
Blank();
writeNullabilityCheck(inverted: false);
Brace();
writeUpdateAndDecl(decl: false, updatedType: true);
writeNullabilityUpdate();
Unbrace();
WriteLine("else");
Brace();
writeUpdateAndDecl(decl: false, updatedType: false);
Unbrace();
WriteLine("return updatedNode;");
}
else
{
// No locals: return the node untouched unless a nullability entry was recorded for it.
writeNullabilityCheck(inverted: true);
Brace();
WriteLine("return node;");
Unbrace();
Blank();
writeUpdateAndDecl(decl: true, updatedType: true);
writeNullabilityUpdate();
WriteLine("return updatedNode;");
}
}
else
{
// Non-expression: rebuild with the remapped symbols; there is no nullability to apply.
Write("return ");
writeUpdate(updatedType: false);
WriteLine(";");
}
Unbrace();
// Emits the lookup of this node in the generated _updatedNullabilities dictionary; `inverted`
// negates the condition (used for the early "return node" path).
void writeNullabilityCheck(bool inverted)
=> WriteLine($"if ({(inverted ? "!" : "")}{updatedNullabilities}.TryGetValue(node, out (NullabilityInfo Info, TypeSymbol? Type) infoAndType))");
// Emits "updatedNode = node.Update(...);", prefixed with the node type when `decl` is true.
void writeUpdateAndDecl(bool decl, bool updatedType)
{
Write($"{(decl ? $"{node.Name} " : "")}updatedNode = ");
writeUpdate(updatedType);
WriteLine(";");
}
// Emits "node.Update(...)" with one argument per specifiable field, choosing for each field either the
// original value, the visited/remapped local, or (when `updatedType`) the recorded type.
void writeUpdate(bool updatedType)
{
Write("node.Update");
ParenList(
allSpecifiableFields,
field =>
{
if (SkipInNullabilityRewriter(field))
{
return $"node.{field.Name}";
}
else if (IsDerivedOrListOfDerived("BoundNode", field.Type))
{
return ToCamelCase(field.Name);
}
else if (updatedType && field.Name == "Type")
{
// Use the override for the field if any.
field = GetMostDerivedField(node, field.Name);
// infoAndType.Type is declared TypeSymbol?; append '!' when the field disallows null.
return $"infoAndType.Type" + (field.Null == "disallow" ? "!" : "");
}
else if (symbolIsPotentiallyUpdated(field) || immutableArrayIsPotentiallyUpdated(field))
{
return $"{ToCamelCase(field.Name)}";
}
else
{
return $"node.{field.Name}";
}
});
}
// Emits the assignment of the recorded top-level nullability onto the rebuilt node.
void writeNullabilityUpdate()
{
WriteLine($"updatedNode.TopLevelNullability = infoAndType.Info;");
}
// True for symbol-typed fields (other than Type, which is handled via infoAndType) whose symbol type
// may be remapped; see typeIsUpdated.
static bool symbolIsPotentiallyUpdated(Field f)
{
if (!TypeIsSymbol(f))
return false;
if (f.Name == "Type")
return false;
return typeIsUpdated(f.Type);
}
// True for ImmutableArray fields whose element type is a symbol type that may be remapped.
bool immutableArrayIsPotentiallyUpdated(Field field)
=> IsImmutableArray(field.Type, out var elementType) && TypeIsSymbol(elementType) && typeIsUpdated(elementType);
// Label, generated-label, alias and namespace symbols are never remapped; any other symbol type may be.
// A trailing '?' (nullable annotation) is ignored.
static bool typeIsUpdated(string type)
{
switch (type.TrimEnd('?'))
{
case "LabelSymbol":
case "GeneratedLabelSymbol":
case "AliasSymbol":
case "NamespaceSymbol":
return false;
default:
return true;
}
}
}
Unbrace();
break;
}
default:
throw new ArgumentException("Unexpected target language", nameof(_targetLang));
}
}
/// <summary>
/// True when <paramref name="derivedType"/> is <paramref name="baseType"/> or one of its subtypes,
/// or is a node list whose element type is.
/// </summary>
private bool IsDerivedOrListOfDerived(string baseType, string derivedType)
{
    if (IsDerivedType(baseType, derivedType))
    {
        return true;
    }

    return IsListOfDerived(baseType, derivedType);
}
/// <summary>
/// True when <paramref name="derivedType"/> is a list-shaped type (see IsNodeList) whose element
/// type is <paramref name="baseType"/> or derives from it.
/// </summary>
private bool IsListOfDerived(string baseType, string derivedType)
{
    if (!IsNodeList(derivedType))
    {
        return false;
    }

    string elementType = GetElementType(derivedType);
    return IsDerivedType(baseType, elementType);
}
/// <summary>
/// Recognizes an ImmutableArray type name and extracts its element type.
/// </summary>
/// <param name="typeName">Type name as written in the tree definition.</param>
/// <param name="elementType">
/// The text between the "ImmutableArray&lt;" / "ImmutableArray(Of " prefix and the final character,
/// or null when <paramref name="typeName"/> is not an ImmutableArray.
/// </param>
/// <returns>True when <paramref name="typeName"/> starts with the language's ImmutableArray prefix.</returns>
/// <exception cref="InvalidOperationException">The target language is not recognized.</exception>
private bool IsImmutableArray(string typeName, out string elementType)
{
    // VB is case-insensitive, so match its prefix the same way IsNodeList / GetGenericType /
    // GetElementType do; C# stays ordinal.
    var (immutableArrayPrefix, comparison) = _targetLang switch
    {
        TargetLanguage.CSharp => ("ImmutableArray<", StringComparison.Ordinal),
        TargetLanguage.VB => ("ImmutableArray(Of ", StringComparison.OrdinalIgnoreCase),
        _ => throw new InvalidOperationException($"Unknown target language {_targetLang}")
    };

    if (typeName.StartsWith(immutableArrayPrefix, comparison))
    {
        // Drop the prefix and the closing '>' / ')'.
        elementType = typeName[immutableArrayPrefix.Length..^1];
        return true;
    }

    elementType = null;
    return false;
}
/// <summary>
/// True when <paramref name="typeName"/> is list-shaped (IList or ImmutableArray), judged by prefix only;
/// the element type is not inspected. VB prefixes are compared case-insensitively.
/// </summary>
private bool IsNodeList(string typeName) => _targetLang switch
{
    TargetLanguage.CSharp =>
        typeName.StartsWith("IList<", StringComparison.Ordinal)
        || typeName.StartsWith("ImmutableArray<", StringComparison.Ordinal),
    TargetLanguage.VB =>
        typeName.StartsWith("IList(Of", StringComparison.OrdinalIgnoreCase)
        || typeName.StartsWith("ImmutableArray(Of", StringComparison.OrdinalIgnoreCase),
    _ => throw new ArgumentException("Unexpected target language", nameof(_targetLang)),
};
/// <summary>True when <paramref name="typeName"/> is a known tree type or a list-shaped type.</summary>
public bool IsNodeOrNodeList(string typeName)
{
    bool isSingleNode = IsNode(typeName);
    return isSingleNode || IsNodeList(typeName);
}
/// <summary>
/// Returns the part of <paramref name="typeName"/> before its generic argument list
/// ("&lt;" in C#, "(Of" in VB, case-insensitive), or the whole name when it is not generic.
/// </summary>
private string GetGenericType(string typeName)
{
    int argumentListStart = _targetLang switch
    {
        TargetLanguage.CSharp => typeName.IndexOf('<'),
        TargetLanguage.VB => typeName.IndexOf("(Of", StringComparison.OrdinalIgnoreCase),
        _ => throw new ArgumentException("Unexpected target language", nameof(_targetLang)),
    };

    return argumentListStart == -1
        ? typeName
        : typeName.Substring(0, argumentListStart);
}
/// <summary>
/// Returns the text inside the outermost generic argument list of <paramref name="typeName"/>
/// (between "&lt;" and the last "&gt;" in C#; between "(Of" and the last ")" in VB, trimmed),
/// or an empty string when the name has no such list.
/// </summary>
private string GetElementType(string typeName)
{
    switch (_targetLang)
    {
        case TargetLanguage.CSharp:
        {
            int iStart = typeName.IndexOf('<');
            if (iStart == -1)
                return string.Empty;
            // Use the LAST '>' so nested generics such as ImmutableArray<ImmutableArray<T>> yield
            // "ImmutableArray<T>" rather than the truncated "ImmutableArray<T".
            int iEnd = typeName.LastIndexOf('>');
            if (iEnd < iStart)
                return string.Empty;
            return typeName.Substring(iStart + 1, iEnd - iStart - 1);
        }
        case TargetLanguage.VB:
        {
            int iStart = typeName.IndexOf("(Of", StringComparison.OrdinalIgnoreCase);
            if (iStart == -1)
                return string.Empty;
            // Use the LAST ')' so nested generics and array element types such as String() stay intact.
            // Any ')' after iStart lies at or beyond iStart + 3, because "(Of" occupies iStart..iStart+2.
            int iEnd = typeName.LastIndexOf(')');
            if (iEnd < iStart)
                return string.Empty;
            return typeName.Substring(iStart + 3, iEnd - iStart - 3).Trim();
        }
        default:
            throw new ArgumentException("Unexpected target language", nameof(_targetLang));
    }
}
/// <summary>
/// True when the non-generic, non-nullable base name of <paramref name="typeName"/> is one of the
/// known value types (tree-declared value types plus the built-in names registered in InitializeValueTypes).
/// </summary>
private bool IsValueType(string typeName)
{
    string baseName = GetGenericType(typeName).TrimEnd('?');
    return _valueTypes.Contains(baseName);
}
/// <summary>
/// True when <paramref name="derivedTypeName"/> equals <paramref name="typeName"/> or reaches it by
/// following base types recorded in the type map. Trailing '?' annotations are ignored on both sides.
/// A null <paramref name="derivedTypeName"/> yields false.
/// </summary>
private bool IsDerivedType(string typeName, string derivedTypeName)
{
    string target = typeName.TrimEnd('?');
    string current = derivedTypeName?.TrimEnd('?');

    // Walk up the base-type chain; the root maps to null, which ends the walk.
    while (current != null)
    {
        if (current == target)
        {
            return true;
        }

        if (!_typeMap.TryGetValue(current, out string parent))
        {
            return false;
        }

        current = parent?.TrimEnd('?');
    }

    return false;
}
/// <summary>
/// True when <paramref name="typeName"/> is recorded in the type map, which holds every tree type that is
/// neither an enum nor a value type, plus the tree root. The lookup is exact (no '?' trimming).
/// </summary>
private bool IsNode(string typeName) => _typeMap.ContainsKey(typeName);
/// <summary>True when the field's <c>New</c> setting is "true" (case-insensitive); false when unset.</summary>
private static bool IsNew(Field f)
{
    // The setting is machine-readable text, so compare ordinally rather than with the current culture.
    return string.Equals(f.New, "true", StringComparison.OrdinalIgnoreCase);
}
/// <summary>True when the field's <c>PropertyOverrides</c> setting is "true" (case-insensitive); false when unset.</summary>
private static bool IsPropertyOverrides(Field f)
{
    // The setting is machine-readable text, so compare ordinally rather than with the current culture.
    return string.Equals(f.PropertyOverrides, "true", StringComparison.OrdinalIgnoreCase);
}
/// <summary>
/// True when ordinary visitors should not visit the field: its <c>SkipInVisitor</c> setting is "true",
/// or "ExceptNullabilityRewriter" (such fields are still visited by the nullability rewriter).
/// Both comparisons are case-insensitive.
/// </summary>
private static bool SkipInVisitor(Field f)
{
    // The setting is machine-readable text, so compare ordinally rather than with the current culture.
    return string.Equals(f.SkipInVisitor, "true", StringComparison.OrdinalIgnoreCase)
        || VisitFieldOnlyInNullabilityRewriter(f);
}
/// <summary>
/// True when the field's <c>SkipInVisitor</c> setting is "ExceptNullabilityRewriter" (case-insensitive):
/// ordinary visitors skip it, but the nullability rewriter still visits it.
/// </summary>
private static bool VisitFieldOnlyInNullabilityRewriter(Field f)
{
    // The setting is machine-readable text, so compare ordinally rather than with the current culture.
    return string.Equals(f.SkipInVisitor, "ExceptNullabilityRewriter", StringComparison.OrdinalIgnoreCase);
}
/// <summary>
/// True when the node's <c>SkipInNullabilityRewriter</c> setting is "true" (case-insensitive);
/// no nullability-rewriter override is generated for such nodes.
/// </summary>
private static bool SkipInNullabilityRewriter(Node n)
{
    // The setting is machine-readable text, so compare ordinally rather than with the current culture.
    return string.Equals(n.SkipInNullabilityRewriter, "true", StringComparison.OrdinalIgnoreCase);
}
/// <summary>
/// True when the field's <c>SkipInNullabilityRewriter</c> setting is "true" (case-insensitive);
/// the nullability rewriter then passes the field's original value to Update unchanged.
/// </summary>
private static bool SkipInNullabilityRewriter(Field f)
{
    // The setting is machine-readable text, so compare ordinally rather than with the current culture.
    return string.Equals(f.SkipInNullabilityRewriter, "true", StringComparison.OrdinalIgnoreCase);
}
/// <summary>
/// Lower-cases the first character of <paramref name="name"/> when it is upper case, then escapes the
/// result if it collides with a keyword of the target language (see FixKeyword).
/// </summary>
private string ToCamelCase(string name)
{
    string camel = char.IsUpper(name[0])
        ? char.ToLowerInvariant(name[0]) + name.Substring(1)
        : name;
    return FixKeyword(camel);
}
/// <summary>
/// Escapes <paramref name="name"/> when it is a keyword of the target language: "@name" in C#,
/// "[name]" in VB. Non-keywords are returned unchanged.
/// </summary>
private string FixKeyword(string name)
{
    if (!IsKeyword(name))
    {
        return name;
    }

    return _targetLang switch
    {
        TargetLanguage.CSharp => "@" + name,
        TargetLanguage.VB => "[" + name + "]",
        _ => throw new ArgumentException("Unexpected target language", nameof(_targetLang)),
    };
}
/// <summary>True when <paramref name="name"/> needs escaping as a keyword in the target language.</summary>
private bool IsKeyword(string name) => _targetLang switch
{
    TargetLanguage.CSharp => name.IsCSharpKeyword(),
    TargetLanguage.VB => name.IsVBKeyword(),
    _ => throw new ArgumentException("Unexpected target language", nameof(_targetLang)),
};
/// <summary>
/// Builds the signature line of a Visit method for <paramref name="nodeName"/>
/// (e.g. VisitCall for BoundCall), returning BoundNode; marked override/Overrides when
/// <paramref name="isOverride"/> is true, otherwise virtual/Overridable.
/// </summary>
private string GetVisitFunctionDeclaration(string nodeName, bool isOverride)
{
    if (_targetLang == TargetLanguage.CSharp)
    {
        string modifier = isOverride ? "override" : "virtual";
        return $"public {modifier} BoundNode? Visit{StripBound(nodeName)}({nodeName} node)";
    }

    if (_targetLang == TargetLanguage.VB)
    {
        string modifier = isOverride ? "Overrides" : "Overridable";
        return $"Public {modifier} Function Visit{StripBound(nodeName)}(node As {nodeName}) As BoundNode";
    }

    throw new ArgumentException("Unexpected target language", nameof(_targetLang));
}
/// <summary>
/// Emits a local named after <paramref name="field"/> (camel-cased) that holds the value a generated
/// visitor should use for that field:
/// the original value when the field is skipped by visitors,
/// the result of VisitList for list-shaped fields,
/// or the cast result of Visit otherwise.
/// </summary>
/// <param name="field">A node or node-list field of the node being visited.</param>
/// <param name="forceVisit">Visit the field even when it is marked SkipInVisitor.</param>
/// <exception cref="ArgumentException">The target language is not recognized.</exception>
private void WriteNodeVisitCall(Field field, bool forceVisit = false)
{
    // Both languages honor forceVisit identically (the VB path previously ignored it).
    bool copyUnchanged = SkipInVisitor(field) && !forceVisit;
    switch (_targetLang)
    {
        case TargetLanguage.CSharp:
            if (copyUnchanged)
            {
                WriteLine($"{field.Type} {ToCamelCase(field.Name)} = node.{field.Name};");
            }
            else if (IsNodeList(field.Type))
            {
                WriteLine($"{field.Type} {ToCamelCase(field.Name)} = this.VisitList(node.{field.Name});");
            }
            else
            {
                WriteLine($"{field.Type} {ToCamelCase(field.Name)} = ({field.Type})this.Visit(node.{field.Name});");
            }
            break;
        case TargetLanguage.VB:
            if (copyUnchanged)
            {
                WriteLine("Dim {0} As {2} = node.{1}", ToCamelCase(field.Name), field.Name, field.Type);
            }
            else if (IsNodeList(field.Type))
            {
                WriteLine("Dim {0} As {2} = Me.VisitList(node.{1})", ToCamelCase(field.Name), field.Name, field.Type);
            }
            else
            {
                WriteLine("Dim {0} As {2} = DirectCast(Me.Visit(node.{1}), {2})", ToCamelCase(field.Name), field.Name, field.Type);
            }
            break;
        default:
            // Consistent with every other language switch in this writer: never emit nothing silently.
            throw new ArgumentException("Unexpected target language", nameof(_targetLang));
    }
}
}
/// <summary>
/// Keyword tests used when turning field names into identifiers in generated C# / VB code.
/// </summary>
internal static class Extensions
{
    // Every reserved C# keyword. A handful of contextual keywords the generator has always escaped are
    // kept as well; escaping them with '@' is harmless, and keeping them preserves existing output.
    private static readonly HashSet<string> s_csharpKeywords = new HashSet<string>(StringComparer.Ordinal)
    {
        "abstract", "as", "base", "bool", "break", "byte", "case", "catch", "char", "checked",
        "class", "const", "continue", "decimal", "default", "delegate", "do", "double", "else", "enum",
        "event", "explicit", "extern", "false", "finally", "fixed", "float", "for", "foreach", "goto",
        "if", "implicit", "in", "int", "interface", "internal", "is", "lock", "long", "namespace",
        "new", "null", "object", "operator", "out", "override", "params", "private", "protected", "public",
        "readonly", "ref", "return", "sbyte", "sealed", "short", "sizeof", "stackalloc", "static", "string",
        "struct", "switch", "this", "throw", "true", "try", "typeof", "uint", "ulong", "unchecked",
        "unsafe", "ushort", "using", "virtual", "void", "volatile", "while",
        // Contextual keywords retained for compatibility with previously generated code.
        "add", "get", "partial", "remove", "set", "where",
    };

    // Reserved Visual Basic keywords (including those kept only for backward compatibility, such as
    // GoSub, Variant and Wend). VB is case-insensitive, so lookups ignore case.
    private static readonly HashSet<string> s_visualBasicKeywords = new HashSet<string>(StringComparer.OrdinalIgnoreCase)
    {
        "addhandler", "addressof", "alias", "and", "andalso", "as", "boolean", "byref", "byte", "byval",
        "call", "case", "catch", "cbool", "cbyte", "cchar", "cdate", "cdbl", "cdec", "char",
        "cint", "class", "clng", "cobj", "const", "continue", "csbyte", "cshort", "csng", "cstr",
        "ctype", "cuint", "culng", "cushort", "date", "decimal", "declare", "default", "delegate", "dim",
        "directcast", "do", "double", "each", "else", "elseif", "end", "endif", "enum", "erase",
        "error", "event", "exit", "false", "finally", "for", "friend", "function", "get", "gettype",
        "getxmlnamespace", "global", "gosub", "goto", "handles", "if", "implements", "imports", "in", "inherits",
        "integer", "interface", "is", "isnot", "let", "lib", "like", "long", "loop", "me",
        "mod", "module", "mustinherit", "mustoverride", "mybase", "myclass", "nameof", "namespace", "narrowing", "new",
        "next", "not", "nothing", "notinheritable", "notoverridable", "object", "of", "on", "operator", "option",
        "optional", "or", "orelse", "out", "overloads", "overridable", "overrides", "paramarray", "partial", "private",
        "property", "protected", "public", "raiseevent", "readonly", "redim", "rem", "removehandler", "resume", "return",
        "sbyte", "select", "set", "shadows", "shared", "short", "single", "static", "step", "stop",
        "string", "structure", "sub", "synclock", "then", "throw", "to", "true", "try", "trycast",
        "typeof", "uinteger", "ulong", "ushort", "using", "variant", "wend", "when", "while", "widening",
        "with", "withevents", "writeonly", "xor",
    };

    /// <summary>
    /// True when <paramref name="name"/> must be escaped (with '@') to be used as a C# identifier.
    /// The comparison is case-sensitive, matching C#. Returns false for null.
    /// </summary>
    public static bool IsCSharpKeyword(this string name)
        => name != null && s_csharpKeywords.Contains(name);

    /// <summary>
    /// True when <paramref name="name"/> must be escaped (with brackets) to be used as a VB identifier.
    /// The comparison ignores case, matching VB. Returns false for null.
    /// </summary>
    public static bool IsVBKeyword(this string name)
        => name != null && s_visualBasicKeywords.Contains(name);
}
}
|