|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Net.Security;
using System.Runtime;
using System.ServiceModel.Channels;
using System.ServiceModel.Description;
using Microsoft.Xml;
namespace System.ServiceModel.Security
{
/// <summary>
/// Describes which message parts must be signed and which must be encrypted on a
/// channel, tracked separately for incoming and outgoing messages.
/// </summary>
/// <remarks>
/// Each of the four <see cref="ScopedMessagePartSpecification"/> instances holds
/// channel-wide parts (<c>ChannelParts</c>) plus parts registered per action string.
/// </remarks>
public class ChannelProtectionRequirements
{
    // The four specifications are not declared readonly because CreateInverse assigns
    // them on the instance it returns.
    private ScopedMessagePartSpecification _incomingSignatureParts;
    private ScopedMessagePartSpecification _incomingEncryptionParts;
    private ScopedMessagePartSpecification _outgoingSignatureParts;
    private ScopedMessagePartSpecification _outgoingEncryptionParts;

    // Set to true by MakeReadOnly; nothing in this class sets it back to false.
    private bool _isReadOnly;

    /// <summary>
    /// Creates an instance with four empty part specifications.
    /// </summary>
    public ChannelProtectionRequirements()
    {
        _incomingSignatureParts = new ScopedMessagePartSpecification();
        _incomingEncryptionParts = new ScopedMessagePartSpecification();
        _outgoingSignatureParts = new ScopedMessagePartSpecification();
        _outgoingEncryptionParts = new ScopedMessagePartSpecification();
    }

    /// <summary>
    /// Gets a value indicating whether <see cref="MakeReadOnly"/> has been called on this instance.
    /// </summary>
    public bool IsReadOnly
    {
        get
        {
            return _isReadOnly;
        }
    }

    /// <summary>
    /// Creates an instance whose four part specifications are copy-constructed from
    /// those of <paramref name="other"/>.
    /// </summary>
    /// <param name="other">The requirements to copy.</param>
    /// <exception cref="ArgumentNullException"><paramref name="other"/> is null.</exception>
    /// <remarks>
    /// The <see cref="IsReadOnly"/> flag is not copied; the new instance starts with it false.
    /// </remarks>
    public ChannelProtectionRequirements(ChannelProtectionRequirements other)
    {
        if (other == null)
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("other"));
        _incomingSignatureParts = new ScopedMessagePartSpecification(other._incomingSignatureParts);
        _incomingEncryptionParts = new ScopedMessagePartSpecification(other._incomingEncryptionParts);
        _outgoingSignatureParts = new ScopedMessagePartSpecification(other._outgoingSignatureParts);
        _outgoingEncryptionParts = new ScopedMessagePartSpecification(other._outgoingEncryptionParts);
    }

