// File: FrameworkFork\System.ServiceModel\System\ServiceModel\Security\ChannelProtectionRequirements.cs
// Web Access
// Project: src\src\dotnet-svcutil\lib\src\dotnet-svcutil-lib.csproj (dotnet-svcutil-lib)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Net.Security;
using System.Runtime;
using System.ServiceModel.Channels;
using System.ServiceModel.Description;
using Microsoft.Xml;
 
namespace System.ServiceModel.Security
{
    /// <summary>
    /// Describes which message parts (body and headers) must be signed and which must be encrypted,
    /// separately for incoming and outgoing messages. Each of the four requirement sets is a
    /// <see cref="ScopedMessagePartSpecification"/> carrying channel-scope parts
    /// (<see cref="ScopedMessagePartSpecification.ChannelParts"/>) plus per-action parts.
    /// </summary>
    public class ChannelProtectionRequirements
    {
        // Every constructor assigns all four specifications and they are never replaced afterwards,
        // so they are never null and can be readonly.
        private readonly ScopedMessagePartSpecification _incomingSignatureParts;
        private readonly ScopedMessagePartSpecification _incomingEncryptionParts;
        private readonly ScopedMessagePartSpecification _outgoingSignatureParts;
        private readonly ScopedMessagePartSpecification _outgoingEncryptionParts;
        private bool _isReadOnly;

        /// <summary>
        /// Creates an instance whose four requirement sets are freshly constructed
        /// <see cref="ScopedMessagePartSpecification"/> instances (default constructor).
        /// </summary>
        public ChannelProtectionRequirements()
            : this(new ScopedMessagePartSpecification(),
                   new ScopedMessagePartSpecification(),
                   new ScopedMessagePartSpecification(),
                   new ScopedMessagePartSpecification())
        {
        }

        /// <summary>
        /// Creates a copy of <paramref name="other"/>. Each requirement set is copied through the
        /// <see cref="ScopedMessagePartSpecification"/> copy constructor. This object's own
        /// <see cref="IsReadOnly"/> flag is not copied and starts out false.
        /// </summary>
        /// <param name="other">The requirements to copy.</param>
        /// <exception cref="ArgumentNullException"><paramref name="other"/> is null.</exception>
        public ChannelProtectionRequirements(ChannelProtectionRequirements other)
        {
            if (other == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("other"));

            _incomingSignatureParts = new ScopedMessagePartSpecification(other._incomingSignatureParts);
            _incomingEncryptionParts = new ScopedMessagePartSpecification(other._incomingEncryptionParts);
            _outgoingSignatureParts = new ScopedMessagePartSpecification(other._outgoingSignatureParts);
            _outgoingEncryptionParts = new ScopedMessagePartSpecification(other._outgoingEncryptionParts);
        }

        /// <summary>
        /// Creates a copy of <paramref name="other"/>, passing a body flag to each copied set.
        /// The flag is true for the signature sets when <paramref name="newBodyProtectionLevel"/>
        /// is not None. It is true for the encryption sets only when the level is EncryptAndSign.
        /// </summary>
        /// <remarks>
        /// NOTE(review): assumes the bool argument of the two-argument
        /// <see cref="ScopedMessagePartSpecification"/> copy constructor controls whether the body
        /// is included. Confirm against that type.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="other"/> is null.</exception>
        internal ChannelProtectionRequirements(ChannelProtectionRequirements other, ProtectionLevel newBodyProtectionLevel)
        {
            if (other == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("other"));

            bool signBody = newBodyProtectionLevel != ProtectionLevel.None;
            bool encryptBody = newBodyProtectionLevel == ProtectionLevel.EncryptAndSign;

            _incomingSignatureParts = new ScopedMessagePartSpecification(other._incomingSignatureParts, signBody);
            _incomingEncryptionParts = new ScopedMessagePartSpecification(other._incomingEncryptionParts, encryptBody);
            _outgoingSignatureParts = new ScopedMessagePartSpecification(other._outgoingSignatureParts, signBody);
            _outgoingEncryptionParts = new ScopedMessagePartSpecification(other._outgoingEncryptionParts, encryptBody);
        }

