// File: System\ServiceModel\Channels\SecurityChannelFactory.cs
// Project: src\src\System.ServiceModel.Primitives\src\System.ServiceModel.Primitives.csproj (System.ServiceModel.Primitives)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Runtime;
using System.ServiceModel.Security;
using System.Threading.Tasks;
 
namespace System.ServiceModel.Channels
{
    internal sealed class SecurityChannelFactory<TChannel> : LayeredChannelFactory<TChannel>
    {
        // Session settings; assigned only by the session-mode constructor, otherwise null.
        private readonly SecuritySessionClientSettings<TChannel> _sessionClientSettings;

        // Returned from GetProperty<ISecurityCapabilities>().
        private readonly ISecurityCapabilities _securityCapabilities;

        // Session-mode constructor: the channel builder and inner channel factory both come from the
        // supplied session settings, and SessionMode is switched on.
        public SecurityChannelFactory(ISecurityCapabilities securityCapabilities, BindingContext context,
            SecuritySessionClientSettings<TChannel> sessionClientSettings)
            : this(securityCapabilities, context, sessionClientSettings.ChannelBuilder, sessionClientSettings.CreateInnerChannelFactory())
        {
            SessionMode = true;
            _sessionClientSettings = sessionClientSettings;
        }

        // Non-session constructor: the inner channel factory is built from channelBuilder.
        public SecurityChannelFactory(ISecurityCapabilities securityCapabilities, BindingContext context, ChannelBuilder channelBuilder, SecurityProtocolFactory protocolFactory)
            : this(securityCapabilities, context, channelBuilder, protocolFactory, channelBuilder.BuildChannelFactory<TChannel>())
        {
        }

        // Non-session constructor over an explicitly supplied inner channel factory.
        public SecurityChannelFactory(ISecurityCapabilities securityCapabilities, BindingContext context, ChannelBuilder channelBuilder, SecurityProtocolFactory protocolFactory, IChannelFactory innerChannelFactory)
            : this(securityCapabilities, context, channelBuilder, innerChannelFactory)
        {
            SecurityProtocolFactory = protocolFactory;
        }

        // Shared initialization used by every public constructor.
        private SecurityChannelFactory(ISecurityCapabilities securityCapabilities, BindingContext context, ChannelBuilder channelBuilder, IChannelFactory innerChannelFactory)
            : base(context.Binding, innerChannelFactory)
        {
            ChannelBuilder = channelBuilder;
            MessageVersion = context.Binding.MessageVersion;
            _securityCapabilities = securityCapabilities;
        }
 
        // Channel builder given at construction (taken from the session settings in session mode).
        public ChannelBuilder ChannelBuilder { get; }

        // Message version of the binding this factory was created from.
        public MessageVersion MessageVersion { get; }

        // Protocol factory used outside session mode. Never set by the session-mode constructor and
        // cleared by CloseProtocolFactoryAsync.
        public SecurityProtocolFactory SecurityProtocolFactory { get; private set; }

        // True only for factories created from SecuritySessionClientSettings.
        public bool SessionMode { get; }

        // Session settings; reading this asserts that the factory is in session mode.
        public SecuritySessionClientSettings<TChannel> SessionClientSettings
        {
            get
            {
                Fx.Assert(SessionMode, "SessionClientSettings can only be used if SessionMode == true");
                return _sessionClientSettings;
            }
        }

        // Capability flags of the non-session protocol factory; both throw if it has not been set.
        private bool SupportsDuplex => RequireProtocolFactory().SupportsDuplex;

        private bool SupportsRequestReply => RequireProtocolFactory().SupportsRequestReply;

        // Returns SecurityProtocolFactory after verifying that it has been set.
        private SecurityProtocolFactory RequireProtocolFactory()
        {
            ThrowIfProtocolFactoryNotSet();
            return SecurityProtocolFactory;
        }
 
