File: ContentSafetyEvaluator.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.AI.Evaluation.Safety\Microsoft.Extensions.AI.Evaluation.Safety.csproj (Microsoft.Extensions.AI.Evaluation.Safety)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
#pragma warning disable S3604
// S3604: Member initializer values should not be redundant.
// We disable this warning because it is a false positive arising from the analyzer's lack of support for C#'s primary
// constructor syntax.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
 
namespace Microsoft.Extensions.AI.Evaluation.Safety;
 
/// <summary>
/// An <see langword="abstract"/> base class that can be used to implement <see cref="IEvaluator"/>s that utilize the
/// Azure AI Content Safety service to evaluate responses produced by an AI model for the presence of a variety of
/// unsafe content such as protected material, vulnerable code, harmful content etc.
/// </summary>
/// <param name="contentSafetyServiceAnnotationTask">
/// The name of the annotation task that should be used when communicating with the Azure AI Content Safety service to
/// perform evaluations.
/// </param>
/// <param name="metricNames">
/// A dictionary containing the mapping from the names of the metrics that are used when communicating with the Azure
/// AI Content Safety service to the <see cref="EvaluationMetric.Name"/>s of the <see cref="EvaluationMetric"/>s
/// returned by this <see cref="IEvaluator"/>.
/// </param>
#pragma warning disable S1694 // An abstract class should have both abstract and concrete methods
public abstract class ContentSafetyEvaluator(
    string contentSafetyServiceAnnotationTask,
    IDictionary<string, string> metricNames) : IEvaluator
