// File: System\ServiceModel\Channels\HttpResponseMessageHelper.cs
// Web Access
// Project: src\src\System.ServiceModel.Http\src\System.ServiceModel.Http.csproj (System.ServiceModel.Http)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
 
using System.Diagnostics.Contracts;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Runtime;
using System.ServiceModel.Security;
using System.Threading;
using System.Threading.Tasks;
using System.Xml;
 
namespace System.ServiceModel.Channels
{
    internal class HttpResponseMessageHelper
    {
        // Channel factory that supplies the message encoder, transfer mode, buffer manager,
        // size quotas and authentication scheme used while parsing the response.
        private readonly HttpChannelFactory<IRequestChannel> _factory;
        // Encoder taken from the factory's MessageEncoderFactory; used to decode the response body.
        private readonly MessageEncoder _encoder;
        // The request that produced the response (HttpResponseMessage.RequestMessage).
        private readonly HttpRequestMessage _httpRequestMessage;
        // The HTTP response being parsed.
        private readonly HttpResponseMessage _httpResponseMessage;
        // String form of the response Content-Type header; string.Empty when the response has content
        // but no Content-Type header. Only assigned when the response has a Content object.
        private string _contentType;
        // Value of the response Content-Length header, or -1 when the header is missing.
        private long _contentLength;
 
        /// <summary>
        /// Creates a helper bound to one HTTP response and the channel factory that issued the request.
        /// </summary>
        /// <param name="httpResponseMessage">The response to parse; its RequestMessage is captured as well.</param>
        /// <param name="factory">The channel factory; its MessageEncoderFactory provides the encoder.</param>
        public HttpResponseMessageHelper(HttpResponseMessage httpResponseMessage, HttpChannelFactory<IRequestChannel> factory)
        {
            // Preconditions are expressed as contract assertions rather than argument exceptions.
            Contract.Assert(httpResponseMessage != null);
            Contract.Assert(httpResponseMessage.RequestMessage != null);
            Contract.Assert(factory != null);

            _factory = factory;
            _encoder = factory.MessageEncoderFactory.Encoder;

            _httpResponseMessage = httpResponseMessage;
            _httpRequestMessage = httpResponseMessage.RequestMessage;
        }
 
        /// <summary>
        /// Validates the HTTP response (authentication, status code, content type) and converts its
        /// body into a <see cref="Message"/>.
        /// </summary>
        /// <param name="timeoutHelper">Timeout passed on to the content validation and body reading steps.</param>
        /// <returns>
        /// The decoded message when the response has content. When it has no content, a <c>NullMessage</c>
        /// if the encoder's message version is <c>MessageVersion.None</c>, otherwise <c>null</c>.
        /// </returns>
        internal async Task<Message> ParseIncomingResponse(TimeoutHelper timeoutHelper)
        {
            // Both status checks throw on failure and run before the body is examined.
            ValidateAuthentication();
            ValidateResponseStatusCode();

            Message response;
            if (await ValidateContentTypeAsync(timeoutHelper))
            {
                response = await ReadStreamAsMessageAsync(timeoutHelper);
            }
            else if (_encoder.MessageVersion == MessageVersion.None)
            {
                response = new NullMessage();
            }
            else
            {
                // No body and a SOAP message version: there is nothing to hand back.
                return null;
            }

            Exception addressingError = ProcessHttpAddressing(response);
            Contract.Assert(addressingError == null, "ProcessHttpAddressing should not set an exception after parsing a response message.");

            return response;
        }
 