        // Closes (or aborts, when aborted is true) the non-session protocol factory. The property is
        // cleared before the close starts, so later calls find nothing to close. No-op in session mode.
        private Task CloseProtocolFactoryAsync(bool aborted, TimeSpan timeout)
        {
            SecurityProtocolFactory factoryToClose = SecurityProtocolFactory;
            if (SessionMode || factoryToClose == null)
            {
                return Task.CompletedTask;
            }

            SecurityProtocolFactory = null;
            return factoryToClose.CloseAsync(aborted, timeout);
        }
 
        // Answers IChannelSecureConversationSessionSettings (session mode only) and ISecurityCapabilities
        // from this factory; every other property request is delegated to the base factory.
        public override T GetProperty<T>()
        {
            Type requested = typeof(T);

            if (SessionMode && requested == typeof(IChannelSecureConversationSessionSettings))
            {
                return (T)(object)SessionClientSettings;
            }

            if (requested == typeof(ISecurityCapabilities))
            {
                return (T)(object)_securityCapabilities;
            }

            return base.GetProperty<T>();
        }
 
        // Aborts the base factory, then the non-session protocol factory, then the session settings (if any).
        protected override void OnAbort()
        {
            base.OnAbort();

            // NOTE(review): the returned task is intentionally discarded, so an asynchronous failure from
            // the protocol factory abort is not observed here; confirm this is the desired behavior.
            _ = CloseProtocolFactoryAsync(true, TimeSpan.Zero);

            _sessionClientSettings?.Abort();
        }
 
        // APM adapter: starts OnCloseAsync and exposes it as an IAsyncResult.
        protected override IAsyncResult OnBeginClose(TimeSpan timeout, AsyncCallback callback, object state)
            => OnCloseAsync(timeout).ToApm(callback, state);

        // APM adapter: completes the close started by OnBeginClose.
        protected override void OnEndClose(IAsyncResult result)
            => result.ToApmEnd();
 
        // Closes the base factory, then the non-session protocol factory, then the session settings (if
        // any), all drawing on one TimeoutHelper. The base close is given RemainingTime() rather than the
        // raw timeout, matching OnOpenAsync. If TimeoutHelper only starts its deadline on the first
        // RemainingTime() call (presumably; confirm), the raw timeout would leave the base close
        // uncounted against the overall budget.
        protected internal override async Task OnCloseAsync(TimeSpan timeout)
        {
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
            await base.OnCloseAsync(timeoutHelper.RemainingTime());
            await CloseProtocolFactoryAsync(false, timeoutHelper.RemainingTime());
            if (_sessionClientSettings != null)
            {
                await _sessionClientSettings.CloseAsync(timeoutHelper.RemainingTime());
            }
        }
 
        // Synchronous close that blocks on OnCloseAsync. GetAwaiter().GetResult() rethrows the original
        // exception (e.g. TimeoutException or CommunicationException), whereas Task.Wait() would wrap it
        // in an AggregateException. This also matches the other synchronous paths in this file.
        protected override void OnClose(TimeSpan timeout)
        {
            OnCloseAsync(timeout).GetAwaiter().GetResult();
        }
 
