File: DataLoadSave\Binary\UnsafeTypeOps.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using Microsoft.ML.Data;
 
namespace Microsoft.ML.Internal.Internallearn
{
    /// <summary>
    /// Per-type helper operations for <typeparamref name="T"/>: the byte size of one value,
    /// pointer-level access to a span of values, and reading/writing a single value through
    /// <see cref="BinaryReader"/> / <see cref="BinaryWriter"/>. Several of these operations are unsafe.
    /// </summary>
    /// <typeparam name="T">The element type the operations act on.</typeparam>
    internal abstract class UnsafeTypeOps<T>
    {
        /// <summary>
        /// The size, in bytes, of a single <typeparamref name="T"/> value.
        /// </summary>
        public abstract int Size { get; }

        /// <summary>
        /// Reads one <typeparamref name="T"/> value from <paramref name="reader"/>.
        /// </summary>
        /// <param name="reader">The reader to consume the value from.</param>
        /// <returns>The decoded value.</returns>
        public abstract T Read(BinaryReader reader);

        /// <summary>
        /// Writes the value <paramref name="a"/> to <paramref name="writer"/>.
        /// </summary>
        /// <param name="a">The value to serialize.</param>
        /// <param name="writer">The destination writer.</param>
        public abstract void Write(T a, BinaryWriter writer);

        /// <summary>
        /// Invokes <paramref name="func"/> with a raw pointer to the first element of <paramref name="array"/>.
        /// The implementations in this file pin the span only for the duration of that call, so the
        /// pointer must not be retained after <paramref name="func"/> returns.
        /// </summary>
        /// <param name="array">The values to expose by address.</param>
        /// <param name="func">The callback that receives the pointer.</param>
        public abstract void Apply(ReadOnlySpan<T> array, Action<IntPtr> func);
    }
 
    /// <summary>
    /// Provides the shared <see cref="UnsafeTypeOps{T}"/> implementation for each supported element type:
    /// the signed/unsigned 8/16/32/64-bit integers, <see cref="float"/>, <see cref="double"/>,
    /// <see cref="TimeSpan"/> and <see cref="DataViewRowId"/>.
    /// </summary>
    internal static class UnsafeTypeOpsFactory
    {
        // Maps an element type to its UnsafeTypeOps<T> instance. Values are typed as object because
        // T differs per entry. The map is filled once in the static constructor and only read afterwards.
        private static readonly Dictionary<Type, object> _type2ops;

        static UnsafeTypeOpsFactory()
        {
            _type2ops = new Dictionary<Type, object>
            {
                [typeof(sbyte)] = new SByteUnsafeTypeOps(),
                [typeof(byte)] = new ByteUnsafeTypeOps(),
                [typeof(short)] = new Int16UnsafeTypeOps(),
                [typeof(ushort)] = new UInt16UnsafeTypeOps(),
                [typeof(int)] = new Int32UnsafeTypeOps(),
                [typeof(uint)] = new UInt32UnsafeTypeOps(),
                [typeof(long)] = new Int64UnsafeTypeOps(),
                [typeof(ulong)] = new UInt64UnsafeTypeOps(),
                [typeof(float)] = new SingleUnsafeTypeOps(),
                [typeof(double)] = new DoubleUnsafeTypeOps(),
                [typeof(TimeSpan)] = new TimeSpanUnsafeTypeOps(),
                [typeof(DataViewRowId)] = new UgUnsafeTypeOps(),
            };
        }

        /// <summary>
        /// Returns the shared <see cref="UnsafeTypeOps{T}"/> instance registered for <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">The element type to look up.</typeparam>
        /// <returns>The operations object for <typeparamref name="T"/>.</returns>
        /// <exception cref="KeyNotFoundException">
        /// No implementation is registered for <typeparamref name="T"/>.
        /// </exception>
        public static UnsafeTypeOps<T> Get<T>()
        {
            // Same exception type as the plain indexer, but with a message that names the offending type.
            if (!_type2ops.TryGetValue(typeof(T), out object ops))
            {
                throw new KeyNotFoundException($"No UnsafeTypeOps implementation is registered for type '{typeof(T)}'.");
            }

            return (UnsafeTypeOps<T>)ops;
        }

        // Every implementation below follows the same pattern:
        //  - Size is the in-memory byte size of one value.
        //  - Apply pins the span with a `fixed` statement and passes the address of its first element to
        //    the callback. The pointer is only valid while the callback runs.
        //  - Write/Read delegate to the matching BinaryWriter.Write overload / BinaryReader.ReadXxx method.

        private sealed class SByteUnsafeTypeOps : UnsafeTypeOps<sbyte>
        {
            public override int Size => sizeof(sbyte);

