// File: System\ServiceModel\Channels\ClientWebSocketTransportDuplexSessionChannel.cs
// Web Access
// Project: src\src\System.ServiceModel.Http\src\System.ServiceModel.Http.csproj (System.ServiceModel.Http)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
 
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.IdentityModel.Selectors;
using System.IdentityModel.Tokens;
using System.Net;
using System.Net.Http;
using System.Net.Security;
using System.Net.WebSockets;
using System.Runtime;
using System.Security.Cryptography.X509Certificates;
using System.Security.Principal;
using System.ServiceModel.Security.Tokens;
using System.Threading;
using System.Threading.Tasks;
 
namespace System.ServiceModel.Channels
{
    /// <summary>
    /// Client-side duplex session channel that exchanges WCF messages over a <see cref="ClientWebSocket"/>.
    /// </summary>
    /// <remarks>
    /// The socket is created, configured and connected to <c>Via</c> in <see cref="OnOpenAsync(TimeSpan)"/>;
    /// incoming messages are then read through a <see cref="WebSocketMessageSource"/> installed on the channel.
    /// </remarks>
    internal class ClientWebSocketTransportDuplexSessionChannel : WebSocketTransportDuplexSessionChannel
    {
        // Factory that created this channel; supplies the security, proxy, cookie, encoder and WebSocket settings.
        // Only assigned in the constructor.
        private readonly HttpChannelFactory<IDuplexSessionChannel> _channelFactory;

        // Token providers opened while building the handshake request. OnOpenAsync aborts them
        // (via CleanupTokenProviders) when it finishes, whether or not the open succeeded.
        private SecurityTokenProviderContainer _webRequestTokenProvider;
        private SecurityTokenProviderContainer _webRequestProxyTokenProvider;

        // Set by OnCleanup. OnOpenAsync checks it after connecting so that a cleanup which ran while the
        // handshake was in flight causes the new socket to be aborted.
        // NOTE(review): volatile suggests OnCleanup may run on a different thread than OnOpenAsync — confirm.
        private volatile bool _cleanupStarted;

        /// <summary>
        /// Creates a channel owned by <paramref name="channelFactory"/> that will connect to <paramref name="via"/>.
        /// </summary>
        /// <param name="channelFactory">Factory that created the channel; must not be null.</param>
        /// <param name="remoteAddress">Endpoint address of the remote service.</param>
        /// <param name="via">URI the WebSocket connects to when the channel opens.</param>
        public ClientWebSocketTransportDuplexSessionChannel(HttpChannelFactory<IDuplexSessionChannel> channelFactory, EndpointAddress remoteAddress, Uri via)
            : base(channelFactory, remoteAddress, via)
        {
            Contract.Assert(channelFactory != null, "connection factory must be set");
            _channelFactory = channelFactory;
        }

        /// <summary>
        /// True when the configured transfer mode streams requests (messages sent by this client).
        /// </summary>
        protected override bool IsStreamedOutput
        {
            get { return TransferModeHelper.IsRequestStreamed(TransferMode); }
        }

        /// <summary>APM begin-open; delegates to <c>CommunicationObjectInternal.OnBeginOpen</c>.</summary>
        protected override IAsyncResult OnBeginOpen(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return CommunicationObjectInternal.OnBeginOpen(this, timeout, callback, state);
        }

        /// <summary>APM end-open; delegates to <c>CommunicationObjectInternal.OnEnd</c>.</summary>
        protected override void OnEndOpen(IAsyncResult result)
        {
            CommunicationObjectInternal.OnEnd(result);
        }

        /// <summary>Synchronous open; delegates to <c>CommunicationObjectInternal.OnOpen</c>.</summary>
        protected override void OnOpen(TimeSpan timeout)
        {
            CommunicationObjectInternal.OnOpen(this, timeout);
        }