        // Creates a client channel of shape TChannel. In session mode the session settings build it.
        // Otherwise a channel is created from the inner factory and wrapped in the matching
        // Security*Channel; the sync and async interface of each shape share one wrapper type. A shape
        // not matched explicitly is treated as a request-session channel.
        protected override TChannel OnCreateChannel(EndpointAddress address, Uri via)
        {
            ThrowIfDisposed();
            if (SessionMode)
            {
                return _sessionClientSettings.OnCreateChannel(address, via);
            }

            Type requestedShape = typeof(TChannel);
            var innerFactory = InnerChannelFactory;
            SecurityProtocolFactory protocolFactory = SecurityProtocolFactory;
            object securityChannel;

            if (requestedShape == typeof(IOutputChannel) || requestedShape == typeof(IAsyncOutputChannel))
            {
                IOutputChannel innerChannel = ((IChannelFactory<IOutputChannel>)innerFactory).CreateChannel(address, via);
                securityChannel = new SecurityOutputChannel(this, protocolFactory, innerChannel, address, via);
            }
            else if (requestedShape == typeof(IOutputSessionChannel) || requestedShape == typeof(IAsyncOutputSessionChannel))
            {
                IOutputSessionChannel innerChannel = ((IChannelFactory<IOutputSessionChannel>)innerFactory).CreateChannel(address, via);
                securityChannel = new SecurityOutputSessionChannel(this, protocolFactory, innerChannel, address, via);
            }
            else if (requestedShape == typeof(IDuplexChannel) || requestedShape == typeof(IAsyncDuplexChannel))
            {
                IDuplexChannel innerChannel = ((IChannelFactory<IDuplexChannel>)innerFactory).CreateChannel(address, via);
                securityChannel = new SecurityDuplexChannel(this, protocolFactory, innerChannel, address, via);
            }
            else if (requestedShape == typeof(IDuplexSessionChannel) || requestedShape == typeof(IAsyncDuplexSessionChannel))
            {
                IDuplexSessionChannel innerChannel = ((IChannelFactory<IDuplexSessionChannel>)innerFactory).CreateChannel(address, via);
                securityChannel = new SecurityDuplexSessionChannel(this, protocolFactory, innerChannel, address, via);
            }
            else if (requestedShape == typeof(IRequestChannel) || requestedShape == typeof(IAsyncRequestChannel))
            {
                IRequestChannel innerChannel = ((IChannelFactory<IRequestChannel>)innerFactory).CreateChannel(address, via);
                securityChannel = new SecurityRequestChannel(this, protocolFactory, innerChannel, address, via);
            }
            else
            {
                // Remaining shape: IRequestSessionChannel (or its async counterpart).
                IRequestSessionChannel innerChannel = ((IChannelFactory<IRequestSessionChannel>)innerFactory).CreateChannel(address, via);
                securityChannel = new SecurityRequestSessionChannel(this, protocolFactory, innerChannel, address, via);
            }

            return (TChannel)securityChannel;
        }
 
        // Opens the security layer (session settings or protocol factory) before the base/inner factory,
        // sharing one timeout budget. Afterwards it pushes the transport's BufferManager onto whichever
        // protocol factory is in use.
        protected internal override async Task OnOpenAsync(TimeSpan timeout)
        {
            var openBudget = new TimeoutHelper(timeout);

            await OnOpenCoreAsync(openBudget.RemainingTime());
            await base.OnOpenAsync(openBudget.RemainingTime());

            SetBufferManager();
        }
 
        // Synchronous open that blocks on OnOpenAsync. GetAwaiter().GetResult() rethrows the original
        // exception instead of the AggregateException that Task.Wait() would produce. This also matches
        // the other synchronous paths in this file.
        protected override void OnOpen(TimeSpan timeout)
        {
            OnOpenAsync(timeout).GetAwaiter().GetResult();
        }
 
        // Copies the transport's BufferManager (when the binding exposes ITransportFactorySettings with a
        // non-null BufferManager) onto a protocol factory's StreamBufferManager. In session mode the target
        // is the session protocol factory when one is available; otherwise it is the non-session
        // SecurityProtocolFactory, which must have been set.
        private void SetBufferManager()
        {
            BufferManager manager = GetProperty<ITransportFactorySettings>()?.BufferManager;
            if (manager == null)
            {
                return;
            }

            if (SessionMode && SessionClientSettings?.SessionProtocolFactory != null)
            {
                SessionClientSettings.SessionProtocolFactory.StreamBufferManager = manager;
                return;
            }

            ThrowIfProtocolFactoryNotSet();
            SecurityProtocolFactory.StreamBufferManager = manager;
        }
 
        // APM adapter: starts OnOpenAsync and exposes it as an IAsyncResult.
        protected override IAsyncResult OnBeginOpen(TimeSpan timeout, AsyncCallback callback, object state)
            => OnOpenAsync(timeout).ToApm(callback, state);

