File: RuntimeTypeSystemHelpers\TypeValidation.cs
Web Access
Project: src\src\runtime\src\native\managed\cdac\Microsoft.Diagnostics.DataContractReader.Contracts\Microsoft.Diagnostics.DataContractReader.Contracts.csproj (Microsoft.Diagnostics.DataContractReader.Contracts)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using Microsoft.Diagnostics.DataContractReader.Contracts;
using Microsoft.Diagnostics.DataContractReader.Data;

namespace Microsoft.Diagnostics.DataContractReader.RuntimeTypeSystemHelpers;

internal sealed class TypeValidation
{
    private readonly Target _target;
    private readonly TargetPointer _continuationMethodTablePointer;

    internal TypeValidation(Target target, TargetPointer continuationMethodTablePointer)
    {
        _target = target;
        _continuationMethodTablePointer = continuationMethodTablePointer;
    }

    // This doesn't need as many properties as MethodTable because we don't want to be operating on
    // a NonValidatedMethodTable for too long
    internal struct NonValidatedMethodTable
    {
        private readonly Target _target;
        private readonly Target.TypeInfo _type;
        internal TargetPointer Address { get; init; }

        private MethodTableFlags_1? _methodTableFlags;

        internal NonValidatedMethodTable(Target target, TargetPointer methodTablePointer)
        {
            _target = target;
            _type = target.GetTypeInfo(DataType.MethodTable);
            Address = methodTablePointer;
            _methodTableFlags = null;
        }

        private MethodTableFlags_1 GetOrCreateFlags()
        {
            if (_methodTableFlags == null)
            {
                // note: may throw if the method table Address is corrupted
                MethodTableFlags_1 flags = new MethodTableFlags_1
                {
                    MTFlags = _target.Read<uint>(Address + (ulong)_type.Fields[nameof(MethodTableFlags_1.MTFlags)].Offset),
                    MTFlags2 = _target.Read<uint>(Address + (ulong)_type.Fields[nameof(MethodTableFlags_1.MTFlags2)].Offset),
                    BaseSize = _target.Read<uint>(Address + (ulong)_type.Fields[nameof(MethodTableFlags_1.BaseSize)].Offset),
                };
                _methodTableFlags = flags;
            }
            return _methodTableFlags.Value;
        }

        internal MethodTableFlags_1 Flags => GetOrCreateFlags();

        internal TargetPointer EEClassOrCanonMT => _target.ReadPointer(Address + (ulong)_type.Fields[nameof(EEClassOrCanonMT)].Offset);
        internal TargetPointer EEClass => MethodTableFlags_1.GetEEClassOrCanonMTBits(EEClassOrCanonMT) == MethodTableFlags_1.EEClassOrCanonMTBits.EEClass ? EEClassOrCanonMT : throw new InvalidOperationException("not an EEClass");
        internal TargetPointer CanonMT
        {
            get
            {
                if (MethodTableFlags_1.GetEEClassOrCanonMTBits(EEClassOrCanonMT) == MethodTableFlags_1.EEClassOrCanonMTBits.CanonMT)
                {
                    return MethodTableFlags_1.UntagEEClassOrCanonMT(EEClassOrCanonMT);
                }
                else
                {
                    throw new InvalidOperationException("not a canonical method table");
                }
            }
        }

        internal readonly bool ValidateReadable() => ValidateDataReadable<MethodTable>(_target, Address);
        internal TargetPointer ParentMethodTable => _target.ReadPointer(Address + (ulong)_type.Fields[nameof(ParentMethodTable)].Offset);
    }

    internal struct NonValidatedEEClass
    {
        public readonly Target _target;
        private readonly Target.TypeInfo _type;

        internal TargetPointer Address { get; init; }

        internal NonValidatedEEClass(Target target, TargetPointer eeClassPointer)
        {
            _target = target;
            Address = eeClassPointer;
            _type = target.GetTypeInfo(DataType.EEClass);
        }

        internal TargetPointer MethodTable => _target.ReadPointer(Address + (ulong)_type.Fields[nameof(MethodTable)].Offset);

        internal readonly bool ValidateReadable() => ValidateDataReadable<EEClass>(_target, Address);
    }

    internal static NonValidatedMethodTable GetMethodTableData(Target target, TargetPointer methodTablePointer)
    {
        return new NonValidatedMethodTable(target, methodTablePointer);
    }

    internal static NonValidatedEEClass GetEEClassData(Target target, TargetPointer eeClassPointer)
    {
        return new NonValidatedEEClass(target, eeClassPointer);
    }