        /// <summary>
        /// Opens the channel. It configures a new <see cref="ClientWebSocket"/>, connects it to <c>Via</c>,
        /// checks the negotiated sub-protocol, and installs the message source that reads from the socket.
        /// </summary>
        /// <param name="timeout">Overall time budget for configuring the socket and completing the handshake.</param>
        /// <remarks>
        /// <para>
        /// Token providers opened during configuration are always aborted before this method returns.
        /// If the open fails, the channel is cleaned up via <see cref="CleanupOnError"/>.
        /// </para>
        /// <para>
        /// A <see cref="WebSocketException"/> raised during the handshake is converted by
        /// <see cref="TryConvertAndThrow"/> into a <see cref="CommunicationException"/>.
        /// </para>
        /// </remarks>
        /// <exception cref="CommunicationObjectAbortedException">
        /// <see cref="OnCleanup"/> ran while the handshake was in progress and a socket had already been attached.
        /// </exception>
        protected internal override async Task OnOpenAsync(TimeSpan timeout)
        {
            TimeoutHelper helper = new TimeoutHelper(timeout);

            bool success = false;
            try
            {
                if (WcfEventSource.Instance.WebSocketConnectionRequestSendStartIsEnabled())
                {
                    WcfEventSource.Instance.WebSocketConnectionRequestSendStart(
                        EventTraceActivity,
                        RemoteAddress != null ? RemoteAddress.ToString() : string.Empty);
                }

                ClientWebSocket clientWebSocket = null;
                try
                {
                    clientWebSocket = new ClientWebSocket();
                    await ConfigureClientWebSocketAsync(clientWebSocket, helper.RemainingTime());
                    await clientWebSocket.ConnectAsync(Via, await helper.GetCancellationTokenAsync());
                    ValidateWebSocketConnection(clientWebSocket);
                    WebSocket = clientWebSocket;

                    // Ownership has moved to the channel's WebSocket property; it must not be disposed below.
                    clientWebSocket = null;
                }
                finally
                {
                    // Still non-null only when configuring, connecting or validating failed before the socket
                    // was handed to the channel. Nothing else references it, so release it here.
                    clientWebSocket?.Dispose();

                    // A cleanup that started while the handshake was in flight wins: tear down the new socket.
                    if (WebSocket != null && _cleanupStarted)
                    {
                        WebSocket.Abort();
                        CommunicationObjectAbortedException communicationObjectAbortedException = new CommunicationObjectAbortedException(
                            new WebSocketException(WebSocketError.ConnectionClosedPrematurely).Message);
                        FxTrace.Exception.AsWarning(communicationObjectAbortedException);
                        throw communicationObjectAbortedException;
                    }
                }

                bool inputUseStreaming = TransferModeHelper.IsResponseStreamed(TransferMode);

                SetMessageSource(new WebSocketMessageSource(
                    this,
                    WebSocket,
                    inputUseStreaming,
                    this));

                success = true;

                if (WcfEventSource.Instance.WebSocketConnectionRequestSendStopIsEnabled())
                {
                    WcfEventSource.Instance.WebSocketConnectionRequestSendStop(
                        EventTraceActivity,
                        WebSocket != null ? WebSocket.GetHashCode() : -1);
                }
            }
            catch (WebSocketException ex)
            {
                if (WcfEventSource.Instance.WebSocketConnectionFailedIsEnabled())
                {
                    WcfEventSource.Instance.WebSocketConnectionFailed(EventTraceActivity, ex.Message);
                }

                TryConvertAndThrow(ex);
            }
            finally
            {
                // The token providers are only needed for the handshake, so they are released on every path.
                CleanupTokenProviders();
                if (!success)
                {
                    CleanupOnError();
                }
            }
        }

        /// <summary>
        /// Verifies that the sub-protocol the server accepted matches the one that was requested.
        /// </summary>
        /// <param name="clientWebSocket">The connected socket. On failure the caller is responsible for disposing it.</param>
        /// <remarks>
        /// When no sub-protocol was requested, the server must not have selected one (null or whitespace).
        /// Otherwise the values must match ignoring case. On mismatch, an <see cref="InvalidOperationException"/>
        /// is thrown via <c>FxTrace.Exception.AsError</c>.
        /// </remarks>
        private void ValidateWebSocketConnection(ClientWebSocket clientWebSocket)
        {
            string requested = WebSocketSettings.SubProtocol;
            string obtained = clientWebSocket.SubProtocol;
            if (!(requested == null ? string.IsNullOrWhiteSpace(obtained) : requested.Equals(obtained, StringComparison.OrdinalIgnoreCase)))
            {
                throw FxTrace.Exception.AsError(new InvalidOperationException(SR.Format(SR.WebSocketInvalidProtocolNotInClientList, obtained, requested)));
            }
        }

