|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Linq;
using System.Text;
using Microsoft.ML.Internal.Utilities;
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// Extension methods that support binary serialization of some base C# types and arrays of these types:
    /// <list type="bullet">
    /// <item><description><c>SizeInBytes</c> - the number of bytes in the binary representation.</description></item>
    /// <item><description><c>value.ToByteArray(buffer, ref position)</c> - writes the binary representation of
    /// <c>value</c> into <c>buffer</c> at <c>position</c> and advances <c>position</c> past it.</description></item>
    /// <item><description><c>byte[].ToXXX(ref position)</c> - reads a value back from its binary representation
    /// and advances <c>position</c> past it.</description></item>
    /// </list>
    /// Scalars are stored as their raw in-memory bytes (native byte order). Arrays and strings are prefixed with
    /// an <see cref="int"/> count (elements for arrays, UTF-16 bytes for strings). A <c>null</c> array is written
    /// as count 0. Every read and write is bounds-checked, so a too-small buffer or a corrupt count prefix raises
    /// an exception instead of reading or writing memory outside the buffer.
    /// </summary>
    internal static class ToByteArrayExtensions
    {
        // ---- helpers ----

        /// <summary>
        /// Throws unless the <paramref name="count"/> bytes starting at <paramref name="position"/> lie inside
        /// <paramref name="buffer"/>. <paramref name="count"/> is a <see cref="long"/> so callers can pass
        /// <c>(long)length * elementSize</c> without <see cref="int"/> overflow; a negative count (for example
        /// from a corrupt length prefix) is rejected.
        /// </summary>
        private static void EnsureAvailable(byte[] buffer, int position, long count)
        {
            if (buffer == null)
                throw new ArgumentNullException(nameof(buffer));
            if (position < 0 || position > buffer.Length)
                throw new ArgumentOutOfRangeException(nameof(position), position, "Position must lie within the buffer.");
            if (count < 0 || count > buffer.Length - position)
            {
                throw new ArgumentException(
                    $"Cannot access {count} bytes at position {position}; only {buffer.Length - position} bytes remain in the buffer.",
                    nameof(buffer));
            }
        }

        /// <summary>
        /// Copies the <paramref name="count"/> raw bytes at <paramref name="source"/> into
        /// <paramref name="buffer"/> at <paramref name="position"/> and advances <paramref name="position"/>.
        /// Copying byte by byte keeps the native byte order while avoiding a typed store to a possibly
        /// misaligned offset.
        /// </summary>
        private static unsafe void WriteRaw(byte* source, int count, byte[] buffer, ref int position)
        {
            EnsureAvailable(buffer, position, count);
            for (int i = 0; i < count; ++i)
                buffer[position + i] = source[i];
            position += count;
        }

        /// <summary>
        /// Writes an <see cref="int"/> element count followed by the raw bytes of <paramref name="array"/>,
        /// whose elements are <paramref name="elementSize"/> bytes each. A <c>null</c> array is written as
        /// count 0. The whole record is checked to fit before anything is written.
        /// </summary>
        private static void WritePrimitiveArray(Array array, int elementSize, byte[] buffer, ref int position)
        {
            int length = array?.Length ?? 0;
            long byteCount = (long)length * elementSize;
            EnsureAvailable(buffer, position, sizeof(int) + byteCount);
            length.ToByteArray(buffer, ref position);
            // byteCount fits in an int here: EnsureAvailable bounded it by the buffer length.
            if (length > 0)
                Buffer.BlockCopy(array, 0, buffer, position, (int)byteCount);
            position += (int)byteCount;
        }

        /// <summary>
        /// Reads an <see cref="int"/> count prefix and verifies that the rest of the buffer holds at least
        /// <paramref name="minBytesPerElement"/> bytes per element. This lets a corrupt count fail fast instead
        /// of causing a huge allocation or a read past the end of the buffer.
        /// </summary>
        private static int ReadLength(byte[] buffer, ref int position, int minBytesPerElement)
        {
            int length = buffer.ToInt(ref position);
            EnsureAvailable(buffer, position, (long)length * minBytesPerElement);
            return length;
        }

        /// <summary>
        /// Copies <paramref name="byteCount"/> raw bytes from <paramref name="buffer"/> at
        /// <paramref name="position"/> into the primitive array <paramref name="destination"/> and advances
        /// <paramref name="position"/>. <see cref="Buffer.BlockCopy"/> bounds-checks both arrays.
        /// </summary>
        private static void CopyFromBuffer(byte[] buffer, ref int position, Array destination, int byteCount)
        {
            Buffer.BlockCopy(buffer, position, destination, 0, byteCount);
            position += byteCount;
        }

        // ---- byte: 1 byte ----
        public static int SizeInBytes(this byte a) => sizeof(byte);

        public static void ToByteArray(this byte a, byte[] buffer, ref int position)
        {
            buffer[position] = a;
            position++;
        }

        public static byte ToByte(this byte[] buffer, ref int position)
        {
            byte a = buffer[position];
            position++;
            return a;
        }

        // ---- short: 2 bytes, native byte order ----
        public static int SizeInBytes(this short a) => sizeof(short);

        public static unsafe void ToByteArray(this short a, byte[] buffer, ref int position)
            => WriteRaw((byte*)&a, sizeof(short), buffer, ref position);

        public static short ToShort(this byte[] buffer, ref int position)
        {
            short a = BitConverter.ToInt16(buffer, position);
            position += sizeof(short);
            return a;
        }

        // ---- ushort: 2 bytes, native byte order ----
        public static int SizeInBytes(this ushort a) => sizeof(ushort);

        public static unsafe void ToByteArray(this ushort a, byte[] buffer, ref int position)
            => WriteRaw((byte*)&a, sizeof(ushort), buffer, ref position);

        public static ushort ToUShort(this byte[] buffer, ref int position)
        {
            ushort a = BitConverter.ToUInt16(buffer, position);
            position += sizeof(ushort);
            return a;
        }

        // ---- int: 4 bytes, native byte order ----
        public static int SizeInBytes(this int a) => sizeof(int);

        public static unsafe void ToByteArray(this int a, byte[] buffer, ref int position)
            => WriteRaw((byte*)&a, sizeof(int), buffer, ref position);

        public static int ToInt(this byte[] buffer, ref int position)
        {
            // BitConverter reads native byte order (matching the writer) and bounds-checks the buffer.
            int a = BitConverter.ToInt32(buffer, position);
            position += sizeof(int);
            return a;
        }

        // ---- uint: 4 bytes, native byte order ----
        public static int SizeInBytes(this uint a) => sizeof(uint);

        public static unsafe void ToByteArray(this uint a, byte[] buffer, ref int position)
            => WriteRaw((byte*)&a, sizeof(uint), buffer, ref position);

        public static uint ToUInt(this byte[] buffer, ref int position)
        {
            uint a = BitConverter.ToUInt32(buffer, position);
            position += sizeof(uint);
            return a;
        }

        // ---- long: 8 bytes, native byte order ----
        public static int SizeInBytes(this long a) => sizeof(long);

        public static unsafe void ToByteArray(this long a, byte[] buffer, ref int position)
            => WriteRaw((byte*)&a, sizeof(long), buffer, ref position);

        public static long ToLong(this byte[] buffer, ref int position)
        {
            long a = BitConverter.ToInt64(buffer, position);
            position += sizeof(long);
            return a;
        }

        // ---- ulong: 8 bytes, native byte order ----
        public static int SizeInBytes(this ulong a) => sizeof(ulong);

        public static unsafe void ToByteArray(this ulong a, byte[] buffer, ref int position)
            => WriteRaw((byte*)&a, sizeof(ulong), buffer, ref position);

        public static ulong ToULong(this byte[] buffer, ref int position)
        {
            ulong a = BitConverter.ToUInt64(buffer, position);
            position += sizeof(ulong);
            return a;
        }

        // ---- float: 4 bytes, native byte order ----
        public static int SizeInBytes(this float a) => sizeof(float);

        public static unsafe void ToByteArray(this float a, byte[] buffer, ref int position)
            => WriteRaw((byte*)&a, sizeof(float), buffer, ref position);

        public static float ToFloat(this byte[] buffer, ref int position)
        {
            float a = BitConverter.ToSingle(buffer, position);
            position += sizeof(float);
            return a;
        }

        // ---- double: 8 bytes, native byte order ----
        public static int SizeInBytes(this double a) => sizeof(double);

        public static unsafe void ToByteArray(this double a, byte[] buffer, ref int position)
            => WriteRaw((byte*)&a, sizeof(double), buffer, ref position);

        public static double ToDouble(this byte[] buffer, ref int position)
        {
            double a = BitConverter.ToDouble(buffer, position);
            position += sizeof(double);
            return a;
        }

        // ---- string: int byte count, then UTF-16 (Encoding.Unicode) bytes ----
        public static int SizeInBytes(this string a)
        {
            return sizeof(int) + Encoding.Unicode.GetByteCount(a);
        }

        public static void ToByteArray(this string a, byte[] buffer, ref int position)
        {
            byte[] bytes = Encoding.Unicode.GetBytes(a);
            // Check the whole record up front so a failed write leaves no dangling length prefix.
            EnsureAvailable(buffer, position, sizeof(int) + (long)bytes.Length);
            bytes.Length.ToByteArray(buffer, ref position);
            Array.Copy(bytes, 0, buffer, position, bytes.Length);
            position += bytes.Length;
        }

        /// <summary>Returns a new array holding exactly the serialized form of <paramref name="a"/>.</summary>
        public static byte[] ToByteArray(this string a)
        {
            byte[] bytes = Encoding.Unicode.GetBytes(a);
            byte[] allBytes = new byte[bytes.Length + sizeof(int)];
            int position = 0;
            bytes.Length.ToByteArray(allBytes, ref position);
            Array.Copy(bytes, 0, allBytes, position, bytes.Length);
            return allBytes;
        }

        public static string ToString(this byte[] buffer, ref int position)
        {
            int length = ReadLength(buffer, ref position, sizeof(byte));
            string a = Encoding.Unicode.GetString(buffer, position, length);
            position += length;
            return a;
        }

        // ---- byte[]: int count, then the bytes ----
        public static int SizeInBytes(this byte[] a)
        {
            return sizeof(int) + Utils.Size(a) * sizeof(byte);
        }

        public static void ToByteArray(this byte[] a, byte[] buffer, ref int position)
            => WritePrimitiveArray(a, sizeof(byte), buffer, ref position);

        public static byte[] ToByteArray(this byte[] buffer, ref int position)
        {
            int length = ReadLength(buffer, ref position, sizeof(byte));
            byte[] a = new byte[length];
            CopyFromBuffer(buffer, ref position, a, length);
            return a;
        }

        // ---- short[] ----
        public static int SizeInBytes(this short[] a)
        {
            return sizeof(int) + Utils.Size(a) * sizeof(short);
        }

        public static void ToByteArray(this short[] a, byte[] buffer, ref int position)
            => WritePrimitiveArray(a, sizeof(short), buffer, ref position);

        public static short[] ToShortArray(this byte[] buffer, ref int position)
        {
            int length = ReadLength(buffer, ref position, sizeof(short));
            short[] a = new short[length];
            CopyFromBuffer(buffer, ref position, a, length * sizeof(short));
            return a;
        }

        // ---- ushort[] ----
        public static int SizeInBytes(this ushort[] a)
        {
            return sizeof(int) + Utils.Size(a) * sizeof(ushort);
        }

        public static void ToByteArray(this ushort[] a, byte[] buffer, ref int position)
            => WritePrimitiveArray(a, sizeof(ushort), buffer, ref position);

        public static ushort[] ToUShortArray(this byte[] buffer, ref int position)
        {
            int length = ReadLength(buffer, ref position, sizeof(ushort));
            ushort[] a = new ushort[length];
            CopyFromBuffer(buffer, ref position, a, length * sizeof(ushort));
            return a;
        }

        // ---- int[]: unlike the other array types, a count of 0 reads back as null ----
        public static int SizeInBytes(this int[] array)
        {
            return sizeof(int) + Utils.Size(array) * sizeof(int);
        }

        public static void ToByteArray(this int[] a, byte[] buffer, ref int position)
            => WritePrimitiveArray(a, sizeof(int), buffer, ref position);

        public static int[] ToIntArray(this byte[] buffer, ref int position)
        {
            int length = buffer.ToInt(ref position);
            return buffer.ToIntArray(ref position, length);
        }

        /// <summary>
        /// Reads <paramref name="length"/> ints (with no count prefix) starting at <paramref name="position"/>.
        /// Returns <c>null</c> when <paramref name="length"/> is 0, so null and empty arrays both round-trip
        /// to null. NOTE(review): callers presumably rely on this null; confirm before changing it.
        /// </summary>
        public static int[] ToIntArray(this byte[] buffer, ref int position, int length)
        {
            if (length == 0)
                return null;
            EnsureAvailable(buffer, position, (long)length * sizeof(int));
            int[] a = new int[length];
            CopyFromBuffer(buffer, ref position, a, length * sizeof(int));
            return a;
        }

        // ---- uint[] ----
        public static int SizeInBytes(this uint[] array)
        {
            return sizeof(int) + Utils.Size(array) * sizeof(uint);
        }

        public static void ToByteArray(this uint[] a, byte[] buffer, ref int position)
            => WritePrimitiveArray(a, sizeof(uint), buffer, ref position);

        public static uint[] ToUIntArray(this byte[] buffer, ref int position)
        {
            int length = ReadLength(buffer, ref position, sizeof(uint));
            uint[] a = new uint[length];
            CopyFromBuffer(buffer, ref position, a, length * sizeof(uint));
            return a;
        }

        // ---- long[] ----
        public static int SizeInBytes(this long[] array)
        {
            return sizeof(int) + Utils.Size(array) * sizeof(long);
        }

        public static void ToByteArray(this long[] a, byte[] buffer, ref int position)
            => WritePrimitiveArray(a, sizeof(long), buffer, ref position);

        public static long[] ToLongArray(this byte[] buffer, ref int position)
        {
            int length = ReadLength(buffer, ref position, sizeof(long));
            long[] a = new long[length];
            CopyFromBuffer(buffer, ref position, a, length * sizeof(long));
            return a;
        }

        // ---- ulong[] ----
        public static int SizeInBytes(this ulong[] array)
        {
            return sizeof(int) + Utils.Size(array) * sizeof(ulong);
        }

        public static void ToByteArray(this ulong[] a, byte[] buffer, ref int position)
            => WritePrimitiveArray(a, sizeof(ulong), buffer, ref position);

        public static ulong[] ToULongArray(this byte[] buffer, ref int position)
        {
            int length = ReadLength(buffer, ref position, sizeof(ulong));
            ulong[] a = new ulong[length];
            CopyFromBuffer(buffer, ref position, a, length * sizeof(ulong));
            return a;
        }

        // ---- float[] ----
        public static int SizeInBytes(this float[] array)
        {
            return sizeof(int) + Utils.Size(array) * sizeof(float);
        }

        public static void ToByteArray(this float[] a, byte[] buffer, ref int position)
            => WritePrimitiveArray(a, sizeof(float), buffer, ref position);

        public static float[] ToFloatArray(this byte[] buffer, ref int position)
        {
            int length = ReadLength(buffer, ref position, sizeof(float));
            float[] a = new float[length];
            CopyFromBuffer(buffer, ref position, a, length * sizeof(float));
            return a;
        }

        // ---- double[] ----
        public static int SizeInBytes(this double[] array)
        {
            return sizeof(int) + Utils.Size(array) * sizeof(double);
        }

        public static void ToByteArray(this double[] a, byte[] buffer, ref int position)
            => WritePrimitiveArray(a, sizeof(double), buffer, ref position);

        public static double[] ToDoubleArray(this byte[] buffer, ref int position)
        {
            int length = ReadLength(buffer, ref position, sizeof(double));
            double[] a = new double[length];
            CopyFromBuffer(buffer, ref position, a, length * sizeof(double));
            return a;
        }

        // ---- double[][]: int count, then each inner double[] record ----
        public static int SizeInBytes(this double[][] array)
        {
            if (Utils.Size(array) == 0)
                return sizeof(int);
            return sizeof(int) + array.Sum(x => x.SizeInBytes());
        }

        public static void ToByteArray(this double[][] a, byte[] buffer, ref int position)
        {
            int length = Utils.Size(a);
            length.ToByteArray(buffer, ref position);
            for (int i = 0; i < length; ++i)
                a[i].ToByteArray(buffer, ref position);
        }

        public static double[][] ToDoubleJaggedArray(this byte[] buffer, ref int position)
        {
            // Each inner array occupies at least its 4-byte count prefix.
            int length = ReadLength(buffer, ref position, sizeof(int));
            double[][] a = new double[length][];
            for (int i = 0; i < length; ++i)
                a[i] = buffer.ToDoubleArray(ref position);
            return a;
        }

        // ---- string[]: int count, then each string record ----
        public static long SizeInBytes(this string[] array)
        {
            long length = sizeof(int);
            for (int i = 0; i < Utils.Size(array); ++i)
                length += array[i].SizeInBytes();
            return length;
        }

        public static void ToByteArray(this string[] a, byte[] buffer, ref int position)
        {
            int length = Utils.Size(a);
            length.ToByteArray(buffer, ref position);
            for (int i = 0; i < length; ++i)
                a[i].ToByteArray(buffer, ref position);
        }

        public static string[] ToStringArray(this byte[] buffer, ref int position)
        {
            // Each string occupies at least its 4-byte count prefix.
            int length = ReadLength(buffer, ref position, sizeof(int));
            string[] a = new string[length];
            for (int i = 0; i < length; ++i)
                a[i] = buffer.ToString(ref position);
            return a;
        }
    }
}
|