File: System\Net\WebSockets\WebSocketStream.cs
Web Access
Project: src\src\libraries\System.Net.WebSockets\src\System.Net.WebSockets.csproj (System.Net.WebSockets)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
 
namespace System.Net.WebSockets
{
    /// <summary>Provides a <see cref="Stream"/> that delegates to a wrapped <see cref="WebSocket"/>.</summary>
    public class WebSocketStream : Stream
    {
        /// <summary>The default number of seconds before canceling CloseAsync operation issued during stream disposal.</summary>
        private const int DefaultCloseTimeoutSeconds = 16;
 
        /// <summary>Whether the stream has been disposed.</summary>
        private bool _disposed;
 
        /// <summary>
        /// Initializes a new instance of the <see cref="WebSocketStream"/> class using a specified <see cref="WebSocket"/> instance.
        /// </summary>
        /// <param name="webSocket">The <see cref="WebSocket"/> wrapped by this instance.</param>
        private WebSocketStream(WebSocket webSocket) => WebSocket = webSocket;
 
        /// <summary>Creates a <see cref="WebSocketStream"/> that delegates to a wrapped <see cref="WebSocket"/>.</summary>
        /// <param name="webSocket">The wrapped <see cref="WebSocket"/>.</param>
        /// <param name="writeMessageType">The type of messages that should be written as part of <see cref="M:Stream.WriteAsync"/> calls. Each write produces a message.</param>
        /// <param name="ownsWebSocket">
        /// <see langword="true"/> if disposing the <see cref="Stream"/> should close the underlying <see cref="WebSocket"/>; otherwise, <see langword="false"/>. Defaults to <see langword="false"/>.
        /// </param>
        /// <returns>A new instance of <see cref="WebSocketStream"/> that forwards reads and writes on the <see cref="Stream"/> to the underlying <see cref="WebSocket"/>.</returns>
        public static WebSocketStream Create(WebSocket webSocket, WebSocketMessageType writeMessageType, bool ownsWebSocket = false)
        {
            ArgumentNullException.ThrowIfNull(webSocket);
            ManagedWebSocket.ThrowIfInvalidMessageType(writeMessageType);
 
            return new ReadWriteStream(
                webSocket,
                writeMessageType,
                closeTimeout: ownsWebSocket ? TimeSpan.FromSeconds(DefaultCloseTimeoutSeconds) : null);
        }
 
        /// <summary>Creates a <see cref="WebSocketStream"/> that delegates to a wrapped <see cref="WebSocket"/>.</summary>
        /// <param name="webSocket">The wrapped <see cref="WebSocket"/>.</param>
        /// <param name="writeMessageType">The type of messages that should be written as part of <see cref="M:Stream.WriteAsync"/> calls. Each write produces a message.</param>
        /// <param name="closeTimeout">The amount of time that disposing the <see cref="WebSocketStream"/> will wait for a graceful closing of the <see cref="WebSocket"/>'s output.</param>
        /// <returns>A new instance of <see cref="WebSocketStream"/> that forwards reads and writes on the <see cref="Stream"/> to the underlying <see cref="WebSocket"/>.</returns>
        public static WebSocketStream Create(WebSocket webSocket, WebSocketMessageType writeMessageType, TimeSpan closeTimeout)
        {
            ArgumentNullException.ThrowIfNull(webSocket);
            ManagedWebSocket.ThrowIfInvalidMessageType(writeMessageType);
            if (closeTimeout < TimeSpan.Zero && closeTimeout != Timeout.InfiniteTimeSpan)
            {
                throw new ArgumentOutOfRangeException(nameof(closeTimeout), SR.net_WebSockets_TimeoutOutOfRange);
            }
 
            return new ReadWriteStream(webSocket, writeMessageType, closeTimeout);
        }
 
        /// <summary>Creates a <see cref="WebSocketStream"/> that writes a single message to the underlying <see cref="WebSocket"/>.</summary>
        /// <param name="webSocket">The wrapped <see cref="WebSocket"/>.</param>
        /// <param name="writeMessageType">
        /// The type of messages that should be written as part of <see cref="M:Stream.WriteAsync"/> calls.
        /// Each write on the <see cref="Stream"/> results in writing a partial message to the underlying <see cref="WebSocket"/>.
        /// When the <see cref="Stream"/> is disposed, it will write an empty message to the underlying <see cref="WebSocket"/> to signal the end of the message.
        /// </param>
        /// <returns>A new instance of <see cref="WebSocketStream"/> that forwards writes on the <see cref="Stream"/> to the underlying <see cref="WebSocket"/>.</returns>
        public static WebSocketStream CreateWritableMessageStream(WebSocket webSocket, WebSocketMessageType writeMessageType)
        {
            ArgumentNullException.ThrowIfNull(webSocket);
            ManagedWebSocket.ThrowIfInvalidMessageType(writeMessageType);
 
            return new WriteMessageStream(webSocket, writeMessageType);
        }
 
