File: System\Net\Http\MultipartContent.cs
Web Access
Project: src\src\libraries\System.Net.Http\src\System.Net.Http.csproj (System.Net.Http)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net.Http.Headers;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
 
namespace System.Net.Http
{
    public class MultipartContent : HttpContent, IEnumerable<HttpContent>
    {
        #region Fields
 
        // Line terminator written after every boundary line and every header line.
        private const string CrLf = "\r\n";
 
        // Lengths of the fixed delimiters ("\r\n", "--", ": ", ", "); used by TryComputeLength
        // so the total can be computed without building the strings.
        private const int CrLfLength = 2;
        private const int DashDashLength = 2;
        private const int ColonSpaceLength = 2;
        private const int CommaSpaceLength = 2;
 
        // The only characters ValidateBoundary accepts in a boundary (RFC 2046 "bchars", including space).
        private static readonly SearchValues<char> s_allowedBoundaryChars =
            SearchValues.Create(" '()+,-./0123456789:=?ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz");
 
        // Parts added through Add, kept (and serialized) in insertion order.
        private readonly List<HttpContent> _nestedContent;
        // Boundary exactly as passed to the constructor; the quoted form only goes into the Content-Type header.
        private readonly string _boundary;
 
        #endregion Fields
 
        #region Construction
 
        /// <summary>
        /// Initializes a new instance with the "mixed" subtype and a newly generated GUID as the boundary.
        /// </summary>
        public MultipartContent()
            : this("mixed", GetDefaultBoundary())
        {
        }
 
        /// <summary>
        /// Initializes a new instance with the given subtype and a newly generated GUID as the boundary.
        /// </summary>
        /// <param name="subtype">The multipart subtype; it is appended to "multipart/" to form the media type.</param>
        public MultipartContent(string subtype)
            : this(subtype, GetDefaultBoundary())
        {
        }
 
        /// <summary>
        /// Initializes a new instance with the given subtype and boundary.
        /// </summary>
        /// <param name="subtype">The multipart subtype; it is appended to "multipart/" to form the media type.</param>
        /// <param name="boundary">The boundary string that separates the parts. It is validated by <c>ValidateBoundary</c>.</param>
        /// <exception cref="ArgumentException">
        /// <paramref name="subtype"/> or <paramref name="boundary"/> is null or white space, or
        /// <paramref name="boundary"/> ends with a space or contains a character that is not allowed.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="boundary"/> is longer than 70 characters.</exception>
        public MultipartContent(string subtype, string boundary)
        {
            ArgumentException.ThrowIfNullOrWhiteSpace(subtype);
            ValidateBoundary(boundary);

            // The raw boundary is what gets written between parts.
            _boundary = boundary;

            // The Content-Type parameter carries the boundary wrapped in double quotes
            // (left alone if it already begins with one).
            string boundaryParameter = boundary.StartsWith('\"')
                ? boundary
                : "\"" + boundary + "\"";

            var mediaType = new MediaTypeHeaderValue("multipart/" + subtype);
            mediaType.Parameters.Add(new NameValueHeaderValue(nameof(boundary), boundaryParameter));
            Headers.ContentType = mediaType;

            _nestedContent = new List<HttpContent>();
        }
 
        /// <summary>
        /// Checks that <paramref name="boundary"/> is a usable multipart boundary and throws if it is not.
        /// </summary>
        /// <param name="boundary">The candidate boundary string.</param>
        /// <exception cref="ArgumentException">
        /// The boundary is null or white space, ends with a space, or contains a character outside the allowed set.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException">The boundary is longer than 70 characters.</exception>
        private static void ValidateBoundary(string boundary)
        {
            // NameValueHeaderValue is too restrictive for boundaries, so the value is
            // validated here and quoted by the caller afterwards.
            ArgumentException.ThrowIfNullOrWhiteSpace(boundary);

            // RFC 2046 Section 5.1.1
            // boundary := 0*69<bchars> bcharsnospace
            // bchars := bcharsnospace / " "
            // bcharsnospace := DIGIT / ALPHA / "'" / "(" / ")" / "+" / "_" / "," / "-" / "." / "/" / ":" / "=" / "?"
            const int MaxBoundaryLength = 70;
            if (boundary.Length > MaxBoundaryLength)
            {
                string tooLongMessage = SR.Format(System.Globalization.CultureInfo.InvariantCulture, SR.net_http_content_field_too_long, MaxBoundaryLength);
                throw new ArgumentOutOfRangeException(nameof(boundary), boundary, tooLongMessage);
            }

            // A trailing space is rejected even though space is otherwise an allowed character;
            // anything outside the allowed set is rejected as well. Both cases report the same error.
            bool endsWithSpace = boundary.EndsWith(' ');
            if (endsWithSpace || boundary.AsSpan().ContainsAnyExcept(s_allowedBoundaryChars))
            {
                string invalidMessage = SR.Format(System.Globalization.CultureInfo.InvariantCulture, SR.net_http_headers_invalid_value, boundary);
                throw new ArgumentException(invalidMessage, nameof(boundary));
            }
        }
 
