File: OllamaChatClient.cs
Web Access
Project: src\src\Libraries\Microsoft.Extensions.AI.Ollama\Microsoft.Extensions.AI.Ollama.csproj (Microsoft.Extensions.AI.Ollama)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Json;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
 
#pragma warning disable EA0011 // Consider removing unnecessary conditional access operator (?)
#pragma warning disable SA1204 // Static elements should appear before instance elements
 
namespace Microsoft.Extensions.AI;
 
/// <summary>Represents an <see cref="IChatClient"/> for Ollama.</summary>
public sealed class OllamaChatClient : IChatClient
{
    /// <summary>Schema used for a function parameter that does not supply its own JSON schema (an empty JSON object).</summary>
    private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement;

    /// <summary>The <c>format</c> value sent when JSON output is requested without a schema (the JSON string <c>"json"</c>).</summary>
    private static readonly JsonElement _schemalessJsonResponseFormatValue = JsonDocument.Parse("\"json\"").RootElement;

    /// <summary>The api/chat endpoint URI.</summary>
    private readonly Uri _apiChatEndpoint;

    /// <summary>The <see cref="HttpClient"/> to use for sending requests.</summary>
    private readonly HttpClient _httpClient;

    /// <summary>The <see cref="JsonSerializerOptions"/> use for any serialization activities related to tool call arguments and results.</summary>
    private JsonSerializerOptions _toolCallJsonSerializerOptions = AIJsonUtilities.DefaultOptions;

    /// <summary>Initializes a new instance of the <see cref="OllamaChatClient"/> class.</summary>
    /// <param name="endpoint">The endpoint URI where Ollama is hosted.</param>
    /// <param name="modelId">
    /// The ID of the model to use. This ID can also be overridden per request via <see cref="ChatOptions.ModelId"/>.
    /// Either this parameter or <see cref="ChatOptions.ModelId"/> must provide a valid model ID.
    /// </param>
    /// <param name="httpClient">An <see cref="HttpClient"/> instance to use for HTTP operations.</param>
    public OllamaChatClient(string endpoint, string? modelId = null, HttpClient? httpClient = null)
        : this(new Uri(Throw.IfNull(endpoint)), modelId, httpClient)
    {
    }

    /// <summary>Initializes a new instance of the <see cref="OllamaChatClient"/> class.</summary>
    /// <param name="endpoint">The endpoint URI where Ollama is hosted.</param>
    /// <param name="modelId">
    /// The ID of the model to use. This ID can also be overridden per request via <see cref="ChatOptions.ModelId"/>.
    /// Either this parameter or <see cref="ChatOptions.ModelId"/> must provide a valid model ID.
    /// </param>
    /// <param name="httpClient">An <see cref="HttpClient"/> instance to use for HTTP operations.</param>
    public OllamaChatClient(Uri endpoint, string? modelId = null, HttpClient? httpClient = null)
    {
        _ = Throw.IfNull(endpoint);
        if (modelId is not null)
        {
            _ = Throw.IfNullOrWhitespace(modelId);
        }

        _apiChatEndpoint = new Uri(endpoint, "api/chat");
        _httpClient = httpClient ?? OllamaUtilities.SharedClient;
        Metadata = new("ollama", endpoint, modelId);
    }

    /// <inheritdoc />
    public ChatClientMetadata Metadata { get; }

    /// <summary>Gets or sets <see cref="JsonSerializerOptions"/> to use for any serialization activities related to tool call arguments and results.</summary>
    public JsonSerializerOptions ToolCallJsonSerializerOptions
    {
        get => _toolCallJsonSerializerOptions;
        set => _toolCallJsonSerializerOptions = Throw.IfNull(value);
    }

    /// <inheritdoc />
    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(chatMessages);

        using var httpResponse = await _httpClient.PostAsJsonAsync(
            _apiChatEndpoint,
            ToOllamaChatRequest(chatMessages, options, stream: false),
            JsonContext.Default.OllamaChatRequest,
            cancellationToken).ConfigureAwait(false);

        // Check the HTTP status before deserializing so that a non-chat payload (for example a non-JSON error
        // body) yields a descriptive error instead of a JSON parsing failure.
        if (!httpResponse.IsSuccessStatusCode)
        {
            await ThrowUnsuccessfulOllamaResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
        }

        var response = await httpResponse.Content.ReadFromJsonAsync(
            JsonContext.Default.OllamaChatResponse,
            cancellationToken).ConfigureAwait(false)
            ?? throw new InvalidOperationException("Ollama returned an empty response.");

        if (!string.IsNullOrEmpty(response.Error))
        {
            throw new InvalidOperationException($"Ollama error: {response.Error}");
        }