        // APM adapter: completes the open started by OnBeginOpen.
        protected override void OnEndOpen(IAsyncResult result)
            => result.ToApmEnd();
 
        // Opens the security component this factory relies on. Outside session mode that is the
        // (required) SecurityProtocolFactory; in session mode it is the session client settings, which
        // receive this factory, the inner factory and the channel builder.
        private Task OnOpenCoreAsync(TimeSpan timeout)
        {
            var budget = new TimeoutHelper(timeout);

            if (!SessionMode)
            {
                ThrowIfProtocolFactoryNotSet();
                return SecurityProtocolFactory.OpenAsync(true, budget.RemainingTime());
            }

            return SessionClientSettings.OpenAsync(this, InnerChannelFactory, ChannelBuilder, budget.RemainingTime());
        }
 
        // Guard for operations that need the non-session SecurityProtocolFactory; throws
        // InvalidOperationException when it is null.
        private void ThrowIfProtocolFactoryNotSet()
        {
            if (SecurityProtocolFactory != null)
            {
                return;
            }

            var notSet = new InvalidOperationException(
                SRP.Format(SRP.SecurityProtocolFactoryShouldBeSetBeforeThisOperation));
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(notSet);
        }
 
        // Common base for the non-session client security channels. It keeps the protocol factory,
        // remote address and via, creates and opens the per-channel SecurityProtocol during open, and
        // exposes a ChannelParameterCollection through GetProperty.
        private abstract class ClientSecurityChannel<UChannel> : SecurityChannel<UChannel>
            where UChannel : class, IChannel
        {
            // Attached to the SecurityProtocol when it is created; also returned by
            // GetProperty<ChannelParameterCollection>().
            private readonly ChannelParameterCollection _channelParameters;

            protected ClientSecurityChannel(ChannelManagerBase factory, SecurityProtocolFactory securityProtocolFactory,
                UChannel innerChannel, EndpointAddress to, Uri via)
                : base(factory, innerChannel)
            {
                RemoteAddress = to;
                Via = via;
                SecurityProtocolFactory = securityProtocolFactory;
                _channelParameters = new ChannelParameterCollection(this);
            }

            protected SecurityProtocolFactory SecurityProtocolFactory { get; }

            public EndpointAddress RemoteAddress { get; }

            public Uri Via { get; }

            // If faultMessage is a fault, builds a MessageFault from it (passing
            // TransportDefaults.MaxSecurityFaultSize as the size limit), converts that to a security fault
            // exception and returns true. Otherwise returns false with faultException set to null.
            protected bool TryGetSecurityFaultException(Message faultMessage, out Exception faultException)
            {
                faultException = null;
                if (!faultMessage.IsFault)
                {
                    return false;
                }
                MessageFault fault = MessageFault.CreateFault(faultMessage, TransportDefaults.MaxSecurityFaultSize);
                faultException = SecurityUtils.CreateSecurityFaultException(fault);
                return true;
            }

            // Enables channel binding where applicable, creates this channel's SecurityProtocol, opens it,
            // then opens the base channel, all within one timeout budget.
            protected internal override async Task OnOpenAsync(TimeSpan timeout)
            {
                TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
                EnableChannelBindingSupport();

                // OnCreateChannel wraps both IRequestChannel and IAsyncRequestChannel in
                // SecurityRequestChannel, so both request shapes must yield the same flag here.
                // NOTE(review): this argument is presumably CreateSecurityProtocol's
                // "return leg security required" flag; confirm against SecurityProtocolFactory.
                bool isRequestChannelShape = typeof(TChannel) == typeof(IRequestChannel)
                    || typeof(TChannel) == typeof(IAsyncRequestChannel);

                SecurityProtocol securityProtocol = SecurityProtocolFactory.CreateSecurityProtocol(
                    RemoteAddress,
                    Via,
                    null,
                    isRequestChannelShape,
                    timeoutHelper.RemainingTime());
                OnProtocolCreationComplete(securityProtocol);
                await SecurityProtocol.OpenAsync(timeoutHelper.RemainingTime());
                await base.OnOpenAsync(timeoutHelper.RemainingTime());
            }

            // Rejects a custom channel binding in the extended protection policy (NotSupportedException).
            // Otherwise, unless channel binding is disabled or the security binding element is not suitable,
            // asks the inner channel's IChannelBindingProvider (if any) to enable channel binding support.
            private void EnableChannelBindingSupport()
            {
                if (SecurityProtocolFactory != null && SecurityProtocolFactory.ExtendedProtectionPolicy != null && SecurityProtocolFactory.ExtendedProtectionPolicy.CustomChannelBinding != null)
                {
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new NotSupportedException(SRP.Format(SRP.ExtendedProtectionPolicyCustomChannelBindingNotSupported)));
                }

                // Do not enable channel binding if there is no reason as it sets up chunking mode.
                if ((SecurityUtils.IsChannelBindingDisabled) || (!SecurityUtils.IsSecurityBindingSuitableForChannelBinding(SecurityProtocolFactory.SecurityBindingElement as TransportSecurityBindingElement)))
                {
                    return;
                }

                if (InnerChannel != null)
                {
                    IChannelBindingProvider cbp = InnerChannel.GetProperty<IChannelBindingProvider>();
                    if (cbp != null)
                    {
                        cbp.EnableChannelBindingSupport();
                    }
                }
            }

