File: ContentSafetyServicePayloadUtilities.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.AI.Evaluation.Safety\Microsoft.Extensions.AI.Evaluation.Safety.csproj (Microsoft.Extensions.AI.Evaluation.Safety)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Nodes;
using System.Threading;
using System.Xml.Linq;
 
namespace Microsoft.Extensions.AI.Evaluation.Safety;
 
internal static class ContentSafetyServicePayloadUtilities
{
    // Builds the request payload for the supplied conversation by dispatching on payloadFormat to the builder
    // that produces the matching layout. Returns the serialized payload together with any diagnostics that the
    // selected builder produced.
    //
    // Throws NotSupportedException when payloadFormat is not one of the formats handled below.
    internal static (string payload, IReadOnlyList<EvaluationDiagnostic>? diagnostics) GetPayload(
        ContentSafetyServicePayloadFormat payloadFormat,
        IEnumerable<ChatMessage> conversation,
        string annotationTask,
        string evaluatorName,
        IEnumerable<string?>? perTurnContext = null,
        IEnumerable<string>? metricNames = null,
        CancellationToken cancellationToken = default)
    {
        switch (payloadFormat)
        {
            case ContentSafetyServicePayloadFormat.HumanSystem:
                // Turns are embedded as XML elements, using the XML builder's default element names.
                return GetUserTextListPayloadWithEmbeddedXml(
                    conversation,
                    annotationTask,
                    evaluatorName,
                    perTurnContext,
                    metricNames,
                    cancellationToken: cancellationToken);

            case ContentSafetyServicePayloadFormat.QuestionAnswer:
                // Turns are embedded as JSON objects, using the JSON builder's default property names.
                return GetUserTextListPayloadWithEmbeddedJson(
                    conversation,
                    annotationTask,
                    evaluatorName,
                    perTurnContext,
                    metricNames,
                    cancellationToken: cancellationToken);

            case ContentSafetyServicePayloadFormat.QueryResponse:
                // Same JSON layout, with the turn properties named "query" and "response".
                return GetUserTextListPayloadWithEmbeddedJson(
                    conversation,
                    annotationTask,
                    evaluatorName,
                    perTurnContext,
                    metricNames,
                    questionPropertyName: "query",
                    answerPropertyName: "response",
                    cancellationToken: cancellationToken);

            case ContentSafetyServicePayloadFormat.ContextCompletion:
                // Same JSON layout, with the turn properties named "context" and "completion".
                return GetUserTextListPayloadWithEmbeddedJson(
                    conversation,
                    annotationTask,
                    evaluatorName,
                    perTurnContext,
                    metricNames,
                    questionPropertyName: "context",
                    answerPropertyName: "completion",
                    cancellationToken: cancellationToken);

            case ContentSafetyServicePayloadFormat.Conversation:
                // Turns are sent as a list of role-tagged messages.
                return GetConversationPayload(
                    conversation,
                    annotationTask,
                    evaluatorName,
                    perTurnContext,
                    metricNames,
                    cancellationToken: cancellationToken);

            default:
                throw new NotSupportedException($"The payload kind '{payloadFormat}' is not supported.");
        }
    }
 
#pragma warning disable S107 // Methods should not have too many parameters
    // Builds a JSON payload whose "UserTextList" array contains the conversation turns rendered as XML elements.
    //
    // Each turn is rendered as the concatenation of up to three elements: the user message text (named
    // questionElementName), the assistant message text (named answerElementName) and, when per-turn context is
    // available, the context for that turn (named contextElementName). The payload also carries
    // "AnnotationTask" and, when metricNames is non-empty, "MetricList".
    //
    // Returns the serialized payload together with the diagnostics reported by PreProcessConversation.
    private static (string payload, IReadOnlyList<EvaluationDiagnostic>? diagnostics)
        GetUserTextListPayloadWithEmbeddedXml(
            IEnumerable<ChatMessage> conversation,
            string annotationTask,
            string evaluatorName,
            IEnumerable<string?>? perTurnContext = null,
            IEnumerable<string>? metricNames = null,
            string questionElementName = "Human",
            string answerElementName = "System",
            string contextElementName = "Context",
            ContentSafetyServicePayloadStrategy strategy = ContentSafetyServicePayloadStrategy.AnnotateConversation,
            CancellationToken cancellationToken = default)
#pragma warning restore S107
    {
        (List<Dictionary<string, ChatMessage>> turns,
            List<string?>? normalizedPerTurnContext,
            List<EvaluationDiagnostic>? diagnostics,
            _) =
            PreProcessConversation(
                conversation,
                evaluatorName,
                perTurnContext,
                returnLastTurnOnly: strategy is ContentSafetyServicePayloadStrategy.AnnotateLastTurn,
                cancellationToken: cancellationToken);

        List<string> turnStrings = new(turns.Count);

        for (int turnIndex = 0; turnIndex < turns.Count; turnIndex++)
        {
            cancellationToken.ThrowIfCancellationRequested();

            Dictionary<string, ChatMessage> turn = turns[turnIndex];
            List<XElement> elements = [];

            if (turn.TryGetValue("question", out ChatMessage? question))
            {
                elements.Add(new XElement(questionElementName, question.Text));
            }

            if (turn.TryGetValue("answer", out ChatMessage? answer))
            {
                elements.Add(new XElement(answerElementName, answer.Text));
            }

            // PreProcessConversation pads or trims the context list so that, when it is non-empty, it has exactly
            // one entry per turn. A null entry produces an element without content.
            if (normalizedPerTurnContext is { Count: > 0 })
            {
                elements.Add(new XElement(contextElementName, normalizedPerTurnContext[turnIndex]));
            }

            turnStrings.Add(string.Concat(elements.Select(element => element.ToString())));
        }

        if (strategy is ContentSafetyServicePayloadStrategy.AnnotateConversation)
        {
            // Combine all turns into a single string. In this case, the service will produce a single annotation
            // result for the entire conversation.
            turnStrings = [string.Join(Environment.NewLine, turnStrings)];
        }

        // Otherwise the turns are left as individual entries:
        //
        // If ContentSafetyServicePayloadStrategy.AnnotateLastTurn is used, all turns except the last one have
        // already been discarded above. In this case, the service will produce a single annotation result for the
        // last conversation turn only.
        //
        // If ContentSafetyServicePayloadStrategy.AnnotateEachTurn is used, all turns are retained individually.
        // In this case, the service will produce a separate annotation result for each conversation turn.

        var payload =
            new JsonObject
            {
                ["UserTextList"] = new JsonArray([.. turnStrings]),
                ["AnnotationTask"] = annotationTask,
            };

        if (metricNames is not null)
        {
            // Materialize once so that a lazily evaluated sequence is not enumerated twice.
            JsonNode?[] metricNodes = [.. metricNames];

            if (metricNodes.Length > 0)
            {
                payload["MetricList"] = new JsonArray(metricNodes);
            }
        }

        return (payload.ToJsonString(), diagnostics);
    }
 
#pragma warning disable S107 // Methods should not have too many parameters
    // Builds a JSON payload whose "UserTextList" array contains one serialized JSON object per conversation turn.
    //
    // Each turn object holds the user message text (under questionPropertyName), the assistant message text
    // (under answerPropertyName) and, when per-turn context is available, the context for that turn (under
    // contextPropertyName). The payload also carries "AnnotationTask" and, when metricNames is non-empty,
    // "MetricList".
    //
    // Returns the serialized payload together with the diagnostics reported by PreProcessConversation.
    // Throws NotSupportedException when strategy is ContentSafetyServicePayloadStrategy.AnnotateConversation.
    private static (string payload, IReadOnlyList<EvaluationDiagnostic>? diagnostics)
        GetUserTextListPayloadWithEmbeddedJson(
            IEnumerable<ChatMessage> conversation,
            string annotationTask,
            string evaluatorName,
            IEnumerable<string?>? perTurnContext = null,
            IEnumerable<string>? metricNames = null,
            string questionPropertyName = "question",
            string answerPropertyName = "answer",
            string contextPropertyName = "context",
            ContentSafetyServicePayloadStrategy strategy = ContentSafetyServicePayloadStrategy.AnnotateLastTurn,
            CancellationToken cancellationToken = default)
#pragma warning restore S107
    {
        if (strategy is ContentSafetyServicePayloadStrategy.AnnotateConversation)
        {
            throw new NotSupportedException(
                $"{nameof(GetUserTextListPayloadWithEmbeddedJson)} does not support the {strategy} {nameof(ContentSafetyServicePayloadStrategy)}.");
        }

        (List<Dictionary<string, ChatMessage>> turns,
            List<string?>? normalizedPerTurnContext,
            List<EvaluationDiagnostic>? diagnostics,
            _) =
            PreProcessConversation(
                conversation,
                evaluatorName,
                perTurnContext,
                returnLastTurnOnly: strategy is ContentSafetyServicePayloadStrategy.AnnotateLastTurn,
                cancellationToken: cancellationToken);

        List<string> turnStrings = new(turns.Count);

        for (int turnIndex = 0; turnIndex < turns.Count; turnIndex++)
        {
            cancellationToken.ThrowIfCancellationRequested();

            Dictionary<string, ChatMessage> turn = turns[turnIndex];
            var turnObject = new JsonObject();

            if (turn.TryGetValue("question", out ChatMessage? question))
            {
                turnObject[questionPropertyName] = question.Text;
            }

            if (turn.TryGetValue("answer", out ChatMessage? answer))
            {
                turnObject[answerPropertyName] = answer.Text;
            }

            // PreProcessConversation pads or trims the context list so that, when it is non-empty, it has exactly
            // one entry per turn. A null entry is written as a JSON null.
            //
            // NOTE(review): when contextPropertyName equals questionPropertyName (the "context" / "completion"
            // layout uses "context" for the question while contextPropertyName also defaults to "context"), this
            // assignment replaces the question text written above - confirm that this is intended.
            if (normalizedPerTurnContext is { Count: > 0 })
            {
                turnObject[contextPropertyName] = normalizedPerTurnContext[turnIndex];
            }

            turnStrings.Add(turnObject.ToJsonString());
        }

        // If ContentSafetyServicePayloadStrategy.AnnotateLastTurn is used, all turns except the last one have
        // already been discarded above. In this case, the service will produce a single annotation result for the
        // last conversation turn only.
        //
        // On the other hand, if ContentSafetyServicePayloadStrategy.AnnotateEachTurn is used, all turns are
        // retained individually in turnStrings. In this case, the service will produce a separate annotation
        // result for each conversation turn.

        var payload =
            new JsonObject
            {
                ["UserTextList"] = new JsonArray([.. turnStrings]),
                ["AnnotationTask"] = annotationTask,
            };

        if (metricNames is not null)
        {
            // Materialize once so that a lazily evaluated sequence is not enumerated twice.
            JsonNode?[] metricNodes = [.. metricNames];

            if (metricNodes.Length > 0)
            {
                payload["MetricList"] = new JsonArray(metricNodes);
            }
        }

        return (payload.ToJsonString(), diagnostics);
    }
 