        return new([FromOllamaMessage(response.Message!)])
        {
            // CreatedAt is reused as the completion id; no separate id is read from the response here.
            CompletionId = response.CreatedAt,
            ModelId = response.Model ?? options?.ModelId ?? Metadata.ModelId,
            CreatedAt = DateTimeOffset.TryParse(response.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null,
            FinishReason = ToFinishReason(response),
            Usage = ParseOllamaChatResponseUsage(response),
        };
    }

    /// <inheritdoc />
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(chatMessages);

        using HttpRequestMessage request = new(HttpMethod.Post, _apiChatEndpoint)
        {
            Content = JsonContent.Create(ToOllamaChatRequest(chatMessages, options, stream: true), JsonContext.Default.OllamaChatRequest)
        };
        using var httpResponse = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);

        // Fail fast on HTTP-level errors instead of trying to parse the error body as a stream of chat chunks.
        if (!httpResponse.IsSuccessStatusCode)
        {
            await ThrowUnsuccessfulOllamaResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
        }

        using var httpResponseStream = await httpResponse.Content
#if NET
            .ReadAsStreamAsync(cancellationToken)
#else
            .ReadAsStreamAsync()
#endif
            .ConfigureAwait(false);

        // The body is read line by line, with one JSON-serialized OllamaChatResponse per line.
        using var streamReader = new StreamReader(httpResponseStream);
#if NET
        while ((await streamReader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) is { } line)
#else
        while ((await streamReader.ReadLineAsync().ConfigureAwait(false)) is { } line)