        /// <summary>
        /// Produces the boundary used when the caller does not supply one: the string form of a new GUID.
        /// </summary>
        private static string GetDefaultBoundary() => Guid.NewGuid().ToString();
 
        /// <summary>
        /// Appends a part to the end of this multipart content.
        /// </summary>
        /// <param name="content">The content to add.</param>
        /// <exception cref="ArgumentNullException"><paramref name="content"/> is <see langword="null"/>.</exception>
        public virtual void Add(HttpContent content)
        {
            ArgumentNullException.ThrowIfNull(content);
            _nestedContent.Add(content);
        }
 
        #endregion Construction
 
        #region Dispose
 
        /// <summary>
        /// Disposes every nested content and empties the list when <paramref name="disposing"/> is true,
        /// then runs the base class disposal.
        /// </summary>
        /// <param name="disposing">The flag forwarded to the base implementation; nested contents are only touched when it is true.</param>
        protected override void Dispose(bool disposing)
        {
            if (disposing)
            {
                foreach (HttpContent part in _nestedContent)
                {
                    part.Dispose();
                }

                _nestedContent.Clear();
            }

            base.Dispose(disposing);
        }
 
        #endregion Dispose
 
        #region IEnumerable<HttpContent> Members
 
        /// <summary>
        /// Returns an enumerator over the nested contents in the order they were added.
        /// </summary>
        public IEnumerator<HttpContent> GetEnumerator() => _nestedContent.GetEnumerator();
 
        #endregion
 
        #region IEnumerable Members
 
        // Non-generic counterpart of GetEnumerator; enumerates the same nested-content list.
        Collections.IEnumerator Collections.IEnumerable.GetEnumerator() => _nestedContent.GetEnumerator();
 
        #endregion
 
        #region Serialization
 
        /// <summary>
        /// Gets or sets a callback that returns the <see cref="Encoding"/> used to encode the value of the specified header
        /// of a nested content when that content's headers are serialized, or <see langword="null"/> to use the default behavior.
        /// </summary>
        /// <remarks>
        /// When no callback is set, or the callback returns <see langword="null"/> for a header,
        /// the default HTTP encoding is used for that header's values. The same selection is applied
        /// when the content length is computed.
        /// </remarks>
        public HeaderEncodingSelector<HttpContent>? HeaderEncodingSelector { get; set; }
 
        // for-each content
        //   write "--" + boundary
        //   for-each content header
        //     write header: header-value
        //   write content.CopyTo[Async]
        // write "--" + boundary + "--"
        // Can't be canceled directly by the user.  If the overall request is canceled
        // then the stream will be closed and an exception thrown.
        protected override void SerializeToStream(Stream stream, TransportContext? context, CancellationToken cancellationToken)
        {
            Debug.Assert(stream is not null);
            try
            {
                // Opening delimiter that precedes the first part.
                WriteToStream(stream, $"--{_boundary}{CrLf}");

                // Emit every part: a divider for all but the first, then its headers, then its body.
                int index = 0;
                while (index < _nestedContent.Count)
                {
                    HttpContent part = _nestedContent[index];
                    bool needsDivider = index > 0;

                    SerializeHeadersToStream(stream, part, needsDivider);
                    part.CopyTo(stream, context, cancellationToken);

                    index++;
                }

                // Closing delimiter.
                WriteToStream(stream, $"{CrLf}--{_boundary}--{CrLf}");
            }
            catch (Exception ex)
            {
                // Trace the failure when event tracing is on, then let it propagate unchanged.
                if (NetEventSource.Log.IsEnabled())
                {
                    NetEventSource.Error(this, ex);
                }

                throw;
            }
        }
 