            // Stores the newly created protocol and attaches this channel's parameter collection to it.
            private void OnProtocolCreationComplete(SecurityProtocol securityProtocol)
            {
                SecurityProtocol = securityProtocol;
                SecurityProtocol.ChannelParameters = _channelParameters;
            }

            public override T GetProperty<T>()
            {
                if (typeof(T) == typeof(ChannelParameterCollection))
                {
                    return (T)(object)_channelParameters;
                }

                return base.GetProperty<T>();
            }
        }
 
        // One-way client channel: each outgoing message is secured by the channel's SecurityProtocol and
        // then sent on the inner output channel. The synchronous and APM entry points all funnel into
        // SendAsync(Message, TimeSpan).
        private class SecurityOutputChannel : ClientSecurityChannel<IOutputChannel>, IOutputChannel, IAsyncOutputChannel
        {
            public SecurityOutputChannel(ChannelManagerBase factory, SecurityProtocolFactory securityProtocolFactory, IOutputChannel innerChannel, EndpointAddress to, Uri via)
                : base(factory, securityProtocolFactory, innerChannel, to, via)
            {
            }

            public void Send(Message message) => Send(message, DefaultSendTimeout);

            public void Send(Message message, TimeSpan timeout) => SendAsync(message, timeout).GetAwaiter().GetResult();

            public Task SendAsync(Message message) => SendAsync(message, DefaultSendTimeout);

            public async Task SendAsync(Message message, TimeSpan timeout)
            {
                ThrowIfFaulted();
                ThrowIfDisposedOrNotOpen(message);

                var sendBudget = new TimeoutHelper(timeout);
                Message securedMessage = await SecurityProtocol.SecureOutgoingMessageAsync(message, sendBudget.RemainingTime());

                // Prefer the inner channel's native async API; fall back to wrapping its APM pair.
                if (InnerChannel is IAsyncOutputChannel asyncInner)
                {
                    await asyncInner.SendAsync(securedMessage, sendBudget.RemainingTime());
                    return;
                }

                await Task.Factory.FromAsync(InnerChannel.BeginSend, InnerChannel.EndSend, securedMessage, sendBudget.RemainingTime(), null);
            }

            public IAsyncResult BeginSend(Message message, AsyncCallback callback, object state)
                => BeginSend(message, DefaultSendTimeout, callback, state);

            public IAsyncResult BeginSend(Message message, TimeSpan timeout, AsyncCallback callback, object state)
                => SendAsync(message, timeout).ToApm(callback, state);