#endif
        {
            // Skip blank lines; deserializing an empty string would throw.
            if (string.IsNullOrWhiteSpace(line))
            {
                continue;
            }

            var chunk = JsonSerializer.Deserialize(line, JsonContext.Default.OllamaChatResponse);
            if (chunk is null)
            {
                continue;
            }

            // Surface an Error reported in a chunk the same way the non-streaming path does,
            // rather than yielding it as an empty update.
            if (!string.IsNullOrEmpty(chunk.Error))
            {
                throw new InvalidOperationException($"Ollama error: {chunk.Error}");
            }

            // Same precedence as the non-streaming path: reported model, then requested model, then default.
            string? modelId = chunk.Model ?? options?.ModelId ?? Metadata.ModelId;

            StreamingChatCompletionUpdate update = new()
            {
                Role = chunk.Message?.Role is not null ? new ChatRole(chunk.Message.Role) : null,
                CreatedAt = DateTimeOffset.TryParse(chunk.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null,
                FinishReason = ToFinishReason(chunk),
                ModelId = modelId,
            };

            if (chunk.Message is { } message)
            {
                if (message.ToolCalls is { Length: > 0 })
                {
                    foreach (var toolCall in message.ToolCalls)
                    {
                        if (toolCall.Function is { } function)
                        {
                            update.Contents.Add(ToFunctionCallContent(function));
                        }
                    }
                }

                // Equivalent rule to the nonstreaming case: only add text if it is non-empty or there were no tool calls.
                if (message.Content?.Length > 0 || update.Contents.Count == 0)
                {
                    update.Contents.Insert(0, new TextContent(message.Content));
                }
            }

            if (ParseOllamaChatResponseUsage(chunk) is { } usage)
            {
                update.Contents.Add(new UsageContent(usage));
            }

            yield return update;
        }
    }

    /// <inheritdoc />
    public object? GetService(Type serviceType, object? serviceKey = null)
    {
        _ = Throw.IfNull(serviceType);

        return
            serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
            null;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Disposes the underlying <see cref="HttpClient"/> unless it is the shared default client.
    /// NOTE(review): this includes an <see cref="HttpClient"/> that was supplied by the caller.
    /// </remarks>
    public void Dispose()
    {
        if (_httpClient != OllamaUtilities.SharedClient)
        {
            _httpClient.Dispose();
        }
    }

    /// <summary>Throws an <see cref="InvalidOperationException"/> describing a non-success HTTP response from Ollama.</summary>
    /// <param name="response">The unsuccessful response. Its body is read to extract an error message.</param>
    /// <param name="cancellationToken">The token used to cancel reading the body.</param>
    /// <remarks>
    /// If the body is a JSON object with a string <c>error</c> property, that text is used in the message.
    /// Otherwise, the status code and the raw body are reported.
    /// </remarks>
    /// <exception cref="InvalidOperationException">Always thrown once the body has been read.</exception>
    private static async Task ThrowUnsuccessfulOllamaResponseAsync(HttpResponseMessage response, CancellationToken cancellationToken)
    {
        string body = await response.Content
#if NET
            .ReadAsStringAsync(cancellationToken)
#else
            .ReadAsStringAsync()
#endif
            .ConfigureAwait(false);

        string? error = null;
        try
        {
            using JsonDocument document = JsonDocument.Parse(body);
            if (document.RootElement.ValueKind == JsonValueKind.Object &&
                document.RootElement.TryGetProperty("error", out JsonElement errorElement) &&
                errorElement.ValueKind == JsonValueKind.String)
            {
                error = errorElement.GetString();
            }
        }
        catch (JsonException)
        {
            // The body is not JSON; the status code and raw body are reported below instead.
        }

        string details = string.IsNullOrEmpty(error)
            ? $"{(int)response.StatusCode} ({response.StatusCode}): {body}"
            : error!;

        throw new InvalidOperationException($"Ollama error: {details}");
    }

    /// <summary>Builds <see cref="UsageDetails"/> from the token counts and durations in an Ollama response.</summary>
    /// <param name="response">The response (or streaming chunk) to read from.</param>
    /// <returns>The usage details, or <see langword="null"/> if the response carries no counts or durations.</returns>
    private static UsageDetails? ParseOllamaChatResponseUsage(OllamaChatResponse response)
    {
        // Each duration is handed to OllamaUtilities.TransferNanosecondsTime, which may create additionalCounts.
        AdditionalPropertiesDictionary<long>? additionalCounts = null;
        OllamaUtilities.TransferNanosecondsTime(response, static r => r.LoadDuration, "load_duration", ref additionalCounts);
        OllamaUtilities.TransferNanosecondsTime(response, static r => r.TotalDuration, "total_duration", ref additionalCounts);
        OllamaUtilities.TransferNanosecondsTime(response, static r => r.PromptEvalDuration, "prompt_eval_duration", ref additionalCounts);
        OllamaUtilities.TransferNanosecondsTime(response, static r => r.EvalDuration, "eval_duration", ref additionalCounts);

        if (additionalCounts is not null || response.PromptEvalCount is not null || response.EvalCount is not null)
        {
            return new()
            {
                InputTokenCount = response.PromptEvalCount,
                OutputTokenCount = response.EvalCount,
                TotalTokenCount = response.PromptEvalCount.GetValueOrDefault() + response.EvalCount.GetValueOrDefault(),
                AdditionalCounts = additionalCounts,
            };
        }

        return null;
    }

    /// <summary>Maps Ollama's <c>DoneReason</c> to a <see cref="ChatFinishReason"/>, passing through unknown values.</summary>
    private static ChatFinishReason? ToFinishReason(OllamaChatResponse response) =>
        response.DoneReason switch
        {
            null => null,
            "length" => ChatFinishReason.Length,
            "stop" => ChatFinishReason.Stop,
            _ => new ChatFinishReason(response.DoneReason),
        };

    /// <summary>Converts an Ollama response message into a <see cref="ChatMessage"/> with text and function-call contents.</summary>
    private static ChatMessage FromOllamaMessage(OllamaChatResponseMessage message)
    {
        List<AIContent> contents = [];

        // Add any tool calls.
        if (message.ToolCalls is { Length: > 0 })
        {
            foreach (var toolCall in message.ToolCalls)
            {
                if (toolCall.Function is { } function)
                {
                    contents.Add(ToFunctionCallContent(function));
                }
            }
        }

        // Ollama frequently sends back empty content with tool calls. Rather than always adding an empty
        // content, we only add the content if either it's not empty or there weren't any tool calls.
        if (message.Content?.Length > 0 || contents.Count == 0)
        {
            contents.Insert(0, new TextContent(message.Content));
        }

        return new ChatMessage(new(message.Role), contents);
    }

    /// <summary>Creates a <see cref="FunctionCallContent"/> for an Ollama tool call.</summary>
    /// <remarks>No call ID is read from the tool call, so a random 8-character hex ID is generated.</remarks>
    private static FunctionCallContent ToFunctionCallContent(OllamaFunctionToolCall function)
    {
#if NET
        var id = System.Security.Cryptography.RandomNumberGenerator.GetHexString(8);
#else
        var id = Guid.NewGuid().ToString().Substring(0, 8);
#endif
        return new FunctionCallContent(id, function.Name, function.Arguments);
    }

    /// <summary>Maps a <see cref="ChatResponseFormat"/> to Ollama's <c>format</c> value.</summary>
    /// <returns>
    /// The JSON schema, or the string <c>"json"</c> when no schema is given, for a JSON format.
    /// <see langword="null"/> for any other format.
    /// </returns>
    private static JsonElement? ToOllamaChatResponseFormat(ChatResponseFormat? format)
    {
        if (format is ChatResponseFormatJson jsonFormat)
        {
            return jsonFormat.Schema ?? _schemalessJsonResponseFormatValue;
        }
        else
        {
            return null;
        }
    }

    /// <summary>Builds the Ollama api/chat request body from the chat messages and options.</summary>
    /// <param name="chatMessages">The conversation to send.</param>
    /// <param name="options">Optional per-request options. Recognized <c>AdditionalProperties</c> keys are copied into the Ollama request options.</param>
    /// <param name="stream">Whether Ollama should stream the response.</param>
    private OllamaChatRequest ToOllamaChatRequest(IList<ChatMessage> chatMessages, ChatOptions? options, bool stream)
    {
        OllamaChatRequest request = new()
        {
            Format = ToOllamaChatResponseFormat(options?.ResponseFormat),
            Messages = chatMessages.SelectMany(ToOllamaChatRequestMessages).ToArray(),
            Model = options?.ModelId ?? Metadata.ModelId ?? string.Empty,
            Stream = stream,
            Tools = options?.Tools is { Count: > 0 } tools ? tools.OfType<AIFunction>().Select(ToOllamaTool) : null,
        };

        if (options is not null)
        {
            TransferMetadataValue<bool>(nameof(OllamaRequestOptions.embedding_only), (options, value) => options.embedding_only = value);
            TransferMetadataValue<bool>(nameof(OllamaRequestOptions.f16_kv), (options, value) => options.f16_kv = value);
            TransferMetadataValue<bool>(nameof(OllamaRequestOptions.logits_all), (options, value) => options.logits_all = value);
            TransferMetadataValue<bool>(nameof(OllamaRequestOptions.low_vram), (options, value) => options.low_vram = value);
            TransferMetadataValue<int>(nameof(OllamaRequestOptions.main_gpu), (options, value) => options.main_gpu = value);
            TransferMetadataValue<float>(nameof(OllamaRequestOptions.min_p), (options, value) => options.min_p = value);
            TransferMetadataValue<int>(nameof(OllamaRequestOptions.mirostat), (options, value) => options.mirostat = value);
            TransferMetadataValue<float>(nameof(OllamaRequestOptions.mirostat_eta), (options, value) => options.mirostat_eta = value);
            TransferMetadataValue<float>(nameof(OllamaRequestOptions.mirostat_tau), (options, value) => options.mirostat_tau = value);
            TransferMetadataValue<int>(nameof(OllamaRequestOptions.num_batch), (options, value) => options.num_batch = value);
            TransferMetadataValue<int>(nameof(OllamaRequestOptions.num_ctx), (options, value) => options.num_ctx = value);
            TransferMetadataValue<int>(nameof(OllamaRequestOptions.num_gpu), (options, value) => options.num_gpu = value);
            TransferMetadataValue<int>(nameof(OllamaRequestOptions.num_keep), (options, value) => options.num_keep = value);
            TransferMetadataValue<int>(nameof(OllamaRequestOptions.num_thread), (options, value) => options.num_thread = value);
            TransferMetadataValue<bool>(nameof(OllamaRequestOptions.numa), (options, value) => options.numa = value);
            TransferMetadataValue<bool>(nameof(OllamaRequestOptions.penalize_newline), (options, value) => options.penalize_newline = value);
            TransferMetadataValue<int>(nameof(OllamaRequestOptions.repeat_last_n), (options, value) => options.repeat_last_n = value);
            TransferMetadataValue<float>(nameof(OllamaRequestOptions.repeat_penalty), (options, value) => options.repeat_penalty = value);
            TransferMetadataValue<float>(nameof(OllamaRequestOptions.tfs_z), (options, value) => options.tfs_z = value);
            TransferMetadataValue<float>(nameof(OllamaRequestOptions.typical_p), (options, value) => options.typical_p = value);
            TransferMetadataValue<bool>(nameof(OllamaRequestOptions.use_mmap), (options, value) => options.use_mmap = value);
            TransferMetadataValue<bool>(nameof(OllamaRequestOptions.use_mlock), (options, value) => options.use_mlock = value);
            TransferMetadataValue<bool>(nameof(OllamaRequestOptions.vocab_only), (options, value) => options.vocab_only = value);

            if (options.FrequencyPenalty is float frequencyPenalty)
            {
                (request.Options ??= new()).frequency_penalty = frequencyPenalty;
            }

            if (options.MaxOutputTokens is int maxOutputTokens)
            {
                (request.Options ??= new()).num_predict = maxOutputTokens;
            }

            if (options.PresencePenalty is float presencePenalty)
            {
                (request.Options ??= new()).presence_penalty = presencePenalty;
            }

            if (options.StopSequences is { Count: > 0 })
            {
                (request.Options ??= new()).stop = [.. options.StopSequences];
            }

            if (options.Temperature is float temperature)
            {
                (request.Options ??= new()).temperature = temperature;
            }

            if (options.TopP is float topP)
            {
                (request.Options ??= new()).top_p = topP;
            }

            if (options.TopK is int topK)
            {
                (request.Options ??= new()).top_k = topK;
            }

            if (options.Seed is long seed)
            {
                (request.Options ??= new()).seed = seed;
            }
        }

        return request;

        // Copies options.AdditionalProperties[propertyName] into request.Options (creating it on demand)
        // when the key is present and convertible to T. Only called after options has been checked for null.
        void TransferMetadataValue<T>(string propertyName, Action<OllamaRequestOptions, T> setOption)
        {
            if (options.AdditionalProperties?.TryGetValue(propertyName, out T? t) is true)
            {
                request.Options ??= new();
                setOption(request.Options, t);
            }
        }
    }

    /// <summary>Converts one <see cref="ChatMessage"/> into zero or more Ollama request messages.</summary>
    private IEnumerable<OllamaChatRequestMessage> ToOllamaChatRequestMessages(ChatMessage content)
    {
        // In general, we return a single request message for each understood content item.
        // However, various image models expect both text and images in the same request message.
        // To handle that, attach images to a previous text message if one exists.

        OllamaChatRequestMessage? currentTextMessage = null;
        foreach (var item in content.Contents)
        {
            if (currentTextMessage is not null && item is not ImageContent)
            {
                yield return currentTextMessage;
                currentTextMessage = null;
            }

            switch (item)
            {
                case TextContent textContent:
                    currentTextMessage = new OllamaChatRequestMessage
                    {
                        Role = content.Role.Value,
                        Content = textContent.Text ?? string.Empty,
                    };
                    break;

                case ImageContent imageContent when imageContent.Data is not null:
                    IList<string> images = currentTextMessage?.Images ?? [];
                    images.Add(Convert.ToBase64String(imageContent.Data.Value
#if NET
                        .Span));
#else
                        .ToArray()));
#endif

                    if (currentTextMessage is not null)
                    {
                        currentTextMessage.Images = images;
                    }
                    else
                    {
                        yield return new OllamaChatRequestMessage
                        {
                            Role = content.Role.Value,
                            Images = images,
                        };
                    }

                    break;

                case FunctionCallContent fcc:
                {
                    // Function calls are sent back as an assistant message whose content is the JSON-serialized call.
                    yield return new OllamaChatRequestMessage
                    {
                        Role = "assistant",
                        Content = JsonSerializer.Serialize(new OllamaFunctionCallContent
                        {
                            CallId = fcc.CallId,
                            Name = fcc.Name,
                            Arguments = JsonSerializer.SerializeToElement(fcc.Arguments, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(IDictionary<string, object?>))),
                        }, JsonContext.Default.OllamaFunctionCallContent)
                    };
                    break;
                }

                case FunctionResultContent frc:
                {
                    // Function results are sent as a "tool" message whose content is the JSON-serialized result.
                    JsonElement jsonResult = JsonSerializer.SerializeToElement(frc.Result, ToolCallJsonSerializerOptions.GetTypeInfo(typeof(object)));
                    yield return new OllamaChatRequestMessage
                    {
                        Role = "tool",
                        Content = JsonSerializer.Serialize(new OllamaFunctionResultContent
                        {
                            CallId = frc.CallId,
                            Result = jsonResult,
                        }, JsonContext.Default.OllamaFunctionResultContent)
                    };
                    break;
                }
            }
        }

        if (currentTextMessage is not null)
        {
            yield return currentTextMessage;
        }
    }

    /// <summary>Converts an <see cref="AIFunction"/> into an Ollama function tool definition.</summary>
    /// <remarks>Parameters without a schema are described with an empty JSON object schema.</remarks>
    private static OllamaTool ToOllamaTool(AIFunction function) => new()
    {
        Type = "function",
        Function = new OllamaFunctionTool
        {
            Name = function.Metadata.Name,
            Description = function.Metadata.Description,
            Parameters = new OllamaFunctionToolParameters
            {
                Properties = function.Metadata.Parameters.ToDictionary(
                    p => p.Name,
                    p => p.Schema is JsonElement e ? e : _defaultParameterSchema),
                Required = function.Metadata.Parameters.Where(p => p.IsRequired).Select(p => p.Name).ToList(),
            },
        }
    };
}