        /// <summary>Creates a <see cref="WebSocketStream"/> that reads a single message from the underlying <see cref="WebSocket"/>.</summary>
        /// <param name="webSocket">The wrapped <see cref="WebSocket"/>.</param>
        /// <returns>A new instance of <see cref="WebSocketStream"/> that forwards reads on the <see cref="Stream"/> to the underlying <see cref="WebSocket"/>.</returns>
        /// <remarks>
        /// Reads on the <see cref="Stream"/> will read a single message from the underlying <see cref="WebSocket"/>. This means that reads will start returning
        /// 0 bytes read once all data has been consumed from the next message received in the <see cref="WebSocket"/>.
        /// </remarks>
        public static WebSocketStream CreateReadableMessageStream(WebSocket webSocket)
        {
            ArgumentNullException.ThrowIfNull(webSocket);
 
            return new ReadMessageStream(webSocket);
        }
 
        /// <summary>Gets the underlying <see cref="WebSocket"/> wrapped by this <see cref="WebSocketStream"/>.</summary>
        /// <remarks>The <see cref="WebSocket"/> used to construct this instance.</remarks>
        public WebSocket WebSocket { get; }
 
        /// <inheritdoc />
        public override bool CanRead => !_disposed && WebSocket.State is WebSocketState.Open or WebSocketState.CloseSent;
 
        /// <inheritdoc />
        public override bool CanWrite => !_disposed && WebSocket.State is WebSocketState.Open or WebSocketState.CloseReceived;
 
        /// <inheritdoc />
        public override bool CanSeek => false;
 
        /// <inheritdoc />
        protected override void Dispose(bool disposing)
        {
            if (disposing)
            {
                // There are no synchronous operations on WebSocket, so we're forced to do sync-over-async.
                DisposeAsync().AsTask().GetAwaiter().GetResult();
            }
        }
 
        /// <inheritdoc />
        public override void Flush() { }
 
        /// <inheritdoc />
        public override Task FlushAsync(CancellationToken cancellationToken) =>
            cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) :
            Task.CompletedTask;
 
        /// <inheritdoc />
        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            ValidateBufferArguments(buffer, offset, count);
 
