File: Utilities\Random.cs
Web Access
Project: src\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj (Microsoft.ML.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML
{
    /// <summary>
    /// Extension methods over <see cref="Random"/> and factory methods for
    /// <see cref="TauswortheHybrid"/> generators.
    /// </summary>
    [BestFriend]
    internal static class RandomUtils
    {
        /// <summary>
        /// Draws a single-precision value that is greater than or equal to 0 and strictly less than 1.
        /// </summary>
        /// <param name="random">The generator to draw from.</param>
        /// <returns>A float in the half-open interval [0, 1).</returns>
        public static float NextSingle(this Random random)
        {
            // Fast path: the Tausworthe generator produces singles directly.
            if (random is TauswortheHybrid tauswortheHybrid)
            {
                return tauswortheHybrid.NextSingle();
            }

            for (; ; )
            {
                // The largest values NextDouble() can return round up to exactly 1 when narrowed
                // to float, so such draws are rejected and a new value is drawn.
                float candidate = (float)random.NextDouble();
                if (candidate < 1.0f)
                    return candidate;
            }
        }

        /// <summary>
        /// Draws a signed 32-bit integer, which may be negative.
        /// </summary>
        /// <param name="random">The generator to draw from.</param>
        /// <returns>
        /// For a <see cref="TauswortheHybrid"/>, the raw 32-bit output reinterpreted as an int.
        /// For any other generator, the result of <c>Next(int.MinValue, int.MaxValue)</c>.
        /// </returns>
        public static int NextSigned(this Random random)
        {
            if (random is TauswortheHybrid tauswortheHybrid)
            {
                return tauswortheHybrid.NextSigned();
            }

            // The upper bound of System.Random.Next(min, max) is documented as exclusive, so this
            // fallback never yields int.MaxValue. That small gap in the range is accepted.
            return random.Next(int.MinValue, int.MaxValue);
        }

        /// <summary>
        /// Creates a generator whose state is drawn from a newly constructed <see cref="Random"/>.
        /// </summary>
        public static TauswortheHybrid Create()
        {
            return new TauswortheHybrid(new Random());
        }

        /// <summary>
        /// Creates a generator from an optional signed seed.
        /// </summary>
        /// <param name="seed">The seed, or null to seed from a new <see cref="Random"/>.</param>
        public static TauswortheHybrid Create(int? seed)
        {
            return seed.HasValue ? Create(seed.GetValueOrDefault()) : Create();
        }

        /// <summary>
        /// Creates a generator from a signed seed. The seed's bit pattern is reinterpreted as unsigned.
        /// </summary>
        /// <param name="seed">The seed; negative values are allowed.</param>
        public static TauswortheHybrid Create(int seed)
        {
            // Explicitly unchecked: negative seeds must wrap rather than throw, regardless of
            // whether the assembly is compiled with overflow checking enabled.
            var state = new TauswortheHybrid.State(unchecked((uint)seed));
            return new TauswortheHybrid(state);
        }

        /// <summary>
        /// Creates a generator from an optional unsigned seed.
        /// </summary>
        /// <param name="seed">The seed, or null to seed from a new <see cref="Random"/>.</param>
        public static TauswortheHybrid Create(uint? seed)
        {
            return seed.HasValue ? Create(seed.GetValueOrDefault()) : Create();
        }

        /// <summary>
        /// Creates a generator from an unsigned seed.
        /// </summary>
        /// <param name="seed">The seed that is expanded into the generator state.</param>
        public static TauswortheHybrid Create(uint seed)
        {
            var state = new TauswortheHybrid.State(seed);
            return new TauswortheHybrid(state);
        }

        /// <summary>
        /// Creates a generator whose state is drawn from the supplied generator.
        /// </summary>
        /// <param name="seed">The generator used to produce the initial state.</param>
        public static TauswortheHybrid Create(Random seed)
        {
            return new TauswortheHybrid(seed);
        }
    }
 
    /// <summary>
    /// Tausworthe hybrid random number generator. Each output is the XOR of three Tausworthe
    /// components and one linear congruential component.
    /// </summary>
    [BestFriend]
    internal sealed class TauswortheHybrid : Random
    {
        /// <summary>
        /// Immutable snapshot of the four 32-bit words that make up the generator state.
        /// </summary>
        public readonly struct State
        {
            public readonly uint U1;
            public readonly uint U2;
            public readonly uint U3;
            public readonly uint U4;

            /// <summary>
            /// Expands a single seed into four state words: the seed itself, followed by three
            /// chained <c>Hashing.MurmurRound</c> results.
            /// </summary>
            /// <param name="seed">The seed value; it is stored unchanged as the first word.</param>
            public State(uint seed)
            {
                // NOTE(review): zero is a fixed point of the Tausworthe update, and with the first
                // component's parameters a seed of 0 or 1 makes that component zero after one step,
                // leaving only the other three components contributing. Left unchanged here so that
                // existing seeded sequences are not altered.
                U1 = seed;
                U2 = Hashing.MurmurRound(U1, seed);
                U3 = Hashing.MurmurRound(U2, seed);
                U4 = Hashing.MurmurRound(U3, seed);
            }

            /// <summary>
            /// Builds a state from four explicit words. The values are stored as given, without validation.
            /// </summary>
            public State(uint u1, uint u2, uint u3, uint u4)
            {
                U1 = u1;
                U2 = u2;
                U3 = u3;
                U4 = u4;
            }

            /// <summary>
            /// Writes the four state words to <paramref name="writer"/> in the order U1, U2, U3, U4.
            /// </summary>
            public void Save(BinaryWriter writer)
            {
                writer.Write(U1);
                writer.Write(U2);
                writer.Write(U3);
                writer.Write(U4);
            }

            /// <summary>
            /// Reads four state words from <paramref name="reader"/> in the order written by <see cref="Save"/>.
            /// </summary>
            /// <returns>The reconstructed state.</returns>
            public static State Load(BinaryReader reader)
            {
                uint u1 = reader.ReadUInt32();
                uint u2 = reader.ReadUInt32();
                uint u3 = reader.ReadUInt32();
                uint u4 = reader.ReadUInt32();
                return new State(u1, u2, u3, u4);
            }
        }

        // _z1 to _z3 are the Tausworthe components; _z4 is the linear congruential component.
        private uint _z1;
        private uint _z2;
        private uint _z3;
        private uint _z4;

        /// <summary>
        /// Creates a generator that starts from the given state.
        /// </summary>
        public TauswortheHybrid(State state)
        {
            _z1 = state.U1;
            _z2 = state.U2;
            _z3 = state.U3;
            _z4 = state.U4;
        }

        /// <summary>
        /// Creates a generator whose four state words are drawn from <paramref name="rng"/>.
        /// </summary>
        public TauswortheHybrid(Random rng)
        {
            // The draw order matters for reproducibility: z1, z2, z3, then z4.
            _z1 = GetSeed(rng);
            _z2 = GetSeed(rng);
            _z3 = GetSeed(rng);
            _z4 = GetU(rng);
        }

        // Builds a full 32-bit value from two 16-bit draws; the first draw becomes the high half.
        private static uint GetU(Random rng)
        {
            return ((uint)rng.Next(0x00010000) << 16) | ((uint)rng.Next(0x00010000));
        }

        // Draws 32-bit values until one is at least 128, presumably to keep the Tausworthe
        // components away from small, degenerate starting values.
        private static uint GetSeed(Random rng)
        {
            for (; ; )
            {
                uint u = GetU(rng);
                if (u >= 128)
                    return u;
            }
        }

        /// <summary>
        /// Advances the generator and returns a float that is at least 0 and strictly less than 1.
        /// </summary>
        public float NextSingle()
        {
            NextState();
            return GetSingle();
        }

        /// <summary>
        /// Advances the generator and returns a double that is at least 0 and strictly less than 1.
        /// </summary>
        public override double NextDouble()
        {
            NextState();
            return GetDouble();
        }

        /// <summary>
        /// Advances the generator and returns a non-negative integer.
        /// </summary>
        /// <returns>
        /// A value from 0 to int.MaxValue inclusive. Unlike the documented range of
        /// <c>Random.Next()</c>, int.MaxValue itself can be produced.
        /// </returns>
        public override int Next()
        {
            NextState();

            // Clearing the sign bit maps the 32-bit output onto the non-negative ints. This is the
            // same value the wrapping sum u + (u & 0x80000000) produces, without relying on overflow.
            int n = (int)(GetUint() & 0x7FFFFFFFU);
            Contracts.Assert(n >= 0);
            return n;
        }

        /// <summary>
        /// Advances the generator and returns a non-negative integer below <paramref name="maxValue"/>.
        /// </summary>
        /// <param name="maxValue">The exclusive upper bound; must be non-negative.</param>
        /// <returns>A value at least 0 and less than <paramref name="maxValue"/>, or 0 when it is 0.</returns>
        public override int Next(int maxValue)
        {
            Contracts.CheckParam(maxValue >= 0, nameof(maxValue), "maxValue must be non-negative");
            NextState();

            // Multiply-and-shift: the high 32 bits of (u * maxValue) scale u onto [0, maxValue).
            ulong product = (ulong)GetUint() * (ulong)maxValue;
            int res = (int)(product >> 32);
            Contracts.Assert(0 <= res && (res < maxValue || res == 0));
            return res;
        }

        /// <summary>
        /// Advances the generator and returns an integer in the requested range.
        /// </summary>
        /// <param name="minValue">The inclusive lower bound.</param>
        /// <param name="maxValue">The exclusive upper bound; must not be less than <paramref name="minValue"/>.</param>
        /// <returns>A value at least <paramref name="minValue"/>; equal to it when the bounds are equal.</returns>
        public override int Next(int minValue, int maxValue)
        {
            Contracts.CheckParam(minValue <= maxValue, nameof(minValue), "minValue must be less than or equal to maxValue.");

            // The width is computed in 64 bits because maxValue - minValue can exceed int.MaxValue.
            long range = (long)maxValue - minValue;
            return (int)((long)(NextDouble() * range) + minValue);
        }

        /// <summary>
        /// Fills <paramref name="buffer"/> with random bytes, advancing the generator once per byte.
        /// </summary>
        /// <param name="buffer">The array to fill; must not be null.</param>
        public override void NextBytes(byte[] buffer)
        {
            Contracts.CheckValue(buffer, nameof(buffer));

            for (int i = 0; i < buffer.Length; i++)
            {
                // Only the low 8 bits are kept; the truncation is intentional, so it is marked
                // unchecked to stay valid when overflow checking is enabled.
                buffer[i] = unchecked((byte)Next());
            }
        }

        /// <summary>
        /// Advances the generator and returns the raw 32-bit output reinterpreted as a signed
        /// integer, so any int value, including negative ones, can be produced.
        /// </summary>
        public int NextSigned()
        {
            NextState();

            // Bit-for-bit reinterpretation; values above int.MaxValue wrap to negative on purpose.
            return unchecked((int)GetUint());
        }

        // Combines the four components into the 32-bit output for the current state.
        private uint GetUint()
        {
            return _z1 ^ _z2 ^ _z3 ^ _z4;
        }

        // Converts the current output to a float that is at least 0 and strictly less than 1.
        private float GetSingle()
        {
            const float scale = (float)1 / (1 << 23);

            // Drop the low 9 bits so the conversion to Single is exact. Allowing rounding would
            // bias the values and, worse, could return exactly 1.
            uint u = GetUint() >> 9;
            Contracts.Assert((uint)(float)u == u);

            return (float)u * scale;
        }

        // Converts the current output to a double by scaling it with 2^-32.
        private double GetDouble()
        {
            const double scale = (double)1 / (1 << 16) / (1 << 16);
            uint u = GetUint();
            return (double)u * scale;
        }

        // Advances all four components by one step.
        private void NextState()
        {
            TauswortheStateChange(ref _z1, 13, 19, 12, ~0x1U);
            TauswortheStateChange(ref _z2, 2, 25, 4, ~0x7U);
            TauswortheStateChange(ref _z3, 3, 11, 17, ~0xfU);
            LcgStateChange(ref _z4, 1664525, 1013904223);
        }

        // One Tausworthe step with shift amounts s1, s2, s3 and mask m.
        private static void TauswortheStateChange(ref uint z, int s1, int s2, int s3, uint m)
        {
            z = ((z & m) << s3) ^ (((z << s1) ^ z) >> s2);
        }

        // One linear congruential step modulo 2^32. The modulus comes from 32-bit wraparound,
        // so the arithmetic is explicitly unchecked.
        private static void LcgStateChange(ref uint z, uint a, uint c)
        {
            z = unchecked(a * z + c);
        }

        /// <summary>
        /// Captures the current state of this generator.
        /// </summary>
        /// <remarks>
        /// A new <see cref="TauswortheHybrid"/> constructed from the returned <see cref="State"/>
        /// produces the same sequence of values that this instance will produce from this point on;
        /// values this instance has already generated are not replayed. To reproduce the sequence
        /// from the start, call this method before any call to NextSingle, NextDouble or Next.
        /// </remarks>
        /// <returns>A snapshot of the four state words.</returns>
        public State GetState()
        {
            return new State(_z1, _z2, _z3, _z4);
        }
    }
}