File: Utilities\ComStreamWrapper.cs
Web Access
Project: src\src\symreader\src\Microsoft.DiaSymReader\Microsoft.DiaSymReader.csproj (Microsoft.DiaSymReader)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the License.txt file in the project root for more information.

using Microsoft.DiaSymReader.Utilities;
using System;
using System.Diagnostics;
using System.IO;
using System.Runtime.InteropServices;

#if NET9_0_OR_GREATER
using System.Runtime.InteropServices.Marshalling;
#endif

#if NETSTANDARD2_0
using STATSTG = System.Runtime.InteropServices.ComTypes.STATSTG;
#endif

namespace Microsoft.DiaSymReader
{
    /// <summary>
    /// Exposes a seekable <see cref="Stream"/> to native code through both the
    /// <see cref="IUnsafeComStream"/> interface and the built-in
    /// <see cref="System.Runtime.InteropServices.ComTypes.IStream"/> interface.
    /// </summary>
    /// <remarks>
    /// Only reading, writing, seeking, committing (flush), resizing and querying the length are
    /// supported; the remaining IStream members throw <see cref="NotSupportedException"/>.
    /// </remarks>
#if NET9_0_OR_GREATER
    [GeneratedComClass]
#endif
    internal unsafe sealed partial class ComStreamWrapper : IUnsafeComStream, System.Runtime.InteropServices.ComTypes.IStream
    {
        private readonly Stream _stream;

        /// <summary>
        /// Wraps <paramref name="stream"/>; the stream is expected to be non-null and seekable.
        /// </summary>
        internal ComStreamWrapper(Stream stream)
        {
            Debug.Assert(stream != null);
            Debug.Assert(stream.CanSeek);

            _stream = stream;
        }

        /// <summary>
        /// Attempts to read all of the requested bytes from the stream into the buffer
        /// </summary>
        /// <returns>
        /// The number of bytes read. Less than <paramref name="count" /> will
        /// only be returned if the end of stream is reached before all bytes can be read.
        /// </returns>
        /// <remarks>
        /// Unlike <see cref="Stream.Read(byte[], int, int)"/> it is not guaranteed that
        /// the stream position or the output buffer will be unchanged if an exception is
        /// returned.
        /// </remarks>
        private static int TryReadAll(Stream stream, byte[] buffer, int offset, int count)
        {
            // The implementations for many streams, e.g. FileStream, allows 0 bytes to be
            // read and returns 0, but the documentation for Stream.Read states that 0 is
            // only returned when the end of the stream has been reached. Rather than deal
            // with this contradiction, let's just never pass a count of 0 bytes
            Debug.Assert(count > 0);

            int totalBytesRead;
            int bytesRead = 0;
            for (totalBytesRead = 0; totalBytesRead < count; totalBytesRead += bytesRead)
            {
                // Note: Don't attempt to save state in-between calls to .Read as it would
                // require a possibly massive intermediate buffer array
                bytesRead = stream.Read(buffer, offset + totalBytesRead, count - totalBytesRead);
                if (bytesRead == 0)
                {
                    break;
                }
            }

            return totalBytesRead;
        }

        /// <summary>
        /// Throws if <paramref name="cb"/> is negative or larger than the managed
        /// <paramref name="buffer"/> (a null buffer counts as length 0).
        /// </summary>
        /// <remarks>
        /// The pointer-based <see cref="Read(byte*, int, int*)"/> and <see cref="Write(byte*, int, int*)"/>
        /// trust <paramref name="cb"/>; without this check a short array would be read or written out of bounds.
        /// </remarks>
        private static void ValidateBufferSize(byte[] buffer, int cb)
        {
            int length = buffer?.Length ?? 0;
            if (cb < 0 || cb > length)
            {
                throw new ArgumentOutOfRangeException(nameof(cb));
            }
        }

        /// <summary>
        /// Reads up to <paramref name="cb"/> bytes from the underlying stream into <paramref name="pv"/>.
        /// The actual number of bytes read can be fewer than the number of bytes requested
        /// if an error occurs or if the end of the stream is reached during the read operation.
        /// </summary>
        /// <param name="pv">Destination memory; must be writable for at least <paramref name="cb"/> bytes.</param>
        /// <param name="cb">Number of bytes requested. Zero is allowed and reads nothing.</param>
        /// <param name="pcbRead">Optional (may be null); receives the number of bytes actually read.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="cb"/> is negative.</exception>
        public unsafe void Read(byte* pv, int cb, int* pcbRead)
        {
            if (cb < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(cb));
            }