        /// <summary>
        /// Attaches the HTTP response properties to <paramref name="message"/> and, when the message's
        /// addressing version is <c>AddressingVersion.None</c>, inspects the Action and To headers before
        /// overwriting the To header with the message's Via property.
        /// </summary>
        /// <param name="message">The decoded response message; it is modified in place.</param>
        /// <returns>
        /// <c>null</c> when nothing was recorded; otherwise a <c>ProtocolException</c> naming the offending
        /// header. The To check runs last, so it wins when both checks record a problem. The exception is
        /// returned to the caller, not thrown.
        /// </returns>
        private Exception ProcessHttpAddressing(Message message)
        {
            AddProperties(message);

            // The header checks only apply to messages that carry no WS-Addressing.
            if (message.Version.Addressing != AddressingVersion.None)
            {
                return null;
            }

            Exception addressingError = null;

            // Reading Headers.Action parses header XML, so malformed data can surface as an XmlException,
            // and ServiceModel failures surface as CommunicationException. Both are swallowed here.
            // NOTE(review): the flag starts as false and is left untouched when the read fails, so a
            // swallowed failure is handled like a header that is present - confirm this is intended.
            bool actionMissing = false;
            try
            {
                actionMissing = message.Headers.Action == null;
            }
            catch (XmlException)
            {
            }
            catch (CommunicationException)
            {
            }

            if (!actionMissing)
            {
                addressingError = new ProtocolException(
                    SR.Format(SR.HttpAddressingNoneHeaderOnWire, XD.AddressingDictionary.Action.Value));
            }

            // Headers.To is read with the same exception handling as Headers.Action above.
            bool toMissing = false;
            try
            {
                toMissing = message.Headers.To == null;
            }
            catch (XmlException)
            {
            }
            catch (CommunicationException)
            {
            }

            if (!toMissing)
            {
                addressingError = new ProtocolException(
                    SR.Format(SR.HttpAddressingNoneHeaderOnWire, XD.AddressingDictionary.To.Value));
            }

            message.Headers.To = message.Properties.Via;

            return addressingError;
        }
 
        /// <summary>
        /// Adds the HTTP response property and the raw status code to the message properties and sets
        /// the message's Via to the anonymous URI of its addressing version.
        /// </summary>
        /// <param name="message">The message whose property bag is populated.</param>
        private void AddProperties(Message message)
        {
            var httpProperty = new HttpResponseMessageProperty(_httpResponseMessage);
            var properties = message.Properties;

            properties.Add(HttpResponseMessageProperty.Name, httpProperty);
            properties.Add("_System.Net.HttpStatusCode", httpProperty.StatusCode);
            properties.Via = message.Version.Addressing.AnonymousUri;
        }
 
        /// <summary>
        /// Captures the response Content-Type and Content-Length headers and checks that the content
        /// type is one the encoder can decode.
        /// </summary>
        /// <param name="timeoutHelper">Timeout passed on to <c>GetStreamAsync</c>.</param>
        /// <returns>
        /// <c>false</c> when the response has no Content-Type and no body; <c>true</c> when a Content-Type
        /// is present and is either supported or the declared Content-Length is zero.
        /// </returns>
        /// <remarks>
        /// Throws a <c>ProtocolException</c> when a body is present without a Content-Type header, or when
        /// the Content-Type is not supported by the encoder (the message then includes a body excerpt).
        /// </remarks>
        private async Task<bool> ValidateContentTypeAsync(TimeoutHelper timeoutHelper)
        {
            var content = _httpResponseMessage.Content;
            if (content != null)
            {
                var mediaValueContentType = content.Headers.ContentType;
                _contentType = mediaValueContentType == null ? string.Empty : mediaValueContentType.ToString();
                _contentLength = content.Headers.ContentLength.HasValue ? content.Headers.ContentLength.Value : -1;
            }

            if (string.IsNullOrEmpty(_contentType))
            {
                // Without a Content-Type the response is only acceptable when it carries no body at all.
                Stream contentStream = await GetStreamAsync(timeoutHelper);
                if (contentStream != null)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ProtocolException(SR.HttpContentTypeHeaderRequired));
                }
                return false;
            }

