File: System\Security\Cryptography\X509Certificates\WindowsHelpers.cs
Web Access
Project: src\src\runtime\src\libraries\System.Security.Cryptography\src\System.Security.Cryptography.csproj (System.Security.Cryptography)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Text;

namespace Internal.Cryptography
{
    internal static partial class Helpers
    {
        /// <summary>
        /// Builds an unmanaged array of LPSTR pointers from the values of the supplied Oids. Every Oid value is
        /// ASCII-encoded and NUL-terminated; the returned allocation holds the pointer table first, immediately
        /// followed by the string bytes those pointers refer to.
        /// "numOids" receives the number of LPSTR pointers in the table. That is normally oids.Count, unless a
        /// malicious caller appends to the OidCollection while this method is running. In that situation the only
        /// guarantee made here is that no unmanaged buffer overflow is created.
        /// </summary>
        public static SafeHandle ToLpstrArray(this OidCollection? oids, out int numOids)
        {
            if (oids is null || oids.Count == 0)
            {
                numOids = 0;
                return SafeLocalAllocHandle.InvalidHandle;
            }

            // Take a private snapshot of the strings first. Sizing and filling the unmanaged buffer from
            // the snapshot (instead of the live collection) closes the race where another thread mutates
            // the OidCollection or one of its Oids between the two passes and causes a buffer overflow.
            string[] snapshot = new string[oids.Count];

            for (int index = 0; index < snapshot.Length; index++)
            {
                snapshot[index] = oids[index].Value!;
            }

            unsafe
            {
                // Layout: [pointer table][string 0 + NUL][string 1 + NUL]...
                // One byte per char is budgeted for each string; the debug assertion in the fill loop
                // checks that the ASCII encoder really produced String.Length bytes.
                int totalBytes = checked(snapshot.Length * sizeof(void*));

                for (int index = 0; index < snapshot.Length; index++)
                {
                    totalBytes = checked(totalBytes + (snapshot[index].Length + 1));
                }

                SafeLocalAllocHandle buffer = SafeLocalAllocHandle.Create(totalBytes);
                byte** pointerTable = (byte**)(buffer.DangerousGetHandle());

                // String contents start right after the last pointer slot.
                byte* cursor = (byte*)(pointerTable + snapshot.Length);

                for (int index = 0; index < snapshot.Length; index++)
                {
                    string text = snapshot[index];
                    int length = text.Length;

                    pointerTable[index] = cursor;

                    int encodedCount = Encoding.ASCII.GetBytes(text, new Span<byte>(cursor, length));
                    Debug.Assert(encodedCount == length);

                    // Terminate the string and move past the terminator.
                    cursor[length] = 0;
                    cursor += length + 1;
                }

                numOids = snapshot.Length;
                return buffer;
            }
        }

        /// <summary>
        /// Callback that is handed a decoded object as a raw pointer plus a size and produces no result.
        /// NOTE(review): no caller of this overload is visible in this file; presumably the arguments have the
        /// same meaning as for the generic overloads below — confirm against callers.
        /// </summary>
        /// <param name="pvDecodedObject">Pointer to the decoded object.</param>
        /// <param name="cbDecodedObject">Size of the decoded object (presumably in bytes, as for the generic overloads).</param>
        public unsafe delegate void DecodedObjectReceiver(void* pvDecodedObject, int cbDecodedObject);
        /// <summary>
        /// Callback that is handed a decoded object and returns a value computed from it.
        /// The DecodeObject helpers in this file invoke it with a pointer into a buffer that is stack-allocated or
        /// pinned only for the duration of the call, so the pointer must not be retained after the callback returns.
        /// </summary>
        /// <typeparam name="TResult">The type of the value the callback produces.</typeparam>
        /// <param name="pvDecodedObject">Pointer to the start of the decoded data.</param>
        /// <param name="cbDecodedObject">Size, in bytes, of the decoded data as reported by the decoder.</param>
        public unsafe delegate TResult DecodedObjectReceiver<TResult>(void* pvDecodedObject, int cbDecodedObject);
        /// <summary>
        /// Callback that is handed a decoded object together with caller-supplied state and returns a value.
        /// DecodeObjectNoThrow in this file invokes it with a pointer into a buffer that is stack-allocated or pinned
        /// only for the duration of the call, so the pointer must not be retained after the callback returns.
        /// </summary>
        /// <typeparam name="TState">The type of the state value passed through to the callback.</typeparam>
        /// <typeparam name="TResult">The type of the value the callback produces.</typeparam>
        /// <param name="pvDecodedObject">Pointer to the start of the decoded data.</param>
        /// <param name="cbDecodedObject">Size, in bytes, of the decoded data as reported by the decoder.</param>
        /// <param name="state">The state value supplied by the caller of the decode helper.</param>
        public unsafe delegate TResult DecodedObjectReceiver<TState, TResult>(void* pvDecodedObject, int cbDecodedObject, TState state);

