File: PRCreator.cs
Web Access
Project: src\src\Microsoft.DotNet.Baselines.Tasks\Microsoft.DotNet.Baselines.Tasks.csproj (Microsoft.DotNet.Baselines.Tasks)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using Octokit;
using System;
using System.IO;
using System.Linq;
using System.Text;
using System.Net.Http;
using System.Threading.Tasks;
using System.Collections.Concurrent;
using System.Collections.Generic;
using Microsoft.Build.Framework;
using Microsoft.Build.Utilities;
using Task = System.Threading.Tasks.Task;
using FileMode = Octokit.FileMode;
 
namespace Microsoft.DotNet.Baselines.Tasks;
 
/// <summary>
/// Pushes updated baseline files to a GitHub repository and opens a pull request for them, or refreshes an
/// existing open pull request that has the same title and targets the same branch.
/// </summary>
internal class PRCreator
{
    private readonly TaskLoggingHelper _logger;
    private readonly string _gitHubOrg;
    private readonly string _gitHubRepoName;
    private readonly GitHubClient _client;
    private const string BuildLink = "https://dev.azure.com/dnceng/internal/_build/results?buildId=";
    private const string TreeMode = "040000";
    private const int MaxRetries = 10;

    // File-name prefix that marks a regenerated baseline (e.g. "UpdatedFoo.txt" replaces "Foo.txt").
    private const string UpdatedFilePrefix = "Updated";

    // Base delay for retrying timed-out requests; retry attempt N waits N times this value.
    private const int RetryDelayMilliseconds = 1000;

    // Lower bound for rate-limit waits. A reset time that is already in the past would otherwise give a
    // negative delay (which Task.Delay rejects) or a tight retry loop.
    private static readonly TimeSpan MinimumRateLimitDelay = TimeSpan.FromSeconds(1);

    /// <summary>
    /// Creates a helper bound to one GitHub repository.
    /// </summary>
    /// <param name="logger">MSBuild logger used for progress messages, warnings and errors.</param>
    /// <param name="gitHubOrg">Owner (organization or user) of the target repository.</param>
    /// <param name="gitHubRepoName">Name of the target repository.</param>
    /// <param name="gitHubToken">Token used as the GitHub client's credentials.</param>
    public PRCreator(TaskLoggingHelper logger, string gitHubOrg, string gitHubRepoName, string gitHubToken)
    {
        _logger = logger;

        // The product header identifies this assembly to GitHub.
        _client = new GitHubClient(new ProductHeaderValue(System.Reflection.Assembly.GetExecutingAssembly().GetName().Name));
        _client.Credentials = new Credentials(gitHubToken);
        _gitHubOrg = gitHubOrg;
        _gitHubRepoName = gitHubRepoName;
    }

    /// <summary>
    /// Builds a new tree for <paramref name="targetDirectory"/> that contains the updated baselines and commits it
    /// to a pull request against <paramref name="targetBranch"/>.
    /// </summary>
    /// <param name="targetDirectory">Repository-relative directory holding the baselines, using '/' separators.</param>
    /// <param name="updatedFiles">Local paths of the regenerated baseline files.</param>
    /// <param name="buildId">Build id linked from the commit message and the PR body.</param>
    /// <param name="title">Pull request title; also used to find an existing open pull request to update.</param>
    /// <param name="targetBranch">Branch the pull request targets.</param>
    /// <param name="defaultBaselineContent">
    /// When non-empty, a regular updated file whose content equals this value causes that baseline to be deleted.
    /// </param>
    /// <param name="unionExclusionsBaselines">
    /// When <c>true</c>, an exclusions baseline keeps the lines of the original file that appear in any updated file;
    /// otherwise it is replaced by the lines common to all updated files.
    /// </param>
    /// <returns><c>true</c> if no errors were logged.</returns>
    /// <exception cref="InvalidOperationException">
    /// The original tree could not be read, or it has no files under <paramref name="targetDirectory"/>.
    /// </exception>
    /// <exception cref="ArgumentException">One of <paramref name="updatedFiles"/> does not exist.</exception>
    public async Task<bool> ExecuteAsync(
        string targetDirectory,
        List<string> updatedFiles,
        int buildId,
        string title,
        string targetBranch,
        string defaultBaselineContent,
        bool unionExclusionsBaselines)
    {
        DateTime startTime = DateTime.UtcNow;

        _logger.LogMessage(MessageImportance.High, $"Starting PR creation at {startTime} UTC.");

        // GitHub tree paths never carry leading or trailing separators.
        targetDirectory = targetDirectory.Trim('/');

        // Fetch the files within the desired path from the original tree. Must be a relative path.
        TreeResponse originalTreeResponse = await ApiRequestWithRetries(() => _client.Git.Tree.Get(_gitHubOrg, _gitHubRepoName, targetBranch));
        List<NewTreeItem> originalTreeItems = await FetchOriginalTreeItemsAsync(originalTreeResponse, targetBranch, targetDirectory);

        // Update the test results tree
        Dictionary<string, HashSet<string>> parsedUpdatedFiles = ParseAndGroupUpdatedFiles(updatedFiles);
        originalTreeItems = await UpdateAllFilesAsync(parsedUpdatedFiles, originalTreeItems, unionExclusionsBaselines, defaultBaselineContent);
        var testResultsTreeResponse = await CreateTreeFromItemsAsync(originalTreeItems);
        var parentTreeResponse = await CreateParentTreeAsync(testResultsTreeResponse, originalTreeResponse, targetDirectory);

        await CreateOrUpdatePullRequestAsync(parentTreeResponse, buildId, title, targetBranch);

        return !_logger.HasLoggedErrors;
    }

