|
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Diagnostics;
using System.Linq;
using Mono.Cecil;
using Mono.Cecil.Cil;
namespace Mono.Linker.Steps
{
public class CodeRewriterStep : BaseStep
{
AssemblyDefinition? assembly;
AssemblyDefinition Assembly
{
get
{
Debug.Assert(assembly != null);
return assembly;
}
}
protected override void ProcessAssembly(AssemblyDefinition assembly)
{
if (Annotations.GetAction(assembly) != AssemblyAction.Link)
return;
this.assembly = assembly;
foreach (var type in assembly.MainModule.Types)
ProcessType(type);
}
void ProcessType(TypeDefinition type)
{
foreach (var method in type.Methods)
{
if (method.HasBody)
ProcessMethod(method);
}
foreach (var nested in type.NestedTypes)
ProcessType(nested);
}
void ProcessMethod(MethodDefinition method)
{
switch (Annotations.GetAction(method))
{
case MethodAction.ConvertToStub:
RewriteBodyToStub(method);
break;
case MethodAction.ConvertToThrow:
RewriteBodyToLinkedAway(method);
break;
}
}
protected virtual void RewriteBodyToLinkedAway(MethodDefinition method)
{
method.ImplAttributes &= ~(MethodImplAttributes.AggressiveInlining | MethodImplAttributes.Synchronized);
method.ImplAttributes |= MethodImplAttributes.NoInlining;
method.Body = CreateThrowLinkedAwayBody(method);
method.ClearDebugInformation();
}
protected virtual void RewriteBodyToStub(MethodDefinition method)
{
if (!method.IsIL)
throw new NotImplementedException();
method.Body = CreateStubBody(method);
method.ClearDebugInformation();
}
MethodBody CreateThrowLinkedAwayBody(MethodDefinition method)
{
var body = new MethodBody(method);
var il = body.GetLinkerILProcessor();
MethodReference? ctor;
// Makes the body verifiable
if (method.IsConstructor && !method.DeclaringType.IsValueType)
{
ctor = Assembly.MainModule.ImportReference(Context.MarkedKnownMembers.ObjectCtor);
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Call, ctor);
}
// import the method into the current assembly
ctor = Context.MarkedKnownMembers.NotSupportedExceptionCtorString;
ctor = Assembly.MainModule.ImportReference(ctor);
il.Emit(OpCodes.Ldstr, "Linked away");
il.Emit(OpCodes.Newobj, ctor);
il.Emit(OpCodes.Throw);
return body;
}
MethodBody CreateStubBody(MethodDefinition method)
{
var body = new MethodBody(method);
#pragma warning disable RS0030 // MethodReference.Parameters is banned. This code already works and doesn't need to be changed
if (method.HasParameters && method.Parameters.Any(l => l.IsOut))
throw new NotSupportedException($"Cannot replace body of method '{method.GetDisplayName()}' because it has an out parameter.");
#pragma warning restore RS0030
var il = body.GetLinkerILProcessor();
if (method.IsInstanceConstructor() && !method.DeclaringType.IsValueType)
{
var baseType = Context.Resolve(method.DeclaringType.BaseType);
if (baseType is null)
return body;
MethodReference base_ctor = baseType.GetDefaultInstanceConstructor(Context);
if (base_ctor == null)
throw new NotSupportedException($"Cannot replace constructor for '{method.DeclaringType}' when no base default constructor exists");
base_ctor = Assembly.MainModule.ImportReference(base_ctor);
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Call, base_ctor);
}
switch (method.ReturnType.MetadataType)
{
case MetadataType.Void:
break;
default:
var instruction = CreateConstantResultInstruction(Context, method);
if (instruction != null)
{
il.Append(instruction);
}
else
{
StubComplexBody(method, body, il);
}
break;
}
il.Emit(OpCodes.Ret);
return body;
}
static void StubComplexBody(MethodDefinition method, MethodBody body, LinkerILProcessor il)
{
switch (method.ReturnType.MetadataType)
{
case MetadataType.MVar:
case MetadataType.ValueType:
var vd = new VariableDefinition(method.ReturnType);
#pragma warning disable RS0030 // Anything after MarkStep should not use ILProvider since all methods are guaranteed processed
body.Variables.Add(vd);
#pragma warning restore RS0030
body.InitLocals = true;
il.Emit(OpCodes.Ldloca_S, vd);
il.Emit(OpCodes.Initobj, method.ReturnType);
il.Emit(OpCodes.Ldloc_0);
return;
case MetadataType.Pointer:
case MetadataType.IntPtr:
case MetadataType.UIntPtr:
il.Emit(OpCodes.Ldc_I4_0);
il.Emit(OpCodes.Conv_I);
return;
}
throw new NotImplementedException(method.FullName);
}
public static Instruction? CreateConstantResultInstruction(LinkContext context, MethodDefinition method)
{
context.Annotations.TryGetMethodStubValue(method, out object? value);
return CreateConstantResultInstruction(context, method.ReturnType, value);
}
public static Instruction? CreateConstantResultInstruction(LinkContext context, TypeReference inputRtype, object? value = null)
{
TypeReference? rtype = inputRtype;
switch (rtype.MetadataType)
{
case MetadataType.ValueType:
var definition = context.TryResolve(rtype);
if (definition?.IsEnum == true)
{
rtype = definition.GetEnumUnderlyingType();
}
break;
case MetadataType.GenericInstance:
rtype = context.TryResolve(rtype);
break;
}
if (rtype == null)
return null;
switch (rtype.MetadataType)
{
case MetadataType.Boolean:
if (value is int bintValue && bintValue == 1)
return Instruction.Create(OpCodes.Ldc_I4_1);
return Instruction.Create(OpCodes.Ldc_I4_0);
case MetadataType.String:
if (value is string svalue)
return Instruction.Create(OpCodes.Ldstr, svalue);
return Instruction.Create(OpCodes.Ldnull);
case MetadataType.Object:
case MetadataType.Array:
case MetadataType.Class:
Debug.Assert(value == null);
return Instruction.Create(OpCodes.Ldnull);
case MetadataType.Double:
if (value is double dvalue)
return Instruction.Create(OpCodes.Ldc_R8, dvalue);
Debug.Assert(value == null);
return Instruction.Create(OpCodes.Ldc_R8, 0.0);
case MetadataType.Single:
if (value is float fvalue)
return Instruction.Create(OpCodes.Ldc_R4, fvalue);
Debug.Assert(value == null);
return Instruction.Create(OpCodes.Ldc_R4, 0.0f);
case MetadataType.Char:
case MetadataType.Byte:
case MetadataType.SByte:
case MetadataType.Int16:
case MetadataType.UInt16:
case MetadataType.Int32:
case MetadataType.UInt32:
if (value is int intValue)
return Instruction.Create(OpCodes.Ldc_I4, intValue);
Debug.Assert(value == null);
return Instruction.Create(OpCodes.Ldc_I4_0);
case MetadataType.UInt64:
case MetadataType.Int64:
if (value is long longValue)
return Instruction.Create(OpCodes.Ldc_I8, longValue);
Debug.Assert(value == null);
return Instruction.Create(OpCodes.Ldc_I8, 0L);
}
return null;
}
}
}
|