File: CpuMathUtils.netstandard.cs
Web Access
Project: src\src\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj (Microsoft.ML.CpuMath)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Numerics.Tensors;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Microsoft.ML.Internal.CpuMath.Core;
 
namespace Microsoft.ML.Internal.CpuMath
{
    internal static partial class CpuMathUtils
    {
        // The count of bytes in Vector128<T>, corresponding to _cbAlign in AlignedArray
        private const int Vector128Alignment = 16;
 
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        public static int GetVectorAlignment()
            => Vector128Alignment;
 
        /// <summary>
        /// 
        /// 
        /// if <paramref name="a"/>'s alignment is suitable to SSE instructions. Returns <see langword="true"/>
        /// if <paramref name="a"/>'s alignment is ok and <see langword="false"/> otherwise.
        /// </summary>
        /// <param name="a">The vector being checked.</param>
        /// <returns>Whether <paramref name="a"/> is aligned well.</returns>
        private static bool Compat(AlignedArray a)
        {
            Contracts.AssertValue(a);
            Contracts.Assert(a.Size > 0);
            return a.CbAlign % Vector128Alignment == 0;
        }
 
        private static unsafe float* Ptr(AlignedArray a, float* p)
        {
            Contracts.AssertValue(a);
            float* q = p + a.GetBase((long)p);
            Contracts.Assert(((long)q & (Vector128Alignment - 1)) == 0);
            return q;
        }
 
        /// <summary>
        /// Compute the product of matrix <paramref name="mat"/> (the matrix is flattened because its type is <see cref="AlignedArray"/> instead of a matrix)
        /// and a vector <paramref name="src"/>.
        /// </summary>
        /// <param name="tran">Whether to transpose <paramref name="mat"/> before doing any computation.</param>
        /// <param name="mat">If <paramref name="tran"/> is <see langword="false"/>, <paramref name="mat"/> is a m-by-n matrix, and the value at the i-th row and the j-th column is indexed by i * n + j in <paramref name="mat"/>.
        /// If <paramref name="tran"/> is <see langword="true"/>, <paramref name="mat"/> would be viewed a n-by-m matrix, and the value at the i-th row and the j-th column in the transposed matrix is indexed by j * m + i in the
        /// original <paramref name="mat"/>.</param>
        /// <param name="src">A n-by-1 matrix, which is also a vector.</param>
        /// <param name="dst">A m-by-1 matrix, which is also a vector.</param>
        /// <param name="crun">The truncation level of <paramref name="dst"/>. For example, if <paramref name="crun"/> is 2, <paramref name="dst"/>
        /// will be considered as a 2-by-1 matrix and therefore elements after its 2nd element will be ignored. If no truncation should happen,
        /// set <paramref name="crun"/> to the length of <paramref name="dst"/>.</param>
        public static void MatrixTimesSource(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun)
        {
            Contracts.Assert(Compat(mat));
            Contracts.Assert(Compat(src));
            Contracts.Assert(Compat(dst));
            Contracts.Assert(mat.Size == dst.Size * src.Size);
 
            unsafe
            {
                fixed (float* pmat = &mat.Items[0])
                fixed (float* psrc = &src.Items[0])
                fixed (float* pdst = &dst.Items[0])
                {
                    if (!tran)
                    {
                        Contracts.Assert(0 <= crun && crun <= dst.Size);
                        Thunk.MatMul(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size);
                    }
                    else
                    {
                        Contracts.Assert(0 <= crun && crun <= src.Size);
                        Thunk.MatMulTran(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun);
                    }
                }
            }
        }
 