        /// <summary>
        /// Decodes <paramref name="encoded"/> as the structure identified by <paramref name="lpszStructType"/>
        /// and passes the decoded data to <paramref name="receiver"/>.
        /// </summary>
        /// <typeparam name="TResult">The type of the value produced by <paramref name="receiver"/>.</typeparam>
        /// <param name="encoded">The encoded bytes to decode.</param>
        /// <param name="lpszStructType">The structure type handed to the native decoder.</param>
        /// <param name="receiver">
        /// Callback invoked with a pointer to the decoded data and its size in bytes. The pointer refers to a
        /// buffer that is stack-allocated or pinned only while this method runs, so it must not be retained.
        /// </param>
        /// <returns>The value returned by <paramref name="receiver"/>.</returns>
        /// <exception cref="CryptographicException">
        /// Either decode call reported failure; the exception is built from the last P/Invoke error
        /// via <c>ToCryptographicException</c>.
        /// </exception>
        public static unsafe TResult DecodeObject<TResult>(
            this byte[] encoded,
            CryptDecodeObjectStructType lpszStructType,
            DecodedObjectReceiver<TResult> receiver)
        {
            int cb = 0;

            // Probe call: no output buffer is supplied, and cb comes back holding the size the
            // decoded structure needs.
            if (!Interop.crypt32.CryptDecodeObjectPointer(
                Interop.Crypt32.CertEncodingType.All,
                lpszStructType,
                encoded,
                encoded.Length,
                Interop.Crypt32.CryptDecodeObjectFlags.None,
                null,
                ref cb))
            {
                throw Marshal.GetLastPInvokeError().ToCryptographicException();
            }

            // Declared const (as in the sibling overloads) so the stack allocation has a
            // compile-time-constant size and the value cannot be reassigned by accident.
            const int MaxStackAllocSize = 256;
            Span<byte> decoded = stackalloc byte[MaxStackAllocSize];

            // The unsigned comparison also routes a negative cb to the heap path rather than
            // treating it as "small enough for the stack buffer".
            if ((uint)cb > MaxStackAllocSize)
            {
                decoded = new byte[cb];
            }

            fixed (byte* pDecoded = decoded)
            {
                // Real decode into the buffer; cb is updated to the size actually produced.
                if (!Interop.crypt32.CryptDecodeObjectPointer(
                    Interop.Crypt32.CertEncodingType.All,
                    lpszStructType,
                    encoded,
                    encoded.Length,
                    Interop.Crypt32.CryptDecodeObjectFlags.None,
                    pDecoded,
                    ref cb))
                {
                    throw Marshal.GetLastPInvokeError().ToCryptographicException();
                }

                // The receiver must run inside the fixed block: pDecoded is only valid here.
                return receiver(pDecoded, cb);
            }
        }

