File: DatabaseOperations.cs
Web Access
Project: src\src\Caching\SqlServer\src\Microsoft.Extensions.Caching.SqlServer.csproj (Microsoft.Extensions.Caching.SqlServer)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System;
using System.Buffers;
using System.Data;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Internal;
 
namespace Microsoft.Extensions.Caching.SqlServer;
 
internal sealed class DatabaseOperations : IDatabaseOperations
{
    /// <summary>
    /// Since there is no specific exception type representing a 'duplicate key' error, we are relying on
    /// the following message number which represents the following text in Microsoft SQL Server database.
    ///     "Violation of %ls constraint '%.*ls'. Cannot insert duplicate key in object '%.*ls'.
    ///     The duplicate key value is %ls."
    /// You can find the list of system messages by executing the following query:
    /// "SELECT * FROM sys.messages WHERE [text] LIKE '%duplicate%'"
    /// </summary>
    private const int DuplicateKeyErrorId = 2627;
 
    private const string UtcNowParameterName = "UtcNow";
 
    public DatabaseOperations(
        string connectionString, string schemaName, string tableName, ISystemClock systemClock)
    {
        ConnectionString = connectionString;
        SchemaName = schemaName;
        TableName = tableName;
        SystemClock = systemClock;
        SqlQueries = new SqlQueries(schemaName, tableName);
    }
 
    internal SqlQueries SqlQueries { get; }
 
    internal string ConnectionString { get; }
 
    internal string SchemaName { get; }
 
    internal string TableName { get; }
 
    private ISystemClock SystemClock { get; }
 
    public void DeleteCacheItem(string key)
    {
        using (var connection = new SqlConnection(ConnectionString))
        using (var command = new SqlCommand(SqlQueries.DeleteCacheItem, connection))
        {
            command.Parameters.AddCacheItemId(key);
 
            connection.Open();
 
            command.ExecuteNonQuery();
        }
    }
 
    public async Task DeleteCacheItemAsync(string key, CancellationToken token = default(CancellationToken))
    {
        token.ThrowIfCancellationRequested();
 
        using (var connection = new SqlConnection(ConnectionString))
        using (var command = new SqlCommand(SqlQueries.DeleteCacheItem, connection))
        {
            command.Parameters.AddCacheItemId(key);
 
            await connection.OpenAsync(token).ConfigureAwait(false);
 
            await command.ExecuteNonQueryAsync(token).ConfigureAwait(false);
        }
    }
 
    public byte[]? GetCacheItem(string key)
    {
        return GetCacheItem(key, includeValue: true);
    }
 
    public bool TryGetCacheItem(string key, IBufferWriter<byte> destination)
    {
        return GetCacheItem(key, includeValue: true, destination: destination) is not null;
    }
 
    public Task<byte[]?> GetCacheItemAsync(string key, CancellationToken token = default(CancellationToken))
    {
        token.ThrowIfCancellationRequested();
 
        return GetCacheItemAsync(key, includeValue: true, token: token);
    }
 
    public async Task<bool> TryGetCacheItemAsync(string key, IBufferWriter<byte> destination, CancellationToken token = default(CancellationToken))
    {
        token.ThrowIfCancellationRequested();
 
        var arr = await GetCacheItemAsync(key, includeValue: true, destination: destination, token: token).ConfigureAwait(false);
        return arr is not null;
    }
 
    public void RefreshCacheItem(string key)
    {
        GetCacheItem(key, includeValue: false);
    }
 
    public Task RefreshCacheItemAsync(string key, CancellationToken token = default(CancellationToken))
    {
        token.ThrowIfCancellationRequested();
 
        return GetCacheItemAsync(key, includeValue: false, token: token);
    }
 
    public void DeleteExpiredCacheItems()
    {
        var utcNow = SystemClock.UtcNow;
 
        using (var connection = new SqlConnection(ConnectionString))
        using (var command = new SqlCommand(SqlQueries.DeleteExpiredCacheItems, connection))
        {
            command.Parameters.AddWithValue(UtcNowParameterName, SqlDbType.DateTimeOffset, utcNow);
 
            connection.Open();
 
            var effectedRowCount = command.ExecuteNonQuery();
        }
    }
 