        public static void MatrixTimesSource(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray srcValues,
            int posMin, int iposMin, int iposLim, AlignedArray dst, int crun)
        {
            Contracts.Assert(Compat(mat));
            Contracts.Assert(Compat(srcValues));
            Contracts.Assert(Compat(dst));
            Contracts.Assert(0 <= iposMin && iposMin <= iposLim && iposLim <= rgposSrc.Length);
            Contracts.Assert(mat.Size == dst.Size * srcValues.Size);
 
            if (iposMin >= iposLim)
            {
                dst.ZeroItems();
                return;
            }
            Contracts.AssertNonEmpty(rgposSrc);
            unsafe
            {
                fixed (float* pdst = &dst.Items[0])
                fixed (float* pmat = &mat.Items[0])
                fixed (float* psrc = &srcValues.Items[0])
                fixed (int* ppossrc = &rgposSrc[0])
                {
                    Contracts.Assert(0 <= crun && crun <= dst.Size);
                    Thunk.MatMulP(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size);
                }
            }
        }
 
        // dst[i] = a * (dst[i] + b)
        public static void ScaleAdd(float a, float b, Span<float> dst)
        {
            Contracts.AssertNonEmpty(dst);
 
            unsafe
            {
                fixed (float* pdst = &MemoryMarshal.GetReference(dst))
                    Thunk.ScaleAddU(a, b, pdst, dst.Length);
            }
        }
 
        public static void AddScale(float a, ReadOnlySpan<float> src, ReadOnlySpan<int> indices, Span<float> dst, int count)
        {
            Contracts.AssertNonEmpty(src);
            Contracts.Assert(0 < count && count <= src.Length);
            Contracts.AssertNonEmpty(indices);
            Contracts.Assert(count <= indices.Length);
            Contracts.AssertNonEmpty(dst);
            Contracts.Assert(count < dst.Length);
 
            unsafe
            {
                fixed (float* psrc = &MemoryMarshal.GetReference(src))
                fixed (int* pi = &MemoryMarshal.GetReference(indices))
                fixed (float* pdst = &MemoryMarshal.GetReference(dst))
                    Thunk.AddScaleSU(a, psrc, pi, pdst, count);
            }
        }
 
        public static void Add(ReadOnlySpan<float> src, ReadOnlySpan<int> indices, Span<float> dst, int count)
        {
            Contracts.AssertNonEmpty(src);
            Contracts.Assert(0 < count && count <= src.Length);
            Contracts.AssertNonEmpty(indices);
            Contracts.Assert(count <= indices.Length);
            Contracts.AssertNonEmpty(dst);
            Contracts.Assert(count < dst.Length);
 
            unsafe
            {
                fixed (float* ps = &MemoryMarshal.GetReference(src))
                fixed (int* pi = &MemoryMarshal.GetReference(indices))
                fixed (float* pd = &MemoryMarshal.GetReference(dst))
                    Thunk.AddSU(ps, pi, pd, count);
            }
        }
 
        public static float SumSq(float mean, ReadOnlySpan<float> src)
        {
            Contracts.AssertNonEmpty(src);
 
            unsafe
            {
                fixed (float* psrc = &MemoryMarshal.GetReference(src))
                    return (mean == 0 ? Thunk.SumSqU(psrc, src.Length) : Thunk.SumSqDiffU(mean, psrc, src.Length));
            }
        }
 
        public static float SumAbs(float mean, ReadOnlySpan<float> src)
        {
            Contracts.AssertNonEmpty(src);
 
            unsafe
            {
                fixed (float* psrc = &MemoryMarshal.GetReference(src))
                    return (mean == 0 ? Thunk.SumAbsU(psrc, src.Length) : Thunk.SumAbsDiffU(mean, psrc, src.Length));
            }
        }
 
        public static float MaxAbs(ReadOnlySpan<float> src)
        {
            Contracts.AssertNonEmpty(src);
 
            unsafe
            {
                fixed (float* psrc = &MemoryMarshal.GetReference(src))
                    return Thunk.MaxAbsU(psrc, src.Length);
            }
        }
 
        public static float MaxAbsDiff(float mean, ReadOnlySpan<float> src)
        {
            Contracts.AssertNonEmpty(src);
 
            unsafe
            {
                fixed (float* psrc = &MemoryMarshal.GetReference(src))
                    return Thunk.MaxAbsDiffU(mean, psrc, src.Length);
            }
        }
 