    // Builds a JSON payload that carries the conversation as a list of role-tagged messages.
    //
    // The payload has the shape { "ContentType", "Contents": [ { "messages": [...] } ], "AnnotationTask" } plus
    // "MetricList" when metricNames is non-empty. User messages are emitted with role "user" and assistant
    // messages with role "assistant"; an assistant message also gets a "context" property when a non-null
    // per-turn context entry exists for its turn. Text content and image content (UriContent / DataContent with
    // top-level media type "image") are included; other content is skipped.
    //
    // Returns the serialized payload together with the diagnostics reported by PreProcessConversation.
    // Throws NotSupportedException when strategy is ContentSafetyServicePayloadStrategy.AnnotateEachTurn.
    private static (string payload, IReadOnlyList<EvaluationDiagnostic>? diagnostics) GetConversationPayload(
        IEnumerable<ChatMessage> conversation,
        string annotationTask,
        string evaluatorName,
        IEnumerable<string?>? perTurnContext = null,
        IEnumerable<string>? metricNames = null,
        ContentSafetyServicePayloadStrategy strategy = ContentSafetyServicePayloadStrategy.AnnotateConversation,
        CancellationToken cancellationToken = default)
    {
        if (strategy is ContentSafetyServicePayloadStrategy.AnnotateEachTurn)
        {
            throw new NotSupportedException(
                $"{nameof(GetConversationPayload)} does not support the {strategy} {nameof(ContentSafetyServicePayloadStrategy)}.");
        }

        (List<Dictionary<string, ChatMessage>> turns,
            List<string?>? normalizedPerTurnContext,
            List<EvaluationDiagnostic>? diagnostics,
            string contentType) =
            PreProcessConversation(
                conversation,
                evaluatorName,
                perTurnContext,
                returnLastTurnOnly: strategy is ContentSafetyServicePayloadStrategy.AnnotateLastTurn,
                areImagesSupported: true,
                cancellationToken);

        // Converts the supported contents of a single message into the service's content objects.
        IEnumerable<JsonObject> GetContents(ChatMessage message)
        {
            foreach (AIContent content in message.Contents)
            {
                cancellationToken.ThrowIfCancellationRequested();

                if (content is TextContent textContent)
                {
                    yield return new JsonObject
                    {
                        ["type"] = "text",
                        ["text"] = textContent.Text
                    };
                }
                else if (content is UriContent uriContent && uriContent.HasTopLevelMediaType("image"))
                {
                    yield return new JsonObject
                    {
                        ["type"] = "image_url",
                        ["image_url"] =
                            new JsonObject
                            {
                                ["url"] = uriContent.Uri.AbsoluteUri
                            }
                    };
                }
                else if (content is DataContent dataContent && dataContent.HasTopLevelMediaType("image"))
                {
                    string url;
                    if (dataContent.IsUriBase64Encoded())
                    {
                        url = dataContent.Uri;
                    }
                    else
                    {
                        // Build a base64 data URI from the raw bytes.
                        BinaryData imageBytes = BinaryData.FromBytes(dataContent.Data);
                        string base64ImageData = Convert.ToBase64String(imageBytes.ToArray());
                        url = $"data:{dataContent.MediaType};base64,{base64ImageData}";
                    }

                    yield return new JsonObject
                    {
                        ["type"] = "image_url",
                        ["image_url"] =
                            new JsonObject
                            {
                                ["url"] = url
                            }
                    };
                }
            }
        }

        // Produces the "user" and / or "assistant" message objects for a single turn.
        IEnumerable<JsonObject> GetMessages(Dictionary<string, ChatMessage> turn, int turnIndex)
        {
            cancellationToken.ThrowIfCancellationRequested();

            if (turn.TryGetValue("question", out ChatMessage? question))
            {
                yield return new JsonObject
                {
                    ["role"] = "user",
                    ["content"] = new JsonArray([.. GetContents(question)])
                };
            }

            if (turn.TryGetValue("answer", out ChatMessage? answer))
            {
                var assistantMessage =
                    new JsonObject
                    {
                        ["role"] = "assistant",
                        ["content"] = new JsonArray([.. GetContents(answer)]),
                    };

                // PreProcessConversation pads or trims the context list so that, when it is non-empty, it has
                // exactly one entry per turn. Turns whose entry is null get no "context" property.
                if (normalizedPerTurnContext is { Count: > 0 } &&
                    normalizedPerTurnContext[turnIndex] is string context)
                {
                    assistantMessage["context"] = context;
                }

                yield return assistantMessage;
            }
        }

        var payload =
            new JsonObject
            {
                ["ContentType"] = contentType,
                ["Contents"] =
                    new JsonArray(
                        new JsonObject
                        {
                            ["messages"] = new JsonArray([.. turns.SelectMany(GetMessages)]),
                        }),
                ["AnnotationTask"] = annotationTask,
            };

        if (metricNames is not null)
        {
            // Materialize once so that a lazily evaluated sequence is not enumerated twice.
            JsonNode?[] metricNodes = [.. metricNames];

            if (metricNodes.Length > 0)
            {
                payload["MetricList"] = new JsonArray(metricNodes);
            }
        }

        // If ContentSafetyServicePayloadStrategy.AnnotateLastTurn is used, all turns except the last one have
        // already been discarded above. In this case, the service will produce a single annotation result for the
        // last conversation turn only.
        //
        // On the other hand, if ContentSafetyServicePayloadStrategy.AnnotateConversation is used, the service will
        // produce a single annotation result for the entire conversation.
        return (payload.ToJsonString(), diagnostics);
    }
 