        /// <summary>
        /// Applies the factory's settings to <paramref name="clientWebSocket"/> before it connects.
        /// </summary>
        /// <param name="clientWebSocket">The not-yet-connected socket to configure.</param>
        /// <param name="timeout">Time budget shared by every asynchronous step in this method.</param>
        /// <remarks>
        /// <para>
        /// The settings applied are: the identity Host header; the client certificate and server-certificate
        /// callback (HTTPS factories only); the sub-protocol; the WCF handshake headers; credentials; the proxy;
        /// cookies; and the keep-alive interval.
        /// </para>
        /// <para>
        /// This method also opens the request and proxy token providers and stores them in fields so that
        /// <see cref="CleanupTokenProviders"/> can abort them afterwards.
        /// </para>
        /// </remarks>
        private async Task ConfigureClientWebSocketAsync(ClientWebSocket clientWebSocket, TimeSpan timeout)
        {
            TimeoutHelper helper = new TimeoutHelper(timeout);
            ChannelParameterCollection channelParameterCollection = new ChannelParameterCollection();
            if (HttpChannelFactory<IDuplexSessionChannel>.MapIdentity(RemoteAddress, _channelFactory.AuthenticationScheme))
            {
                clientWebSocket.Options.SetRequestHeader("Host", HttpTransportSecurityHelpers.GetIdentityHostHeader(RemoteAddress));
            }

            (_webRequestTokenProvider, _webRequestProxyTokenProvider) =
                await _channelFactory.CreateAndOpenTokenProvidersAsync(
                    RemoteAddress,
                    Via,
                    channelParameterCollection,
                    helper.RemainingTime());

            SecurityTokenContainer clientCertificateToken = null;
            if (_channelFactory is HttpsChannelFactory<IDuplexSessionChannel> httpsChannelFactory && httpsChannelFactory.RequireClientCertificate)
            {
                SecurityTokenProvider certificateProvider = await httpsChannelFactory.CreateAndOpenCertificateTokenProviderAsync(RemoteAddress, Via, channelParameterCollection, helper.RemainingTime());
                clientCertificateToken = await httpsChannelFactory.GetCertificateSecurityTokenAsync(certificateProvider, RemoteAddress, Via, channelParameterCollection, helper);
                if (clientCertificateToken != null)
                {
                    X509SecurityToken x509Token = (X509SecurityToken)clientCertificateToken.Token;
                    clientWebSocket.Options.ClientCertificates.Add(x509Token.Certificate);
                }

                if (httpsChannelFactory.WebSocketCertificateCallback != null)
                {
                    clientWebSocket.Options.RemoteCertificateValidationCallback = httpsChannelFactory.WebSocketCertificateCallback;
                }
            }

            if (WebSocketSettings.SubProtocol != null)
            {
                clientWebSocket.Options.AddSubProtocol(WebSocketSettings.SubProtocol);
            }

            // WCF-specific handshake headers. They make an encoder or transfer-mode mismatch between client and
            // server easier to diagnose. With this sessionful channel, the binary encoder differs between Buffered
            // and Streamed transfer modes, so for the binary encoder the transfer mode is sent as well.
            if (_channelFactory.MessageVersion != MessageVersion.None)
            {
                clientWebSocket.Options.SetRequestHeader(WebSocketTransportSettings.SoapContentTypeHeader, _channelFactory.WebSocketSoapContentType);

                if (_channelFactory.MessageEncoderFactory is BinaryMessageEncoderFactory)
                {
                    clientWebSocket.Options.SetRequestHeader(WebSocketTransportSettings.BinaryEncoderTransferModeHeader, _channelFactory.TransferMode.ToString());
                }
            }

            // Use what is left of the shared budget. Passing the original timeout here would hand this step a
            // fresh full timeout after earlier steps had already consumed part of it.
            (NetworkCredential credential, TokenImpersonationLevel impersonationLevel, AuthenticationLevel authenticationLevel) =
                await HttpChannelUtilities.GetCredentialAsync(_channelFactory.AuthenticationScheme, _webRequestTokenProvider, helper.RemainingTime());

            // An explicitly configured proxy takes precedence over one produced by the proxy factory.
            if (_channelFactory.Proxy != null)
            {
                clientWebSocket.Options.Proxy = _channelFactory.Proxy;
            }
            else if (_channelFactory.ProxyFactory != null)
            {
                clientWebSocket.Options.Proxy = await _channelFactory.ProxyFactory.CreateWebProxyAsync(
                    authenticationLevel,
                    impersonationLevel,
                    _webRequestProxyTokenProvider,
                    helper.RemainingTime());
            }

            if (credential == CredentialCache.DefaultCredentials || credential == null)
            {
                // No explicit credential: use default credentials unless the scheme is Anonymous.
                if (_channelFactory.AuthenticationScheme != AuthenticationSchemes.Anonymous)
                {
                    clientWebSocket.Options.UseDefaultCredentials = true;
                }
            }
            else
            {
                clientWebSocket.Options.UseDefaultCredentials = false;
                CredentialCache credentials = new CredentialCache();
                Uri credentialCacheUriPrefix = _channelFactory.GetCredentialCacheUriPrefix(Via);
                if (_channelFactory.AuthenticationScheme == AuthenticationSchemes.IntegratedWindowsAuthentication)
                {
                    // Integrated Windows authentication: register the credential under both Negotiate and Ntlm.
                    credentials.Add(credentialCacheUriPrefix, AuthenticationSchemesHelper.ToString(AuthenticationSchemes.Negotiate),
                        credential);
                    credentials.Add(credentialCacheUriPrefix, AuthenticationSchemesHelper.ToString(AuthenticationSchemes.Ntlm),
                        credential);
                }
                else
                {
                    credentials.Add(credentialCacheUriPrefix, AuthenticationSchemesHelper.ToString(_channelFactory.AuthenticationScheme),
                        credential);
                }

                clientWebSocket.Options.Credentials = credentials;
            }

            if (_channelFactory.AllowCookies)
            {
                var cookieContainerManager = _channelFactory.GetHttpCookieContainerManager();
                clientWebSocket.Options.Cookies = cookieContainerManager.CookieContainer;
            }

            clientWebSocket.Options.KeepAliveInterval = _channelFactory.WebSocketSettings.KeepAliveInterval;
        }