            return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
        }
 
        /// <inheritdoc />
        public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default) =>
            ValueTask.FromException<int>(new NotSupportedException());
 
        /// <inheritdoc />
        public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) =>
            TaskToAsyncResult.Begin(ReadAsync(buffer, offset, count), callback, state);
 
        /// <inheritdoc />
        public override int EndRead(IAsyncResult asyncResult) =>
            TaskToAsyncResult.End<int>(asyncResult);
 
        /// <inheritdoc />
        public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            ValidateBufferArguments(buffer, offset, count);
 
            return WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
        }
 
        /// <inheritdoc />
        public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) =>
            ValueTask.FromException(new NotSupportedException());
 
        /// <inheritdoc />
        public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) =>
            TaskToAsyncResult.Begin(WriteAsync(buffer, offset, count), callback, state);
 
        /// <inheritdoc />
        public override void EndWrite(IAsyncResult asyncResult) =>
            TaskToAsyncResult.End(asyncResult);
 
        /// <inheritdoc />
        public override int Read(byte[] buffer, int offset, int count) =>
            ReadAsync(buffer, offset, count, default).GetAwaiter().GetResult();
 
        /// <inheritdoc />
        public override void Write(byte[] buffer, int offset, int count) =>
            WriteAsync(buffer, offset, count, default).GetAwaiter().GetResult();
 
        /// <inheritdoc />
        public override long Length => throw new NotSupportedException();
 
        /// <inheritdoc />
        public override long Position
        {
            get => throw new NotSupportedException();
            set => throw new NotSupportedException();
        }
 
        /// <inheritdoc />
        public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
 
        /// <inheritdoc />
        public override void SetLength(long value) => throw new NotSupportedException();
 
        /// <summary>Provides stream that wraps a <see cref="WebSocket"/> and forwards reads/writes.</summary>
        private sealed class ReadWriteStream(WebSocket webSocket, WebSocketMessageType writeMessageType, TimeSpan? closeTimeout) : WebSocketStream(webSocket)
        {
            private readonly WebSocketMessageType _messageType = writeMessageType;
            private readonly TimeSpan? _closeTimeout = closeTimeout;
 
            /// <inheritdoc />
            public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
            {
                if (_disposed)
                {
                    return ValueTask.FromException(new ObjectDisposedException(GetType().FullName));
                }
 
                if (!CanWrite)
                {
                    return ValueTask.FromException(new NotSupportedException(SR.NotWriteableStream));
                }
 
                if (cancellationToken.IsCancellationRequested)
                {
                    return ValueTask.FromCanceled(cancellationToken);
                }
 
                return WebSocket.SendAsync(buffer, _messageType, endOfMessage: true, cancellationToken);
            }
 
            /// <inheritdoc />
            public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
            {
                ObjectDisposedException.ThrowIf(_disposed, this);
 
                if (!CanRead)
                {
                    throw new NotSupportedException(SR.NotReadableStream);
                }
 
                cancellationToken.ThrowIfCancellationRequested();
 
                while (WebSocket.State < WebSocketState.CloseReceived)
                {
                    ValueWebSocketReceiveResult result = await WebSocket.ReceiveAsync(buffer, cancellationToken).ConfigureAwait(false);
                    if (result.MessageType is WebSocketMessageType.Close)
                    {
                        break;
                    }
 
                    if (result.Count > 0 || buffer.IsEmpty)
                    {
                        return result.Count;
                    }
                }
 
                return 0;
            }
 
            /// <inheritdoc />
            public override async ValueTask DisposeAsync()
            {
                if (!_disposed)
                {
                    _disposed = true;
 
                    if (_closeTimeout is { } timeout)
                    {
                        if (WebSocket.State is < WebSocketState.Closed)
                        {
                            CancellationTokenSource? cts = null;
                            CancellationToken ct;
 
                            if (timeout == default)
                            {
                                ct = new CancellationToken(canceled: true);
                            }
                            else if (timeout == Timeout.InfiniteTimeSpan)
                            {
                                ct = CancellationToken.None;
                            }
                            else
                            {
                                cts = new CancellationTokenSource(timeout);
                                ct = cts.Token;
                            }
 
                            try
                            {
                                await WebSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, ct).ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing);
                            }
                            finally
                            {
                                cts?.Dispose();
                            }
                        }
 
                        WebSocket.Dispose();
                    }
                }
            }
        }
 
        /// <summary>Provides a stream that wraps a <see cref="WebSocket"/> and writes a single message.</summary>
        private sealed class WriteMessageStream(WebSocket webSocket, WebSocketMessageType writeMessageType) : WebSocketStream(webSocket)
        {
            private readonly WebSocketMessageType _messageType = writeMessageType;
 
            /// <inheritdoc />
            public override bool CanRead => false;
 
            /// <inheritdoc />
            public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
            {
                if (_disposed)
                {
                    return ValueTask.FromException(new ObjectDisposedException(GetType().FullName));
                }
 
                if (!CanWrite)
                {
                    return ValueTask.FromException(new NotSupportedException(SR.NotWriteableStream));
                }
 
                if (cancellationToken.IsCancellationRequested)
                {
                    return ValueTask.FromCanceled(cancellationToken);
                }
 
                return WebSocket.SendAsync(buffer, _messageType, endOfMessage: false, cancellationToken);
            }
 
            public override ValueTask DisposeAsync()
            {
                if (!_disposed)
                {
                    _disposed = true;
                    return WebSocket.SendAsync(ReadOnlyMemory<byte>.Empty, _messageType, endOfMessage: true, CancellationToken.None);
                }
 
                return default;
            }
        }
 
        /// <summary>Provides a stream that wraps a <see cref="WebSocket"/> and reads a single message.</summary>
        private sealed class ReadMessageStream(WebSocket webSocket) : WebSocketStream(webSocket)
        {
            /// <summary>Whether we've seen and end-of-message marker.</summary>
            private bool _eof;
 
            /// <inheritdoc />
            public override bool CanWrite => false;
 
            /// <inheritdoc />
            public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
            {
                ObjectDisposedException.ThrowIf(_disposed, this);
 
                if (!CanRead)
                {
                    throw new NotSupportedException(SR.NotReadableStream);
                }
 
                cancellationToken.ThrowIfCancellationRequested();
 
                while (!_eof && WebSocket.State < WebSocketState.CloseReceived)
                {
                    ValueWebSocketReceiveResult result = await WebSocket.ReceiveAsync(buffer, cancellationToken).ConfigureAwait(false);
                    if (result.MessageType is WebSocketMessageType.Close)
                    {
                        break;
                    }
 
                    if (result.EndOfMessage)
                    {
                        _eof = true;
                    }
 
                    if (result.Count > 0 || buffer.IsEmpty)
                    {
                        return result.Count;
                    }
                }
 
                return 0;
            }
 
            /// <inheritdoc />
            public override ValueTask DisposeAsync()
            {
                _disposed = true;
                if (!_eof && WebSocket.State < WebSocketState.CloseReceived)
                {
                    WebSocket.Abort();
                }
                return default;
            }
        }
    }
}