            int bytesRead = 0;

            // TryReadAll asserts a positive count, but a zero-byte read is legal for IStream:
            // skip the stream entirely and report that nothing was read.
            if (cb > 0)
            {
                var buffer = new byte[cb];
                bytesRead = TryReadAll(_stream, buffer, 0, cb);

                if (bytesRead > 0)
                {
                    Marshal.Copy(buffer, 0, (IntPtr)pv, bytesRead);
                }
            }

            if (pcbRead != null)
            {
                *pcbRead = bytesRead;
            }
        }

        void System.Runtime.InteropServices.ComTypes.IStream.Read(byte[] pv, int cb, nint pcbRead)
        {
            // Guard against cb exceeding the pinned array before handing out a raw pointer.
            ValidateBufferSize(pv, cb);

            fixed (byte* p = pv)
            {
                Read(p, cb, (int*)pcbRead);
            }
        }

        /// <summary>
        /// Seeks the underlying stream. The COM STREAM_SEEK_SET/CUR/END values share their numeric
        /// values with <see cref="SeekOrigin"/>, so <paramref name="origin"/> is cast directly.
        /// </summary>
        /// <param name="plibNewPosition">Optional (may be null); receives the new stream position.</param>
        public unsafe void Seek(long dlibMove, int origin, long* plibNewPosition)
        {
            long newPosition = _stream.Seek(dlibMove, (SeekOrigin)origin);
            if (plibNewPosition != null)
            {
                *plibNewPosition = newPosition;
            }
        }

        void System.Runtime.InteropServices.ComTypes.IStream.Seek(long dlibMove, int dwOrigin, nint plibNewPosition)
            => Seek(dlibMove, dwOrigin, (long*)plibNewPosition);

        /// <summary>
        /// Writes <paramref name="cb"/> bytes from <paramref name="pv"/> to the underlying stream.
        /// </summary>
        /// <param name="pv">Source memory; must be readable for at least <paramref name="cb"/> bytes.</param>
        /// <param name="cb">Number of bytes to write.</param>
        /// <param name="pcbWritten">Optional (may be null); receives <paramref name="cb"/> on success.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="cb"/> is negative.</exception>
        public unsafe void Write(byte* pv, int cb, int* pcbWritten)
        {
            if (cb < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(cb));
            }

            var buffer = new byte[cb];
            if (cb > 0)
            {
                Marshal.Copy((IntPtr)pv, buffer, 0, cb);
            }

            _stream.Write(buffer, 0, cb);
            if (pcbWritten != null)
            {
                *pcbWritten = cb;
            }
        }

        void System.Runtime.InteropServices.ComTypes.IStream.Write(byte[] pv, int cb, nint pcbWritten)
        {
            // Guard against cb exceeding the pinned array before handing out a raw pointer.
            ValidateBufferSize(pv, cb);

            fixed (byte* p = pv)
            {
                Write(p, cb, (int*)pcbWritten);
            }
        }

        /// <summary>Not supported.</summary>
        public void Clone(out IntPtr ppstm)
            => throw new NotSupportedException();

        void System.Runtime.InteropServices.ComTypes.IStream.Clone(out System.Runtime.InteropServices.ComTypes.IStream ppstm)
            => throw new NotSupportedException();

        /// <summary>Flushes the underlying stream; <paramref name="grfCommitFlags"/> is ignored.</summary>
        public void Commit(int grfCommitFlags)
            => _stream.Flush();

        void System.Runtime.InteropServices.ComTypes.IStream.Commit(int grfCommitFlags)
            => Commit(grfCommitFlags);

        /// <summary>Not supported.</summary>
        public void CopyTo(IntPtr pstm, long cb, int* pcbRead, int* pcbWritten)
            => throw new NotSupportedException();

        void System.Runtime.InteropServices.ComTypes.IStream.CopyTo(System.Runtime.InteropServices.ComTypes.IStream pstm, long cb, nint pcbRead, nint pcbWritten)
            => throw new NotSupportedException();

        /// <summary>Not supported.</summary>
        public void LockRegion(long libOffset, long cb, int lockType)
            => throw new NotSupportedException();

        void System.Runtime.InteropServices.ComTypes.IStream.LockRegion(long libOffset, long cb, int dwLockType)
            => throw new NotSupportedException();

        /// <summary>Not supported.</summary>
        public void UnlockRegion(long libOffset, long cb, int lockType)
            => throw new NotSupportedException();

        void System.Runtime.InteropServices.ComTypes.IStream.UnlockRegion(long libOffset, long cb, int dwLockType)
            => throw new NotSupportedException();

        /// <summary>Not supported.</summary>
        public void Revert()
            => throw new NotSupportedException();

        void System.Runtime.InteropServices.ComTypes.IStream.Revert()
            => throw new NotSupportedException();

        /// <summary>Sets the length of the underlying stream.</summary>
        public void SetSize(long libNewSize)
            => _stream.SetLength(libNewSize);

        void System.Runtime.InteropServices.ComTypes.IStream.SetSize(long libNewSize)
            => SetSize(libNewSize);

        /// <summary>
        /// Returns a <see cref="STATSTG"/> whose only populated field is <c>cbSize</c> (the stream length);
        /// <paramref name="grfStatFlag"/> is ignored.
        /// </summary>
        public void Stat(out STATSTG pstatstg, int grfStatFlag)
        {
            pstatstg = new STATSTG()
            {
                cbSize = _stream.Length
            };
        }

        void System.Runtime.InteropServices.ComTypes.IStream.Stat(out System.Runtime.InteropServices.ComTypes.STATSTG pstatstg, int grfStatFlag)
        {
            pstatstg = new System.Runtime.InteropServices.ComTypes.STATSTG()
            {
                cbSize = _stream.Length
            };
        }