#pragma warning restore S1694
{
    /// <inheritdoc/>
    public IReadOnlyCollection<string> EvaluationMetricNames { get; } = [.. Throw.IfNull(metricNames).Values];

    /// <inheritdoc/>
    public virtual ValueTask<EvaluationResult> EvaluateAsync(
        IEnumerable<ChatMessage> messages,
        ChatResponse modelResponse,
        ChatConfiguration? chatConfiguration = null,
        IEnumerable<EvaluationContext>? additionalContext = null,
        CancellationToken cancellationToken = default)
    {
        // The Azure AI Content Safety service is reached through the IChatClient carried by the configuration, so a
        // configuration is mandatory here even though the IEvaluator signature marks it as optional.
        _ = Throw.IfNull(chatConfiguration);

        return EvaluateContentSafetyAsync(
            chatConfiguration.ChatClient,
            messages,
            modelResponse,
            additionalContext,
            cancellationToken: cancellationToken);
    }

    /// <summary>
    /// Evaluates the supplied <paramref name="modelResponse"/> using the Azure AI Content Safety Service and returns
    /// an <see cref="EvaluationResult"/> containing one or more <see cref="EvaluationMetric"/>s.
    /// </summary>
    /// <remarks>
    /// Each returned <see cref="EvaluationMetric"/> is renamed according to the metric name mapping supplied to the
    /// constructor (when the service's metric name is present in that mapping), and the returned
    /// <see cref="EvaluationResult.Metrics"/> dictionary is keyed by the resulting <see cref="EvaluationMetric.Name"/>.
    /// Every metric also receives an <c>evaluation-duration</c> metadata entry, an interpretation (for
    /// <see cref="BooleanMetric"/>s and <see cref="NumericMetric"/>s), and any diagnostics produced while building
    /// the service payload.
    /// </remarks>
    /// <param name="contentSafetyServiceChatClient">
    /// The <see cref="IChatClient"/> that should be used to communicate with the Azure AI Content Safety Service when
    /// performing evaluations.
    /// </param>
    /// <param name="messages">
    /// The conversation history including the request that produced the supplied <paramref name="modelResponse"/>.
    /// </param>
    /// <param name="modelResponse">The response that is to be evaluated.</param>
    /// <param name="additionalContext">
    /// Additional contextual information (beyond that which is available in <paramref name="messages"/>) that the
    /// <see cref="IEvaluator"/> may need to accurately evaluate the supplied <paramref name="modelResponse"/>.
    /// </param>
    /// <param name="contentSafetyServicePayloadFormat">
    /// An identifier that specifies the format of the payload that should be used when communicating with the Azure AI
    /// Content Safety service to perform evaluations.
    /// </param>
    /// <param name="includeMetricNamesInContentSafetyServicePayload">
    /// A <see cref="bool"/> flag that indicates whether the names of the metrics should be included in the payload
    /// that is sent to the Azure AI Content Safety service when performing evaluations.
    /// </param>
    /// <param name="cancellationToken">
    /// A <see cref="CancellationToken"/> that can cancel the evaluation operation.
    /// </param>
    /// <returns>An <see cref="EvaluationResult"/> containing one or more <see cref="EvaluationMetric"/>s.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="contentSafetyServiceChatClient"/>, <paramref name="messages"/> or
    /// <paramref name="modelResponse"/> is <see langword="null"/>.
    /// </exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="contentSafetyServicePayloadFormat"/> does not name a known payload format.
    /// </exception>
    protected async ValueTask<EvaluationResult> EvaluateContentSafetyAsync(
        IChatClient contentSafetyServiceChatClient,
        IEnumerable<ChatMessage> messages,
        ChatResponse modelResponse,
        IEnumerable<EvaluationContext>? additionalContext = null,
        string contentSafetyServicePayloadFormat = "HumanSystem", // ContentSafetyServicePayloadFormat.HumanSystem.ToString()
        bool includeMetricNamesInContentSafetyServicePayload = true,
        CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(contentSafetyServiceChatClient);
        _ = Throw.IfNull(messages);
        _ = Throw.IfNull(modelResponse);

        string payload;
        string annotationResult;
        IReadOnlyList<EvaluationDiagnostic>? diagnostics;
        EvaluationResult result;
        Stopwatch stopwatch = Stopwatch.StartNew();

        try
        {
            ContentSafetyServicePayloadFormat payloadFormat =
#if NET
                Enum.Parse<ContentSafetyServicePayloadFormat>(contentSafetyServicePayloadFormat);
#else
                (ContentSafetyServicePayloadFormat)Enum.Parse(
                    typeof(ContentSafetyServicePayloadFormat),
                    contentSafetyServicePayloadFormat);
#endif

            // The service evaluates the full conversation: the history followed by the response under evaluation.
            IEnumerable<ChatMessage> conversation = [.. messages, .. modelResponse.Messages];

            string evaluatorName = GetType().Name;

            IEnumerable<string>? perTurnContext = GetPerTurnContext(additionalContext);

            (payload, diagnostics) =
                ContentSafetyServicePayloadUtilities.GetPayload(
                    payloadFormat,
                    conversation,
                    contentSafetyServiceAnnotationTask,
                    evaluatorName,
                    perTurnContext,
                    metricNames: includeMetricNamesInContentSafetyServicePayload ? metricNames.Keys : null,
                    cancellationToken);

            var payloadMessage = new ChatMessage(ChatRole.User, payload);

            ChatResponse annotationResponse =
                await contentSafetyServiceChatClient.GetResponseAsync(
                    payloadMessage,
                    options: new ContentSafetyChatOptions(contentSafetyServiceAnnotationTask, evaluatorName),
                    cancellationToken: cancellationToken).ConfigureAwait(false);

            annotationResult = annotationResponse.Text;
            result = ContentSafetyService.ParseAnnotationResult(annotationResult);
        }
        finally
        {
            stopwatch.Stop();
        }

        string duration = $"{stopwatch.Elapsed.TotalSeconds.ToString("F2", CultureInfo.InvariantCulture)} s";

        return UpdateMetrics();

        // Renames and annotates each metric, and returns a result whose Metrics dictionary is keyed by the renamed
        // metric names. Re-keying is required because the parsed result is keyed by the service's own metric names;
        // leaving those keys in place would make lookups by the names in EvaluationMetricNames fail.
        EvaluationResult UpdateMetrics()
        {
            var updatedResult = new EvaluationResult();

            foreach (EvaluationMetric metric in result.Metrics.Values)
            {
                string contentSafetyServiceMetricName = metric.Name;
                if (metricNames.TryGetValue(contentSafetyServiceMetricName, out string? metricName))
                {
                    metric.Name = metricName;
                }

                metric.AddOrUpdateMetadata(name: "evaluation-duration", value: duration);

                metric.Interpretation =
                    metric switch
                    {
                        BooleanMetric booleanMetric => booleanMetric.InterpretContentSafetyScore(),
                        NumericMetric numericMetric => numericMetric.InterpretContentSafetyScore(),
                        _ => metric.Interpretation
                    };

                if (diagnostics is not null)
                {
                    metric.AddDiagnostics(diagnostics);
                }

#pragma warning disable S125 // Sections of code should not be commented out
                // The following commented code can be useful for debugging purposes.
                // metric.LogJsonData(payload);
                // metric.LogJsonData(annotationResult);
#pragma warning restore S125

                updatedResult.Metrics[metric.Name] = metric;
            }

            return updatedResult;
        }
    }

    /// <summary>
    /// Filters the <see cref="EvaluationContext"/>s supplied by the caller via <paramref name="additionalContext"/>
    /// down to just the <see cref="EvaluationContext"/>s that are relevant to the evaluation being performed by this
    /// <see cref="ContentSafetyEvaluator"/>.
    /// </summary>
    /// <remarks>
    /// The default implementation returns <see langword="null"/>, in which case no additional context is sent to the
    /// Azure AI Content Safety service.
    /// </remarks>
    /// <param name="additionalContext">The <see cref="EvaluationContext"/>s supplied by the caller.</param>
    /// <returns>
    /// The <see cref="EvaluationContext"/>s that are relevant to the evaluation being performed by this
    /// <see cref="ContentSafetyEvaluator"/>.
    /// </returns>
    protected virtual IReadOnlyList<EvaluationContext>? FilterAdditionalContext(
        IEnumerable<EvaluationContext>? additionalContext)
            => null;

    // Builds the per-turn context sent to the service: the text of all relevant contexts joined by newlines, or
    // null when there is no relevant, non-whitespace text to send.
    private IEnumerable<string>? GetPerTurnContext(IEnumerable<EvaluationContext>? additionalContext)
    {
        if (additionalContext is null || !additionalContext.Any())
        {
            return null;
        }

        IReadOnlyList<EvaluationContext>? relevantContext = FilterAdditionalContext(additionalContext);
        if (relevantContext is not { Count: > 0 })
        {
            return null;
        }

        // Only textual content is forwarded; an empty join (no text content at all) is rejected by the whitespace
        // check below.
        string contextString =
            string.Join(
                Environment.NewLine,
                relevantContext
                    .SelectMany(c => c.GetContents())
                    .OfType<TextContent>()
                    .Select(c => c.Text));

        if (string.IsNullOrWhiteSpace(contextString))
        {
            return null;
        }

        // Currently we only support supplying a context for the last conversation turn (which is the main one
        // that is being evaluated).
        return [contextString];
    }
}