            if (_contentLength != 0 && !_encoder.IsContentTypeSupported(_contentType))
            {
                // Clamp instead of a plain cast so a Content-Length above int.MaxValue cannot wrap around
                // to an arbitrary (possibly zero or negative) value. -1 (unknown length) passes through.
                int bytesToRead = _contentLength > int.MaxValue ? int.MaxValue : (int)_contentLength;

                // GetStreamAsync returns null when the length is unknown and the body turns out to be
                // empty; there is nothing to excerpt in that case, so do not hand a null stream on.
                Stream contentStream = await GetStreamAsync(timeoutHelper);
                string responseExcerpt;
                if (contentStream != null)
                {
                    responseExcerpt = HttpChannelUtilities.GetResponseStreamExcerptString(contentStream, ref bytesToRead);
                }
                else
                {
                    bytesToRead = 0;
                    responseExcerpt = string.Empty;
                }

                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(HttpChannelUtilities.TraceResponseException(
                    new ProtocolException(
                        SR.Format(
                            SR.ResponseContentTypeMismatch,
                            _contentType,
                            _encoder.ContentType,
                            bytesToRead,
                            responseExcerpt))));
            }

            return true;
        }
 
        /// <summary>
        /// Starts reading the response body and picks the decoding strategy: streamed when the factory's
        /// transfer mode streams responses, chunked-buffered when no Content-Length header is present,
        /// and fixed-length buffered otherwise.
        /// </summary>
        /// <param name="timeoutHelper">Timeout passed on to the stream acquisition and the buffered readers.</param>
        /// <returns>A task producing the decoded message.</returns>
        private Task<Message> ReadStreamAsMessageAsync(TimeoutHelper timeoutHelper)
        {
            var responseContent = _httpResponseMessage.Content;

            // The stream task is started once and handed to whichever reader is selected below.
            Task<Stream> streamTask = GetStreamAsync(timeoutHelper);

            if (TransferModeHelper.IsResponseStreamed(_factory.TransferMode))
            {
                return ReadStreamedMessageAsync(streamTask);
            }

            return responseContent.Headers.ContentLength.HasValue
                ? ReadBufferedMessageAsync(streamTask, timeoutHelper)
                : ReadChunkedBufferedMessageAsync(streamTask, timeoutHelper);
        }
 
        /// <summary>
        /// Decodes a buffered message from a body that has no Content-Length header by letting the
        /// encoder read the stream directly, bounded by the factory's MaxBufferSize.
        /// </summary>
        /// <param name="inputStreamTask">Task producing the response body stream.</param>
        /// <param name="timeoutHelper">Source of the cancellation token passed to the encoder.</param>
        /// <returns>The decoded message.</returns>
        /// <remarks>An <c>XmlException</c> raised while decoding is rethrown wrapped in a <c>ProtocolException</c>.</remarks>
        private async Task<Message> ReadChunkedBufferedMessageAsync(Task<Stream> inputStreamTask, TimeoutHelper timeoutHelper)
        {
            try
            {
                Stream bodyStream = await inputStreamTask;
                var bufferManager = _factory.BufferManager;
                int maxBufferSize = _factory.MaxBufferSize;
                string contentType = _contentType;
                CancellationToken cancellationToken = await timeoutHelper.GetCancellationTokenAsync();

                return await _encoder.ReadMessageAsync(bodyStream, bufferManager, maxBufferSize, contentType, cancellationToken);
            }
            catch (XmlException malformedXml)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new ProtocolException(SR.MessageXmlProtocolError, malformedXml));
            }
        }
 
        /// <summary>
        /// Reads a body of known Content-Length into a buffer taken from the factory's buffer manager
        /// and passes the filled portion to <c>DecodeBufferedMessageAsync</c>.
        /// </summary>
        /// <param name="inputStreamTask">Task producing the response body stream.</param>
        /// <param name="timeoutHelper">Source of the cancellation token used for the reads.</param>
        /// <returns>The decoded message.</returns>
        /// <remarks>
        /// Throws (via <c>ThrowMaxReceivedMessageSizeExceeded</c>) when the declared length exceeds the
        /// factory's MaxReceivedMessageSize, and throws a <c>ProtocolException</c> when the stream ends
        /// before the declared number of bytes has been read.
        /// </remarks>
        private async Task<Message> ReadBufferedMessageAsync(Task<Stream> inputStreamTask, TimeoutHelper timeoutHelper)
        {
            var inputStream = await inputStreamTask;
            if (_contentLength > _factory.MaxReceivedMessageSize)
            {
                ThrowMaxReceivedMessageSizeExceeded();
            }

            int bufferSize = (int)_contentLength;
            byte[] buffer = _factory.BufferManager.TakeBuffer(bufferSize);
            int offset = 0;

            try
            {
                var messageBuffer = new ArraySegment<byte>(buffer, 0, bufferSize);
                int count = messageBuffer.Count;
                var ct = await timeoutHelper.GetCancellationTokenAsync();

                while (count > 0)
                {
                    int bytesRead = await inputStream.ReadAsync(buffer, offset, count, ct);
                    if (bytesRead == 0) // EOF
                    {
                        if (_contentLength != -1)
                        {
                            // The server promised more bytes than it delivered.
                            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                                new ProtocolException(SR.HttpContentLengthIncorrect));
                        }

                        break;
                    }
                    count -= bytesRead;
                    offset += bytesRead;
                }
            }
            catch
            {
                // The buffer has not been handed to the encoder yet, so nobody else will give it back
                // to the pool if reading fails (short body, I/O error, cancellation).
                _factory.BufferManager.ReturnBuffer(buffer);
                throw;
            }

            return await DecodeBufferedMessageAsync(new ArraySegment<byte>(buffer, 0, offset), inputStream, timeoutHelper);
        }
 
        /// <summary>
        /// Decodes a streamed message from the response body, wrapping the body in a
        /// <c>MaxMessageSizeStream</c> constructed with the factory's MaxReceivedMessageSize.
        /// </summary>
        /// <param name="inputStreamTask">Task producing the response body stream.</param>
        /// <returns>
        /// The decoded message. When the body stream is a <c>BufferedReadStream</c>, it is stored in the
        /// message properties under <c>BufferedReadStream.BufferedReadStreamPropertyName</c>.
        /// </returns>
        /// <remarks>An <c>XmlException</c> raised while decoding is rethrown wrapped in a <c>ProtocolException</c>.</remarks>
        private async Task<Message> ReadStreamedMessageAsync(Task<Stream> inputStreamTask)
        {
            Stream responseStream = await inputStreamTask;
            var sizeLimitedStream = new MaxMessageSizeStream(responseStream, _factory.MaxReceivedMessageSize);

            try
            {
                Message decoded = await _encoder.ReadMessageAsync(sizeLimitedStream, _factory.MaxBufferSize, _contentType);

                var preReadStream = responseStream as BufferedReadStream;
                if (preReadStream != null)
                {
                    decoded.Properties[BufferedReadStream.BufferedReadStreamPropertyName] = preReadStream;
                }

                return decoded;
            }
            catch (XmlException malformedXml)
            {
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                    new ProtocolException(SR.MessageXmlProtocolError, malformedXml));
            }
        }
 
        /// <summary>
        /// Reports that the response exceeds the factory's MaxReceivedMessageSize: emits the matching
        /// WcfEventSource event when it is enabled, then throws a <c>CommunicationException</c> whose inner
        /// exception is a <c>QuotaExceededException</c> carrying the same text.
        /// </summary>
        private void ThrowMaxReceivedMessageSizeExceeded()
        {
            string quotaMessage = SR.Format(SR.MaxReceivedMessageSizeExceeded, _factory.MaxReceivedMessageSize);

            if (WcfEventSource.Instance.MaxReceivedMessageSizeExceededIsEnabled())
            {
                WcfEventSource.Instance.MaxReceivedMessageSizeExceeded(quotaMessage);
            }

            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                new CommunicationException(quotaMessage, new QuotaExceededException(quotaMessage)));
        }
 
        /// <summary>
        /// Decodes a fully buffered response body into a message and disposes the body stream.
        /// </summary>
        /// <param name="buffer">The bytes read from the response body.</param>
        /// <param name="inputStream">
        /// The body stream the bytes came from; may be <c>null</c> when the body was empty. It is disposed
        /// before this method completes, whether decoding succeeds or fails.
        /// </param>
        /// <param name="timeoutHelper">Source of the cancellation token used for the extra-byte probe.</param>
        /// <returns>The decoded message.</returns>
        /// <remarks>
        /// An <c>XmlException</c> raised by the encoder is rethrown wrapped in a <c>ProtocolException</c>.
        /// Throws (via <c>ThrowMaxReceivedMessageSizeExceeded</c>) when a body of unknown length filled
        /// the whole quota and still has more data.
        /// </remarks>
        private async Task<Message> DecodeBufferedMessageAsync(ArraySegment<byte> buffer, Stream inputStream, TimeoutHelper timeoutHelper)
        {
            try
            {
                var ct = await timeoutHelper.GetCancellationTokenAsync();
                // if we're chunked, make sure we've consumed the whole body
                if (inputStream != null && _contentLength == -1 && buffer.Count == _factory.MaxReceivedMessageSize)
                {
                    byte[] extraBuffer = new byte[1];
                    int extraReceived = await inputStream.ReadAsync(extraBuffer, 0, 1, ct);
                    if (extraReceived > 0)
                    {
                        ThrowMaxReceivedMessageSizeExceeded();
                    }
                }

                try
                {
                    return _encoder.ReadMessage(buffer, _factory.BufferManager, _contentType);
                }
                catch (XmlException xmlException)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                        new ProtocolException(SR.MessageXmlProtocolError, xmlException));
                }
            }
            finally
            {
                // The stream is null for an empty body. Dereferencing it here would raise a
                // NullReferenceException that replaces whatever exception (or result) is in flight.
                if (inputStream != null)
                {
                    inputStream.Dispose();
                }
            }
        }
 
        /// <summary>
        /// Opens the response body stream and refreshes <c>_contentLength</c> from the Content-Length header
        /// (-1 when the header is missing or the response has no content).
        /// </summary>
        /// <param name="timeoutHelper">Source of the cancellation token used for the pre-reads.</param>
        /// <returns>
        /// <c>null</c> when the response has no content, or when the declared length is not positive and
        /// the first read returns no bytes. Otherwise the body stream, wrapped in a <c>BufferedReadStream</c>
        /// when a byte had to be pre-read or when the factory's transfer mode streams responses.
        /// </returns>
        private async Task<Stream> GetStreamAsync(TimeoutHelper timeoutHelper)
        {
            var content = _httpResponseMessage.Content;
            _contentLength = -1;
            if (content == null)
            {
                return null;
            }

            Stream bodyStream = await content.ReadAsStreamAsync();
            long? declaredLength = content.Headers.ContentLength;
            _contentLength = declaredLength.HasValue ? declaredLength.Value : -1;
            CancellationToken cancellationToken = await timeoutHelper.GetCancellationTokenAsync();

            if (_contentLength > 0)
            {
                if (!TransferModeHelper.IsResponseStreamed(_factory.TransferMode))
                {
                    return bodyStream;
                }

                // The message was sent buffered but is being received streamed, so it still needs a
                // buffered reading stream.
                var streamedWrapper = new BufferedReadStream(bodyStream, _factory.BufferManager);
                await streamedWrapper.PreReadBufferAsync(cancellationToken);
                return streamedWrapper;
            }

            // Length is zero or unknown: read a single byte to find out whether there is any body.
            var firstByte = new byte[1];
            int probed = await bodyStream.ReadAsync(firstByte, 0, 1, cancellationToken);
            if (probed == 0)
            {
                bodyStream.Dispose();
                return null;
            }

            // Hand the probed byte to the wrapper together with the rest of the stream.
            var probedWrapper = new BufferedReadStream(bodyStream, _factory.BufferManager);
            await probedWrapper.PreReadBufferAsync(firstByte[0], cancellationToken);
            return probedWrapper;
        }
 
        /// <summary>
        /// Maps unsuccessful HTTP status codes to exceptions. Status codes in the 2xx range and
        /// 500 (Internal Server Error) are let through for further processing.
        /// </summary>
        /// <remarks>
        /// Throws <c>EndpointNotFoundException</c> for 404, <c>ServerTooBusyException</c> for 503,
        /// <c>ProtocolException</c> for 415, <c>TimeoutException</c> for 504, and the exception built by
        /// <c>HttpChannelUtilities.CreateUnexpectedResponseException</c> for every other rejected status.
        /// </remarks>
        private void ValidateResponseStatusCode()
        {
            if (((int)_httpResponseMessage.StatusCode < 200 || (int)_httpResponseMessage.StatusCode >= 300) && _httpResponseMessage.StatusCode != HttpStatusCode.InternalServerError)
            {
                if (_httpResponseMessage.StatusCode == HttpStatusCode.NotFound)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new EndpointNotFoundException(SR.Format(SR.EndpointNotFound, _httpRequestMessage.RequestUri.AbsoluteUri)));
                }

                if (_httpResponseMessage.StatusCode == HttpStatusCode.ServiceUnavailable)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ServerTooBusyException(SR.Format(SR.HttpServerTooBusy, _httpRequestMessage.RequestUri.AbsoluteUri)));
                }

                if (_httpResponseMessage.StatusCode == HttpStatusCode.UnsupportedMediaType)
                {
                    string statusDescription = _httpResponseMessage.ReasonPhrase;
                    if (!string.IsNullOrEmpty(statusDescription)
                        && string.Equals(statusDescription, HttpChannelUtilities.StatusDescriptionStrings.HttpContentTypeMissing, StringComparison.OrdinalIgnoreCase))
                    {
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ProtocolException(SR.Format(SR.MissingContentType, _httpRequestMessage.RequestUri)));
                    }

                    // The request may have been sent without an entity body or without a Content-Type
                    // header; report an empty content type rather than failing with a null dereference
                    // while building the error message.
                    string requestContentType = string.Empty;
                    HttpContent requestContent = _httpRequestMessage.Content;
                    if (requestContent != null && requestContent.Headers.ContentType != null)
                    {
                        requestContentType = requestContent.Headers.ContentType.ToString();
                    }

                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ProtocolException(SR.Format(SR.FramingContentTypeMismatch, requestContentType, _httpRequestMessage.RequestUri)));
                }

                if (_httpResponseMessage.StatusCode == HttpStatusCode.GatewayTimeout)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException(_httpResponseMessage.StatusCode + " " + _httpResponseMessage.ReasonPhrase));
                }

                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(HttpChannelUtilities.CreateUnexpectedResponseException(_httpResponseMessage));
            }
        }
 
        /// <summary>
        /// Throws a traced <c>MessageSecurityException</c> when the response status is 401 (Unauthorized)
        /// or 403 (Forbidden); does nothing for any other status.
        /// </summary>
        private void ValidateAuthentication()
        {
            switch (_httpResponseMessage.StatusCode)
            {
                case HttpStatusCode.Unauthorized:
                {
                    // The 401 text includes the server's WWW-Authenticate header value.
                    string challenge = _httpResponseMessage.Headers.WwwAuthenticate.ToString();
                    string unauthorizedText = SR.Format(SR.HttpAuthorizationFailed, _factory.AuthenticationScheme, challenge);
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                        HttpChannelUtilities.TraceResponseException(new MessageSecurityException(unauthorizedText)));
                }

                case HttpStatusCode.Forbidden:
                {
                    string forbiddenText = SR.Format(SR.HttpAuthorizationForbidden, _factory.AuthenticationScheme);
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                        HttpChannelUtilities.TraceResponseException(new MessageSecurityException(forbiddenText)));
                }
            }
        }
    }
}