        // Stores the four specifications as given (no copying). Used by the parameterless
        // constructor and by CreateInverse.
        private ChannelProtectionRequirements(
            ScopedMessagePartSpecification incomingSignatureParts,
            ScopedMessagePartSpecification incomingEncryptionParts,
            ScopedMessagePartSpecification outgoingSignatureParts,
            ScopedMessagePartSpecification outgoingEncryptionParts)
        {
            _incomingSignatureParts = incomingSignatureParts;
            _incomingEncryptionParts = incomingEncryptionParts;
            _outgoingSignatureParts = outgoingSignatureParts;
            _outgoingEncryptionParts = outgoingEncryptionParts;
        }

        /// <summary>True once <see cref="MakeReadOnly"/> has been called on this instance.</summary>
        public bool IsReadOnly
        {
            get
            {
                return _isReadOnly;
            }
        }

        /// <summary>Parts that must be signed on incoming messages.</summary>
        public ScopedMessagePartSpecification IncomingSignatureParts
        {
            get
            {
                return _incomingSignatureParts;
            }
        }

        /// <summary>Parts that must be encrypted on incoming messages.</summary>
        public ScopedMessagePartSpecification IncomingEncryptionParts
        {
            get
            {
                return _incomingEncryptionParts;
            }
        }

        /// <summary>Parts that must be signed on outgoing messages.</summary>
        public ScopedMessagePartSpecification OutgoingSignatureParts
        {
            get
            {
                return _outgoingSignatureParts;
            }
        }

        /// <summary>Parts that must be encrypted on outgoing messages.</summary>
        public ScopedMessagePartSpecification OutgoingEncryptionParts
        {
            get
            {
                return _outgoingEncryptionParts;
            }
        }

        /// <summary>
        /// Merges both the channel-scope and the per-action parts of
        /// <paramref name="protectionRequirements"/> into this instance.
        /// </summary>
        /// <param name="protectionRequirements">The requirements to merge in.</param>
        /// <exception cref="ArgumentNullException"><paramref name="protectionRequirements"/> is null.</exception>
        public void Add(ChannelProtectionRequirements protectionRequirements)
        {
            this.Add(protectionRequirements, false);
        }

        /// <summary>
        /// Merges <paramref name="protectionRequirements"/> into this instance. Channel-scope parts
        /// are always merged. Per-action parts are merged only when
        /// <paramref name="channelScopeOnly"/> is false.
        /// </summary>
        /// <param name="protectionRequirements">The requirements to merge in.</param>
        /// <param name="channelScopeOnly">When true, skip the per-action parts.</param>
        /// <exception cref="ArgumentNullException"><paramref name="protectionRequirements"/> is null.</exception>
        /// <remarks>
        /// NOTE(review): behaviour on a read-only instance is decided by
        /// <see cref="ScopedMessagePartSpecification"/>.AddParts, which presumably throws. Confirm.
        /// </remarks>
        public void Add(ChannelProtectionRequirements protectionRequirements, bool channelScopeOnly)
        {
            if (protectionRequirements == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("protectionRequirements"));

            // The source's specifications are never null (every constructor assigns them), so no
            // per-field null checks are needed. Channel-scope parts are merged first for all four
            // sets, then the per-action parts.
            _incomingSignatureParts.AddParts(protectionRequirements._incomingSignatureParts.ChannelParts);
            _incomingEncryptionParts.AddParts(protectionRequirements._incomingEncryptionParts.ChannelParts);
            _outgoingSignatureParts.AddParts(protectionRequirements._outgoingSignatureParts.ChannelParts);
            _outgoingEncryptionParts.AddParts(protectionRequirements._outgoingEncryptionParts.ChannelParts);

            if (!channelScopeOnly)
            {
                AddActionParts(_incomingSignatureParts, protectionRequirements._incomingSignatureParts);
                AddActionParts(_incomingEncryptionParts, protectionRequirements._incomingEncryptionParts);
                AddActionParts(_outgoingSignatureParts, protectionRequirements._outgoingSignatureParts);
                AddActionParts(_outgoingEncryptionParts, protectionRequirements._outgoingEncryptionParts);
            }
        }