    public void SetCacheItem(string key, ArraySegment<byte> value, DistributedCacheEntryOptions options)
    {
        var utcNow = SystemClock.UtcNow;
 
        var absoluteExpiration = DatabaseOperations.GetAbsoluteExpiration(utcNow, options);
        DatabaseOperations.ValidateOptions(options.SlidingExpiration, absoluteExpiration);
 
        using (var connection = new SqlConnection(ConnectionString))
        using (var upsertCommand = new SqlCommand(SqlQueries.SetCacheItem, connection))
        {
            upsertCommand.Parameters
                .AddCacheItemId(key)
                .AddCacheItemValue(value)
                .AddSlidingExpirationInSeconds(options.SlidingExpiration)
                .AddAbsoluteExpiration(absoluteExpiration)
                .AddWithValue(UtcNowParameterName, SqlDbType.DateTimeOffset, utcNow);
 
            connection.Open();
 
            try
            {
                upsertCommand.ExecuteNonQuery();
            }
            catch (SqlException ex)
            {
                if (DatabaseOperations.IsDuplicateKeyException(ex))
                {
                    // There is a possibility that multiple requests can try to add the same item to the cache, in
                    // which case we receive a 'duplicate key' exception on the primary key column.
                }
                else
                {
                    throw;
                }
            }
        }
    }
 
    public async Task SetCacheItemAsync(string key, ArraySegment<byte> value, DistributedCacheEntryOptions options, CancellationToken token = default(CancellationToken))
    {
        token.ThrowIfCancellationRequested();
 
        var utcNow = SystemClock.UtcNow;
 
        var absoluteExpiration = DatabaseOperations.GetAbsoluteExpiration(utcNow, options);
        DatabaseOperations.ValidateOptions(options.SlidingExpiration, absoluteExpiration);
 
        using (var connection = new SqlConnection(ConnectionString))
        using (var upsertCommand = new SqlCommand(SqlQueries.SetCacheItem, connection))
        {
            upsertCommand.Parameters
                .AddCacheItemId(key)
                .AddCacheItemValue(value)
                .AddSlidingExpirationInSeconds(options.SlidingExpiration)
                .AddAbsoluteExpiration(absoluteExpiration)
                .AddWithValue(UtcNowParameterName, SqlDbType.DateTimeOffset, utcNow);
 
            await connection.OpenAsync(token).ConfigureAwait(false);
 
            try
            {
                await upsertCommand.ExecuteNonQueryAsync(token).ConfigureAwait(false);
            }
            catch (SqlException ex)
            {
                if (DatabaseOperations.IsDuplicateKeyException(ex))
                {
                    // There is a possibility that multiple requests can try to add the same item to the cache, in
                    // which case we receive a 'duplicate key' exception on the primary key column.
                }
                else
                {
                    throw;
                }
            }
        }
    }
 
    private byte[]? GetCacheItem(string key, bool includeValue, IBufferWriter<byte>? destination = null)
    {
        var utcNow = SystemClock.UtcNow;
 
        string query;
        if (includeValue)
        {
            query = SqlQueries.GetCacheItem;
        }
        else
        {
            query = SqlQueries.GetCacheItemWithoutValue;
        }
 
        byte[]? value = null;
        using (var connection = new SqlConnection(ConnectionString))
        using (var command = new SqlCommand(query, connection))
        {
            command.Parameters
                .AddCacheItemId(key)
                .AddWithValue(UtcNowParameterName, SqlDbType.DateTimeOffset, utcNow);
 
            connection.Open();
 
            if (includeValue)
            {
                using var reader = command.ExecuteReader(
                    CommandBehavior.SequentialAccess | CommandBehavior.SingleRow | CommandBehavior.SingleResult);
 
                if (reader.Read())
                {
                    if (destination is null)
                    {
                        value = reader.GetFieldValue<byte[]>(Columns.Indexes.CacheItemValueIndex);
                    }
                    else
                    {
                        StreamOut(reader, Columns.Indexes.CacheItemValueIndex, destination);
                        value = []; // use non-null here as a sentinel to say "we got one"
                    }
                }
            }
            else
            {
                command.ExecuteNonQuery();
            }
        }
 
        return value;
    }
 