            public void EndSend(IAsyncResult result) => result.ToApmEnd();
        }
 
        // Session-aware variant of SecurityOutputChannel; the session is the inner channel's session.
        private sealed class SecurityOutputSessionChannel : SecurityOutputChannel, IOutputSessionChannel
        {
            public SecurityOutputSessionChannel(ChannelManagerBase factory, SecurityProtocolFactory securityProtocolFactory, IOutputSessionChannel innerChannel, EndpointAddress to, Uri via)
                : base(factory, securityProtocolFactory, innerChannel, to, via)
            {
            }

            public IOutputSession Session => ((IOutputSessionChannel)InnerChannel).Session;
        }
 
        // Request/reply client channel. Each request is secured with the channel's SecurityProtocol and
        // sent on the inner request channel. The reply is verified using the correlation state produced
        // when the request was secured.
        private class SecurityRequestChannel : ClientSecurityChannel<IRequestChannel>, IAsyncRequestChannel
        {
            public SecurityRequestChannel(ChannelManagerBase factory, SecurityProtocolFactory securityProtocolFactory, IRequestChannel innerChannel, EndpointAddress to, Uri via)
                : base(factory, securityProtocolFactory, innerChannel, to, via)
            {
            }

            public IAsyncResult BeginRequest(Message message, AsyncCallback callback, object state)
            {
                return BeginRequest(message, DefaultSendTimeout, callback, state);
            }

            public IAsyncResult BeginRequest(Message message, TimeSpan timeout, AsyncCallback callback, object state)
            {
                return RequestAsyncInternal(message, timeout).ToApm(callback, state);
            }

            public Message EndRequest(IAsyncResult result)
            {
                return result.ToApmEnd<Message>();
            }

            public Message Request(Message message)
            {
                return Request(message, DefaultSendTimeout);
            }

            // Verifies a non-null reply; a null reply is returned unchanged. If verification throws
            // MessageSecurityException and the unverified reply is a fault, this channel is faulted and the
            // converted fault exception is thrown instead. Otherwise the verification exception propagates.
            internal Message ProcessReply(Message reply, SecurityProtocolCorrelationState correlationState, TimeSpan timeout)
            {
                if (reply != null)
                {
                    Message unverifiedMessage = reply;
                    Exception faultException = null;
                    try
                    {
                        SecurityProtocol.VerifyIncomingMessage(ref reply, timeout, correlationState);
                    }
                    catch (MessageSecurityException)
                    {
                        TryGetSecurityFaultException(unverifiedMessage, out faultException);
                        if (faultException == null)
                        {
                            throw;
                        }
                    }
                    if (faultException != null)
                    {
                        Fault(faultException);
                        throw DiagnosticUtility.ExceptionUtility.ThrowHelperWarning(faultException);
                    }
                }
                return reply;
            }

            public Task<Message> RequestAsync(Message message)
            {
                return RequestAsync(message, DefaultSendTimeout);
            }

            public async Task<Message> RequestAsync(Message message, TimeSpan timeout)
            {
                ThrowIfFaulted();
                ThrowIfDisposedOrNotOpen(message);
                TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
                SecurityProtocolCorrelationState correlationState;
                (correlationState, message) = await SecurityProtocol.SecureOutgoingMessageAsync(message, timeoutHelper.RemainingTime(), null);

                // Prefer the inner channel's native async API, as SendAsync/TryReceiveAsync do elsewhere in
                // this file; fall back to wrapping the APM pair only for sync-only inner channels.
                Message reply;
                if (InnerChannel is IAsyncRequestChannel asyncRequestChannel)
                {
                    reply = await asyncRequestChannel.RequestAsync(message, timeoutHelper.RemainingTime());
                }
                else
                {
                    reply = await Task.Factory.FromAsync(InnerChannel.BeginRequest, InnerChannel.EndRequest, message, timeoutHelper.RemainingTime(), null);
                }

                return ProcessReply(reply, correlationState, timeoutHelper.RemainingTime());
            }