        // Copies every per-action part specification of 'from' into 'to' under the same action.
        // NOTE(review): the 'true' passed to TryGetParts is assumed to exclude channel-scope parts
        // (those are merged separately in Add). Confirm against ScopedMessagePartSpecification.
        private static void AddActionParts(ScopedMessagePartSpecification to, ScopedMessagePartSpecification from)
        {
            foreach (string action in from.Actions)
            {
                MessagePartSpecification parts;
                if (from.TryGetParts(action, true, out parts))
                    to.AddParts(parts, action);
            }
        }

        /// <summary>
        /// Makes all four requirement sets read-only and sets <see cref="IsReadOnly"/>.
        /// Calling it again has no further effect.
        /// </summary>
        public void MakeReadOnly()
        {
            if (!_isReadOnly)
            {
                _incomingSignatureParts.MakeReadOnly();
                _incomingEncryptionParts.MakeReadOnly();
                _outgoingSignatureParts.MakeReadOnly();
                _outgoingEncryptionParts.MakeReadOnly();
                _isReadOnly = true;
            }
        }

        /// <summary>
        /// Returns new requirements with the incoming and outgoing roles swapped. Each of this
        /// instance's outgoing sets becomes the result's incoming set, and vice versa. Every set
        /// is copied, so the result shares no specification objects with this instance.
        /// </summary>
        public ChannelProtectionRequirements CreateInverse()
        {
            // The original implementation first merged this instance's channel parts into a fresh
            // result and then replaced all four sets, discarding that work; build the result directly.
            return new ChannelProtectionRequirements(
                new ScopedMessagePartSpecification(_outgoingSignatureParts),
                new ScopedMessagePartSpecification(_outgoingEncryptionParts),
                new ScopedMessagePartSpecification(_incomingSignatureParts),
                new ScopedMessagePartSpecification(_incomingEncryptionParts));
        }

        /// <summary>
        /// Builds requirements from <paramref name="contract"/>. The binding element's supported
        /// request and response protection levels serve as the defaults.
        /// </summary>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="bindingElement"/> or <paramref name="contract"/> is null.
        /// </exception>
        internal static ChannelProtectionRequirements CreateFromContract(ContractDescription contract, ISecurityCapabilities bindingElement, bool isForClient)
        {
            if (bindingElement == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("bindingElement"));

            return CreateFromContract(contract, bindingElement.SupportedRequestProtectionLevel, bindingElement.SupportedResponseProtectionLevel, isForClient);
        }

        // Collapses all per-action part specifications of 'actionParts' into a single specification:
        // the body is included if any action includes it, and headers are de-duplicated by name and
        // namespace.
        // NOTE(review): the two-argument TryGetParts overload presumably also folds in channel-scope
        // parts. Confirm against ScopedMessagePartSpecification.
        private static MessagePartSpecification UnionMessagePartSpecifications(ScopedMessagePartSpecification actionParts)
        {
            MessagePartSpecification result = new MessagePartSpecification(false);
            foreach (string action in actionParts.Actions)
            {
                MessagePartSpecification parts;
                if (!actionParts.TryGetParts(action, out parts))
                    continue;

                if (parts.IsBodyIncluded)
                {
                    result.IsBodyIncluded = true;
                }
                foreach (XmlQualifiedName headerType in parts.HeaderTypes)
                {
                    if (!result.IsHeaderIncluded(headerType.Name, headerType.Namespace))
                    {
                        result.HeaderTypes.Add(headerType);
                    }
                }
            }
            return result;
        }