    /// <summary>
    /// Creates an instance from <paramref name="other"/>, passing a flag derived from
    /// <paramref name="newBodyProtectionLevel"/> to each specification's copy constructor.
    /// </summary>
    /// <param name="other">The requirements to copy.</param>
    /// <param name="newBodyProtectionLevel">
    /// Signature specifications receive <c>true</c> when this is not
    /// <see cref="ProtectionLevel.None"/>; encryption specifications receive <c>true</c>
    /// only when it is <see cref="ProtectionLevel.EncryptAndSign"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="other"/> is null.</exception>
    /// <remarks>
    /// NOTE(review): the flag presumably overrides whether the message body is included in
    /// the copied parts - confirm against the ScopedMessagePartSpecification constructor.
    /// </remarks>
    internal ChannelProtectionRequirements(ChannelProtectionRequirements other, ProtectionLevel newBodyProtectionLevel)
    {
        if (other == null)
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("other"));
        _incomingSignatureParts = new ScopedMessagePartSpecification(other._incomingSignatureParts, newBodyProtectionLevel != ProtectionLevel.None);
        _incomingEncryptionParts = new ScopedMessagePartSpecification(other._incomingEncryptionParts, newBodyProtectionLevel == ProtectionLevel.EncryptAndSign);
        _outgoingSignatureParts = new ScopedMessagePartSpecification(other._outgoingSignatureParts, newBodyProtectionLevel != ProtectionLevel.None);
        _outgoingEncryptionParts = new ScopedMessagePartSpecification(other._outgoingEncryptionParts, newBodyProtectionLevel == ProtectionLevel.EncryptAndSign);
    }

    /// <summary>Gets the parts of incoming messages that must be signed.</summary>
    public ScopedMessagePartSpecification IncomingSignatureParts
    {
        get
        {
            return _incomingSignatureParts;
        }
    }

    /// <summary>Gets the parts of incoming messages that must be encrypted.</summary>
    public ScopedMessagePartSpecification IncomingEncryptionParts
    {
        get
        {
            return _incomingEncryptionParts;
        }
    }

    /// <summary>Gets the parts of outgoing messages that must be signed.</summary>
    public ScopedMessagePartSpecification OutgoingSignatureParts
    {
        get
        {
            return _outgoingSignatureParts;
        }
    }

    /// <summary>Gets the parts of outgoing messages that must be encrypted.</summary>
    public ScopedMessagePartSpecification OutgoingEncryptionParts
    {
        get
        {
            return _outgoingEncryptionParts;
        }
    }

    /// <summary>
    /// Merges both the channel-scoped and the action-scoped parts of
    /// <paramref name="protectionRequirements"/> into this instance.
    /// </summary>
    /// <param name="protectionRequirements">The requirements to merge in.</param>
    /// <exception cref="ArgumentNullException"><paramref name="protectionRequirements"/> is null.</exception>
    public void Add(ChannelProtectionRequirements protectionRequirements)
    {
        this.Add(protectionRequirements, false);
    }

    /// <summary>
    /// Merges the parts of <paramref name="protectionRequirements"/> into this instance.
    /// </summary>
    /// <param name="protectionRequirements">The requirements to merge in.</param>
    /// <param name="channelScopeOnly">
    /// When true, only the <c>ChannelParts</c> of each specification are merged; when
    /// false, the per-action parts are merged as well.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="protectionRequirements"/> is null.</exception>
    public void Add(ChannelProtectionRequirements protectionRequirements, bool channelScopeOnly)
    {
        if (protectionRequirements == null)
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("protectionRequirements"));

        // Channel-scoped parts are merged first, for all four specifications.
        if (protectionRequirements._incomingSignatureParts != null)
            _incomingSignatureParts.AddParts(protectionRequirements._incomingSignatureParts.ChannelParts);
        if (protectionRequirements._incomingEncryptionParts != null)
            _incomingEncryptionParts.AddParts(protectionRequirements._incomingEncryptionParts.ChannelParts);
        if (protectionRequirements._outgoingSignatureParts != null)
            _outgoingSignatureParts.AddParts(protectionRequirements._outgoingSignatureParts.ChannelParts);
        if (protectionRequirements._outgoingEncryptionParts != null)
            _outgoingEncryptionParts.AddParts(protectionRequirements._outgoingEncryptionParts.ChannelParts);

        if (!channelScopeOnly)
        {
            // AddActionParts tolerates a null source, matching the null checks above.
            AddActionParts(_incomingSignatureParts, protectionRequirements._incomingSignatureParts);
            AddActionParts(_incomingEncryptionParts, protectionRequirements._incomingEncryptionParts);
            AddActionParts(_outgoingSignatureParts, protectionRequirements._outgoingSignatureParts);
            AddActionParts(_outgoingEncryptionParts, protectionRequirements._outgoingEncryptionParts);
        }
    }

    /// <summary>
    /// Copies the parts registered for every action in <paramref name="from"/> into
    /// <paramref name="to"/> under the same action. Does nothing when
    /// <paramref name="from"/> is null.
    /// </summary>
    private static void AddActionParts(ScopedMessagePartSpecification to, ScopedMessagePartSpecification from)
    {
        // Add() skips null specifications for channel parts; do the same here rather
        // than failing with a NullReferenceException on from.Actions.
        if (from == null)
            return;

        foreach (string action in from.Actions)
        {
            MessagePartSpecification p;
            // NOTE(review): the 'true' argument presumably excludes channel-scoped parts
            // from the lookup - confirm against ScopedMessagePartSpecification.TryGetParts.
            if (from.TryGetParts(action, true, out p))
                to.AddParts(p, action);
        }
    }

    /// <summary>
    /// Calls <c>MakeReadOnly</c> on each of the four part specifications and sets
    /// <see cref="IsReadOnly"/> to true. Calls after the first do nothing.
    /// </summary>
    public void MakeReadOnly()
    {
        if (!_isReadOnly)
        {
            _incomingSignatureParts.MakeReadOnly();
            _incomingEncryptionParts.MakeReadOnly();
            _outgoingSignatureParts.MakeReadOnly();
            _outgoingEncryptionParts.MakeReadOnly();
            _isReadOnly = true;
        }
    }

    /// <summary>
    /// Creates a new instance in which incoming and outgoing are swapped: its incoming
    /// specifications are copies of this instance's outgoing specifications, and vice versa.
    /// </summary>
    /// <returns>A new <see cref="ChannelProtectionRequirements"/>; this instance is not modified.</returns>
    public ChannelProtectionRequirements CreateInverse()
    {
        ChannelProtectionRequirements result = new ChannelProtectionRequirements();

        // A previous "result.Add(this, true)" call was removed here: every specification
        // it populated was replaced by the assignments below, so it had no effect on the
        // returned object.
        result._incomingSignatureParts = new ScopedMessagePartSpecification(this.OutgoingSignatureParts);
        result._outgoingSignatureParts = new ScopedMessagePartSpecification(this.IncomingSignatureParts);
        result._incomingEncryptionParts = new ScopedMessagePartSpecification(this.OutgoingEncryptionParts);
        result._outgoingEncryptionParts = new ScopedMessagePartSpecification(this.IncomingEncryptionParts);
        return result;
    }

    /// <summary>
    /// Builds requirements from <paramref name="contract"/>, using the request/response
    /// protection levels supported by <paramref name="bindingElement"/> as defaults.
    /// </summary>
    /// <param name="contract">The contract whose operations, messages and faults are inspected.</param>
    /// <param name="bindingElement">Supplies the default request and response protection levels.</param>
    /// <param name="isForClient">Forwarded unchanged to the four-argument overload.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="bindingElement"/> or <paramref name="contract"/> is null.
    /// </exception>
    internal static ChannelProtectionRequirements CreateFromContract(ContractDescription contract, ISecurityCapabilities bindingElement, bool isForClient)
    {
        if (bindingElement == null)
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("bindingElement"));
        return CreateFromContract(contract, bindingElement.SupportedRequestProtectionLevel, bindingElement.SupportedResponseProtectionLevel, isForClient);
    }

    /// <summary>
    /// Produces a single <see cref="MessagePartSpecification"/> that is the union of the
    /// parts registered for every action in <paramref name="actionParts"/>: the body is
    /// included if any action includes it, and each header type is added once.
    /// </summary>
    private static MessagePartSpecification UnionMessagePartSpecifications(ScopedMessagePartSpecification actionParts)
    {
        // NOTE(review): the 'false' argument presumably starts with the body excluded - confirm.
        MessagePartSpecification result = new MessagePartSpecification(false);
        foreach (string action in actionParts.Actions)
        {
            MessagePartSpecification parts;
            if (actionParts.TryGetParts(action, out parts))
            {
                if (parts.IsBodyIncluded)
                {
                    result.IsBodyIncluded = true;
                }
                foreach (XmlQualifiedName headerType in parts.HeaderTypes)
                {
                    // Skip header types already present in the result.
                    if (!result.IsHeaderIncluded(headerType.Name, headerType.Namespace))
                    {
                        result.HeaderTypes.Add(headerType);
                    }
                }
            }
        }
        return result;
    }

    /// <summary>
    /// Builds requirements from <paramref name="contract"/> and then collapses the
    /// per-action requirements of one direction into a single entry registered under
    /// <c>MessageHeaders.WildcardAction</c>.
    /// </summary>
    /// <param name="contract">The contract whose operations, messages and faults are inspected.</param>
    /// <param name="bindingElement">Supplies the default request and response protection levels.</param>
    /// <param name="isForClient">
    /// When true, the incoming specifications are unioned under the wildcard action and the
    /// outgoing specifications are copied as-is; when false, the roles are swapped.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="bindingElement"/> or <paramref name="contract"/> is null.
    /// </exception>
    internal static ChannelProtectionRequirements CreateFromContractAndUnionResponseProtectionRequirements(ContractDescription contract, ISecurityCapabilities bindingElement, bool isForClient)
    {
        if (bindingElement == null)
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("bindingElement"));
        ChannelProtectionRequirements contractRequirements = CreateFromContract(contract, bindingElement.SupportedRequestProtectionLevel, bindingElement.SupportedResponseProtectionLevel, isForClient);
        // union all the protection requirements for the response actions
        ChannelProtectionRequirements result = new ChannelProtectionRequirements();
        if (isForClient)
        {
            result.IncomingEncryptionParts.AddParts(UnionMessagePartSpecifications(contractRequirements.IncomingEncryptionParts), MessageHeaders.WildcardAction);
            result.IncomingSignatureParts.AddParts(UnionMessagePartSpecifications(contractRequirements.IncomingSignatureParts), MessageHeaders.WildcardAction);
            contractRequirements.OutgoingEncryptionParts.CopyTo(result.OutgoingEncryptionParts);
            contractRequirements.OutgoingSignatureParts.CopyTo(result.OutgoingSignatureParts);
        }
        else
        {
            result.OutgoingEncryptionParts.AddParts(UnionMessagePartSpecifications(contractRequirements.OutgoingEncryptionParts), MessageHeaders.WildcardAction);
            result.OutgoingSignatureParts.AddParts(UnionMessagePartSpecifications(contractRequirements.OutgoingSignatureParts), MessageHeaders.WildcardAction);
            contractRequirements.IncomingEncryptionParts.CopyTo(result.IncomingEncryptionParts);
            contractRequirements.IncomingSignatureParts.CopyTo(result.IncomingSignatureParts);
        }
        return result;
    }

    /// <summary>
    /// Builds requirements by walking every operation, message and fault of
    /// <paramref name="contract"/>.
    /// </summary>
    /// <param name="contract">The contract to inspect.</param>
    /// <param name="defaultRequestProtectionLevel">
    /// Protection level for input messages when neither the contract nor the message sets one.
    /// </param>
    /// <param name="defaultResponseProtectionLevel">
    /// Protection level for non-input messages when neither the contract nor the message sets one.
    /// </param>
    /// <param name="isForClient">Not read by this method.</param>
    /// <returns>
    /// Requirements in which input messages populate the incoming specifications and all
    /// other messages populate the outgoing specifications, each keyed by message action.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="contract"/> is null.</exception>
    /// <exception cref="InvalidOperationException">
    /// A message has no body parts and its return value is not exactly a <c>MessagePartDescription</c>.
    /// </exception>
    internal static ChannelProtectionRequirements CreateFromContract(ContractDescription contract, ProtectionLevel defaultRequestProtectionLevel, ProtectionLevel defaultResponseProtectionLevel, bool isForClient)
    {
        if (contract == null)
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("contract"));
        ChannelProtectionRequirements requirements = new ChannelProtectionRequirements();

        // A protection level set on the contract overrides both caller-supplied defaults.
        ProtectionLevel contractScopeDefaultRequestProtectionLevel;
        ProtectionLevel contractScopeDefaultResponseProtectionLevel;
        if (contract.HasProtectionLevel)
        {
            contractScopeDefaultRequestProtectionLevel = contract.ProtectionLevel;
            contractScopeDefaultResponseProtectionLevel = contract.ProtectionLevel;
        }
        else
        {
            contractScopeDefaultRequestProtectionLevel = defaultRequestProtectionLevel;
            contractScopeDefaultResponseProtectionLevel = defaultResponseProtectionLevel;
        }

        foreach (OperationDescription operation in contract.Operations)
        {
            // Operations inherit the contract-scope defaults unchanged.
            ProtectionLevel operationScopeDefaultRequestProtectionLevel = contractScopeDefaultRequestProtectionLevel;
            ProtectionLevel operationScopeDefaultResponseProtectionLevel = contractScopeDefaultResponseProtectionLevel;

            foreach (MessageDescription message in operation.Messages)
            {
                ProtectionLevel messageScopeDefaultProtectionLevel;
                if (message.HasProtectionLevel)
                {
                    messageScopeDefaultProtectionLevel = message.ProtectionLevel;
                }
                else if (message.Direction == MessageDirection.Input)
                {
                    messageScopeDefaultProtectionLevel = operationScopeDefaultRequestProtectionLevel;
                }
                else
                {
                    messageScopeDefaultProtectionLevel = operationScopeDefaultResponseProtectionLevel;
                }

                MessagePartSpecification signedParts = new MessagePartSpecification();
                MessagePartSpecification encryptedParts = new MessagePartSpecification();

                // determine header protection requirements for message
                foreach (MessageHeaderDescription header in message.Headers)
                {
                    AddHeaderProtectionRequirements(header, signedParts, encryptedParts, messageScopeDefaultProtectionLevel);
                }

                // determine body protection requirements for message
                ProtectionLevel bodyProtectionLevel = GetBodyProtectionLevel(message, messageScopeDefaultProtectionLevel);
                if (bodyProtectionLevel != ProtectionLevel.None)
                {
                    signedParts.IsBodyIncluded = true;
                    if (bodyProtectionLevel == ProtectionLevel.EncryptAndSign)
                        encryptedParts.IsBodyIncluded = true;
                }

                // add requirements for message
                if (message.Direction == MessageDirection.Input)
                {
                    requirements.IncomingSignatureParts.AddParts(signedParts, message.Action);
                    requirements.IncomingEncryptionParts.AddParts(encryptedParts, message.Action);
                }
                else
                {
                    requirements.OutgoingSignatureParts.AddParts(signedParts, message.Action);
                    requirements.OutgoingEncryptionParts.AddParts(encryptedParts, message.Action);
                }
            }

            if (operation.Faults != null)
            {
                // Faults of server-initiated operations are registered as incoming and use
                // the request default; all other faults are outgoing and use the response default.
                if (operation.IsServerInitiated())
                {
                    AddFaultProtectionRequirements(operation.Faults, requirements, operationScopeDefaultRequestProtectionLevel, true);
                }
                else
                {
                    AddFaultProtectionRequirements(operation.Faults, requirements, operationScopeDefaultResponseProtectionLevel, false);
                }
            }
        }
        return requirements;
    }

    /// <summary>
    /// Computes the effective protection level of the body of <paramref name="message"/>.
    /// </summary>
    /// <remarks>
    /// With body parts present, the result is the maximum over all parts, each part falling
    /// back to <paramref name="messageScopeDefaultProtectionLevel"/> when it sets no level.
    /// With no parts but a return value, the return value's level (or the default) is used.
    /// Otherwise the default is returned.
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    /// There are no body parts and the return value is not exactly a <c>MessagePartDescription</c>.
    /// </exception>
    private static ProtectionLevel GetBodyProtectionLevel(MessageDescription message, ProtectionLevel messageScopeDefaultProtectionLevel)
    {
        if (message.Body.Parts.Count > 0)
        {
            // initialize the body protection level to none. all the body parts will be
            // unioned to get the effective body protection level
            ProtectionLevel bodyProtectionLevel = ProtectionLevel.None;
            foreach (MessagePartDescription body in message.Body.Parts)
            {
                ProtectionLevel partProtectionLevel = body.HasProtectionLevel ? body.ProtectionLevel : messageScopeDefaultProtectionLevel;
                bodyProtectionLevel = ProtectionLevelHelper.Max(bodyProtectionLevel, partProtectionLevel);
                // EncryptAndSign is the level that enables both signing and encryption,
                // so there is nothing more to gain from the remaining parts.
                if (bodyProtectionLevel == ProtectionLevel.EncryptAndSign)
                    break;
            }
            return bodyProtectionLevel;
        }

        if (message.Body.ReturnValue != null)
        {
            if (!(message.Body.ReturnValue.GetType().Equals(typeof(MessagePartDescription))))
            {
                Fx.Assert("Only body return values are supported currently");
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SRServiceModel.OnlyBodyReturnValuesSupported));
            }
            MessagePartDescription desc = message.Body.ReturnValue;
            return desc.HasProtectionLevel ? desc.ProtectionLevel : messageScopeDefaultProtectionLevel;
        }

        return messageScopeDefaultProtectionLevel;
    }

    /// <summary>
    /// Adds <paramref name="header"/> to <paramref name="signedParts"/> when its effective
    /// protection level is not <see cref="ProtectionLevel.None"/>, and additionally to
    /// <paramref name="encryptedParts"/> when it is <see cref="ProtectionLevel.EncryptAndSign"/>.
    /// </summary>
    /// <param name="header">The header description to inspect.</param>
    /// <param name="signedParts">Receives the header's qualified name if it must be signed.</param>
    /// <param name="encryptedParts">Receives the header's qualified name if it must be encrypted.</param>
    /// <param name="defaultProtectionLevel">Used when the header sets no protection level.</param>
    private static void AddHeaderProtectionRequirements(MessageHeaderDescription header, MessagePartSpecification signedParts,
        MessagePartSpecification encryptedParts, ProtectionLevel defaultProtectionLevel)
    {
        ProtectionLevel p = header.HasProtectionLevel ? header.ProtectionLevel : defaultProtectionLevel;
        if (p != ProtectionLevel.None)
        {
            XmlQualifiedName headerName = new XmlQualifiedName(header.Name, header.Namespace);
            signedParts.HeaderTypes.Add(headerName);
            if (p == ProtectionLevel.EncryptAndSign)
                encryptedParts.HeaderTypes.Add(headerName);
        }
    }

    /// <summary>
    /// Registers body protection requirements for every fault in <paramref name="faults"/>,
    /// keyed by the fault's action.
    /// </summary>
    /// <param name="faults">The faults to inspect.</param>
    /// <param name="requirements">The requirements to add to.</param>
    /// <param name="defaultProtectionLevel">Used when a fault sets no protection level.</param>
    /// <param name="addToIncoming">
    /// When true, parts are added to the incoming specifications; otherwise to the outgoing ones.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="faults"/> or <paramref name="requirements"/> is null.
    /// </exception>
    private static void AddFaultProtectionRequirements(FaultDescriptionCollection faults, ChannelProtectionRequirements requirements, ProtectionLevel defaultProtectionLevel, bool addToIncoming)
    {
        if (faults == null)
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("faults"));
        if (requirements == null)
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ArgumentNullException("requirements"));
        foreach (FaultDescription fault in faults)
        {
            MessagePartSpecification signedParts = new MessagePartSpecification();
            MessagePartSpecification encryptedParts = new MessagePartSpecification();
            ProtectionLevel p = fault.HasProtectionLevel ? fault.ProtectionLevel : defaultProtectionLevel;
            if (p != ProtectionLevel.None)
            {
                signedParts.IsBodyIncluded = true;
                if (p == ProtectionLevel.EncryptAndSign)
                {
                    encryptedParts.IsBodyIncluded = true;
                }
            }

            // The specifications are added under the fault's action even when the level is
            // None, in which case they are still in their freshly constructed state.
            if (addToIncoming)
            {
                requirements.IncomingSignatureParts.AddParts(signedParts, fault.Action);
                requirements.IncomingEncryptionParts.AddParts(encryptedParts, fault.Action);
            }
            else
            {
                requirements.OutgoingSignatureParts.AddParts(signedParts, fault.Action);
                requirements.OutgoingEncryptionParts.AddParts(encryptedParts, fault.Action);
            }
        }
    }
}
}
|