    private async Task<byte[]?> GetCacheItemAsync(string key, bool includeValue, IBufferWriter<byte>? destination = null, CancellationToken token = default(CancellationToken))
    {
        token.ThrowIfCancellationRequested();
 
        var utcNow = SystemClock.UtcNow;
 
        string query;
        if (includeValue)
        {
            query = SqlQueries.GetCacheItem;
        }
        else
        {
            query = SqlQueries.GetCacheItemWithoutValue;
        }
 
        byte[]? value = null;
        using (var connection = new SqlConnection(ConnectionString))
        using (var command = new SqlCommand(query, connection))
        {
            command.Parameters
                .AddCacheItemId(key)
                .AddWithValue(UtcNowParameterName, SqlDbType.DateTimeOffset, utcNow);
 
            await connection.OpenAsync(token).ConfigureAwait(false);
 
            if (includeValue)
            {
                using var reader = await command.ExecuteReaderAsync(
                    CommandBehavior.SequentialAccess | CommandBehavior.SingleRow | CommandBehavior.SingleResult, token).ConfigureAwait(false);
 
                if (await reader.ReadAsync(token).ConfigureAwait(false))
                {
                    if (destination is null)
                    {
                        value = await reader.GetFieldValueAsync<byte[]>(Columns.Indexes.CacheItemValueIndex, token).ConfigureAwait(false);
                    }
                    else
                    {
                        StreamOut(reader, Columns.Indexes.CacheItemValueIndex, destination);
                        value = []; // use non-null here as a sentinel to say "we got one"
                    }
                }
            }
            else
            {
                await command.ExecuteNonQueryAsync(token).ConfigureAwait(false);
            }
        }
 
        return value;
    }
 
    private static long StreamOut(SqlDataReader source, int ordinal, IBufferWriter<byte> destination)
    {
        long dataIndex = 0;
        int read = 0;
        byte[]? lease = null;
        do
        {
            dataIndex += read; // increment offset
 
            const int DefaultPageSize = 8192;
 
            var memory = destination.GetMemory(DefaultPageSize); // start from the page size
            if (MemoryMarshal.TryGetArray<byte>(memory, out var segment))
            {
                // avoid an extra copy by writing directly to the target array when possible
                read = (int)source.GetBytes(ordinal, dataIndex, segment.Array, segment.Offset, segment.Count);
                if (read > 0)
                {
                    destination.Advance(read);
                }
            }
            else
            {
                lease ??= ArrayPool<byte>.Shared.Rent(DefaultPageSize);
                read = (int)source.GetBytes(ordinal, dataIndex, lease, 0, lease.Length);
 
                if (read > 0)
                {
                    if (new ReadOnlySpan<byte>(lease, 0, read).TryCopyTo(memory.Span))
                    {
                        destination.Advance(read);
                    }
                    else
                    {
                        // multi-chunk write (utility method)
                        destination.Write(new(lease, 0, read));
                    }
                }
            }
        }
        while (read > 0);
 
        if (lease is not null)
        {
            ArrayPool<byte>.Shared.Return(lease);
        }
        return dataIndex;
    }
 
    private static bool IsDuplicateKeyException(SqlException ex)
    {
        if (ex.Errors != null)
        {
            return ex.Errors.Cast<SqlError>().Any(error => error.Number == DuplicateKeyErrorId);
        }
        return false;
    }
 
    private static DateTimeOffset? GetAbsoluteExpiration(DateTimeOffset utcNow, DistributedCacheEntryOptions options)
    {
        // calculate absolute expiration
        DateTimeOffset? absoluteExpiration = null;
        if (options.AbsoluteExpirationRelativeToNow.HasValue)
        {
            absoluteExpiration = utcNow.Add(options.AbsoluteExpirationRelativeToNow.Value);
        }
        else if (options.AbsoluteExpiration.HasValue)
        {
            if (options.AbsoluteExpiration.Value <= utcNow)
            {
                throw new InvalidOperationException("The absolute expiration value must be in the future.");
            }
 
            absoluteExpiration = options.AbsoluteExpiration.Value;
        }
        return absoluteExpiration;
    }
 
    private static void ValidateOptions(TimeSpan? slidingExpiration, DateTimeOffset? absoluteExpiration)
    {
        if (!slidingExpiration.HasValue && !absoluteExpiration.HasValue)
        {
            throw new InvalidOperationException("Either absolute or sliding expiration needs " +
                "to be provided.");
        }
    }
}