        /// <summary>
        /// Records that cleanup has begun, so that an in-flight <see cref="OnOpenAsync(TimeSpan)"/> aborts its
        /// socket, then runs the base cleanup.
        /// </summary>
        protected override void OnCleanup()
        {
            _cleanupStarted = true;
            base.OnCleanup();
        }

        /// <summary>
        /// Wraps a handshake <see cref="WebSocketException"/> in a <see cref="CommunicationException"/> and throws it.
        /// This method never returns normally.
        /// </summary>
        /// <param name="ex">The exception raised while opening the socket.</param>
        private static void TryConvertAndThrow(WebSocketException ex)
        {
            switch (ex.WebSocketErrorCode)
            {
                case WebSocketError.UnsupportedVersion:
                    throw FxTrace.Exception.AsError(new CommunicationException(SR.Format(SR.WebSocketVersionMismatchFromServer, ""), ex));
                case WebSocketError.UnsupportedProtocol:
                    throw FxTrace.Exception.AsError(new CommunicationException(SR.Format(SR.WebSocketSubProtocolMismatchFromServer, ""), ex));
                default:
                    // Every other error code keeps the WebSocketException's own message.
                    throw FxTrace.Exception.AsError(new CommunicationException(ex.Message, ex));
            }
        }

        /// <summary>Cleans up the channel after a failed open by calling the base <c>Cleanup</c>.</summary>
        private void CleanupOnError()
        {
            Cleanup();
        }

        /// <summary>
        /// Aborts and forgets the request and proxy token providers opened by
        /// <see cref="ConfigureClientWebSocketAsync"/>. It is safe to call when either provider is null.
        /// </summary>
        private void CleanupTokenProviders()
        {
            if (_webRequestTokenProvider != null)
            {
                _webRequestTokenProvider.Abort();
                _webRequestTokenProvider = null;
            }

            if (_webRequestProxyTokenProvider != null)
            {
                _webRequestProxyTokenProvider.Abort();
                _webRequestProxyTokenProvider = null;
            }
        }
    }
}