    /// <summary>
    /// Validates that the given address is a valid MethodTable.
    /// </summary>
    ///  <remarks>
    ///  If the target process has memory corruption, we may see pointers that are not valid method tables.
    ///  We validate by looking at the MethodTable -> EEClass -> MethodTable relationship (which may throw if we access invalid memory).
    ///  And then we do some ad-hoc checks on the method table flags.
    private bool ValidateMethodTablePointer(NonValidatedMethodTable umt)
    {
        try
        {
            // Make sure that we can read the method table's data.
            if (!umt.ValidateReadable())
            {
                return false;
            }
            if (!ValidateThrowing(umt))
            {
                return false;
            }
            if (!ValidateMethodTableAdHoc(umt))
            {
                return false;
            }
        }
        catch (System.Exception)
        {
            // TODO(cdac): maybe don't swallow all exceptions? We could consider a richer contract that
            // helps to track down what sort of memory corruption caused the validation to fail.
            // TODO(cdac): we could also consider a more fine-grained exception type so we don't mask
            // programmer mistakes in the cdac.
            return false;
        }
        return true;
    }

    // This portion of validation may throw if we are trying to read an invalid address in the target process
    private bool ValidateThrowing(NonValidatedMethodTable methodTable)
    {
        // For non-generic classes, we can rely on comparing
        //    object->methodtable->class->methodtable
        // to
        //    object->methodtable
        //
        //  However, for generic instantiation this does not work. There we must
        //  compare
        //
        //    object->methodtable->class->methodtable->class
        // to
        //    object->methodtable->class
        TargetPointer eeClassPtr = GetClassThrowing(methodTable);
        if (eeClassPtr != TargetPointer.Null)
        {
            NonValidatedEEClass eeClass = GetEEClassData(_target, eeClassPtr);
            if (!eeClass.ValidateReadable())
            {
                return false;
            }
            TargetPointer methodTablePtrFromClass = eeClass.MethodTable;
            if (methodTable.Address == methodTablePtrFromClass)
            {
                return true;
            }
            if (methodTable.Flags.HasInstantiation || methodTable.Flags.IsArray || IsContinuation(methodTable))
            {
                NonValidatedMethodTable methodTableFromClass = GetMethodTableData(_target, methodTablePtrFromClass);
                if (!methodTableFromClass.ValidateReadable())
                {
                    return false;
                }
                TargetPointer classFromMethodTable = GetClassThrowing(methodTableFromClass);
                return classFromMethodTable == eeClassPtr;
            }
        }
        return false;
    }

    private bool ValidateMethodTableAdHoc(NonValidatedMethodTable methodTable)
    {
        // ad-hoc checks; add more here as needed
        if (!methodTable.Flags.IsInterface && !methodTable.Flags.IsString)
        {
            if (methodTable.Flags.BaseSize == 0 || !_target.IsAlignedToPointerSize(methodTable.Flags.BaseSize))
            {
                return false;
            }
        }
        return true;
    }

    private static bool ValidateDataReadable<T>(Target target, TargetPointer dataAddress) where T : IData<T>
    {
        try
        {
            T dataClass = T.Create(target, dataAddress);
            return true;
        }
        catch (VirtualReadException)
        {
            return false;
        }
    }

    private TargetPointer GetClassThrowing(NonValidatedMethodTable methodTable)
    {
        TargetPointer eeClassOrCanonMT = methodTable.EEClassOrCanonMT;

        if (MethodTableFlags_1.GetEEClassOrCanonMTBits(eeClassOrCanonMT) == MethodTableFlags_1.EEClassOrCanonMTBits.EEClass)
        {
            return methodTable.EEClass;
        }
        else
        {
            TargetPointer canonicalMethodTablePtr = methodTable.CanonMT;
            NonValidatedMethodTable umt = GetMethodTableData(_target, canonicalMethodTablePtr);
            if (!umt.ValidateReadable())
            {
                throw new InvalidOperationException("canon MT is not readable");
            }
            return umt.EEClass;
        }
    }

    // NOTE: The continuation check is duplicated here and in RuntimeTypeSystem_1.IsContinuation.
    // TypeValidation runs before the MethodTable is added to the RuntimeTypeSystem's cache, so we
    // cannot call into RuntimeTypeSystem_1 — the type handle does not exist yet. Instead we
    // duplicate the check using the raw ParentMethodTable read from target memory.
    private bool IsContinuation(NonValidatedMethodTable methodTable)
    {
        return _continuationMethodTablePointer != TargetPointer.Null
            && methodTable.ParentMethodTable == _continuationMethodTablePointer;
    }

    internal bool TryValidateMethodTablePointer(TargetPointer methodTablePointer)
    {
        NonValidatedMethodTable nonvalidatedMethodTable = GetMethodTableData(_target, methodTablePointer);

        return ValidateMethodTablePointer(nonvalidatedMethodTable);
    }
}