        /// <summary>
        /// Decodes <paramref name="encoded"/> as the structure named by <paramref name="lpszStructType"/>
        /// and passes the decoded data to <paramref name="receiver"/>.
        /// </summary>
        /// <typeparam name="TResult">The type of the value produced by <paramref name="receiver"/>.</typeparam>
        /// <param name="encoded">The encoded bytes to decode.</param>
        /// <param name="lpszStructType">The structure type, as a string, handed to the native decoder.</param>
        /// <param name="receiver">
        /// Callback invoked with a pointer to the decoded data and its size in bytes. The pointer refers to a
        /// buffer that is stack-allocated or pinned only while this method runs, so it must not be retained.
        /// </param>
        /// <returns>The value returned by <paramref name="receiver"/>.</returns>
        /// <exception cref="CryptographicException">
        /// Either decode call reported failure; the exception is built from the last P/Invoke error
        /// via <c>ToCryptographicException</c>.
        /// </exception>
        public static unsafe TResult DecodeObject<TResult>(
            this byte[] encoded,
            string lpszStructType,
            DecodedObjectReceiver<TResult> receiver)
        {
            const int StackBufferSize = 256;

            // Probe call: no output buffer is supplied, and decodedSize comes back holding the
            // size the decoded structure needs.
            int decodedSize = 0;
            bool sizeProbeSucceeded = Interop.Crypt32.CryptDecodeObjectPointer(
                Interop.Crypt32.CertEncodingType.All,
                lpszStructType,
                encoded,
                encoded.Length,
                Interop.Crypt32.CryptDecodeObjectFlags.None,
                null,
                ref decodedSize);

            if (!sizeProbeSucceeded)
            {
                throw Marshal.GetLastPInvokeError().ToCryptographicException();
            }

            // Small results decode into a fixed-size stack buffer; anything larger (or a negative
            // size, thanks to the unsigned comparison) takes the heap path.
            Span<byte> buffer = (uint)decodedSize <= StackBufferSize
                ? stackalloc byte[StackBufferSize]
                : new byte[decodedSize];

            fixed (byte* pBuffer = buffer)
            {
                // Real decode into the buffer; decodedSize is updated to the size actually produced.
                bool decodeSucceeded = Interop.Crypt32.CryptDecodeObjectPointer(
                    Interop.Crypt32.CertEncodingType.All,
                    lpszStructType,
                    encoded,
                    encoded.Length,
                    Interop.Crypt32.CryptDecodeObjectFlags.None,
                    pBuffer,
                    ref decodedSize);

                if (!decodeSucceeded)
                {
                    throw Marshal.GetLastPInvokeError().ToCryptographicException();
                }

                // The receiver must run inside the fixed block: pBuffer is only valid here.
                return receiver(pBuffer, decodedSize);
            }
        }

        /// <summary>
        /// Attempts to decode <paramref name="encoded"/> as the structure identified by
        /// <paramref name="lpszStructType"/>, reporting a decode failure through the return value
        /// instead of an exception.
        /// </summary>
        /// <typeparam name="TState">The type of the state value forwarded to <paramref name="receiver"/>.</typeparam>
        /// <typeparam name="TResult">The type of the value produced by <paramref name="receiver"/>.</typeparam>
        /// <param name="encoded">The encoded bytes to decode.</param>
        /// <param name="lpszStructType">The structure type handed to the native decoder.</param>
        /// <param name="state">Caller-supplied value passed through to <paramref name="receiver"/> unchanged.</param>
        /// <param name="receiver">
        /// Callback invoked with a pointer to the decoded data, its size in bytes, and <paramref name="state"/>.
        /// The pointer refers to a buffer that is stack-allocated or pinned only while this method runs.
        /// </param>
        /// <param name="result">
        /// On success, the value returned by <paramref name="receiver"/>; otherwise the default value.
        /// </param>
        /// <returns>
        /// <see langword="true"/> if both decode calls succeeded and the receiver ran;
        /// <see langword="false"/> if either decode call reported failure.
        /// </returns>
        public static unsafe bool DecodeObjectNoThrow<TState, TResult>(
            this ReadOnlySpan<byte> encoded,
            CryptDecodeObjectStructType lpszStructType,
            TState state,
            DecodedObjectReceiver<TState, TResult> receiver,
            out TResult? result)
        {
            const int StackBufferSize = 256;

            // Probe call: no output buffer is supplied, and requiredSize comes back holding the
            // size the decoded structure needs.
            int requiredSize = 0;
            bool sizeProbeSucceeded = Interop.crypt32.CryptDecodeObjectPointer(
                Interop.Crypt32.CertEncodingType.All,
                lpszStructType,
                encoded,
                Interop.Crypt32.CryptDecodeObjectFlags.None,
                null,
                ref requiredSize);

            if (sizeProbeSucceeded)
            {
                // Small results decode into a fixed-size stack buffer; anything larger (or a negative
                // size, thanks to the unsigned comparison) takes the heap path.
                Span<byte> buffer = (uint)requiredSize <= StackBufferSize
                    ? stackalloc byte[StackBufferSize]
                    : new byte[requiredSize];

                fixed (byte* pBuffer = buffer)
                {
                    // Real decode into the buffer; requiredSize is updated to the size actually produced.
                    bool decodeSucceeded = Interop.crypt32.CryptDecodeObjectPointer(
                        Interop.Crypt32.CertEncodingType.All,
                        lpszStructType,
                        encoded,
                        Interop.Crypt32.CryptDecodeObjectFlags.None,
                        pBuffer,
                        ref requiredSize);

                    if (decodeSucceeded)
                    {
                        // The receiver must run inside the fixed block: pBuffer is only valid here.
                        result = receiver(pBuffer, requiredSize, state);
                        return true;
                    }
                }
            }

            // Shared failure exit for both the size probe and the real decode.
            result = default;
            return false;
        }
    }
}