File: src\Shared\ServerInfrastructure\StringUtilities.cs
Web Access
Project: src\src\Servers\Kestrel\Core\src\Microsoft.AspNetCore.Server.Kestrel.Core.csproj (Microsoft.AspNetCore.Server.Kestrel.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using System.Text;
 
#nullable enable
 
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
 
internal static class StringUtilities
{
    /// <summary>
    /// Converts <paramref name="span"/> to a string, first trying a plain (non-null) ASCII widening and
    /// falling back to <paramref name="defaultEncoding"/> when the bytes are not all in the range 1-127.
    /// </summary>
    /// <param name="span">The raw bytes to convert. An empty span yields <see cref="string.Empty"/>.</param>
    /// <param name="defaultEncoding">The encoding used for the fallback decode when the data is not plain ASCII.</param>
    /// <returns>The decoded string.</returns>
    /// <exception cref="InvalidOperationException">
    /// The data contains a null byte, or <paramref name="defaultEncoding"/> raised a <see cref="DecoderFallbackException"/>.
    /// </exception>
    public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this ReadOnlySpan<byte> span, Encoding defaultEncoding)
    {
        if (span.Length == 0)
        {
            return string.Empty;
        }

        fixed (byte* bytes = &MemoryMarshal.GetReference(span))
        {
            var candidate = string.Create(span.Length, (IntPtr)bytes, s_getAsciiOrUTF8StringNonNullCharacters);

            // The callback leaves '\0' in the first char when the ASCII fast path failed;
            // any other first char means the string is complete.
            if (candidate[0] != '\0')
            {
                return candidate;
            }

            // Null bytes are never acceptable, whatever the encoding.
            if (span.IndexOf((byte)0) >= 0)
            {
                throw new InvalidOperationException();
            }

            try
            {
                return defaultEncoding.GetString(span);
            }
            catch (DecoderFallbackException)
            {
                throw new InvalidOperationException();
            }
        }
    }
 
    // string.Create callback used by GetAsciiOrUTF8StringNonNullCharacters.
    // 'state' carries the address of the source bytes; 'buffer' is the new string's character storage.
    private static readonly unsafe SpanAction<char, IntPtr> s_getAsciiOrUTF8StringNonNullCharacters = (Span<char> buffer, IntPtr state) =>
    {
        fixed (char* destination = &MemoryMarshal.GetReference(buffer))
        {
            var sourceBytes = (byte*)state.ToPointer();

            // TryGetAsciiString reports failure when any byte is null ('\0') or non-ASCII (> 127).
            var isPlainAscii = TryGetAsciiString(sourceBytes, destination, buffer.Length);
            if (isPlainAscii)
            {
                return;
            }

            // Leave a '\0' sentinel in the first char so the caller knows to decode with the fallback encoding.
            destination[0] = '\0';
        }
    };
 
    /// <summary>
    /// Converts <paramref name="span"/> to a string, requiring every byte to be a non-null ASCII value (1-127).
    /// </summary>
    /// <param name="span">The raw bytes to convert. An empty span yields <see cref="string.Empty"/>.</param>
    /// <returns>A string with one char per input byte.</returns>
    /// <exception cref="InvalidOperationException">Thrown by the conversion callback when a byte is null or above 127.</exception>
    public static unsafe string GetAsciiStringNonNullCharacters(this ReadOnlySpan<byte> span)
    {
        if (span.Length == 0)
        {
            return string.Empty;
        }

        fixed (byte* bytes = &MemoryMarshal.GetReference(span))
        {
            // The pinned address is handed to the callback as its state.
            var sourceAddress = (IntPtr)bytes;
            return string.Create(span.Length, sourceAddress, s_getAsciiStringNonNullCharacters);
        }
    }
 
    // string.Create callback used by GetAsciiStringNonNullCharacters.
    // 'state' carries the address of the source bytes; 'buffer' is the new string's character storage.
    private static readonly unsafe SpanAction<char, IntPtr> s_getAsciiStringNonNullCharacters = (Span<char> buffer, IntPtr state) =>
    {
        fixed (char* destination = &MemoryMarshal.GetReference(buffer))
        {
            var sourceBytes = (byte*)state.ToPointer();

            // TryGetAsciiString reports failure when any byte is null ('\0') or non-ASCII (> 127).
            var isPlainAscii = TryGetAsciiString(sourceBytes, destination, buffer.Length);
            if (isPlainAscii)
            {
                return;
            }

            throw new InvalidOperationException();
        }
    };
 
    /// <summary>
    /// Converts <paramref name="span"/> to a string by widening each byte to one char, requiring every byte to be non-null.
    /// </summary>
    /// <param name="span">The raw bytes to convert. An empty span yields <see cref="string.Empty"/>.</param>
    /// <returns>A string with one char per input byte.</returns>
    /// <exception cref="InvalidOperationException">Thrown by the conversion callback when the data contains a null byte.</exception>
    public static unsafe string GetLatin1StringNonNullCharacters(this ReadOnlySpan<byte> span)
    {
        if (span.Length == 0)
        {
            return string.Empty;
        }

        fixed (byte* bytes = &MemoryMarshal.GetReference(span))
        {
            // The pinned address is handed to the callback as its state.
            var sourceAddress = (IntPtr)bytes;
            return string.Create(span.Length, sourceAddress, s_getLatin1StringNonNullCharacters);
        }
    }
 
    // string.Create callback used by GetLatin1StringNonNullCharacters.
    // 'state' carries the address of the source bytes; 'buffer' is the new string's character storage.
    private static readonly unsafe SpanAction<char, IntPtr> s_getLatin1StringNonNullCharacters = (Span<char> buffer, IntPtr state) =>
    {
        fixed (char* destination = &MemoryMarshal.GetReference(buffer))
        {
            var sourceBytes = (byte*)state.ToPointer();

            var hasNoNullBytes = TryGetLatin1String(sourceBytes, destination, buffer.Length);
            if (hasNoNullBytes)
            {
                return;
            }

            // null characters are considered invalid
            throw new InvalidOperationException();
        }
    };
 
    /// <summary>
    /// Widens <paramref name="count"/> bytes starting at <paramref name="input"/> into UTF-16 chars at
    /// <paramref name="output"/>, validating that every byte is a non-null ASCII value (1-127).
    /// </summary>
    /// <param name="input">Pointer to the first source byte; asserted non-null in debug builds.</param>
    /// <param name="output">
    /// Pointer to the destination chars; asserted non-null in debug builds. One char is written per input byte
    /// (the callers in this file pass a buffer of exactly <paramref name="count"/> chars).
    /// </param>
    /// <param name="count">Number of bytes to convert.</param>
    /// <returns>
    /// <c>true</c> when all bytes were in range and have been written; <c>false</c> as soon as a chunk containing
    /// a zero byte or a byte above 127 is found. On <c>false</c>, chunks validated before the failing one have
    /// already been written to <paramref name="output"/>.
    /// </returns>
    /// <remarks>
    /// The data is consumed in progressively smaller chunks: 32-byte AVX2 vectors, 16-byte SSE2 vectors
    /// (or <see cref="Vector{T}"/> when SSE2 is unavailable but vectors are hardware accelerated), then
    /// 8/4-byte scalars, and finally a 2-byte and a 1-byte tail.
    /// </remarks>
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    public static unsafe bool TryGetAsciiString(byte* input, char* output, int count)
    {
        Debug.Assert(input != null);
        Debug.Assert(output != null);

        // One-past-the-end of the source bytes; all loop bounds below are expressed relative to it.
        var end = input + count;

        // NOTE(review): presumably guards the 'end - N' pointer arithmetic below against wrapping
        // around for a very low 'end' address - confirm.
        Debug.Assert((long)end >= Vector256<sbyte>.Count);

        // PERF: so the JIT can reuse the zero from a register
        var zero = Vector128<sbyte>.Zero;

        if (Sse2.IsSupported)
        {
            // 32 bytes per iteration while at least a full 256-bit vector remains.
            if (Avx2.IsSupported && input <= end - Vector256<sbyte>.Count)
            {
                var avxZero = Vector256<sbyte>.Zero;

                do
                {
                    var vector = Avx.LoadVector256(input).AsSByte();
                    if (!CheckBytesInAsciiRange(vector, avxZero))
                    {
                        return false;
                    }

                    // Interleave each source byte with a zero byte to form 16-bit chars.
                    var tmp0 = Avx2.UnpackLow(vector, avxZero);
                    var tmp1 = Avx2.UnpackHigh(vector, avxZero);

                    // Bring into the right order
                    var out0 = Avx2.Permute2x128(tmp0, tmp1, 0x20);
                    var out1 = Avx2.Permute2x128(tmp0, tmp1, 0x31);

                    // Two stores of 16 chars each cover the 32 chars produced by this iteration.
                    Avx.Store((ushort*)output, out0.AsUInt16());
                    Avx.Store((ushort*)output + Vector256<ushort>.Count, out1.AsUInt16());

                    // 'output' is a char*, so this advances by 32 chars to match the 32 bytes consumed.
                    input += Vector256<sbyte>.Count;
                    output += Vector256<sbyte>.Count;
                } while (input <= end - Vector256<sbyte>.Count);

                // Length was an exact multiple of the vector size; nothing left for the smaller paths.
                if (input == end)
                {
                    return true;
                }
            }

            // 16 bytes per iteration while at least a full 128-bit vector remains.
            if (input <= end - Vector128<sbyte>.Count)
            {
                do
                {
                    var vector = Sse2.LoadVector128(input).AsSByte();
                    if (!CheckBytesInAsciiRange(vector, zero))
                    {
                        return false;
                    }

                    // Interleave with zero: low 8 bytes -> first 8 chars, high 8 bytes -> next 8 chars.
                    var c0 = Sse2.UnpackLow(vector, zero).AsUInt16();
                    var c1 = Sse2.UnpackHigh(vector, zero).AsUInt16();

                    Sse2.Store((ushort*)output, c0);
                    Sse2.Store((ushort*)output + Vector128<ushort>.Count, c1);

                    input += Vector128<sbyte>.Count;
                    output += Vector128<sbyte>.Count;
                } while (input <= end - Vector128<sbyte>.Count);

                if (input == end)
                {
                    return true;
                }
            }
        }
        else if (Vector.IsHardwareAccelerated)
        {
            // Portable vector path, taken only when SSE2 is not supported.
            while (input <= end - Vector<sbyte>.Count)
            {
                var vector = Unsafe.AsRef<Vector<sbyte>>(input);
                if (!CheckBytesInAsciiRange(vector))
                {
                    return false;
                }

                // Every element was just validated as > 0, so widening the signed bytes cannot produce negative values.
                Vector.Widen(
                    vector,
                    out Unsafe.AsRef<Vector<short>>(output),
                    out Unsafe.AsRef<Vector<short>>(output + Vector<short>.Count));

                input += Vector<sbyte>.Count;
                output += Vector<sbyte>.Count;
            }

            if (input == end)
            {
                return true;
            }
        }

        // Scalar paths handle whatever the vector paths left over (or everything, when no vector path applies).
        if (Environment.Is64BitProcess) // Use Intrinsic switch for branch elimination
        {
            // 64-bit: Loop longs by default
            while (input <= end - sizeof(long))
            {
                var value = *(long*)input;
                if (!CheckBytesInAsciiRange(value))
                {
                    return false;
                }

                // BMI2 could be used, but this variant is faster on both Intel and AMD.
                if (Sse2.X64.IsSupported)
                {
                    // Move the 8 bytes into a vector, interleave with zero, and store the resulting 8 chars (16 bytes).
                    var vecNarrow = Sse2.X64.ConvertScalarToVector128Int64(value).AsSByte();
                    var vecWide = Sse2.UnpackLow(vecNarrow, zero).AsUInt64();
                    Sse2.Store((ulong*)output, vecWide);
                }
                else
                {
                    output[0] = (char)input[0];
                    output[1] = (char)input[1];
                    output[2] = (char)input[2];
                    output[3] = (char)input[3];
                    output[4] = (char)input[4];
                    output[5] = (char)input[5];
                    output[6] = (char)input[6];
                    output[7] = (char)input[7];
                }

                input += sizeof(long);
                output += sizeof(long);
            }

            // At most 7 bytes remain after the loop above, so a single 4-byte step is sufficient here.
            if (input <= end - sizeof(int))
            {
                var value = *(int*)input;
                if (!CheckBytesInAsciiRange(value))
                {
                    return false;
                }

                WidenFourAsciiBytesToUtf16AndWriteToBuffer(output, input, value, zero);

                input += sizeof(int);
                output += sizeof(int);
            }
        }
        else
        {
            // 32-bit: Loop ints by default
            while (input <= end - sizeof(int))
            {
                var value = *(int*)input;
                if (!CheckBytesInAsciiRange(value))
                {
                    return false;
                }

                WidenFourAsciiBytesToUtf16AndWriteToBuffer(output, input, value, zero);

                input += sizeof(int);
                output += sizeof(int);
            }
        }

        // At most 3 bytes remain: handle a 2-byte pair, then a final single byte.
        if (input <= end - sizeof(short))
        {
            if (!CheckBytesInAsciiRange(((short*)input)[0]))
            {
                return false;
            }

            output[0] = (char)input[0];
            output[1] = (char)input[1];

            input += sizeof(short);
            output += sizeof(short);
        }

        if (input < end)
        {
            if (!CheckBytesInAsciiRange(((sbyte*)input)[0]))
            {
                return false;
            }
            output[0] = (char)input[0];
        }

        return true;
    }
 
    /// <summary>
    /// Widens <paramref name="count"/> bytes starting at <paramref name="input"/> into chars at
    /// <paramref name="output"/> (each byte value 0-255 becomes the char with the same numeric value),
    /// and reports whether the data was free of null bytes.
    /// </summary>
    /// <param name="input">Pointer to the first source byte; asserted non-null in debug builds.</param>
    /// <param name="output">
    /// Pointer to the destination chars; asserted non-null in debug builds. One char is written per input byte.
    /// </param>
    /// <param name="count">Number of bytes to convert.</param>
    /// <returns><c>true</c> when no input byte was zero; otherwise <c>false</c>.</returns>
    /// <remarks>
    /// Unlike <c>TryGetAsciiString</c>, this method does not stop at the first invalid byte: every byte is
    /// converted and the validity of each chunk is accumulated into a single flag that is returned at the end.
    /// </remarks>
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    public static unsafe bool TryGetLatin1String(byte* input, char* output, int count)
    {
        Debug.Assert(input != null);
        Debug.Assert(output != null);

        // Calculate end position
        var end = input + count;
        // Start as valid
        var isValid = true;

        // The outer loop runs at most twice: once through the vector path (when it applies) and then
        // once through the scalar path below, which always returns.
        do
        {
            // If Vector not-accelerated or remaining less than vector size
            if (!Vector.IsHardwareAccelerated || input > end - Vector<sbyte>.Count)
            {
                if (Environment.Is64BitProcess) // Use Intrinsic switch for branch elimination
                {
                    // 64-bit: Loop longs by default
                    while (input <= end - sizeof(long))
                    {
                        // Validate 8 bytes at once, then widen them individually.
                        isValid &= CheckBytesNotNull(((long*)input)[0]);

                        output[0] = (char)input[0];
                        output[1] = (char)input[1];
                        output[2] = (char)input[2];
                        output[3] = (char)input[3];
                        output[4] = (char)input[4];
                        output[5] = (char)input[5];
                        output[6] = (char)input[6];
                        output[7] = (char)input[7];

                        input += sizeof(long);
                        output += sizeof(long);
                    }
                    // At most 7 bytes remain after the loop above, so a single 4-byte step is sufficient here.
                    if (input <= end - sizeof(int))
                    {
                        isValid &= CheckBytesNotNull(((int*)input)[0]);

                        output[0] = (char)input[0];
                        output[1] = (char)input[1];
                        output[2] = (char)input[2];
                        output[3] = (char)input[3];

                        input += sizeof(int);
                        output += sizeof(int);
                    }
                }
                else
                {
                    // 32-bit: Loop ints by default
                    while (input <= end - sizeof(int))
                    {
                        isValid &= CheckBytesNotNull(((int*)input)[0]);

                        output[0] = (char)input[0];
                        output[1] = (char)input[1];
                        output[2] = (char)input[2];
                        output[3] = (char)input[3];

                        input += sizeof(int);
                        output += sizeof(int);
                    }
                }
                // At most 3 bytes remain: handle a 2-byte pair, then a final single byte.
                if (input <= end - sizeof(short))
                {
                    isValid &= CheckBytesNotNull(((short*)input)[0]);

                    output[0] = (char)input[0];
                    output[1] = (char)input[1];

                    input += sizeof(short);
                    output += sizeof(short);
                }
                if (input < end)
                {
                    isValid &= CheckBytesNotNull(((sbyte*)input)[0]);
                    output[0] = (char)input[0];
                }

                // The scalar path always consumes everything that is left, so the result is final.
                return isValid;
            }

            // do/while as entry condition already checked
            do
            {
                // Use byte/ushort instead of signed equivalents to ensure it doesn't fill based on the high bit.
                var vector = Unsafe.AsRef<Vector<byte>>(input);
                isValid &= CheckBytesNotNull(vector);
                Vector.Widen(
                    vector,
                    out Unsafe.AsRef<Vector<ushort>>(output),
                    out Unsafe.AsRef<Vector<ushort>>(output + Vector<ushort>.Count));

                // 'output' is a char*, so this advances by one char per byte consumed.
                input += Vector<byte>.Count;
                output += Vector<byte>.Count;
            } while (input <= end - Vector<byte>.Count);

            // Vector path done, loop back to do non-Vector
            // If is a exact multiple of vector size, bail now
        } while (input < end);

        return isValid;
    }
 
    /// <summary>
    /// Compares a previously materialized string against raw bytes using <see cref="Ascii.Equals(ReadOnlySpan{char}, ReadOnlySpan{byte})"/>.
    /// </summary>
    /// <param name="previousValue">A string that has already passed header validation (checked in debug builds only).</param>
    /// <param name="newValue">The raw bytes to compare against <paramref name="previousValue"/>.</param>
    /// <returns>The result of the ASCII comparison of the two values.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool BytesOrdinalEqualsStringAndAscii(string previousValue, ReadOnlySpan<byte> newValue)
    {
        // The cached string *must* already have been validated when it was first materialized.
        Debug.Assert(IsValidHeaderString(previousValue));

        ReadOnlySpan<char> cachedChars = previousValue;
        return Ascii.Equals(cachedChars, newValue);
    }
 
    /// <summary>
    /// Writes four chars to <paramref name="output"/>, one per byte of the four source bytes.
    /// </summary>
    /// <param name="output">Destination for the four chars.</param>
    /// <param name="input">Pointer to the four source bytes; read only by the scalar fallback.</param>
    /// <param name="value">The same four source bytes, already loaded as an <see cref="int"/>; used by the SSE2 path.</param>
    /// <param name="zero">An all-zero vector supplied by the caller, used as the interleave filler on the SSE2 path.</param>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static unsafe void WidenFourAsciiBytesToUtf16AndWriteToBuffer(char* output, byte* input, int value, Vector128<sbyte> zero)
    {
        if (!Sse2.X64.IsSupported)
        {
            // Scalar fallback: widen byte by byte.
            output[0] = (char)input[0];
            output[1] = (char)input[1];
            output[2] = (char)input[2];
            output[3] = (char)input[3];
            return;
        }

        // BMI2 could be used, but this variant is faster on both Intel and AMD.
        // Interleave the four bytes with zero bytes, then write the low 64 bits (four chars) in one store.
        var packedBytes = Sse2.ConvertScalarToVector128Int32(value).AsSByte();
        var widenedChars = Sse2.UnpackLow(packedBytes, zero).AsUInt64();
        Unsafe.WriteUnaligned(output, Sse2.X64.ConvertToUInt64(widenedChars));
    }
 
    // Debug.Assert helper that ensures BytesOrdinalEqualsStringAndAscii is not called
    // with an unvalidated string comparator.
    // Returns true only when 'value' is non-null, can be encoded to UTF-8 without loss,
    // and contains no '\0' character.
    private static bool IsValidHeaderString(string value)
    {
        if (value is null)
        {
            return false;
        }

        // A throwing encoder is used so that invalid UTF-16 surfaces as an exception.
        var strictUtf8 = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);

        try
        {
            strictUtf8.GetByteCount(value);
        }
        catch (ArgumentOutOfRangeException)
        {
            // 'value' too large to compute a UTF-8 byte count
            return false;
        }
        catch (EncoderFallbackException)
        {
            // 'value' cannot be converted losslessly to UTF-8
            return false;
        }

        return value.IndexOf('\0') < 0;
    }
 
    /// <summary>
    /// A faster version of String.Concat(<paramref name="str"/>, <paramref name="separator"/>, <paramref name="number"/>.ToString("X8"))
    /// </summary>
    /// <param name="str">The prefix to place before the separator. <c>null</c> is accepted and contributes no characters.</param>
    /// <param name="separator">The single character written between the prefix and the hex digits.</param>
    /// <param name="number">The value written as exactly eight upper-case hexadecimal digits.</param>
    /// <returns>A new string of length <c>(str?.Length ?? 0) + 1 + 8</c>.</returns>
    public static string ConcatAsHexSuffix(string? str, char separator, uint number)
    {
        // One char for the separator plus eight hex digits, plus the prefix when there is one.
        var length = 1 + 8 + (str?.Length ?? 0);

        return string.Create(length, (str, separator, number), s_populateSpanWithHexSuffix);
    }
 
    // string.Create callback for ConcatAsHexSuffix: writes the optional prefix string, then the separator,
    // then 'number' as exactly eight upper-case hexadecimal digits, most significant digit first.
    // The buffer must hold (prefix length + 1 + 8) chars; both paths below write a full eight digits.
    private static readonly SpanAction<char, (string? str, char separator, uint number)> s_populateSpanWithHexSuffix = (Span<char> buffer, (string? str, char separator, uint number) tuple) =>
    {
        var (tupleStr, tupleSeparator, tupleNumber) = tuple;

        // 'i' tracks the next char position to write.
        var i = 0;
        if (tupleStr is not null)
        {
            tupleStr.AsSpan().CopyTo(buffer);
            i = tupleStr.Length;
        }

        buffer[i] = tupleSeparator;
        i++;

        if (Ssse3.IsSupported)
        {
            // The constant inline vectors are read from the data section without any additional
            // moves. See https://github.com/dotnet/runtime/issues/44115 Case 1.1 for further details.

            // Spread the four bytes of the number (most significant first) to byte positions 2, 6, 10 and 14.
            // Every other position uses shuffle index 0xF, i.e. it copies byte 15 of the source vector as a
            // zero filler. CreateScalar (rather than CreateScalarUnsafe) is therefore required: it guarantees
            // the upper elements are zero, whereas CreateScalarUnsafe leaves them uninitialized.
            var lowNibbles = Ssse3.Shuffle(Vector128.CreateScalar(tupleNumber).AsByte(), Vector128.Create(
                0xF, 0xF, 3, 0xF,
                0xF, 0xF, 2, 0xF,
                0xF, 0xF, 1, 0xF,
                0xF, 0xF, 0, 0xF
            ).AsByte());

            // Move each byte two positions down and shift it right by four bits so that its high nibble
            // lands in the even byte position preceding the one that holds its low nibble.
            var highNibbles = Sse2.ShiftRightLogical(Sse2.ShiftRightLogical128BitLane(lowNibbles, 2).AsInt32(), 4).AsByte();
            // Combine both and keep only the low four bits of every byte: these are the hex digit indices.
            var indices = Sse2.And(Sse2.Or(lowNibbles, highNibbles), Vector128.Create((byte)0xF));

            // Lookup the hex values at the positions of the indices
            var hex = Ssse3.Shuffle(Vector128.Create(
                (byte)'0', (byte)'1', (byte)'2', (byte)'3',
                (byte)'4', (byte)'5', (byte)'6', (byte)'7',
                (byte)'8', (byte)'9', (byte)'A', (byte)'B',
                (byte)'C', (byte)'D', (byte)'E', (byte)'F'
            ), indices);

            // The high bytes (0x00) of the chars have also been converted to ascii hex '0', so clear them out.
            hex = Sse2.And(hex, Vector128.Create((ushort)0xFF).AsByte());

            // This generates much more efficient asm than fixing the buffer and using
            // Sse2.Store((byte*)(p + i), chars.AsByte());
            // The 16-byte write covers exactly the eight chars that follow the separator.
            Unsafe.WriteUnaligned(
                ref Unsafe.As<char, byte>(
                    ref Unsafe.Add(ref MemoryMarshal.GetReference(buffer), i)),
                hex);
        }
        else
        {
            var number = (int)tupleNumber;
            // Slice the buffer so we can use constant offsets in a backwards order
            // and the highest index [7] will eliminate the bounds checks for all the lower indices.
            buffer = buffer.Slice(i);

            // This must be explicitly typed as ReadOnlySpan<byte>
            // This then becomes a non-allocating mapping to the data section of the assembly.
            // If it is a var, Span<byte> or byte[], it allocates the byte array per call.
            ReadOnlySpan<byte> hexEncodeMap = "0123456789ABCDEF"u8;
            // Note: this only works with byte due to endian ambiguity for other types,
            // hence the later (char) casts

            // Least significant nibble goes to the last position; the & 0xF mask discards any sign-extended bits.
            buffer[7] = (char)hexEncodeMap[number & 0xF];
            buffer[6] = (char)hexEncodeMap[(number >> 4) & 0xF];
            buffer[5] = (char)hexEncodeMap[(number >> 8) & 0xF];
            buffer[4] = (char)hexEncodeMap[(number >> 12) & 0xF];
            buffer[3] = (char)hexEncodeMap[(number >> 16) & 0xF];
            buffer[2] = (char)hexEncodeMap[(number >> 20) & 0xF];
            buffer[1] = (char)hexEncodeMap[(number >> 24) & 0xF];
            buffer[0] = (char)hexEncodeMap[(number >> 28) & 0xF];
        }
    };
 
    // Vectorized range check: viewed as signed bytes, the values 1-127 are exactly those greater than zero.
    [MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push
    private static bool CheckBytesInAsciiRange(Vector<sbyte> check)
        => Vector.GreaterThanAll(check, Vector<sbyte>.Zero);

    // AVX2 variant: 'zero' is supplied by the caller. Returns true when all 32 signed bytes are > 0.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static bool CheckBytesInAsciiRange(Vector256<sbyte> check, Vector256<sbyte> zero)
    {
        Debug.Assert(Avx2.IsSupported);

        // One mask bit per byte; all 32 bits set is -1 when read as a signed int.
        const int AllLanesPositive = -1;
        return Avx2.MoveMask(Avx2.CompareGreaterThan(check, zero)) == AllLanesPositive;
    }

    // SSE2 variant: 'zero' is supplied by the caller. Returns true when all 16 signed bytes are > 0.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static bool CheckBytesInAsciiRange(Vector128<sbyte> check, Vector128<sbyte> zero)
    {
        Debug.Assert(Sse2.IsSupported);

        // One mask bit per byte; all 16 must be set.
        const int AllLanesPositive = 0xFFFF;
        var positiveLanes = Sse2.MoveMask(Sse2.CompareGreaterThan(check, zero));
        return positiveLanes == AllLanesPositive;
    }
 
    // Scalar validation that every byte packed into the value satisfies: byte != 0 && byte <= 127
    //  1. Subtract 1 from every byte, so a 0 byte wraps and ends up with its high bit set
    //  2. Bitwise-or with the original value, so any byte that was already > 127 keeps its high bit
    //  3. Keep only the high bit of each byte; the value is valid when nothing is left

    [MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push
    private static bool CheckBytesInAsciiRange(long check)
    {
        const long OneInEachByte = 0x0101010101010101L;
        const long HighBitOfEachByte = unchecked((long)0x8080808080808080L);

        var flagged = (check - OneInEachByte) | check;
        return (flagged & HighBitOfEachByte) == 0;
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static bool CheckBytesInAsciiRange(int check)
    {
        const int OneInEachByte = 0x01010101;
        const int HighBitOfEachByte = unchecked((int)0x80808080);

        var flagged = (check - OneInEachByte) | check;
        return (flagged & HighBitOfEachByte) == 0;
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static bool CheckBytesInAsciiRange(short check)
    {
        const short HighBitOfEachByte = unchecked((short)0x8080);

        // The subtraction is performed as int, so narrow it back to 16 bits before combining.
        var decremented = (short)(check - 0x0101);
        return ((decremented | check) & HighBitOfEachByte) == 0;
    }

    // Single byte: as a signed value, 1-127 is simply "at least one".
    private static bool CheckBytesInAsciiRange(sbyte check)
    {
        return check >= 1;
    }
 
    // Vectorized check that no element of the (unsigned) byte vector is zero.
    [MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push
    private static bool CheckBytesNotNull(Vector<byte> check)
        => !Vector.EqualsAny(check, Vector<byte>.Zero);

    // Scalar validation that every byte packed into the value satisfies: byte != 0
    //  1. Subtract 1 from every byte, so a 0 byte wraps and ends up with its high bit set
    //  2. Bitwise-and with ~check, so a high bit only survives for bytes whose own high bit was clear
    //  3. Keep only the high bit of each byte; the value is valid when nothing is left

    [MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push
    private static bool CheckBytesNotNull(long check)
    {
        const long OneInEachByte = 0x0101010101010101L;
        const long HighBitOfEachByte = unchecked((long)0x8080808080808080L);

        var zeroMarkers = (check - OneInEachByte) & ~check;
        return (zeroMarkers & HighBitOfEachByte) == 0;
    }

    private static bool CheckBytesNotNull(int check)
    {
        const int OneInEachByte = 0x01010101;
        const int HighBitOfEachByte = unchecked((int)0x80808080);

        var zeroMarkers = (check - OneInEachByte) & ~check;
        return (zeroMarkers & HighBitOfEachByte) == 0;
    }

    private static bool CheckBytesNotNull(short check)
    {
        const short HighBitOfEachByte = unchecked((short)0x8080);

        // All operands are promoted to int here, exactly as in a single combined expression.
        var zeroMarkers = (check - 0x0101) & ~check;
        return (zeroMarkers & HighBitOfEachByte) == 0;
    }

    private static bool CheckBytesNotNull(sbyte check)
    {
        return check != 0;
    }
}