#if NET9_0_OR_GREATER
        /// <summary>
        /// Custom marshaller that moves <see cref="System.Runtime.InteropServices.ComTypes.IStream"/>
        /// across the source-generated COM boundary by way of <see cref="IUnsafeComStream"/>.
        /// </summary>
        [CustomMarshaller(typeof(System.Runtime.InteropServices.ComTypes.IStream), MarshalMode.ManagedToUnmanagedIn, typeof(Marshaller))]
        [CustomMarshaller(typeof(System.Runtime.InteropServices.ComTypes.IStream), MarshalMode.UnmanagedToManagedIn, typeof(Marshaller))]
        public static class Marshaller
        {
            /// <summary>
            /// Returns <see cref="IntPtr.Zero"/> for null; otherwise the stream must also implement
            /// <see cref="IUnsafeComStream"/>, or <see cref="NotSupportedException"/> is thrown.
            /// </summary>
            public static IntPtr ConvertToUnmanaged(System.Runtime.InteropServices.ComTypes.IStream stream)
            {
                if (stream is null)
                {
                    return IntPtr.Zero;
                }
                else if (stream is IUnsafeComStream unsafeComStream)
                {
                    return (IntPtr)ComInterfaceMarshaller<IUnsafeComStream>.ConvertToUnmanaged(unsafeComStream);
                }

                throw new NotSupportedException("IStream implementation cannot be marshalled");
            }

            /// <summary>
            /// Wraps a native COM stream pointer in a managed <see cref="System.Runtime.InteropServices.ComTypes.IStream"/>.
            /// </summary>
            /// <exception cref="NotSupportedException">The pointer could not be marshalled to <see cref="IUnsafeComStream"/>.</exception>
            public static System.Runtime.InteropServices.ComTypes.IStream ConvertToManaged(IntPtr native)
            {
                // NOTE(review): unlike ConvertToUnmanaged (null -> Zero), a null native pointer ends up
                // here as a null marshalledStream and throws rather than returning null — confirm intended.
                IUnsafeComStream marshalledStream = ComInterfaceMarshaller<IUnsafeComStream>.ConvertToManaged((void*)native);
                if (marshalledStream is null)
                {
                    throw new NotSupportedException("IStream cannot be marshalled to managed");
                }

                // Put the ComWrappers-compatible IUnsafeComStream into a wrapper that also implements the ComWrappers-incompatible IStream
                return new UnsafeComStreamWrapper(marshalledStream);
            }

            /// <summary>Releases the COM reference obtained by <see cref="ConvertToUnmanaged"/>, if any.</summary>
            public static void Free(IntPtr unmanaged)
            {
                if (unmanaged != IntPtr.Zero)
                {
                    Marshal.Release(unmanaged);
                }
            }
        }
#endif
    }
}