        /// <summary>
        /// Builds requirements from <paramref name="contract"/> and unions the response-side sets
        /// into a single specification under <see cref="MessageHeaders.WildcardAction"/>.
        /// <list type="bullet">
        /// <item>Client (<paramref name="isForClient"/> true): the incoming sets are unioned and the
        /// outgoing sets are copied per action.</item>
        /// <item>Otherwise: the outgoing sets are unioned and the incoming sets are copied per action.</item>
        /// </list>
        /// </summary>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="bindingElement"/> or <paramref name="contract"/> is null.
        /// </exception>
        internal static ChannelProtectionRequirements CreateFromContractAndUnionResponseProtectionRequirements(ContractDescription contract, ISecurityCapabilities bindingElement, bool isForClient)
        {
            ChannelProtectionRequirements contractRequirements = CreateFromContract(contract, bindingElement, isForClient);
            // union all the protection requirements for the response actions
            ChannelProtectionRequirements result = new ChannelProtectionRequirements();

            if (isForClient)
            {
                result.IncomingEncryptionParts.AddParts(UnionMessagePartSpecifications(contractRequirements.IncomingEncryptionParts), MessageHeaders.WildcardAction);
                result.IncomingSignatureParts.AddParts(UnionMessagePartSpecifications(contractRequirements.IncomingSignatureParts), MessageHeaders.WildcardAction);
                contractRequirements.OutgoingEncryptionParts.CopyTo(result.OutgoingEncryptionParts);
                contractRequirements.OutgoingSignatureParts.CopyTo(result.OutgoingSignatureParts);
            }
            else
            {
                result.OutgoingEncryptionParts.AddParts(UnionMessagePartSpecifications(contractRequirements.OutgoingEncryptionParts), MessageHeaders.WildcardAction);
                result.OutgoingSignatureParts.AddParts(UnionMessagePartSpecifications(contractRequirements.OutgoingSignatureParts), MessageHeaders.WildcardAction);
                contractRequirements.IncomingEncryptionParts.CopyTo(result.IncomingEncryptionParts);
                contractRequirements.IncomingSignatureParts.CopyTo(result.IncomingSignatureParts);
            }
            return result;
        }

