|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
using System.Runtime.Intrinsics;
namespace System.Numerics.Tensors
{
public static partial class TensorPrimitives
{
/// <summary>Computes the element-wise result of raising <c>e</c> to the number powers in the specified tensor.</summary>
/// <typeparam name="T">The element type of the tensors.</typeparam>
/// <param name="x">The tensor, represented as a span.</param>
/// <param name="destination">The destination tensor, represented as a span.</param>
/// <exception cref="ArgumentException">Destination is too short.</exception>
/// <exception cref="ArgumentException"><paramref name="x"/> and <paramref name="destination"/> reference overlapping memory locations and do not begin at the same location.</exception>
/// <remarks>
/// <para>
/// This method effectively computes <c><paramref name="destination" />[i] = <typeparamref name="T"/>.Exp(<paramref name="x" />[i])</c>.
/// </para>
/// <para>
/// If a value equals <see cref="IFloatingPointIeee754{TSelf}.NaN"/>, the result stored into the corresponding destination location is set to NaN.
/// If a value equals <see cref="IFloatingPointIeee754{TSelf}.PositiveInfinity"/>, the result stored into the corresponding destination location is set to <see cref="IFloatingPointIeee754{TSelf}.PositiveInfinity"/>.
/// If a value equals <see cref="IFloatingPointIeee754{TSelf}.NegativeInfinity"/>, the result stored into the corresponding destination location is set to 0.
/// </para>
/// <para>
/// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Exp<T>(ReadOnlySpan<T> x, Span<T> destination)
    where T : IExponentialFunctions<T> =>
    InvokeSpanIntoSpan<T, ExpOperator<T>>(x, destination);
/// <summary>
/// Element-wise operator computing <c>T.Exp(x)</c>. Only <see cref="float"/> and <see cref="double"/>
/// have vector implementations. On .NET 9+ they forward to <c>Vector128/256/512.Exp</c>; otherwise
/// they forward to the <c>ExpOperatorSingle</c> / <c>ExpOperatorDouble</c> kernels.
/// </summary>
internal readonly struct ExpOperator<T> : IUnaryOperator<T, T>
    where T : IExponentialFunctions<T>
{
    /// <summary>Vector paths exist only for <see cref="float"/> and <see cref="double"/>.</summary>
    public static bool Vectorizable => typeof(T) == typeof(float) || typeof(T) == typeof(double);

    /// <summary>Scalar path: defers to the generic-math <c>T.Exp</c>.</summary>
    public static T Invoke(T x) => T.Exp(x);

    /// <summary>Applies exp to every lane of a 128-bit vector.</summary>
    public static Vector128<T> Invoke(Vector128<T> x)
    {
        if (typeof(T) != typeof(double))
        {
            // Vectorizable is true only for float/double, so the non-double case is expected to be float.
            Debug.Assert(typeof(T) == typeof(float));
#if NET9_0_OR_GREATER
            return Vector128.Exp(x.AsSingle()).As<float, T>();
#else
            return ExpOperatorSingle.Invoke(x.AsSingle()).As<float, T>();
#endif
        }

#if NET9_0_OR_GREATER
        return Vector128.Exp(x.AsDouble()).As<double, T>();
#else
        return ExpOperatorDouble.Invoke(x.AsDouble()).As<double, T>();
#endif
    }

    /// <summary>Applies exp to every lane of a 256-bit vector.</summary>
    public static Vector256<T> Invoke(Vector256<T> x)
    {
        if (typeof(T) != typeof(double))
        {
            // Vectorizable is true only for float/double, so the non-double case is expected to be float.
            Debug.Assert(typeof(T) == typeof(float));
#if NET9_0_OR_GREATER
            return Vector256.Exp(x.AsSingle()).As<float, T>();
#else
            return ExpOperatorSingle.Invoke(x.AsSingle()).As<float, T>();
#endif
        }

#if NET9_0_OR_GREATER
        return Vector256.Exp(x.AsDouble()).As<double, T>();
#else
        return ExpOperatorDouble.Invoke(x.AsDouble()).As<double, T>();
#endif
    }

    /// <summary>Applies exp to every lane of a 512-bit vector.</summary>
    public static Vector512<T> Invoke(Vector512<T> x)
    {
        if (typeof(T) != typeof(double))
        {
            // Vectorizable is true only for float/double, so the non-double case is expected to be float.
            Debug.Assert(typeof(T) == typeof(float));
#if NET9_0_OR_GREATER
            return Vector512.Exp(x.AsSingle()).As<float, T>();
#else
            return ExpOperatorSingle.Invoke(x.AsSingle()).As<float, T>();
#endif
        }

#if NET9_0_OR_GREATER
        return Vector512.Exp(x.AsDouble()).As<double, T>();
#else
        return ExpOperatorDouble.Invoke(x.AsDouble()).As<double, T>();
#endif
    }
}
#if !NET9_0_OR_GREATER
/// <summary>
/// double.Exp(x) with hand-written vector paths, used when NET9_0_OR_GREATER is not defined
/// (on .NET 9+ ExpOperator&lt;T&gt; calls Vector128/256/512.Exp instead).
/// </summary>
private readonly struct ExpOperatorDouble : IUnaryOperator<double, double>
{
// This code is based on `vrd2_exp` from amd/aocl-libm-ose
// Copyright (C) 2019-2020 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text
// Implementation Notes
// ----------------------
// The upstream algorithm reduces the argument with a 64-entry table of 2^(j/64).
// This port uses no table. It is the same scheme with the table size fixed at 1.
//
// 1. Argument Reduction:
//      e^x = 2^(x/ln2)
//      Choose an integer n such that
//          x/ln2 = n + f,   |f| <= 0.5
//      and let
//          r = f*ln2 = x - n*ln2,   |r| <= ln2/2
//      ln2 is split into V_LN2_HEAD + V_LN2_TAIL.
//      V_LN2_HEAD has only 9 significant bits, so n*V_LN2_HEAD is exact for the n reached on
//      the fast path, and r keeps its precision.
//
// 2. Polynomial Evaluation
//      e^r ~= 1 + r + C3*r^2 + C4*r^3 + ... + C12*r^11
//
// 3. Reconstruction
//      e^x = 2^n * e^r
//      2^n is built directly by writing (n + 1023) into the exponent field of a double.
// Bit pattern of 708.0. The vector fast path is taken only when every lane has |x| <= 708.0.
private const ulong V_ARG_MAX = 0x40862000_00000000;
// IEEE 754 binary64 exponent bias.
private const ulong V_DP64_BIAS = 1023;
// 1.5 * 2^52. Adding it to a double of modest magnitude rounds it to an integer and leaves that
// integer (two's complement) in the low mantissa bits. Subtracting it again yields the integer as a double.
private const double V_EXPF_HUGE = 6755399441055744;
// 1/ln(2) (== log2(e)). Despite the name, no lookup table is involved.
private const double V_TBL_LN2 = 1.4426950408889634;
// ln(2) = V_LN2_HEAD + V_LN2_TAIL (see the argument-reduction note above).
private const double V_LN2_HEAD = +0.693359375;
private const double V_LN2_TAIL = -0.00021219444005469057;
// Polynomial coefficients for e^r. Ck multiplies r^(k-1) and is close to 1/(k-1)!.
private const double C3 = 0.5000000000000018;
private const double C4 = 0.1666666666666617;
private const double C5 = 0.04166666666649277;
private const double C6 = 0.008333333333559272;
private const double C7 = 0.001388888895122404;
private const double C8 = 0.00019841269432677495;
private const double C9 = 2.4801486521374483E-05;
private const double C10 = 2.7557622532543023E-06;
private const double C11 = 2.7632293298250954E-07;
private const double C12 = 2.499430431958571E-08;
public static bool Vectorizable => true;
// Scalar path defers to double.Exp.
public static double Invoke(double x) => double.Exp(x);
public static Vector128<double> Invoke(Vector128<double> x)
{
// Fast path only if every lane has |x| <= 708.0.
// For non-negative doubles, comparing bit patterns as unsigned integers orders them like the
// values. NaN and infinity have larger bit patterns, so they take the scalar fallback.
if (Vector128.LessThanOrEqualAll(Vector128.Abs(x).AsUInt64(), Vector128.Create(V_ARG_MAX)))
{
// z = x / ln(2)
Vector128<double> z = x * Vector128.Create(V_TBL_LN2);
Vector128<double> dn = z + Vector128.Create(V_EXPF_HUGE);
// n = round(z), held in the low mantissa bits of dn (above them sit the magic constant's bits)
Vector128<ulong> n = dn.AsUInt64();
// dn = (double)n
dn -= Vector128.Create(V_EXPF_HUGE);
// r = x - (dn * ln(2))
// where ln(2) is split into Head and Tail values
Vector128<double> r = x - (dn * Vector128.Create(V_LN2_HEAD)) - (dn * Vector128.Create(V_LN2_TAIL));
Vector128<double> r2 = r * r;
Vector128<double> r4 = r2 * r2;
Vector128<double> r8 = r4 * r4;
// e^r ~= 1 + r + C3*r^2 + ... + C12*r^11, grouped Estrin-style by r2/r4/r8.
// The evaluation order affects the exact bits of the result.
Vector128<double> poly = ((Vector128.Create(C12) * r + Vector128.Create(C11)) * r2 +
Vector128.Create(C10) * r + Vector128.Create(C9)) * r8 +
((Vector128.Create(C8) * r + Vector128.Create(C7)) * r2 +
(Vector128.Create(C6) * r + Vector128.Create(C5))) * r4 +
((Vector128.Create(C4) * r + Vector128.Create(C3)) * r2 + (r + Vector128<double>.One));
// result = poly * 2^n.
// The << 52 keeps only the low 12 bits of (n + 1023): the magic-constant bits are shifted out,
// leaving the biased exponent of 2^n. For |x| <= 708, n + 1023 lies in [2, 2044], a normal exponent.
return poly * ((n + Vector128.Create(V_DP64_BIAS)) << 52).AsDouble();
}
else
{
// At least one lane is out of range, NaN or infinite, so compute every lane with Math.Exp.
return ScalarFallback(x);
static Vector128<double> ScalarFallback(Vector128<double> x) =>
Vector128.Create(Math.Exp(x.GetElement(0)),
Math.Exp(x.GetElement(1)));
}
}
public static Vector256<double> Invoke(Vector256<double> x)
{
// Fast path only if every lane has |x| <= 708.0.
// For non-negative doubles, comparing bit patterns as unsigned integers orders them like the
// values. NaN and infinity have larger bit patterns, so they take the scalar fallback.
if (Vector256.LessThanOrEqualAll(Vector256.Abs(x).AsUInt64(), Vector256.Create(V_ARG_MAX)))
{
// z = x / ln(2)
Vector256<double> z = x * Vector256.Create(V_TBL_LN2);
Vector256<double> dn = z + Vector256.Create(V_EXPF_HUGE);
// n = round(z), held in the low mantissa bits of dn (above them sit the magic constant's bits)
Vector256<ulong> n = dn.AsUInt64();
// dn = (double)n
dn -= Vector256.Create(V_EXPF_HUGE);
// r = x - (dn * ln(2))
// where ln(2) is split into Head and Tail values
Vector256<double> r = x - (dn * Vector256.Create(V_LN2_HEAD)) - (dn * Vector256.Create(V_LN2_TAIL));
Vector256<double> r2 = r * r;
Vector256<double> r4 = r2 * r2;
Vector256<double> r8 = r4 * r4;
// e^r ~= 1 + r + C3*r^2 + ... + C12*r^11, grouped Estrin-style by r2/r4/r8.
// The evaluation order affects the exact bits of the result.
Vector256<double> poly = ((Vector256.Create(C12) * r + Vector256.Create(C11)) * r2 +
Vector256.Create(C10) * r + Vector256.Create(C9)) * r8 +
((Vector256.Create(C8) * r + Vector256.Create(C7)) * r2 +
(Vector256.Create(C6) * r + Vector256.Create(C5))) * r4 +
((Vector256.Create(C4) * r + Vector256.Create(C3)) * r2 + (r + Vector256<double>.One));
// result = poly * 2^n.
// The << 52 keeps only the low 12 bits of (n + 1023): the magic-constant bits are shifted out,
// leaving the biased exponent of 2^n. For |x| <= 708, n + 1023 lies in [2, 2044], a normal exponent.
return poly * ((n + Vector256.Create(V_DP64_BIAS)) << 52).AsDouble();
}
else
{
// At least one lane is out of range, NaN or infinite, so compute every lane with Math.Exp.
return ScalarFallback(x);
static Vector256<double> ScalarFallback(Vector256<double> x) =>
Vector256.Create(Math.Exp(x.GetElement(0)),
Math.Exp(x.GetElement(1)),
Math.Exp(x.GetElement(2)),
Math.Exp(x.GetElement(3)));
}
}
public static Vector512<double> Invoke(Vector512<double> x)
{
// Fast path only if every lane has |x| <= 708.0.
// For non-negative doubles, comparing bit patterns as unsigned integers orders them like the
// values. NaN and infinity have larger bit patterns, so they take the scalar fallback.
if (Vector512.LessThanOrEqualAll(Vector512.Abs(x).AsUInt64(), Vector512.Create(V_ARG_MAX)))
{
// z = x / ln(2)
Vector512<double> z = x * Vector512.Create(V_TBL_LN2);
Vector512<double> dn = z + Vector512.Create(V_EXPF_HUGE);
// n = round(z), held in the low mantissa bits of dn (above them sit the magic constant's bits)
Vector512<ulong> n = dn.AsUInt64();
// dn = (double)n
dn -= Vector512.Create(V_EXPF_HUGE);
// r = x - (dn * ln(2))
// where ln(2) is split into Head and Tail values
Vector512<double> r = x - (dn * Vector512.Create(V_LN2_HEAD)) - (dn * Vector512.Create(V_LN2_TAIL));
Vector512<double> r2 = r * r;
Vector512<double> r4 = r2 * r2;
Vector512<double> r8 = r4 * r4;
// e^r ~= 1 + r + C3*r^2 + ... + C12*r^11, grouped Estrin-style by r2/r4/r8.
// The evaluation order affects the exact bits of the result.
Vector512<double> poly = ((Vector512.Create(C12) * r + Vector512.Create(C11)) * r2 +
Vector512.Create(C10) * r + Vector512.Create(C9)) * r8 +
((Vector512.Create(C8) * r + Vector512.Create(C7)) * r2 +
(Vector512.Create(C6) * r + Vector512.Create(C5))) * r4 +
((Vector512.Create(C4) * r + Vector512.Create(C3)) * r2 + (r + Vector512<double>.One));
// result = poly * 2^n.
// The << 52 keeps only the low 12 bits of (n + 1023): the magic-constant bits are shifted out,
// leaving the biased exponent of 2^n. For |x| <= 708, n + 1023 lies in [2, 2044], a normal exponent.
return poly * ((n + Vector512.Create(V_DP64_BIAS)) << 52).AsDouble();
}
else
{
// At least one lane is out of range, NaN or infinite, so compute every lane with Math.Exp.
return ScalarFallback(x);
static Vector512<double> ScalarFallback(Vector512<double> x) =>
Vector512.Create(Math.Exp(x.GetElement(0)),
Math.Exp(x.GetElement(1)),
Math.Exp(x.GetElement(2)),
Math.Exp(x.GetElement(3)),
Math.Exp(x.GetElement(4)),
Math.Exp(x.GetElement(5)),
Math.Exp(x.GetElement(6)),
Math.Exp(x.GetElement(7)));
}
}
}
/// <summary>
/// float.Exp(x) with hand-written vector paths, used when NET9_0_OR_GREATER is not defined
/// (on .NET 9+ ExpOperator&lt;T&gt; calls Vector128/256/512.Exp instead).
/// Each half of the input is widened to double. The computation runs in double and is narrowed back to float at the end.
/// </summary>
private readonly struct ExpOperatorSingle : IUnaryOperator<float, float>
{
// This code is based on `vrs4_expf` from amd/aocl-libm-ose
// Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text
// Implementation Notes:
// 1. Argument Reduction:
//      e^x = 2^(x/ln2)                          --- (1)
//
//      Let x/ln(2) = z                          --- (2)
//
//      Let z = n + r , where n is an integer    --- (3)
//                      |r| <= 1/2
//
//      From (1), (2) and (3),
//      e^x = 2^z
//          = 2^(n+r)
//          = (2^n)*(2^r)                        --- (4)
//
// 2. Polynomial Evaluation
//      From (4),
//      r = z - n
//      2^r ~= C1 + C2*r + C3*r^2 + C4*r^3 + C5*r^4 + C6*r^5
//
// 3. Reconstruction
//      Thus,
//      e^x = (2^n) * (2^r)
//      2^n is applied by adding n to the exponent field of the double-precision 2^r.
// Bit pattern of 87.0f. If any lane has |x| above this, the special-value fix-up below runs.
private const uint V_ARG_MAX = 0x42AE0000;
// ~= ln(2^-150). Lanes with x below this are forced to 0.
private const float V_EXPF_MIN = -103.97208f;
// ~= ln(float.MaxValue). Lanes with x above this are forced to +infinity.
private const float V_EXPF_MAX = +88.72284f;
// 1.5 * 2^52. Adding it rounds z to an integer n held in the low mantissa bits.
private const double V_EXPF_HUGE = 6755399441055744;
// 1/ln(2) (== log2(e)).
private const double V_TBL_LN2 = 1.4426950408889634;
// Polynomial coefficients for 2^r. Ck multiplies r^(k-1) and is close to ln(2)^(k-1) / (k-1)!.
private const double C1 = 1.0000000754895704;
private const double C2 = 0.6931472254087585;
private const double C3 = 0.2402210737432219;
private const double C4 = 0.05550297297702539;
private const double C5 = 0.009676036358193323;
private const double C6 = 0.001341000536524434;
public static bool Vectorizable => true;
// Scalar path defers to float.Exp.
public static float Invoke(float x) => float.Exp(x);
public static Vector128<float> Invoke(Vector128<float> x)
{
// Widen the lower (xl) and upper (xu) halves of x to double precision
(Vector128<double> xl, Vector128<double> xu) = Vector128.Widen(x);
// z = x / ln(2)
Vector128<double> v_tbl_ln2 = Vector128.Create(V_TBL_LN2);
Vector128<double> zl = xl * v_tbl_ln2;
Vector128<double> zu = xu * v_tbl_ln2;
Vector128<double> v_expf_huge = Vector128.Create(V_EXPF_HUGE);
Vector128<double> dnl = zl + v_expf_huge;
Vector128<double> dnu = zu + v_expf_huge;
// n = round(z), held in the low mantissa bits of dn
Vector128<ulong> nl = dnl.AsUInt64();
Vector128<ulong> nu = dnu.AsUInt64();
// dn = (double)n
dnl -= v_expf_huge;
dnu -= v_expf_huge;
// r = z - dn, then 2^r ~= C1 + C2*r + ... + C6*r^5 for each half
Vector128<double> c1 = Vector128.Create(C1);
Vector128<double> c2 = Vector128.Create(C2);
Vector128<double> c3 = Vector128.Create(C3);
Vector128<double> c4 = Vector128.Create(C4);
Vector128<double> c5 = Vector128.Create(C5);
Vector128<double> c6 = Vector128.Create(C6);
Vector128<double> rl = zl - dnl;
Vector128<double> rl2 = rl * rl;
Vector128<double> rl4 = rl2 * rl2;
Vector128<double> polyl = (c4 * rl + c3) * rl2
+ ((c6 * rl + c5) * rl4
+ (c2 * rl + c1));
Vector128<double> ru = zu - dnu;
Vector128<double> ru2 = ru * ru;
Vector128<double> ru4 = ru2 * ru2;
Vector128<double> polyu = (c4 * ru + c3) * ru2
+ ((c6 * ru + c5) * ru4
+ (c2 * ru + c1));
// result = (float)(poly * 2^n).
// Adding (n << 52) to the bit pattern of poly adds n to its exponent field (only the low 12 bits
// of n survive the shift). Narrow then rounds both halves back to float.
Vector128<float> ret = Vector128.Narrow(
(polyl.AsUInt64() + (nl << 52)).AsDouble(),
(polyu.AsUInt64() + (nu << 52)).AsDouble()
);
// Fix-up only if some lane has |x| > 87.0f (bitwise compare; this also catches NaN and +/-infinity).
// Lanes that match neither mask below (including NaN lanes) keep the value computed above.
if (Vector128.GreaterThanAny(Vector128.Abs(x).AsUInt32(), Vector128.Create(V_ARG_MAX)))
{
// (x > V_EXPF_MAX) ? float.PositiveInfinity : ret
Vector128<float> infinityMask = Vector128.GreaterThan(x, Vector128.Create(V_EXPF_MAX));
ret = Vector128.ConditionalSelect(
infinityMask,
Vector128.Create(float.PositiveInfinity),
ret
);
// (x < V_EXPF_MIN) ? 0 : ret
ret = Vector128.AndNot(ret, Vector128.LessThan(x, Vector128.Create(V_EXPF_MIN)));
}
return ret;
}
public static Vector256<float> Invoke(Vector256<float> x)
{
// Widen the lower (xl) and upper (xu) halves of x to double precision
(Vector256<double> xl, Vector256<double> xu) = Vector256.Widen(x);
// z = x / ln(2)
Vector256<double> v_tbl_ln2 = Vector256.Create(V_TBL_LN2);
Vector256<double> zl = xl * v_tbl_ln2;
Vector256<double> zu = xu * v_tbl_ln2;
Vector256<double> v_expf_huge = Vector256.Create(V_EXPF_HUGE);
Vector256<double> dnl = zl + v_expf_huge;
Vector256<double> dnu = zu + v_expf_huge;
// n = round(z), held in the low mantissa bits of dn
Vector256<ulong> nl = dnl.AsUInt64();
Vector256<ulong> nu = dnu.AsUInt64();
// dn = (double)n
dnl -= v_expf_huge;
dnu -= v_expf_huge;
// r = z - dn, then 2^r ~= C1 + C2*r + ... + C6*r^5 for each half
Vector256<double> c1 = Vector256.Create(C1);
Vector256<double> c2 = Vector256.Create(C2);
Vector256<double> c3 = Vector256.Create(C3);
Vector256<double> c4 = Vector256.Create(C4);
Vector256<double> c5 = Vector256.Create(C5);
Vector256<double> c6 = Vector256.Create(C6);
Vector256<double> rl = zl - dnl;
Vector256<double> rl2 = rl * rl;
Vector256<double> rl4 = rl2 * rl2;
Vector256<double> polyl = (c4 * rl + c3) * rl2
+ ((c6 * rl + c5) * rl4
+ (c2 * rl + c1));
Vector256<double> ru = zu - dnu;
Vector256<double> ru2 = ru * ru;
Vector256<double> ru4 = ru2 * ru2;
Vector256<double> polyu = (c4 * ru + c3) * ru2
+ ((c6 * ru + c5) * ru4
+ (c2 * ru + c1));
// result = (float)(poly * 2^n).
// Adding (n << 52) to the bit pattern of poly adds n to its exponent field (only the low 12 bits
// of n survive the shift). Narrow then rounds both halves back to float.
Vector256<float> ret = Vector256.Narrow(
(polyl.AsUInt64() + (nl << 52)).AsDouble(),
(polyu.AsUInt64() + (nu << 52)).AsDouble()
);
// Fix-up only if some lane has |x| > 87.0f (bitwise compare; this also catches NaN and +/-infinity).
// Lanes that match neither mask below (including NaN lanes) keep the value computed above.
if (Vector256.GreaterThanAny(Vector256.Abs(x).AsUInt32(), Vector256.Create(V_ARG_MAX)))
{
// (x > V_EXPF_MAX) ? float.PositiveInfinity : ret
Vector256<float> infinityMask = Vector256.GreaterThan(x, Vector256.Create(V_EXPF_MAX));
ret = Vector256.ConditionalSelect(
infinityMask,
Vector256.Create(float.PositiveInfinity),
ret
);
// (x < V_EXPF_MIN) ? 0 : ret
ret = Vector256.AndNot(ret, Vector256.LessThan(x, Vector256.Create(V_EXPF_MIN)));
}
return ret;
}
public static Vector512<float> Invoke(Vector512<float> x)
{
// Widen the lower (xl) and upper (xu) halves of x to double precision
(Vector512<double> xl, Vector512<double> xu) = Vector512.Widen(x);
// z = x / ln(2)
Vector512<double> v_tbl_ln2 = Vector512.Create(V_TBL_LN2);
Vector512<double> zl = xl * v_tbl_ln2;
Vector512<double> zu = xu * v_tbl_ln2;
Vector512<double> v_expf_huge = Vector512.Create(V_EXPF_HUGE);
Vector512<double> dnl = zl + v_expf_huge;
Vector512<double> dnu = zu + v_expf_huge;
// n = round(z), held in the low mantissa bits of dn
Vector512<ulong> nl = dnl.AsUInt64();
Vector512<ulong> nu = dnu.AsUInt64();
// dn = (double)n
dnl -= v_expf_huge;
dnu -= v_expf_huge;
// r = z - dn, then 2^r ~= C1 + C2*r + ... + C6*r^5 for each half
Vector512<double> c1 = Vector512.Create(C1);
Vector512<double> c2 = Vector512.Create(C2);
Vector512<double> c3 = Vector512.Create(C3);
Vector512<double> c4 = Vector512.Create(C4);
Vector512<double> c5 = Vector512.Create(C5);
Vector512<double> c6 = Vector512.Create(C6);
Vector512<double> rl = zl - dnl;
Vector512<double> rl2 = rl * rl;
Vector512<double> rl4 = rl2 * rl2;
Vector512<double> polyl = (c4 * rl + c3) * rl2
+ ((c6 * rl + c5) * rl4
+ (c2 * rl + c1));
Vector512<double> ru = zu - dnu;
Vector512<double> ru2 = ru * ru;
Vector512<double> ru4 = ru2 * ru2;
Vector512<double> polyu = (c4 * ru + c3) * ru2
+ ((c6 * ru + c5) * ru4
+ (c2 * ru + c1));
// result = (float)(poly * 2^n).
// Adding (n << 52) to the bit pattern of poly adds n to its exponent field (only the low 12 bits
// of n survive the shift). Narrow then rounds both halves back to float.
Vector512<float> ret = Vector512.Narrow(
(polyl.AsUInt64() + (nl << 52)).AsDouble(),
(polyu.AsUInt64() + (nu << 52)).AsDouble()
);
// Fix-up only if some lane has |x| > 87.0f (bitwise compare; this also catches NaN and +/-infinity).
// Lanes that match neither mask below (including NaN lanes) keep the value computed above.
if (Vector512.GreaterThanAny(Vector512.Abs(x).AsUInt32(), Vector512.Create(V_ARG_MAX)))
{
// (x > V_EXPF_MAX) ? float.PositiveInfinity : ret
Vector512<float> infinityMask = Vector512.GreaterThan(x, Vector512.Create(V_EXPF_MAX));
ret = Vector512.ConditionalSelect(
infinityMask,
Vector512.Create(float.PositiveInfinity),
ret
);
// (x < V_EXPF_MIN) ? 0 : ret
ret = Vector512.AndNot(ret, Vector512.LessThan(x, Vector512.Create(V_EXPF_MIN)));
}
return ret;
}
}
#endif
}
}
|