        public static float DotProductSparse(ReadOnlySpan<float> a, ReadOnlySpan<float> b, ReadOnlySpan<int> indices, int count)
        {
            Contracts.AssertNonEmpty(a);
            Contracts.AssertNonEmpty(b);
            Contracts.Assert(0 < count);
            Contracts.Assert(count < a.Length);
            Contracts.Assert(count <= b.Length);
            Contracts.Assert(count <= indices.Length);
 
            unsafe
            {
                fixed (float* pa = &MemoryMarshal.GetReference(a))
                fixed (float* pb = &MemoryMarshal.GetReference(b))
                fixed (int* pi = &MemoryMarshal.GetReference(indices))
                    return Thunk.DotSU(pa, pb, pi, count);
            }
        }
 
        public static float L2DistSquared(ReadOnlySpan<float> a, ReadOnlySpan<float> b, int count)
        {
            Contracts.AssertNonEmpty(a);
            Contracts.AssertNonEmpty(b);
            Contracts.Assert(0 < count && count <= a.Length);
            Contracts.Assert(count <= b.Length);
 
            var value = TensorPrimitives.Distance(a.Slice(0, count), b.Slice(0, count));
            return value * value;
        }
 
        public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] indices)
        {
            Contracts.Assert(0 < ccol && ccol <= cfltRow);
 
            unsafe
            {
                fixed (float* pdst = &dst.Items[0])
                fixed (int* pi = &indices[0])
                {
                    if (ccol == cfltRow)
                        Thunk.ZeroItemsU(Ptr(dst, pdst), dst.Size, pi, indices.Length);
                    else
                        Thunk.ZeroMatrixItemsCore(Ptr(dst, pdst), dst.Size, ccol, cfltRow, pi, indices.Length);
                }
            }
        }
 
        public static void SdcaL1UpdateDense(float primalUpdate, int count, ReadOnlySpan<float> src, float threshold, Span<float> v, Span<float> w)
        {
            Contracts.AssertNonEmpty(src);
            Contracts.Assert(count <= src.Length);
            Contracts.AssertNonEmpty(v);
            Contracts.Assert(count <= v.Length);
            Contracts.AssertNonEmpty(w);
            Contracts.Assert(count <= w.Length);
            Contracts.Assert(count > 0);
 
            unsafe
            {
                fixed (float* psrc = &MemoryMarshal.GetReference(src))
                fixed (float* pd1 = &MemoryMarshal.GetReference(v))
                fixed (float* pd2 = &MemoryMarshal.GetReference(w))
                    Thunk.SdcaL1UpdateU(primalUpdate, psrc, threshold, pd1, pd2, count);
            }
        }
 
        public static void SdcaL1UpdateSparse(float primalUpdate, int count, ReadOnlySpan<float> source, ReadOnlySpan<int> indices, float threshold, Span<float> v, Span<float> w)
        {
            Contracts.AssertNonEmpty(source);
            Contracts.Assert(count <= source.Length);
            Contracts.AssertNonEmpty(indices);
            Contracts.Assert(count <= indices.Length);
            Contracts.AssertNonEmpty(v);
            Contracts.Assert(count <= v.Length);
            Contracts.AssertNonEmpty(w);
            Contracts.Assert(count <= w.Length);
            Contracts.Assert(count > 0);
 
            unsafe
            {
                fixed (float* psrc = &MemoryMarshal.GetReference(source))
                fixed (int* pi = &MemoryMarshal.GetReference(indices))
                fixed (float* pd1 = &MemoryMarshal.GetReference(v))
                fixed (float* pd2 = &MemoryMarshal.GetReference(w))
                    Thunk.SdcaL1UpdateSU(primalUpdate, psrc, pi, threshold, pd1, pd2, count);
            }
        }
    }
}