        /// <summary>
        /// Builds per-action signing and encryption requirements for every message and fault of
        /// every operation in <paramref name="contract"/>.
        /// </summary>
        /// <param name="contract">The contract to inspect.</param>
        /// <param name="defaultRequestProtectionLevel">Default for request (Input) messages when the contract sets no level.</param>
        /// <param name="defaultResponseProtectionLevel">Default for other messages when the contract sets no level.</param>
        /// <param name="isForClient">Not read by this method.</param>
        /// <returns>
        /// Requirements in which Input-direction messages are recorded as incoming and all other
        /// messages as outgoing.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="contract"/> is null.</exception>
        /// <exception cref="InvalidOperationException">
        /// A message body's return value is not exactly a <see cref="MessagePartDescription"/>.
        /// </exception>
        internal static ChannelProtectionRequirements CreateFromContract(ContractDescription contract, ProtectionLevel defaultRequestProtectionLevel, ProtectionLevel defaultResponseProtectionLevel, bool isForClient)
        {
            if (contract == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("contract"));

            ChannelProtectionRequirements requirements = new ChannelProtectionRequirements();

            // An explicit contract-level protection level replaces the supplied defaults for both
            // request and response messages.
            ProtectionLevel contractScopeDefaultRequestProtectionLevel;
            ProtectionLevel contractScopeDefaultResponseProtectionLevel;
            if (contract.HasProtectionLevel)
            {
                contractScopeDefaultRequestProtectionLevel = contract.ProtectionLevel;
                contractScopeDefaultResponseProtectionLevel = contract.ProtectionLevel;
            }
            else
            {
                contractScopeDefaultRequestProtectionLevel = defaultRequestProtectionLevel;
                contractScopeDefaultResponseProtectionLevel = defaultResponseProtectionLevel;
            }

            foreach (OperationDescription operation in contract.Operations)
            {
                // NOTE(review): no operation-level protection setting is consulted here. Every
                // operation inherits the contract-scope defaults directly.
                foreach (MessageDescription message in operation.Messages)
                {
                    // A message-level protection level wins; otherwise use the direction's default.
                    ProtectionLevel messageScopeDefaultProtectionLevel;
                    if (message.HasProtectionLevel)
                    {
                        messageScopeDefaultProtectionLevel = message.ProtectionLevel;
                    }
                    else if (message.Direction == MessageDirection.Input)
                    {
                        messageScopeDefaultProtectionLevel = contractScopeDefaultRequestProtectionLevel;
                    }
                    else
                    {
                        messageScopeDefaultProtectionLevel = contractScopeDefaultResponseProtectionLevel;
                    }

                    AddMessageProtectionRequirements(message, requirements, messageScopeDefaultProtectionLevel);
                }

                if (operation.Faults != null)
                {
                    // Faults of server-initiated operations are recorded as incoming with the request
                    // default. All other faults are recorded as outgoing with the response default.
                    if (operation.IsServerInitiated())
                    {
                        AddFaultProtectionRequirements(operation.Faults, requirements, contractScopeDefaultRequestProtectionLevel, true);
                    }
                    else
                    {
                        AddFaultProtectionRequirements(operation.Faults, requirements, contractScopeDefaultResponseProtectionLevel, false);
                    }
                }
            }

            return requirements;
        }

        // Builds the signed/encrypted specifications for one message (headers, then body) and
        // records them under the message's action. Input messages go to the incoming sets; all
        // other directions go to the outgoing sets.
        private static void AddMessageProtectionRequirements(MessageDescription message, ChannelProtectionRequirements requirements, ProtectionLevel messageScopeDefaultProtectionLevel)
        {
            MessagePartSpecification signedParts = new MessagePartSpecification();
            MessagePartSpecification encryptedParts = new MessagePartSpecification();

            foreach (MessageHeaderDescription header in message.Headers)
            {
                AddHeaderProtectionRequirements(header, signedParts, encryptedParts, messageScopeDefaultProtectionLevel);
            }

            ProtectionLevel bodyProtectionLevel = GetBodyProtectionLevel(message, messageScopeDefaultProtectionLevel);
            ApplyBodyProtectionLevel(bodyProtectionLevel, signedParts, encryptedParts);

            AddPartsForAction(requirements, signedParts, encryptedParts, message.Action, message.Direction == MessageDirection.Input);
        }

        // Computes the effective protection level of a message body.
        // - With body parts: the maximum over all parts (via ProtectionLevelHelper.Max), starting
        //   from None. A part with no explicit level uses the message default.
        // - Without parts but with a return value: the return value's level, or the message default.
        // - Otherwise: the message default.
        // NOTE(review): when parts exist, any ReturnValue is not consulted. This matches the
        // original WCF logic; confirm whether return-value protection should also be unioned.
        private static ProtectionLevel GetBodyProtectionLevel(MessageDescription message, ProtectionLevel messageScopeDefaultProtectionLevel)
        {
            if (message.Body.Parts.Count > 0)
            {
                ProtectionLevel bodyProtectionLevel = ProtectionLevel.None;
                foreach (MessagePartDescription part in message.Body.Parts)
                {
                    ProtectionLevel partProtectionLevel = part.HasProtectionLevel ? part.ProtectionLevel : messageScopeDefaultProtectionLevel;
                    bodyProtectionLevel = ProtectionLevelHelper.Max(bodyProtectionLevel, partProtectionLevel);
                    // EncryptAndSign is the strongest ProtectionLevel; no later part can raise it.
                    if (bodyProtectionLevel == ProtectionLevel.EncryptAndSign)
                        break;
                }
                return bodyProtectionLevel;
            }

            MessagePartDescription returnValue = message.Body.ReturnValue;
            if (returnValue != null)
            {
                // Only a plain MessagePartDescription (not a derived type) is accepted as a body
                // return value.
                if (!(returnValue.GetType().Equals(typeof(MessagePartDescription))))
                {
                    Fx.Assert("Only body return values are supported currently");
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SRServiceModel.OnlyBodyReturnValuesSupported));
                }
                return returnValue.HasProtectionLevel ? returnValue.ProtectionLevel : messageScopeDefaultProtectionLevel;
            }