            // Used by the synchronous and APM entry points. It awaits TaskHelpers.EnsureDefaultTaskScheduler()
            // first (presumably to move off a non-default TaskScheduler; confirm), then runs RequestAsync.
            private async Task<Message> RequestAsyncInternal(Message message, TimeSpan timeout)
            {
                await TaskHelpers.EnsureDefaultTaskScheduler();
                return await RequestAsync(message, timeout);
            }

            public Message Request(Message message, TimeSpan timeout)
            {
                return RequestAsyncInternal(message, timeout).GetAwaiter().GetResult();
            }
        }
 
        // Session-aware variant of SecurityRequestChannel; the session is the inner channel's session.
        private sealed class SecurityRequestSessionChannel : SecurityRequestChannel, IAsyncRequestSessionChannel
        {
            public SecurityRequestSessionChannel(ChannelManagerBase factory, SecurityProtocolFactory securityProtocolFactory, IRequestSessionChannel innerChannel, EndpointAddress to, Uri via)
                : base(factory, securityProtocolFactory, innerChannel, to, via)
            {
            }

            public IOutputSession Session => ((IRequestSessionChannel)InnerChannel).Session;
        }
 
        // Duplex client channel. Sending is inherited from SecurityOutputChannel; each received message is
        // verified with the channel's SecurityProtocol. WaitForMessage* calls are forwarded to the inner
        // duplex channel.
        private class SecurityDuplexChannel : SecurityOutputChannel, IDuplexChannel, IAsyncDuplexChannel
        {
            public SecurityDuplexChannel(ChannelManagerBase factory, SecurityProtocolFactory securityProtocolFactory, IDuplexChannel innerChannel, EndpointAddress to, Uri via)
                : base(factory, securityProtocolFactory, innerChannel, to, via)
            {
            }

            internal IDuplexChannel InnerDuplexChannel
            {
                get { return (IDuplexChannel)InnerChannel; }
            }

            public EndpointAddress LocalAddress
            {
                get
                {
                    return InnerDuplexChannel.LocalAddress;
                }
            }

            // When true, ProcessMessage also faults this channel when an unverifiable message turns out to
            // be a fault. False here; SecurityDuplexSessionChannel overrides it to true.
            internal virtual bool AcceptUnsecuredFaults
            {
                get { return false; }
            }

            public Task<Message> ReceiveAsync()
            {
                return ReceiveAsync(DefaultReceiveTimeout);
            }

            public Task<Message> ReceiveAsync(TimeSpan timeout)
            {
                return InputChannel.HelpReceiveAsync(this, timeout);
            }

            public Message Receive()
            {
                return Receive(DefaultReceiveTimeout);
            }

            public Message Receive(TimeSpan timeout)
            {
                return ReceiveAsync(timeout).GetAwaiter().GetResult();
            }

            public IAsyncResult BeginReceive(AsyncCallback callback, object state)
            {
                return BeginReceive(DefaultReceiveTimeout, callback, state);
            }

            public IAsyncResult BeginReceive(TimeSpan timeout, AsyncCallback callback, object state)
            {
                return ReceiveAsync(timeout).ToApm(callback, state);
            }

            public Message EndReceive(IAsyncResult result)
            {
                return result.ToApmEnd<Message>();
            }

            public virtual IAsyncResult BeginTryReceive(TimeSpan timeout, AsyncCallback callback, object state)
            {
                return TryReceiveAsync(timeout).ToApm(callback, state);
            }

            public virtual bool EndTryReceive(IAsyncResult result, out Message message)
            {
                bool success;
                (success, message) = result.ToApmEnd<(bool, Message)>();
                return success;
            }