    // Splits the supplied conversation into turns and validates it.
    //
    // A turn is a dictionary holding the user message under "question" and / or the assistant message under
    // "answer". An assistant message always closes the current turn; a user message closes the current turn only
    // if that turn already has a question. Messages with any other role are ignored and reported.
    //
    // Returns:
    //   turns                    - the turns (only the last one when returnLastTurnOnly is true).
    //   normalizedPerTurnContext - null when no context was supplied; otherwise a list padded (with leading
    //                              nulls) or trimmed (from the front) to have exactly one entry per turn.
    //   diagnostics              - warnings for ignored messages, incomplete turns, unsupported content and
    //                              context / turn count mismatches; null when there are none.
    //   contentType              - "image" when images are supported and at least one supported image was found
    //                              in the returned turns; otherwise "text".
    private static
        (List<Dictionary<string, ChatMessage>> turns,
        List<string?>? normalizedPerTurnContext,
        List<EvaluationDiagnostic>? diagnostics,
        string contentType) PreProcessConversation(
            IEnumerable<ChatMessage> conversation,
            string evaluatorName,
            IEnumerable<string?>? perTurnContext = null,
            bool returnLastTurnOnly = false,
            bool areImagesSupported = false,
            CancellationToken cancellationToken = default)
    {
        List<Dictionary<string, ChatMessage>> turns = [];
        Dictionary<string, ChatMessage> currentTurn = [];

        // Materialize the context once so that a lazily evaluated sequence is not enumerated twice. An empty
        // collection is treated the same as no context at all.
        List<string?>? normalizedPerTurnContext = perTurnContext?.ToList();
        if (normalizedPerTurnContext is { Count: 0 })
        {
            normalizedPerTurnContext = null;
        }

        int currentTurnIndex = 0;
        int ignoredMessageCount = 0;
        int incompleteTurnCount = 0;

        void StartNewTurn()
        {
            if (!currentTurn.ContainsKey("question") || !currentTurn.ContainsKey("answer"))
            {
                ++incompleteTurnCount;
            }

            turns.Add(currentTurn);
            currentTurn = [];
            ++currentTurnIndex;
        }

        foreach (ChatMessage message in conversation)
        {
            cancellationToken.ThrowIfCancellationRequested();

            if (message.Role == ChatRole.User)
            {
                if (currentTurn.ContainsKey("question"))
                {
                    StartNewTurn();
                }

                currentTurn["question"] = message;
            }
            else if (message.Role == ChatRole.Assistant)
            {
                currentTurn["answer"] = message;

                StartNewTurn();
            }
            else
            {
                // System prompts are currently not supported.
                ignoredMessageCount++;
            }
        }

        // NOTE: a trailing user message that is not followed by an assistant message is still held in
        // currentTurn at this point and is never added to turns.

        // Keep only the last turn. The count check matters because turns can be empty here (for example, for an
        // empty conversation), and List<T>.RemoveRange throws ArgumentOutOfRangeException for a negative count.
        if (returnLastTurnOnly && turns.Count > 1)
        {
            turns.RemoveRange(index: 0, count: turns.Count - 1);
        }

        int imagesCount = 0;
        int unsupportedContentCount = 0;

        void ValidateContents(ChatMessage message)
        {
            foreach (AIContent content in message.Contents)
            {
                cancellationToken.ThrowIfCancellationRequested();

                if (areImagesSupported && content.IsImageWithSupportedFormat())
                {
                    ++imagesCount;
                }
                else if (!content.IsTextOrUsage())
                {
                    ++unsupportedContentCount;
                }
            }
        }

        // Only the turns that are actually returned are validated.
        foreach (Dictionary<string, ChatMessage> turn in turns)
        {
            cancellationToken.ThrowIfCancellationRequested();

            foreach (ChatMessage message in turn.Values)
            {
                cancellationToken.ThrowIfCancellationRequested();

                ValidateContents(message);
            }
        }

        List<EvaluationDiagnostic>? diagnostics = null;

        if (ignoredMessageCount > 0)
        {
            diagnostics = [
                EvaluationDiagnostic.Warning(
                    $"The supplied conversation contained {ignoredMessageCount} messages with unsupported roles. " +
                    $"{evaluatorName} only considers messages with role '{ChatRole.User}' and '{ChatRole.Assistant}'. " +
                    $"The unsupported messages (which may include messages with role '{ChatRole.System}' and '{ChatRole.Tool}') were ignored.")];
        }

        if (incompleteTurnCount > 0)
        {
            diagnostics ??= [];
            diagnostics.Add(
                EvaluationDiagnostic.Warning(
                    $"The supplied conversation contained {incompleteTurnCount} incomplete turns. " +
                    $"These turns were either missing a message with role '{ChatRole.User}' or '{ChatRole.Assistant}'. " +
                    $"This may indicate that the supplied conversation was not well-formed and may result in inaccurate evaluation results."));
        }

        if (unsupportedContentCount > 0)
        {
            diagnostics ??= [];
            if (areImagesSupported)
            {
                diagnostics.Add(
                    EvaluationDiagnostic.Warning(
                        $"The supplied conversation contained {unsupportedContentCount} instances of unsupported content within messages. " +
                        $"The current evaluation being performed by {evaluatorName} only supports content of type '{nameof(TextContent)}', '{nameof(UriContent)}' and '{nameof(DataContent)}'. " +
                        $"For '{nameof(UriContent)}' and '{nameof(DataContent)}', only content with media type 'image/png', 'image/jpeg' and 'image/gif' are supported. " +
                        $"The unsupported contents were ignored for this evaluation."));
            }
            else
            {
                diagnostics.Add(
                    EvaluationDiagnostic.Warning(
                        $"The supplied conversation contained {unsupportedContentCount} instances of unsupported content within messages. " +
                        $"The current evaluation being performed by {evaluatorName} only supports content of type '{nameof(TextContent)}'. " +
                        $"The unsupported contents were ignored for this evaluation."));
            }
        }

        // Align the context list with the turns so that callers can index it by turn index.
        if (normalizedPerTurnContext is not null)
        {
            if (normalizedPerTurnContext.Count > turns.Count)
            {
                int ignoredContextCount = normalizedPerTurnContext.Count - turns.Count;

                diagnostics ??= [];
                diagnostics.Add(
                    EvaluationDiagnostic.Warning(
                        $"The supplied conversation contained {turns.Count} turns. " +
                        $"However, context for {normalizedPerTurnContext.Count} turns were supplied as part of the context collection. " +
                        $"The initial {ignoredContextCount} items from the context collection were ignored. " +
                        $"Only the last {turns.Count} items from the context collection were used."));

                normalizedPerTurnContext.RemoveRange(0, ignoredContextCount);
            }
            else if (normalizedPerTurnContext.Count < turns.Count)
            {
                int missingContextCount = turns.Count - normalizedPerTurnContext.Count;

                diagnostics ??= [];
                diagnostics.Add(
                    EvaluationDiagnostic.Warning(
                        $"The supplied conversation contained {turns.Count} turns. " +
                        $"However, context for only {normalizedPerTurnContext.Count} turns were supplied as part of the context collection. " +
                        $"The initial {missingContextCount} turns in the conversations were evaluated without any context. " +
                        $"The supplied items in the context collection were applied to the last {normalizedPerTurnContext.Count} turns."));

                normalizedPerTurnContext.InsertRange(0, Enumerable.Repeat<string?>(null, missingContextCount));
            }
        }

        string contentType = areImagesSupported && imagesCount > 0 ? "image" : "text";

        return (turns, normalizedPerTurnContext, diagnostics, contentType);
    }
}