    /// <summary>
    /// Collects every file under <paramref name="targetDirectory"/> in the given root tree. Each item's path is
    /// relative to that directory.
    /// </summary>
    /// <exception cref="InvalidOperationException">The tree is missing or no files were found.</exception>
    private async Task<List<NewTreeItem>> FetchOriginalTreeItemsAsync(
        TreeResponse? treeResponse,
        string targetBranch,
        string targetDirectory)
    {
        if (treeResponse == null)
        {
            _logger.LogError($"Failed to fetch the original tree for branch '{targetBranch}' in repository '{_gitHubOrg}/{_gitHubRepoName}'.");
            throw new InvalidOperationException("Original tree response is invalid.");
        }

        ConcurrentBag<NewTreeItem> treeItems = new();
        await CollectTreeItemsAsync(treeResponse, treeItems, targetDirectory);

        List<NewTreeItem> items = treeItems.ToList();
        if (items.Count == 0)
        {
            _logger.LogError($"No files found in the original tree for branch '{targetBranch}' in repository '{_gitHubOrg}/{_gitHubRepoName}'.");
            throw new InvalidOperationException("No files found in the original tree.");
        }

        return items;
    }

    /// <summary>
    /// Recursively walks <paramref name="treeResponse"/>, whose entries live under <paramref name="relativePath"/>.
    /// It descends only into subtrees that are the target directory, lie beneath it, or lead to it. Every non-tree
    /// entry strictly inside <paramref name="targetDirectory"/> is added to <paramref name="treeItems"/>, with a path
    /// relative to that directory.
    /// </summary>
    /// <remarks>Sibling entries are processed concurrently, which is why a concurrent collection is used.</remarks>
    private async Task CollectTreeItemsAsync(
        TreeResponse? treeResponse,
        ConcurrentBag<NewTreeItem> treeItems,
        string targetDirectory,
        string relativePath = "")
    {
        if (treeResponse == null)
        {
            return;
        }

        await Parallel.ForEachAsync(treeResponse.Tree, async (item, cancellationToken) =>
        {
            // GitHub tree paths always use '/', independent of the OS this task runs on.
            string path = relativePath.Length == 0 ? item.Path : $"{relativePath}/{item.Path}";

            if (item.Type == TreeType.Tree)
            {
                // Descend into the target directory, anything beneath it, or any ancestor of it.
                if (IsSameOrDescendantPath(path, targetDirectory) || IsSameOrDescendantPath(targetDirectory, path))
                {
                    TreeResponse subTree = await ApiRequestWithRetries(() => _client.Git.Tree.Get(_gitHubOrg, _gitHubRepoName, item.Sha));
                    await CollectTreeItemsAsync(subTree, treeItems, targetDirectory, path);
                }
            }
            else if (targetDirectory.Length == 0 || path.StartsWith(targetDirectory + "/", StringComparison.Ordinal))
            {
                treeItems.Add(new NewTreeItem
                {
                    Path = targetDirectory.Length == 0 ? path : path.Substring(targetDirectory.Length + 1),
                    Mode = item.Mode,
                    Type = item.Type.Value,
                    Sha = item.Sha
                });
            }
        });
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="path"/> equals <paramref name="directory"/> or lies beneath it. Whole
    /// '/'-separated segments are compared, so "a/bc" is not considered to be under "a/b".
    /// </summary>
    private static bool IsSameOrDescendantPath(string path, string directory) =>
        directory.Length == 0
        || path.Equals(directory, StringComparison.Ordinal)
        || path.StartsWith(directory + "/", StringComparison.Ordinal);

    /// <summary>
    /// Groups the updated files by baseline key. The key is the file name without the "Updated" prefix (if present),
    /// truncated at its first '.'.
    /// </summary>
    /// <exception cref="ArgumentException">A path does not point to an existing file.</exception>
    private Dictionary<string, HashSet<string>> ParseAndGroupUpdatedFiles(List<string> updatedFiles)
    {
        foreach (string updatedFile in updatedFiles)
        {
            if (!File.Exists(updatedFile))
            {
                throw new ArgumentException($"Updated file path '{updatedFile}' is not a valid file.");
            }
        }

        return updatedFiles
            .GroupBy(updatedFile => GetBaselineKey(updatedFile))
            .ToDictionary(
                group => group.Key,
                group => new HashSet<string>(group));
    }

    /// <summary>Baseline key of an updated file: its prefix-less file name up to the first '.'.</summary>
    private static string GetBaselineKey(string updatedFile) =>
        ParseUpdatedFileName(updatedFile).Split('.')[0];

    /// <summary>
    /// Applies each group of updated files to <paramref name="tree"/>. Keys containing "Exclusions" are merged into a
    /// single exclusions file; all other files replace their baseline one-to-one.
    /// </summary>
    private async Task<List<NewTreeItem>> UpdateAllFilesAsync(
        Dictionary<string, HashSet<string>> updatedFiles,
        List<NewTreeItem> tree,
        bool unionExclusionsBaselines,
        string defaultBaselineContent)
    {
        foreach (var (fileNameKey, files) in updatedFiles)
        {
            tree = fileNameKey.Contains("Exclusions")
                ? await UpdateExclusionFileAsync(fileNameKey, files, tree, unionExclusionsBaselines)
                : await UpdateRegularFilesAsync(files, tree, defaultBaselineContent);
        }
        return tree;
    }

    /// <summary>
    /// Computes the new content of the exclusions baseline identified by <paramref name="fileNameKey"/> and applies it
    /// to <paramref name="tree"/>.
    /// </summary>
    /// <remarks>
    /// <para>
    /// When <paramref name="union"/> is <c>true</c>, lines of the original file (downloaded from GitHub) are removed
    /// unless they appear in at least one updated file. If no original file exists, nothing changes.
    /// </para>
    /// <para>
    /// Otherwise the file is replaced by the lines common to all updated files, in the first file's order. It is
    /// deleted when no line is common.
    /// </para>
    /// </remarks>
    private async Task<List<NewTreeItem>> UpdateExclusionFileAsync(
        string fileNameKey,
        HashSet<string> updatedFiles,
        List<NewTreeItem> tree,
        bool union = false)
    {
        string? content = null;

        if (union)
        {
            HashSet<string> keptLines = ReadUnionOfLines(updatedFiles);

            // Find the key in the tree, download the blob, and convert it to UTF-8.
            NewTreeItem? originalTreeItem = FindTreeItem(tree, fileNameKey);
            if (originalTreeItem != null)
            {
                var originalBlob = await ApiRequestWithRetries(() => _client.Git.Blob.Get(_gitHubOrg, _gitHubRepoName, originalTreeItem.Sha));
                string originalContent = Encoding.UTF8.GetString(Convert.FromBase64String(originalBlob.Content));
                content = RemoveLinesNotIn(originalContent, keptLines);
            }
        }
        else
        {
            List<string> commonLines = ReadIntersectionOfLines(updatedFiles);
            if (commonLines.Count > 0)
            {
                // No need to compare to the original file, just write the common lines.
                content = string.Join("\n", commonLines) + "\n";
            }
        }

        string updatedFilePath = fileNameKey + ".txt";
        return await UpdateFileAsync(tree, content, fileNameKey, updatedFilePath);
    }

    /// <summary>Returns the set of all lines appearing in any of <paramref name="files"/>.</summary>
    private static HashSet<string> ReadUnionOfLines(IEnumerable<string> files) =>
        files.SelectMany(file => File.ReadAllLines(file)).ToHashSet();

    /// <summary>
    /// Returns the lines of the first file that also appear in every other file, keeping the first file's order and
    /// duplicates. Returns an empty list when there are no files.
    /// </summary>
    private static List<string> ReadIntersectionOfLines(IEnumerable<string> files)
    {
        List<string>? commonLines = null;
        foreach (string file in files)
        {
            string[] lines = File.ReadAllLines(file);
            if (commonLines == null)
            {
                // The first file seeds the result, even when it is empty.
                commonLines = lines.ToList();
            }
            else
            {
                var lineSet = new HashSet<string>(lines);
                commonLines.RemoveAll(line => !lineSet.Contains(line));
            }
        }
        return commonLines ?? new List<string>();
    }

    /// <summary>
    /// Drops every line of <paramref name="content"/> that is not in <paramref name="keptLines"/>. Whole lines are
    /// compared, ignoring a trailing '\r'. The remaining lines keep their order and line endings, and a trailing
    /// newline is kept if the input had one.
    /// </summary>
    private static string RemoveLinesNotIn(string content, HashSet<string> keptLines)
    {
        bool endsWithNewline = content.EndsWith('\n');
        string[] lines = content.Split('\n');

        // Splitting "a\nb\n" yields a trailing empty element that is not a real line.
        int lineCount = endsWithNewline ? lines.Length - 1 : lines.Length;

        List<string> remaining = lines
            .Take(lineCount)
            .Where(line => keptLines.Contains(line.TrimEnd('\r')))
            .ToList();

        string result = string.Join("\n", remaining);
        return endsWithNewline && remaining.Count > 0 ? result + "\n" : result;
    }

    /// <summary>
    /// Replaces each regular baseline with its updated file. A file is deleted when its content equals a non-empty
    /// <paramref name="defaultBaselineContent"/>.
    /// </summary>
    private async Task<List<NewTreeItem>> UpdateRegularFilesAsync(
        HashSet<string> updatedFiles,
        List<NewTreeItem> tree,
        string defaultBaselineContent)
    {
        foreach (string filePath in updatedFiles)
        {
            string? content = File.ReadAllText(filePath);
            if (!string.IsNullOrEmpty(defaultBaselineContent) && content == defaultBaselineContent)
            {
                content = null;
            }
            string originalFileName = ParseUpdatedFileName(filePath);
            tree = await UpdateFileAsync(tree, content, originalFileName, originalFileName);
        }
        return tree;
    }

    /// <summary>
    /// Applies <paramref name="content"/> to the tree item matching <paramref name="searchFileName"/>.
    /// <list type="bullet">
    /// <item><c>null</c> content deletes the item, if present.</item>
    /// <item>A missing item is added at <paramref name="updatedPath"/>.</item>
    /// <item>An existing item gets a new blob.</item>
    /// </list>
    /// </summary>
    /// <returns>The same list instance, mutated in place.</returns>
    private async Task<List<NewTreeItem>> UpdateFileAsync(
        List<NewTreeItem> tree,
        string? content,
        string searchFileName,
        string updatedPath)
    {
        NewTreeItem? originalTreeItem = FindTreeItem(tree, searchFileName);

        if (content == null)
        {
            // Content is null, delete the file if it exists
            if (originalTreeItem != null)
            {
                tree.Remove(originalTreeItem);
            }
        }
        else if (originalTreeItem == null)
        {
            // Path not in the tree, add a new tree item
            var blob = await CreateBlobAsync(content);
            tree.Add(new NewTreeItem
            {
                Type = TreeType.Blob,
                Mode = FileMode.File,
                Path = updatedPath,
                Sha = blob.Sha
            });
        }
        else
        {
            // Path in the tree, update the sha and the content
            var blob = await CreateBlobAsync(content);
            originalTreeItem.Sha = blob.Sha;
        }
        return tree;
    }

    /// <summary>
    /// Finds the first tree item whose file name equals <paramref name="searchFileName"/>, or whose stem (the name up
    /// to its first '.') equals it. Whole names are compared so unrelated files that merely contain the search text,
    /// e.g. "MySdkExclusions.txt" when searching "SdkExclusions", are not matched.
    /// </summary>
    private static NewTreeItem? FindTreeItem(List<NewTreeItem> tree, string searchFileName) =>
        tree.FirstOrDefault(item =>
        {
            string fileName = Path.GetFileName(item.Path) ?? string.Empty;
            return fileName.Equals(searchFileName, StringComparison.Ordinal)
                || fileName.Split('.')[0].Equals(searchFileName, StringComparison.Ordinal);
        });

    /// <summary>Uploads <paramref name="content"/> as a UTF-8 blob and returns its reference.</summary>
    private async Task<BlobReference> CreateBlobAsync(string content)
    {
        var blob = new NewBlob
        {
            Content = content,
            Encoding = EncodingType.Utf8
        };
        return await ApiRequestWithRetries(() => _client.Git.Blob.Create(_gitHubOrg, _gitHubRepoName, blob));
    }

    /// <summary>
    /// Returns the file name of <paramref name="updatedFile"/> with a leading "Updated" prefix removed, if present.
    /// Only the file name is inspected, so directories containing "Updated" do not affect the result.
    /// </summary>
    private static string ParseUpdatedFileName(string updatedFile)
    {
        string fileName = Path.GetFileName(updatedFile);
        return fileName.StartsWith(UpdatedFilePrefix, StringComparison.Ordinal)
            ? fileName.Substring(UpdatedFilePrefix.Length)
            : fileName;
    }

    /// <summary>
    /// Creates a GitHub tree from <paramref name="items"/>, whose paths are relative to the target directory.
    /// Items directly in <paramref name="path"/> (or at the root) become entries of this tree. Every other directory
    /// group is built recursively and linked as a subtree entry.
    /// </summary>
    /// <remarks>
    /// NOTE(review): subtree entries use the group's full relative directory (e.g. "a/b") as their path. This relies
    /// on GitHub accepting multi-segment entry paths — confirm. <c>Path.GetDirectoryName</c> uses the OS separator,
    /// so this presumably expects to run where that is '/'.
    /// </remarks>
    private async Task<TreeResponse> CreateTreeFromItemsAsync(List<NewTreeItem> items, string path = "")
    {
        List<NewTreeItem> newTreeItems = [];

        var groups = items.GroupBy(item => Path.GetDirectoryName(item.Path));
        foreach (var group in groups)
        {
            if (string.IsNullOrEmpty(group.Key) || group.Key == path)
            {
                // These items are in the current directory, so add them to the new tree items
                foreach (var item in group)
                {
                    if (item.Type != TreeType.Tree)
                    {
                        newTreeItems.Add(new NewTreeItem
                        {
                            Path = path == string.Empty ? item.Path : Path.GetRelativePath(path, item.Path),
                            Mode = item.Mode,
                            Type = item.Type,
                            Sha = item.Sha
                        });
                    }
                }
            }
            else
            {
                // These items are in a subdirectory, so recursively create a tree for them
                var subtreeResponse = await CreateTreeFromItemsAsync(group.ToList(), group.Key);
                newTreeItems.Add(new NewTreeItem
                {
                    Path = group.Key,
                    Mode = TreeMode,
                    Type = TreeType.Tree,
                    Sha = subtreeResponse.Sha
                });
            }
        }

        var newTree = new NewTree();
        foreach (var item in newTreeItems)
        {
            newTree.Tree.Add(item);
        }
        return await ApiRequestWithRetries(() => _client.Git.Tree.Create(_gitHubOrg, _gitHubRepoName, newTree));
    }

    /// <summary>
    /// Creates a root tree based on <paramref name="originalTreeResponse"/> in which
    /// <paramref name="targetDirectory"/> points at the updated test results tree.
    /// </summary>
    private async Task<TreeResponse> CreateParentTreeAsync(
        TreeResponse testResultsTreeResponse,
        TreeResponse originalTreeResponse,
        string targetDirectory)
    {
        // Create a new tree for the parent directory
        NewTree parentTree = new NewTree { BaseTree = originalTreeResponse.Sha };

        // Connect the updated test results tree
        parentTree.Tree.Add(new NewTreeItem
        {
            Path = targetDirectory,
            Mode = TreeMode,
            Type = TreeType.Tree,
            Sha = testResultsTreeResponse.Sha
        });

        return await ApiRequestWithRetries(() => _client.Git.Tree.Create(_gitHubOrg, _gitHubRepoName, parentTree));
    }

    /// <summary>
    /// Commits <paramref name="parentTreeResponse"/> and opens a new pull request, or pushes the commit to the branch
    /// of an existing pull request with the same title. Nothing is pushed when the commit has no changes.
    /// </summary>
    /// <remarks>
    /// For an existing pull request, the target branch is first merged into the PR branch on a best-effort basis.
    /// A failed merge only logs warnings.
    /// </remarks>
    private async Task CreateOrUpdatePullRequestAsync(TreeResponse parentTreeResponse, int buildId, string title, string targetBranch)
    {
        var existingPullRequest = await GetExistingPullRequestAsync(title, targetBranch);

        // Create the branch name and get the head reference
        string newBranchName;
        string headSha = await GetHeadShaAsync(targetBranch);
        if (existingPullRequest == null)
        {
            string utcTime = DateTime.UtcNow.ToString("yyyyMMddHHmmss");
            newBranchName = $"pr-baseline-{utcTime}";
        }
        else
        {
            newBranchName = existingPullRequest.Head.Ref;

            try
            {
                // Merge the target branch into the existing pull request
                var merge = new NewMerge(newBranchName, headSha);
                await ApiRequestWithRetries(() => _client.Repository.Merging.Create(_gitHubOrg, _gitHubRepoName, merge));
            }
            catch (Exception e)
            {
                _logger.LogWarning($"Failed to merge the target branch into the existing pull request: {e.Message}");
                _logger.LogWarning("Continuing with updating the existing pull request. You may need to resolve conflicts manually in the PR.");
            }

            headSha = await GetHeadShaAsync(newBranchName);
        }

        string commitSha = await CreateCommitAsync(
            parentTreeResponse.Sha,
            headSha,
            $"Update baselines for build {BuildLink}{buildId} (internal Microsoft link)");

        if (await ShouldMakeUpdatesAsync(headSha, commitSha))
        {
            string pullRequestBody = $"This PR was created by `Microsoft.DotNet.Baselines.Tasks.CreateUpdatePR` for build {buildId}. \n\n" +
                                 $"The updated test results can be found at {BuildLink}{buildId} (internal Microsoft link)";
            if (existingPullRequest != null)
            {
                await UpdatePullRequestAsync(newBranchName, commitSha, pullRequestBody, existingPullRequest);
            }
            else
            {
                await CreatePullRequestAsync(newBranchName, commitSha, targetBranch, title, pullRequestBody);
            }
        }
    }

    /// <summary>
    /// Returns the first pull request against <paramref name="targetBranch"/> whose title equals
    /// <paramref name="title"/>, or <c>null</c>.
    /// </summary>
    private async Task<PullRequest?> GetExistingPullRequestAsync(string title, string targetBranch)
    {
        var request = new PullRequestRequest
        {
            Base = targetBranch
        };

        var existingPullRequests = await ApiRequestWithRetries(() =>
            _client.PullRequest.GetAllForRepository(_gitHubOrg, _gitHubRepoName, request));

        return existingPullRequests.FirstOrDefault(pr => pr.Title == title);
    }

    /// <summary>Creates a commit of tree <paramref name="newSha"/> with parent <paramref name="headSha"/> and returns its SHA.</summary>
    private async Task<string> CreateCommitAsync(string newSha, string headSha, string commitMessage)
    {
        var newCommit = new NewCommit(commitMessage, newSha, headSha);
        var commit = await ApiRequestWithRetries(() => _client.Git.Commit.Create(_gitHubOrg, _gitHubRepoName, newCommit));
        return commit.Sha;
    }

    /// <summary>Returns <c>true</c> when the two commits differ in at least one file.</summary>
    private async Task<bool> ShouldMakeUpdatesAsync(string headSha, string commitSha)
    {
        var comparison = await ApiRequestWithRetries(() => _client.Repository.Commit.Compare(_gitHubOrg, _gitHubRepoName, headSha, commitSha));
        if (!comparison.Files.Any())
        {
            _logger.LogMessage(MessageImportance.High, "No changes to commit. Skipping PR creation/updates.");
            return false;
        }
        return true;
    }

    /// <summary>Moves the PR branch to <paramref name="commitSha"/> and replaces the pull request body.</summary>
    private async Task UpdatePullRequestAsync(string branchName, string commitSha, string body, PullRequest pullRequest)
    {
        await UpdateReferenceAsync(branchName, commitSha);

        var pullRequestUpdate = new PullRequestUpdate
        {
            Body = body
        };
        await ApiRequestWithRetries(() => _client.PullRequest.Update(_gitHubOrg, _gitHubRepoName, pullRequest.Number, pullRequestUpdate));

        _logger.LogMessage(MessageImportance.High, $"Updated existing pull request #{pullRequest.Number}. URL: {pullRequest.HtmlUrl}");
    }

    /// <summary>Creates branch <paramref name="newBranchName"/> at <paramref name="commitSha"/> and opens a pull request from it.</summary>
    private async Task CreatePullRequestAsync(string newBranchName, string commitSha, string targetBranch, string title, string body)
    {
        await CreateReferenceAsync(newBranchName, commitSha);

        var newPullRequest = new NewPullRequest(title, newBranchName, targetBranch)
        {
            Body = body
        };
        var pullRequest = await ApiRequestWithRetries(() => _client.PullRequest.Create(_gitHubOrg, _gitHubRepoName, newPullRequest));

        _logger.LogMessage(MessageImportance.High, $"Created pull request #{pullRequest.Number}. URL: {pullRequest.HtmlUrl}");
    }

    /// <summary>Returns the SHA that branch <paramref name="branchName"/> points to.</summary>
    private async Task<string> GetHeadShaAsync(string branchName)
    {
        var reference = await ApiRequestWithRetries(() => _client.Git.Reference.Get(_gitHubOrg, _gitHubRepoName, $"heads/{branchName}"));
        return reference.Object.Sha;
    }

    /// <summary>Points existing branch <paramref name="branchName"/> at <paramref name="commitSha"/>.</summary>
    private async Task UpdateReferenceAsync(string branchName, string commitSha)
    {
        var referenceUpdate = new ReferenceUpdate(commitSha);
        await ApiRequestWithRetries(() => _client.Git.Reference.Update(_gitHubOrg, _gitHubRepoName, $"heads/{branchName}", referenceUpdate));
    }

    /// <summary>Creates branch <paramref name="branchName"/> at <paramref name="commitSha"/>.</summary>
    private async Task CreateReferenceAsync(string branchName, string commitSha)
    {
        var newReference = new NewReference($"refs/heads/{branchName}", commitSha);
        await ApiRequestWithRetries(() => _client.Git.Reference.Create(_gitHubOrg, _gitHubRepoName, newReference));
    }

    /// <summary>
    /// Runs <paramref name="action"/>, retrying in two cases.
    /// <list type="bullet">
    /// <item>Rate-limit errors: waits until the reported reset time, or at least a short minimum; not capped.</item>
    /// <item>Timed-out requests (an <see cref="ApiException"/> or <see cref="HttpRequestException"/> wrapping a
    /// <see cref="TaskCanceledException"/>): up to <see cref="MaxRetries"/> times with linearly increasing delays.</item>
    /// </list>
    /// Any other exception propagates immediately.
    /// </summary>
    private async Task<T> ApiRequestWithRetries<T>(Func<Task<T>> action)
    {
        int attempt = 0;
        while (true)
        {
            try
            {
                return await action();
            }
            catch (RateLimitExceededException ex)
            {
                TimeSpan delay = ex.Reset.UtcDateTime - DateTime.UtcNow;

                // A reset time in the past would make Task.Delay throw; enforce a small minimum wait instead.
                if (delay < MinimumRateLimitDelay)
                {
                    delay = MinimumRateLimitDelay;
                }

                _logger.LogWarning($"Rate limit exceeded. Retrying after {delay.TotalSeconds} seconds...");
                await Task.Delay(delay);
            }
            catch (Exception ex) when (
                attempt < MaxRetries
                && (ex is ApiException || ex is HttpRequestException)
                && (ex.InnerException is TaskCanceledException))
            {
                attempt++;
                int delayMilliseconds = RetryDelayMilliseconds * attempt; // Linear backoff
                _logger.LogWarning($"Attempt {attempt} failed: {ex.Message}. Retrying in {delayMilliseconds}ms...");
                await Task.Delay(delayMilliseconds);
            }
        }
    }
}