            return messageScopeDefaultProtectionLevel;
        }

        // Marks the body as signed for any level other than None. For EncryptAndSign it is also
        // marked as encrypted.
        private static void ApplyBodyProtectionLevel(ProtectionLevel protectionLevel, MessagePartSpecification signedParts, MessagePartSpecification encryptedParts)
        {
            if (protectionLevel == ProtectionLevel.None)
                return;

            signedParts.IsBodyIncluded = true;
            if (protectionLevel == ProtectionLevel.EncryptAndSign)
                encryptedParts.IsBodyIncluded = true;
        }

        // Records a signed/encrypted specification pair under 'action', either in the incoming
        // sets or in the outgoing sets.
        private static void AddPartsForAction(ChannelProtectionRequirements requirements, MessagePartSpecification signedParts, MessagePartSpecification encryptedParts, string action, bool incoming)
        {
            if (incoming)
            {
                requirements.IncomingSignatureParts.AddParts(signedParts, action);
                requirements.IncomingEncryptionParts.AddParts(encryptedParts, action);
            }
            else
            {
                requirements.OutgoingSignatureParts.AddParts(signedParts, action);
                requirements.OutgoingEncryptionParts.AddParts(encryptedParts, action);
            }
        }

        // Adds the header's qualified name to 'signedParts' unless its effective level is None.
        // The name is also added to 'encryptedParts' when the level is EncryptAndSign. The
        // header's own level wins over 'defaultProtectionLevel'.
        private static void AddHeaderProtectionRequirements(MessageHeaderDescription header, MessagePartSpecification signedParts,
            MessagePartSpecification encryptedParts, ProtectionLevel defaultProtectionLevel)
        {
            ProtectionLevel p = header.HasProtectionLevel ? header.ProtectionLevel : defaultProtectionLevel;
            if (p != ProtectionLevel.None)
            {
                XmlQualifiedName headerName = new XmlQualifiedName(header.Name, header.Namespace);
                signedParts.HeaderTypes.Add(headerName);
                if (p == ProtectionLevel.EncryptAndSign)
                    encryptedParts.HeaderTypes.Add(headerName);
            }
        }

        // For each fault, computes its effective level (its own, or 'defaultProtectionLevel') and
        // records body signing/encryption under the fault's action. Faults go to the incoming sets
        // when 'addToIncoming' is true, otherwise to the outgoing sets.
        // Throws ArgumentNullException when 'faults' or 'requirements' is null.
        private static void AddFaultProtectionRequirements(FaultDescriptionCollection faults, ChannelProtectionRequirements requirements, ProtectionLevel defaultProtectionLevel, bool addToIncoming)
        {
            if (faults == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("faults"));
            if (requirements == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("requirements"));

            foreach (FaultDescription fault in faults)
            {
                MessagePartSpecification signedParts = new MessagePartSpecification();
                MessagePartSpecification encryptedParts = new MessagePartSpecification();
                ProtectionLevel p = fault.HasProtectionLevel ? fault.ProtectionLevel : defaultProtectionLevel;
                ApplyBodyProtectionLevel(p, signedParts, encryptedParts);
                AddPartsForAction(requirements, signedParts, encryptedParts, fault.Action, addToIncoming);
            }
        }
    }
}