|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data.IO
{
    internal sealed partial class CodecFactory
    {
        // REVIEW: In future, this scheme might probably use loadable classes with
        // assembly attributes instead of having the mapping from load name to reader hard coded.
        // Or maybe not. That may depend on how much flexibility we really need from this.
        // Maps the load name written at the head of a codec type description to the function
        // that reads the rest of that description and reconstructs the codec.
        private readonly Dictionary<string, GetCodecFromStreamDelegate> _loadNameToCodecCreator;

        // The non-vector non-generic types can have a very simple codec mapping, keyed by raw type.
        private readonly Dictionary<Type, IValueCodec> _simpleCodecTypeMap;

        // A shared object pool of memory buffers. Streams obtained from this pool must be handed
        // back with _memPool.Return(ref stream). WriteCodec hands its stream back without clearing
        // it, so presumably Return resets length and position -- TODO confirm in MemoryStreamPool.
        private readonly MemoryStreamPool _memPool;

        // This is the encoding used for strings and textspans.
        private readonly Encoding _encoding;

        private readonly IHost _host;

        private delegate bool GetCodecFromStreamDelegate(Stream definitionStream, out IValueCodec codec);

        private delegate bool GetCodecFromTypeDelegate(DataViewType type, out IValueCodec codec);

        /// <summary>
        /// Creates a codec factory and registers the built-in codecs, along with the readers for
        /// codec names from the old type system.
        /// </summary>
        /// <param name="env">The host environment; a host named "CodecFactory" is registered from it.</param>
        /// <param name="memPool">Optional pool of memory streams to share; a new pool is created when null.</param>
        public CodecFactory(IHostEnvironment env, MemoryStreamPool memPool = null)
        {
            Contracts.AssertValue(env, "env");
            Contracts.AssertValueOrNull(memPool);
            _host = env.Register("CodecFactory");
            _memPool = memPool ?? new MemoryStreamPool();
            _encoding = Encoding.UTF8;
            _loadNameToCodecCreator = new Dictionary<string, GetCodecFromStreamDelegate>();
            _simpleCodecTypeMap = new Dictionary<Type, IValueCodec>();

            // Register the current codecs.
            RegisterSimpleCodec(new UnsafeTypeCodec<sbyte>(this));
            RegisterSimpleCodec(new UnsafeTypeCodec<byte>(this));
            RegisterSimpleCodec(new UnsafeTypeCodec<short>(this));
            RegisterSimpleCodec(new UnsafeTypeCodec<ushort>(this));
            RegisterSimpleCodec(new UnsafeTypeCodec<int>(this));
            RegisterSimpleCodec(new UnsafeTypeCodec<uint>(this));
            RegisterSimpleCodec(new UnsafeTypeCodec<long>(this));
            RegisterSimpleCodec(new UnsafeTypeCodec<ulong>(this));
            RegisterSimpleCodec(new UnsafeTypeCodec<float>(this));
            RegisterSimpleCodec(new UnsafeTypeCodec<double>(this));
            RegisterSimpleCodec(new UnsafeTypeCodec<TimeSpan>(this));
            RegisterSimpleCodec(new TextCodec(this));
            RegisterSimpleCodec(new BoolCodec(this));
            RegisterSimpleCodec(new DateTimeCodec(this));
            RegisterSimpleCodec(new DateTimeOffsetCodec(this));
            RegisterSimpleCodec(new UnsafeTypeCodec<DataViewRowId>(this));

            // Register the old type system reading codec. These are registered by load name only,
            // so they are reachable from TryReadCodec but not from the raw-type lookup in TryGetCodec.
            RegisterOtherCodec("DvBool", new OldBoolCodec(this).GetCodec);
            RegisterOtherCodec("DvDateTimeZone", new DateTimeOffsetCodec(this).GetCodec);
            RegisterOtherCodec("DvDateTime", new DateTimeCodec(this).GetCodec);
            RegisterOtherCodec("DvTimeSpan", new UnsafeTypeCodec<TimeSpan>(this).GetCodec);

            // Readers for the vector and key codec descriptions.
            RegisterOtherCodec("VBuffer", GetVBufferCodec);
            RegisterOtherCodec("Key2", GetKeyCodec);
            RegisterOtherCodec("Key", GetKeyCodecOld);
        }

        /// <summary>
        /// Opens a writer over <paramref name="stream"/> using this factory's encoding. Disposing
        /// the writer leaves the stream open.
        /// </summary>
        private BinaryWriter OpenBinaryWriter(Stream stream)
        {
            return new BinaryWriter(stream, _encoding, leaveOpen: true);
        }

        /// <summary>
        /// Opens a reader over <paramref name="stream"/> using this factory's encoding. Disposing
        /// the reader leaves the stream open.
        /// </summary>
        private BinaryReader OpenBinaryReader(Stream stream)
        {
            return new BinaryReader(stream, _encoding, leaveOpen: true);
        }

        /// <summary>
        /// Registers a simple codec both by its load name (for reading codec descriptions) and by
        /// the raw type of its column type (for <see cref="TryGetCodec"/>). Asserts that neither
        /// key was already registered.
        /// </summary>
        private void RegisterSimpleCodec<T>(SimpleCodec<T> codec)
        {
            Contracts.Assert(!_loadNameToCodecCreator.ContainsKey(codec.LoadName));
            Contracts.Assert(!_simpleCodecTypeMap.ContainsKey(codec.Type.RawType));
            _loadNameToCodecCreator.Add(codec.LoadName, codec.GetCodec);
            _simpleCodecTypeMap.Add(codec.Type.RawType, codec);
        }

        /// <summary>
        /// Registers a reader for codec descriptions with load name <paramref name="name"/>.
        /// Asserts that the name was not already registered.
        /// </summary>
        private void RegisterOtherCodec(string name, GetCodecFromStreamDelegate fn)
        {
            Contracts.Assert(!_loadNameToCodecCreator.ContainsKey(name));
            _loadNameToCodecCreator.Add(name, fn);
        }

        /// <summary>
        /// Attempts to find a codec able to serialize values of <paramref name="type"/>.
        /// </summary>
        /// <param name="type">The column type to find a codec for.</param>
        /// <param name="codec">The codec, when one is found.</param>
        /// <returns>Whether a codec was found.</returns>
        public bool TryGetCodec(DataViewType type, out IValueCodec codec)
        {
            // Handle the premier types specially: key and vector types get dedicated codec
            // construction; every other type is looked up by raw type among the simple codecs.
            if (type is KeyDataViewType)
                return GetKeyCodec(type, out codec);
            if (type is VectorDataViewType vectorType)
                return GetVBufferCodec(vectorType, out codec);
            return _simpleCodecTypeMap.TryGetValue(type.RawType, out codec);
        }

        /// <summary>
        /// Given a codec, write a type description to a stream, from which this codec can be
        /// reconstructed later. This returns the number of bytes written, so that, if this
        /// were a seekable stream, the positions would differ by this amount before and after
        /// a call to this method.
        /// </summary>
        /// <param name="definitionStream">The stream the description is written to; it is left open.</param>
        /// <param name="codec">The codec to describe.</param>
        /// <returns>The number of bytes written to <paramref name="definitionStream"/>.</returns>
        public int WriteCodec(Stream definitionStream, IValueCodec codec)
        {
            // *** Codec type description ***
            // string: codec loadname
            // LEB128 int: Byte size of the parameterization
            // byte[]: The indicated parameterization
            using (BinaryWriter writer = OpenBinaryWriter(definitionStream))
            {
                string loadName = codec.LoadName;
                writer.Write(loadName);
                // BinaryWriter.Write(string) prefixes the encoded bytes with their length as a
                // 7-bit (LEB128) encoded integer, so both parts are counted.
                int bytes = _encoding.GetByteCount(loadName);
                bytes = checked(bytes + Utils.Leb128IntLength((uint)bytes));

                // The parameterization is staged in a pooled buffer so that its length can be
                // written ahead of its contents.
                MemoryStream mem = _memPool.Get();
                try
                {
                    int output = codec.WriteParameterization(mem);
                    Contracts.Check(mem.Length == output, "codec description length did not match stream length");
                    Contracts.Check(mem.Length <= int.MaxValue); // Is this even possible in the current implementation of MemoryStream?
                    writer.WriteLeb128Int((ulong)mem.Length);
                    bytes = checked(bytes + Utils.Leb128IntLength((uint)mem.Length) + output);
                    mem.Position = 0;
                    mem.CopyTo(definitionStream);
                }
                finally
                {
                    // Hand the buffer back to the pool even when the codec or a check throws.
                    _memPool.Return(ref mem);
                }
                return bytes;
            }
        }

        /// <summary>
        /// Attempts to define a codec, given a stream positioned at the start of a serialized
        /// codec type definition.
        /// </summary>
        /// <param name="definitionStream">The input stream, which whether this returns true or false
        /// will be left at the end of the codec type definition</param>
        /// <param name="codec">A codec castable to a generic <c>IValueCodec{T}</c> where
        /// <c>typeof(T)==codec.Type.RawType</c></param>
        /// <returns>Whether the codec type definition was understood. If true the codec has defined
        /// value, and should be usable. If false, the name of the codec was unrecognized. Note that
        /// if a malformed definition is detected, this will throw instead of returning either true
        /// or false.</returns>
        public bool TryReadCodec(Stream definitionStream, out IValueCodec codec)
        {
            Contracts.AssertValue(definitionStream, "definitionStream");
            using (IChannel ch = _host.Start("TryGetCodec"))
            using (BinaryReader reader = OpenBinaryReader(definitionStream))
            {
                string signature = reader.ReadString();
                Contracts.CheckDecode(!string.IsNullOrEmpty(signature), "Non-empty signature string expected");
                ulong ulen = reader.ReadLeb128Int();
                Contracts.CheckDecode(ulen <= long.MaxValue, "Codec type definition read from stream too large");
                long len = (long)ulen;

                GetCodecFromStreamDelegate del;
                if (!_loadNameToCodecCreator.TryGetValue(signature, out del))
                {
                    codec = default(IValueCodec);
                    if (len == 0)
                        return false;
                    // Move the stream past the end of the definition.
                    if (definitionStream.CanSeek)
                    {
                        long remaining = definitionStream.Length - definitionStream.Position;
                        if (remaining < len)
                            throw ch.ExceptDecode("Codec type definition supposedly has {0} bytes, but end-of-stream reached after {1} bytes", len, remaining);
                        definitionStream.Seek(len, SeekOrigin.Current);
                    }
                    else
                    {
                        // Not seekable: consume the parameterization one byte at a time.
                        for (long i = 0; i < len; ++i)
                        {
                            if (definitionStream.ReadByte() == -1)
                                throw ch.ExceptDecode("Codec type definition supposedly has {0} bytes, but end-of-stream reached after {1} bytes", len, i);
                        }
                    }
                    ch.Warning("Did not recognize value codec signature '{0}'", signature);
                    return false;
                }

                // Opportunistically validate in the case of a seekable stream.
                long pos = definitionStream.CanSeek ? definitionStream.Position : -1;
                bool retval = del(definitionStream, out codec);
                if (definitionStream.CanSeek && definitionStream.Position - pos != len)
                    throw ch.ExceptDecode("Codec type definition supposedly has {0} bytes, but the handler consumed {1}", len, definitionStream.Position - pos);
                return retval;
            }
        }
    }
}
|