            public override unsafe void Apply(ReadOnlySpan<sbyte> array, Action<IntPtr> func)
            {
                fixed (sbyte* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(sbyte a, BinaryWriter writer) => writer.Write(a);
            public override sbyte Read(BinaryReader reader) => reader.ReadSByte();
        }

        private sealed class ByteUnsafeTypeOps : UnsafeTypeOps<byte>
        {
            public override int Size => sizeof(byte);

            public override unsafe void Apply(ReadOnlySpan<byte> array, Action<IntPtr> func)
            {
                fixed (byte* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(byte a, BinaryWriter writer) => writer.Write(a);
            public override byte Read(BinaryReader reader) => reader.ReadByte();
        }

        private sealed class Int16UnsafeTypeOps : UnsafeTypeOps<short>
        {
            public override int Size => sizeof(short);

            public override unsafe void Apply(ReadOnlySpan<short> array, Action<IntPtr> func)
            {
                fixed (short* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(short a, BinaryWriter writer) => writer.Write(a);
            public override short Read(BinaryReader reader) => reader.ReadInt16();
        }

        private sealed class UInt16UnsafeTypeOps : UnsafeTypeOps<ushort>
        {
            public override int Size => sizeof(ushort);

            public override unsafe void Apply(ReadOnlySpan<ushort> array, Action<IntPtr> func)
            {
                fixed (ushort* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(ushort a, BinaryWriter writer) => writer.Write(a);
            public override ushort Read(BinaryReader reader) => reader.ReadUInt16();
        }

        private sealed class Int32UnsafeTypeOps : UnsafeTypeOps<int>
        {
            public override int Size => sizeof(int);

            public override unsafe void Apply(ReadOnlySpan<int> array, Action<IntPtr> func)
            {
                fixed (int* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(int a, BinaryWriter writer) => writer.Write(a);
            public override int Read(BinaryReader reader) => reader.ReadInt32();
        }

        private sealed class UInt32UnsafeTypeOps : UnsafeTypeOps<uint>
        {
            public override int Size => sizeof(uint);

            public override unsafe void Apply(ReadOnlySpan<uint> array, Action<IntPtr> func)
            {
                fixed (uint* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(uint a, BinaryWriter writer) => writer.Write(a);
            public override uint Read(BinaryReader reader) => reader.ReadUInt32();
        }

        private sealed class Int64UnsafeTypeOps : UnsafeTypeOps<long>
        {
            public override int Size => sizeof(long);

            public override unsafe void Apply(ReadOnlySpan<long> array, Action<IntPtr> func)
            {
                fixed (long* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(long a, BinaryWriter writer) => writer.Write(a);
            public override long Read(BinaryReader reader) => reader.ReadInt64();
        }

        private sealed class UInt64UnsafeTypeOps : UnsafeTypeOps<ulong>
        {
            public override int Size => sizeof(ulong);

            public override unsafe void Apply(ReadOnlySpan<ulong> array, Action<IntPtr> func)
            {
                fixed (ulong* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(ulong a, BinaryWriter writer) => writer.Write(a);
            public override ulong Read(BinaryReader reader) => reader.ReadUInt64();
        }

        private sealed class SingleUnsafeTypeOps : UnsafeTypeOps<float>
        {
            public override int Size => sizeof(float);

            public override unsafe void Apply(ReadOnlySpan<float> array, Action<IntPtr> func)
            {
                fixed (float* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(float a, BinaryWriter writer) => writer.Write(a);
            public override float Read(BinaryReader reader) => reader.ReadSingle();
        }

        private sealed class DoubleUnsafeTypeOps : UnsafeTypeOps<double>
        {
            public override int Size => sizeof(double);

            public override unsafe void Apply(ReadOnlySpan<double> array, Action<IntPtr> func)
            {
                fixed (double* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(double a, BinaryWriter writer) => writer.Write(a);
            public override double Read(BinaryReader reader) => reader.ReadDouble();
        }

        private sealed class TimeSpanUnsafeTypeOps : UnsafeTypeOps<TimeSpan>
        {
            // A TimeSpan is serialized as its Int64 tick count.
            public override int Size => sizeof(long);

            public override unsafe void Apply(ReadOnlySpan<TimeSpan> array, Action<IntPtr> func)
            {
                fixed (TimeSpan* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(TimeSpan a, BinaryWriter writer) => writer.Write(a.Ticks);

            public override TimeSpan Read(BinaryReader reader)
            {
                long ticks = reader.ReadInt64();
                // A stored tick count of long.MinValue is decoded as TimeSpan.Zero, not TimeSpan.MinValue,
                // so Write(TimeSpan.MinValue) does not round-trip.
                // NOTE(review): presumably long.MinValue is a legacy missing-value sentinel — confirm before changing.
                return new TimeSpan(ticks == long.MinValue ? default : ticks);
            }
        }

        private sealed class UgUnsafeTypeOps : UnsafeTypeOps<DataViewRowId>
        {
            // A DataViewRowId is serialized as two UInt64 halves: Low, then High.
            public override int Size => 2 * sizeof(ulong);

            public override unsafe void Apply(ReadOnlySpan<DataViewRowId> array, Action<IntPtr> func)
            {
                fixed (DataViewRowId* pArray = &MemoryMarshal.GetReference(array))
                {
                    func(new IntPtr(pArray));
                }
            }

            public override void Write(DataViewRowId a, BinaryWriter writer)
            {
                writer.Write(a.Low);
                writer.Write(a.High);
            }

            public override DataViewRowId Read(BinaryReader reader)
            {
                // Read in the same order Write emits: Low first, then High.
                ulong lo = reader.ReadUInt64();
                ulong hi = reader.ReadUInt64();
                return new DataViewRowId(lo, hi);
            }
        }
    }
}