        // for-each content
        //   write "--" + boundary
        //   for-each content header
        //     write header: header-value
        //   write content.CopyTo[Async]
        // write "--" + boundary + "--"
        // Can't be canceled directly by the user.  If the overall request is canceled
        // then the stream will be closed and an exception thrown.
        protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context)
        {
            // This overload has no token to forward, so the shared implementation runs without cancellation.
            return SerializeToStreamAsyncCore(stream, context, CancellationToken.None);
        }
 
        protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context, CancellationToken cancellationToken)
        {
            // A derived type may have overridden the original protected virtual SerializeToStreamAsync,
            // so only the exact MultipartContent type goes straight to the shared implementation;
            // everything else goes through the base method.
            if (GetType() != typeof(MultipartContent))
            {
                return base.SerializeToStreamAsync(stream, context, cancellationToken);
            }

            return SerializeToStreamAsyncCore(stream, context, cancellationToken);
        }
 
        /// <summary>
        /// Asynchronously writes the opening delimiter, each part (divider, headers, body) and the closing delimiter to <paramref name="stream"/>.
        /// </summary>
        /// <param name="stream">The destination stream.</param>
        /// <param name="context">Transport context forwarded to each nested content's <c>CopyToAsync</c>.</param>
        /// <param name="cancellationToken">Token forwarded to every asynchronous write and copy.</param>
        private protected async Task SerializeToStreamAsyncCore(Stream stream, TransportContext? context, CancellationToken cancellationToken)
        {
            Debug.Assert(stream is not null);
            try
            {
                // Opening delimiter that precedes the first part.
                string openingDelimiter = "--" + _boundary + CrLf;
                await EncodeStringToStreamAsync(stream, openingDelimiter, cancellationToken).ConfigureAwait(false);

                // Headers come from a synchronous writer, so they are staged in one reusable
                // in-memory buffer and copied to the destination asynchronously.
                var headerBuffer = new MemoryStream();
                int index = 0;
                while (index < _nestedContent.Count)
                {
                    HttpContent part = _nestedContent[index];

                    headerBuffer.SetLength(0);
                    SerializeHeadersToStream(headerBuffer, part, writeDivider: index > 0);
                    headerBuffer.Position = 0;
                    await headerBuffer.CopyToAsync(stream, cancellationToken).ConfigureAwait(false);

                    await part.CopyToAsync(stream, context, cancellationToken).ConfigureAwait(false);

                    index++;
                }

                // Closing delimiter.
                string closingDelimiter = CrLf + "--" + _boundary + "--" + CrLf;
                await EncodeStringToStreamAsync(stream, closingDelimiter, cancellationToken).ConfigureAwait(false);
            }
            catch (Exception ex)
            {
                // Trace the failure when event tracing is on, then let it propagate unchanged.
                if (NetEventSource.Log.IsEnabled())
                {
                    NetEventSource.Error(this, ex);
                }

                throw;
            }
        }
 
        /// <summary>
        /// Synchronously builds the read stream by running the shared implementation in its non-async mode.
        /// </summary>
        /// <param name="cancellationToken">Token forwarded to the shared implementation.</param>
        protected override Stream CreateContentReadStream(CancellationToken cancellationToken)
        {
            ValueTask<Stream> pending = CreateContentReadStreamAsyncCore(async: false, cancellationToken);

            // With async: false the core is expected to have finished by the time it returns.
            Debug.Assert(pending.IsCompleted);
            return pending.GetAwaiter().GetResult();
        }
 
        /// <summary>
        /// Asynchronously builds the read stream without a cancellation token.
        /// </summary>
        protected override Task<Stream> CreateContentReadStreamAsync()
        {
            ValueTask<Stream> pending = CreateContentReadStreamAsyncCore(async: true, CancellationToken.None);
            return pending.AsTask();
        }
 
        /// <summary>
        /// Asynchronously builds the read stream, honoring <paramref name="cancellationToken"/>.
        /// </summary>
        /// <param name="cancellationToken">Token forwarded to the implementation that is chosen.</param>
        protected override Task<Stream> CreateContentReadStreamAsync(CancellationToken cancellationToken)
        {
            // A derived type may have overridden the original protected virtual CreateContentReadStreamAsync,
            // so only the exact MultipartContent type goes straight to the shared implementation.
            if (GetType() != typeof(MultipartContent))
            {
                return base.CreateContentReadStreamAsync(cancellationToken);
            }

            return CreateContentReadStreamAsyncCore(async: true, cancellationToken).AsTask();
        }
 
        /// <summary>
        /// Builds a single read stream over the serialized form: opening delimiter, then headers and body
        /// of each part, then closing delimiter. Falls back to the base implementation if any part's
        /// stream is not seekable.
        /// </summary>
        /// <param name="async">True to obtain nested streams asynchronously; false to obtain them synchronously.</param>
        /// <param name="cancellationToken">Checked before each part and forwarded to the nested stream reads.</param>
        private async ValueTask<Stream> CreateContentReadStreamAsyncCore(bool async, CancellationToken cancellationToken)
        {
            try
            {
                // One slot for the opening delimiter, two per part (headers + body), one for the closing delimiter.
                var segments = new Stream[(_nestedContent.Count * 2) + 2];
                int filled = 0;

                // Opening delimiter.
                segments[filled++] = EncodeStringToNewStream("--" + _boundary + CrLf);

                for (int i = 0; i < _nestedContent.Count; i++)
                {
                    cancellationToken.ThrowIfCancellationRequested();

                    HttpContent part = _nestedContent[i];
                    segments[filled++] = EncodeHeadersToNewStream(part, writeDivider: i != 0);

                    Stream body = async
                        ? (part.TryReadAsStream() ?? await part.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false))
                        : part.ReadAsStream(cancellationToken);

                    // Cannot be null, at least an empty stream is necessary.
                    body ??= new MemoryStream();

                    if (body.CanSeek)
                    {
                        segments[filled++] = body;
                        continue;
                    }

                    // Seekability impacts whether HttpClientHandlers are able to rewind. To maintain compat
                    // and to allow such use cases when a nested stream isn't seekable (which should be rare),
                    // we fall back to the base behavior. We don't dispose of the streams already obtained
                    // as we don't necessarily own them yet.

#pragma warning disable CA2016
                    // Do not pass a cancellationToken to base.CreateContentReadStreamAsync() as it would trigger an infinite loop => StackOverflow
                    if (async)
                    {
                        return await base.CreateContentReadStreamAsync().ConfigureAwait(false);
                    }

                    return base.CreateContentReadStream(cancellationToken);
#pragma warning restore CA2016
                }

                // Closing delimiter.
                segments[filled] = EncodeStringToNewStream(CrLf + "--" + _boundary + "--" + CrLf);

                return new ContentReadStream(segments);
            }
            catch (Exception ex)
            {
                // Trace the failure when event tracing is on, then let it propagate unchanged.
                if (NetEventSource.Log.IsEnabled())
                {
                    NetEventSource.Error(this, ex);
                }

                throw;
            }
        }
 
        /// <summary>
        /// Writes an optional divider line, then the headers of <paramref name="content"/> as
        /// "name: value, value" lines, then the blank line that ends the header section.
        /// </summary>
        /// <param name="stream">The destination stream.</param>
        /// <param name="content">The nested content whose headers are written.</param>
        /// <param name="writeDivider">True to write the boundary divider first; callers pass false for the first part.</param>
        private void SerializeHeadersToStream(Stream stream, HttpContent content, bool writeDivider)
        {
            // Divider between this part and the previous one.
            if (writeDivider)
            {
                WriteToStream(stream, CrLf + "--"); // const strings
                WriteToStream(stream, _boundary);
                WriteToStream(stream, CrLf);
            }

            // One line per header; multiple values are joined with ", ".
            foreach (KeyValuePair<string, HeaderStringValues> header in content.Headers.NonValidated)
            {
                string name = header.Key;

                // Use the caller-selected encoding for this header's values when one is provided.
                Encoding? selected = HeaderEncodingSelector?.Invoke(name, content);
                Encoding valueEncoding = selected ?? HttpRuleParser.DefaultHttpEncoding;

                WriteToStream(stream, name);
                WriteToStream(stream, ": ");

                string separator = string.Empty;
                foreach (string value in header.Value)
                {
                    WriteToStream(stream, separator);
                    WriteToStream(stream, value, valueEncoding);
                    separator = ", ";
                }

                WriteToStream(stream, CrLf);
            }

            // Extra CRLF to end headers (even if there are no headers).
            WriteToStream(stream, CrLf);
        }
 
        /// <summary>
        /// Encodes <paramref name="input"/> with the default HTTP encoding and writes the bytes to <paramref name="stream"/> asynchronously.
        /// </summary>
        /// <param name="stream">The destination stream.</param>
        /// <param name="input">The text to encode and write.</param>
        /// <param name="cancellationToken">Token forwarded to the write.</param>
        private static ValueTask EncodeStringToStreamAsync(Stream stream, string input, CancellationToken cancellationToken)
        {
            ReadOnlyMemory<byte> encoded = HttpRuleParser.DefaultHttpEncoding.GetBytes(input);
            return stream.WriteAsync(encoded, cancellationToken);
        }
 
        /// <summary>
        /// Encodes <paramref name="input"/> with the default HTTP encoding and wraps the bytes in a read-only <see cref="MemoryStream"/>.
        /// </summary>
        private static MemoryStream EncodeStringToNewStream(string input)
        {
            byte[] encoded = HttpRuleParser.DefaultHttpEncoding.GetBytes(input);
            return new MemoryStream(encoded, writable: false);
        }
 
        /// <summary>
        /// Serializes the headers of <paramref name="content"/> (and optionally a leading divider) into a new
        /// <see cref="MemoryStream"/> and returns it positioned at the start.
        /// </summary>
        private MemoryStream EncodeHeadersToNewStream(HttpContent content, bool writeDivider)
        {
            var headerStream = new MemoryStream();
            SerializeHeadersToStream(headerStream, content, writeDivider);

            // Rewind so the caller reads from the first byte.
            headerStream.Position = 0;
            return headerStream;
        }
 
        // Multipart content always reports false here.
        internal override bool AllowDuplex
        {
            get { return false; }
        }
 
        /// <summary>
        /// Computes the total number of bytes the serialized multipart body will occupy.
        /// </summary>
        /// <param name="length">
        /// On success, the combined length of all delimiters, header lines and nested content bodies;
        /// 0 when the length cannot be determined.
        /// </param>
        /// <returns>
        /// <see langword="true"/> if every nested content could report its own length; otherwise <see langword="false"/>.
        /// </returns>
        /// <remarks>
        /// All sums and products are carried out in 64-bit arithmetic so that a very large number of parts
        /// or header values cannot silently wrap a 32-bit intermediate.
        /// </remarks>
        protected internal override bool TryComputeLength(out long length)
        {
            long boundaryLength = _boundary.Length;

            // Start Boundary.
            long currentLength = DashDashLength + boundaryLength + CrLfLength;

            if (_nestedContent.Count > 1)
            {
                // Internal boundaries: one divider before every part except the first.
                long dividerLength = CrLfLength + DashDashLength + boundaryLength + CrLfLength;
                currentLength += (_nestedContent.Count - 1) * dividerLength;
            }

            foreach (HttpContent content in _nestedContent)
            {
                // Headers.
                foreach (KeyValuePair<string, HeaderStringValues> headerPair in content.Headers.NonValidated)
                {
                    // The header name is counted by its character count, followed by ": ".
                    currentLength += headerPair.Key.Length + ColonSpaceLength;

                    // Values are measured with the same encoding that serialization will use for this header.
                    Encoding headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content) ?? HttpRuleParser.DefaultHttpEncoding;

                    long valueCount = 0;
                    foreach (string value in headerPair.Value)
                    {
                        currentLength += headerValueEncoding.GetByteCount(value);
                        valueCount++;
                    }

                    if (valueCount > 1)
                    {
                        // ", " between consecutive values.
                        currentLength += (valueCount - 1) * CommaSpaceLength;
                    }

                    // CRLF terminating this header line.
                    currentLength += CrLfLength;
                }

                // Blank line that ends the header section (present even with no headers).
                currentLength += CrLfLength;

                // Content.
                if (!content.TryComputeLength(out long tempContentLength))
                {
                    length = 0;
                    return false;
                }
                currentLength += tempContentLength;
            }

            // Terminating boundary.
            currentLength += CrLfLength + DashDashLength + boundaryLength + DashDashLength + CrLfLength;

            length = currentLength;
            return true;
        }
 
        /// <summary>
        /// Read-only, seekable stream that exposes an ordered array of streams as one continuous stream.
        /// Reads drain the streams in array order; the reported length is the sum of their lengths.
        /// </summary>
        private sealed class ContentReadStream : Stream
        {
            private readonly Stream[] _streams;
            private readonly long _length;

            // Index of the stream that will become current once the present one is exhausted.
            private int _next;
            // Stream currently being read, or null when none has been selected yet or all are exhausted.
            private Stream? _current;
            // Logical position within the concatenation.
            private long _position;

            /// <summary>Wraps <paramref name="streams"/> and records their combined length.</summary>
            internal ContentReadStream(Stream[] streams)
            {
                Debug.Assert(streams != null);
                _streams = streams;

                long total = 0;
                for (int i = 0; i < streams.Length; i++)
                {
                    total += streams[i].Length;
                }

                _length = total;
            }

            /// <summary>Disposes every wrapped stream when <paramref name="disposing"/> is true.</summary>
            protected override void Dispose(bool disposing)
            {
                if (!disposing)
                {
                    return;
                }

                for (int i = 0; i < _streams.Length; i++)
                {
                    _streams[i].Dispose();
                }
            }

            /// <summary>Asynchronously disposes every wrapped stream, one after another.</summary>
            public override async ValueTask DisposeAsync()
            {
                for (int i = 0; i < _streams.Length; i++)
                {
                    await _streams[i].DisposeAsync().ConfigureAwait(false);
                }
            }

            public override bool CanRead => true;
            public override bool CanSeek => true;
            public override bool CanWrite => false;

            public override long Length => _length;

            /// <summary>
            /// Makes the next wrapped stream current.
            /// </summary>
            /// <returns>False when there are no more streams to move to.</returns>
            private bool TryMoveToNextStream()
            {
                if (_next >= _streams.Length)
                {
                    return false;
                }

                _current = _streams[_next++];
                return true;
            }

            /// <summary>
            /// Reads from the current stream, moving on to following streams whenever one returns 0 bytes.
            /// Returns 0 only when <paramref name="count"/> is 0 or every stream is exhausted.
            /// </summary>
            public override int Read(byte[] buffer, int offset, int count)
            {
                ValidateBufferArguments(buffer, offset, count);
                if (count == 0)
                {
                    return 0;
                }

                do
                {
                    if (_current is Stream active)
                    {
                        int bytesRead = active.Read(buffer, offset, count);
                        if (bytesRead != 0)
                        {
                            _position += bytesRead;
                            return bytesRead;
                        }

                        // Current stream is drained; drop it and try the next one.
                        _current = null;
                    }
                }
                while (TryMoveToNextStream());

                return 0;
            }

            /// <summary>Reads a single byte through the span overload; returns -1 at end of stream.</summary>
            public override int ReadByte()
            {
                byte single = 0;
                int bytesRead = Read(new Span<byte>(ref single));
                return bytesRead == 1 ? single : -1;
            }

            /// <summary>Span-based counterpart of <see cref="Read(byte[], int, int)"/> with the same advancing behavior.</summary>
            public override int Read(Span<byte> buffer)
            {
                if (buffer.Length == 0)
                {
                    return 0;
                }

                do
                {
                    if (_current is Stream active)
                    {
                        int bytesRead = active.Read(buffer);
                        if (bytesRead != 0)
                        {
                            _position += bytesRead;
                            return bytesRead;
                        }

                        _current = null;
                    }
                }
                while (TryMoveToNextStream());

                return 0;
            }

            public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
            {
                ValidateBufferArguments(buffer, offset, count);

                var destination = new Memory<byte>(buffer, offset, count);
                return ReadAsyncPrivate(destination, cancellationToken).AsTask();
            }

            public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
            {
                return ReadAsyncPrivate(buffer, cancellationToken);
            }

            public override IAsyncResult BeginRead(byte[] array, int offset, int count, AsyncCallback? asyncCallback, object? asyncState)
            {
                Task<int> readTask = ReadAsync(array, offset, count, CancellationToken.None);
                return TaskToAsyncResult.Begin(readTask, asyncCallback, asyncState);
            }

            public override int EndRead(IAsyncResult asyncResult)
            {
                return TaskToAsyncResult.End<int>(asyncResult);
            }

            /// <summary>Asynchronous read shared by both <c>ReadAsync</c> overloads; advances across streams like the synchronous reads.</summary>
            public async ValueTask<int> ReadAsyncPrivate(Memory<byte> buffer, CancellationToken cancellationToken)
            {
                if (buffer.Length == 0)
                {
                    return 0;
                }

                do
                {
                    if (_current is Stream active)
                    {
                        int bytesRead = await active.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
                        if (bytesRead != 0)
                        {
                            _position += bytesRead;
                            return bytesRead;
                        }

                        _current = null;
                    }
                }
                while (TryMoveToNextStream());

                return 0;
            }

            /// <summary>
            /// Gets or sets the logical position. Setting it selects the stream that contains the position,
            /// positions that stream accordingly and rewinds every later stream to its start. A position at
            /// or past the total length leaves no current stream, so subsequent reads return 0.
            /// </summary>
            public override long Position
            {
                get
                {
                    return _position;
                }
                set
                {
                    ArgumentOutOfRangeException.ThrowIfNegative(value);

                    // Logical offset at which the stream under inspection begins.
                    long start = 0;
                    for (int index = 0; index < _streams.Length; index++)
                    {
                        Stream candidate = _streams[index];
                        long end = start + candidate.Length;
                        if (value >= end)
                        {
                            // Target lies beyond this stream; keep looking.
                            start = end;
                            continue;
                        }

                        _current = candidate;
                        _next = index + 1;
                        candidate.Position = value - start;

                        // Later streams must start from their beginning when reading reaches them.
                        for (int later = index + 1; later < _streams.Length; later++)
                        {
                            _streams[later].Position = 0;
                        }

                        _position = value;
                        return;
                    }

                    // At or beyond the end of the last stream.
                    _current = null;
                    _next = _streams.Length;
                    _position = value;
                }
            }

            /// <summary>Moves the position relative to the beginning, the current position or the end, and returns the new position.</summary>
            public override long Seek(long offset, SeekOrigin origin)
            {
                long target = origin switch
                {
                    SeekOrigin.Begin => offset,
                    SeekOrigin.Current => _position + offset,
                    SeekOrigin.End => _length + offset,
                    _ => throw new ArgumentOutOfRangeException(nameof(origin)),
                };

                Position = target;
                return _position;
            }

            // Nothing is buffered by this stream, so flushing has nothing to do.
            public override void Flush()
            {
            }

            public override Task FlushAsync(CancellationToken cancellationToken)
            {
                return Task.CompletedTask;
            }

            // The stream is read-only: resizing and every write overload are unsupported.
            public override void SetLength(long value) => throw new NotSupportedException();
            public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
            public override void Write(ReadOnlySpan<byte> buffer) => throw new NotSupportedException();
            public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => throw new NotSupportedException();
            public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) => throw new NotSupportedException();
        }
 
 
        /// <summary>
        /// Writes <paramref name="content"/> to <paramref name="stream"/> using the default HTTP encoding.
        /// </summary>
        private static void WriteToStream(Stream stream, string content)
        {
            WriteToStream(stream, content, HttpRuleParser.DefaultHttpEncoding);
        }
 
        /// <summary>
        /// Encodes <paramref name="content"/> with <paramref name="encoding"/> and writes the resulting bytes to
        /// <paramref name="stream"/> in a single synchronous write.
        /// </summary>
        /// <param name="stream">The destination stream.</param>
        /// <param name="content">The text to encode.</param>
        /// <param name="encoding">The encoding used to produce the bytes.</param>
        private static void WriteToStream(Stream stream, string content, Encoding encoding)
        {
            const int StackallocThreshold = 1024;

            // Worst-case encoded size decides between a stack buffer and a pooled array.
            int worstCaseBytes = encoding.GetMaxByteCount(content.Length);
            bool fitsOnStack = worstCaseBytes <= StackallocThreshold;

            byte[]? pooled = fitsOnStack ? null : ArrayPool<byte>.Shared.Rent(worstCaseBytes);
            Span<byte> scratch = pooled is null
                ? stackalloc byte[StackallocThreshold]
                : pooled;

            try
            {
                int byteCount = encoding.GetBytes(content, scratch);
                stream.Write(scratch[..byteCount]);
            }
            finally
            {
                // Hand a rented array back to the pool even if encoding or writing failed.
                if (pooled is not null)
                {
                    ArrayPool<byte>.Shared.Return(pooled);
                }
            }
        }
 
        #endregion Serialization
    }
}