            // Verifies a received message; null is passed through. If verification throws
            // MessageSecurityException and the unverified message is a fault, the fault is converted to an
            // exception and thrown, after faulting this channel when AcceptUnsecuredFaults is true.
            // Otherwise the verification exception propagates.
            internal Message ProcessMessage(Message message, TimeSpan timeout)
            {
                if (message == null)
                {
                    return null;
                }
                Message unverifiedMessage = message;
                Exception faultException = null;
                try
                {
                    SecurityProtocol.VerifyIncomingMessage(ref message, timeout);
                }
                catch (MessageSecurityException)
                {
                    TryGetSecurityFaultException(unverifiedMessage, out faultException);
                    if (faultException == null)
                    {
                        throw;
                    }
                }
                if (faultException != null)
                {
                    if (AcceptUnsecuredFaults)
                    {
                        Fault(faultException);
                    }
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperWarning(faultException);
                }
                return message;
            }

            public async Task<(bool, Message)> TryReceiveAsync(TimeSpan timeout)
            {
                // Receiving is finished for the current state: report success with no message.
                if (DoneReceivingInCurrentState())
                {
                    return (true, null);
                }

                TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);
                bool success;
                Message message;
                if (InnerDuplexChannel is IAsyncDuplexChannel asyncDuplexChannel)
                {
                    (success, message) = await asyncDuplexChannel.TryReceiveAsync(timeoutHelper.RemainingTime());
                }
                else
                {
                    // Use the remaining budget here too, matching the async branch. If TimeoutHelper only
                    // starts its deadline on the first RemainingTime() call (presumably; confirm), passing
                    // the raw timeout would let receive plus verification exceed the caller's timeout.
                    (success, message) = await TaskHelpers.FromAsync<TimeSpan, bool, Message>(InnerDuplexChannel.BeginTryReceive, InnerDuplexChannel.EndTryReceive, timeoutHelper.RemainingTime(), null);
                }
                if (success)
                {
                    message = ProcessMessage(message, timeoutHelper.RemainingTime());
                }

                return (success, message);
            }

            public bool TryReceive(TimeSpan timeout, out Message message)
            {
                bool success;
                (success, message) = TryReceiveAsync(timeout).GetAwaiter().GetResult();
                return success;
            }

            public Task<bool> WaitForMessageAsync(TimeSpan timeout)
            {
                if (InnerDuplexChannel is IAsyncDuplexChannel asyncDuplexChannel)
                {
                    return asyncDuplexChannel.WaitForMessageAsync(timeout);
                }
                else
                {
                    return Task.Factory.FromAsync(InnerDuplexChannel.BeginWaitForMessage, InnerDuplexChannel.EndWaitForMessage, timeout, null);
                }
            }

            public bool WaitForMessage(TimeSpan timeout)
            {
                return InnerDuplexChannel.WaitForMessage(timeout);
            }

            public IAsyncResult BeginWaitForMessage(TimeSpan timeout, AsyncCallback callback, object state)
            {
                return InnerDuplexChannel.BeginWaitForMessage(timeout, callback, state);
            }

            public bool EndWaitForMessage(IAsyncResult result)
            {
                return InnerDuplexChannel.EndWaitForMessage(result);
            }
        }
 
        // Session-aware duplex channel. Both session views are read from the inner channel, and an
        // unverifiable fault message also faults this channel (AcceptUnsecuredFaults is true).
        private sealed class SecurityDuplexSessionChannel : SecurityDuplexChannel, IDuplexSessionChannel, IAsyncDuplexSessionChannel
        {
            public SecurityDuplexSessionChannel(ChannelManagerBase factory, SecurityProtocolFactory securityProtocolFactory, IDuplexSessionChannel innerChannel, EndpointAddress to, Uri via)
                : base(factory, securityProtocolFactory, innerChannel, to, via)
            {
            }

            IDuplexSession ISessionChannel<IDuplexSession>.Session
                => ((ISessionChannel<IDuplexSession>)InnerChannel).Session;

            IAsyncDuplexSession ISessionChannel<IAsyncDuplexSession>.Session
                => ((ISessionChannel<IAsyncDuplexSession>)InnerChannel).Session;

            internal override bool AcceptUnsecuredFaults => true;
        }
    }
}