File: Commands\BaseCommand.cs
Web Access
Project: src\src\Tools\Microsoft.dotnet-openapi\src\Microsoft.dotnet-openapi.csproj (dotnet-openapi)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Reflection;
using System.Security.Cryptography;
using System.Text.Json;
using Microsoft.Build.Evaluation;
using Microsoft.DotNet.Openapi.Tools;
using Microsoft.DotNet.Openapi.Tools.Internal;
using Microsoft.Extensions.CommandLineUtils;
 
namespace Microsoft.DotNet.OpenApi.Commands;
 
/// <summary>
/// Shared plumbing for the dotnet-openapi sub-commands: project resolution, adding OpenAPI references
/// (and the NuGet packages they need) to a project, and downloading OpenAPI documents to disk.
/// </summary>
internal abstract class BaseCommand : CommandLineApplication
{
    protected string WorkingDirectory;

    protected readonly IHttpClientWrapper _httpClient;

    public const string OpenApiReference = "OpenApiReference";
    public const string OpenApiProjectReference = "OpenApiProjectReference";
    protected const string SourceUrlAttrName = "SourceUrl";

    public const string ContentDispositionHeaderName = "Content-Disposition";
    private const string CodeGeneratorAttrName = "CodeGenerator";
    private const string DefaultExtension = ".json";

    internal const string PackageVersionUrl = "https://go.microsoft.com/fwlink/?linkid=2099561";

    /// <summary>
    /// Creates a command attached to <paramref name="parent"/>, inheriting its output writers and the
    /// root application's working directory.
    /// </summary>
    /// <param name="parent">The owning application or command.</param>
    /// <param name="name">The command name as typed on the command line.</param>
    /// <param name="httpClient">Client used for all downloads.</param>
    public BaseCommand(CommandLineApplication parent, string name, IHttpClientWrapper httpClient)
    {
        Parent = parent;
        Name = name;
        Out = parent.Out ?? Out;
        Error = parent.Error ?? Error;
        _httpClient = httpClient;

        ProjectFileOption = Option("-p|--updateProject", "The project file update.", CommandOptionType.SingleValue);

        // Commands may be nested under other commands; walk up to the root Application rather than
        // assuming a fixed depth.
        WorkingDirectory = GetApplication().WorkingDirectory;

        OnExecute(ExecuteAsync);
    }

    /// <summary>The <c>-p|--updateProject</c> option naming the project file to update.</summary>
    public CommandOption ProjectFileOption { get; }

    /// <summary>Writer used for warnings; currently the same as <see cref="CommandLineApplication.Out"/>.</summary>
    public TextWriter Warning => Out;

    /// <summary>Runs the command after help handling and argument validation have passed.</summary>
    protected abstract Task<int> ExecuteCoreAsync();

    /// <summary>Returns <c>false</c> when the parsed arguments are unusable; help is then shown.</summary>
    protected abstract bool ValidateArguments();

    /// <summary>
    /// Entry point registered with <see cref="CommandLineApplication.OnExecute(Func{Task{int}})"/>:
    /// shows help (exit 0) when requested, shows help (exit 1) on invalid arguments, otherwise runs the command.
    /// </summary>
    private async Task<int> ExecuteAsync()
    {
        if (GetApplication().Help.HasValue())
        {
            ShowHelp();
            return 0;
        }

        if (!ValidateArguments())
        {
            ShowHelp();
            return 1;
        }

        return await ExecuteCoreAsync();
    }

    /// <summary>Walks the parent chain up to the root <see cref="Application"/>.</summary>
    private Application GetApplication()
    {
        var parent = Parent;
        while (parent is not Application)
        {
            parent = parent.Parent;
        }

        return (Application)parent;
    }

    /// <summary>
    /// Resolves the project to operate on: the explicit <paramref name="projectOption"/> value (relative to
    /// <see cref="WorkingDirectory"/>), or else the single <c>*.csproj</c> in the working directory.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// The given project does not exist, or the working directory contains zero or several projects.
    /// </exception>
    internal FileInfo ResolveProjectFile(CommandOption projectOption)
    {
        string project;
        if (projectOption.HasValue())
        {
            project = projectOption.Value();
            project = GetFullPath(project);
            if (!File.Exists(project))
            {
                throw new ArgumentException($"The project '{project}' does not exist.");
            }
        }
        else
        {
            var projects = Directory.GetFiles(WorkingDirectory, "*.csproj", SearchOption.TopDirectoryOnly);
            if (projects.Length == 0)
            {
                throw new ArgumentException("No project files were found in the current directory. Either move to a new directory or provide the project explicitly");
            }
            if (projects.Length > 1)
            {
                throw new ArgumentException("More than one project was found in this directory, either remove a duplicate or explicitly provide the project.");
            }

            project = projects[0];
        }

        return new FileInfo(project);
    }

    /// <summary>Loads (or fetches from the global collection) and re-evaluates the MSBuild project.</summary>
    protected static Project LoadProject(FileInfo projectFile)
    {
        var project = ProjectCollection.GlobalProjectCollection.LoadProject(
            projectFile.FullName,
            globalProperties: null,
            toolsVersion: null);
        project.ReevaluateIfNecessary();
        return project;
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="file"/> exists and ends in <c>.csproj</c>.
    /// Note: existence is checked relative to the process's current directory, not <see cref="WorkingDirectory"/>.
    /// </summary>
    internal static bool IsProjectFile(string file)
    {
        return File.Exists(Path.GetFullPath(file)) && file.EndsWith(".csproj", StringComparison.Ordinal);
    }

    /// <summary>Returns <c>true</c> when <paramref name="file"/> is an absolute http or https URL.</summary>
    internal static bool IsUrl(string file)
    {
        // Compare the parsed scheme rather than a string prefix: "httpx://..." is not a web URL,
        // and schemes are case-insensitive ("HTTP://..." is).
        return Uri.TryCreate(file, UriKind.Absolute, out var uri)
            && (uri.Scheme == Uri.UriSchemeHttp || uri.Scheme == Uri.UriSchemeHttps);
    }

    /// <summary>
    /// Ensures the code-generation packages are referenced, then adds a <paramref name="tagName"/> item for
    /// <paramref name="sourceFile"/> to the project and saves it. Nothing is added (a warning is written
    /// instead) when an item for the same file, or with the same <c>SourceUrl</c>, already exists.
    /// </summary>
    /// <param name="tagName">Item type to add, e.g. <see cref="OpenApiReference"/>.</param>
    /// <param name="projectFile">Project to modify.</param>
    /// <param name="sourceFile">Value of the item's <c>Include</c>.</param>
    /// <param name="codeGenerator">Optional generator, stored as <c>CodeGenerator</c> metadata.</param>
    /// <param name="sourceUrl">Optional origin URL, stored as <c>SourceUrl</c> metadata.</param>
    internal async Task AddOpenAPIReference(
        string tagName,
        FileInfo projectFile,
        string sourceFile,
        CodeGenerator? codeGenerator,
        string sourceUrl = null)
    {
        // EnsurePackagesInProjectAsync MUST happen before LoadProject, because otherwise the global state set by ProjectCollection doesn't pick up the nuget edits, and we end up losing them.
        await EnsurePackagesInProjectAsync(projectFile, codeGenerator);
        var project = LoadProject(projectFile);
        var items = project.GetItems(tagName);

        var sourceFullPath = GetFullPath(sourceFile);
        if (items.Any(i => string.Equals(GetFullPath(i.EvaluatedInclude), sourceFullPath, StringComparison.Ordinal)))
        {
            Warning.WriteLine($"One or more references to {sourceFile} already exist in '{project.FullPath}'. Duplicate references could lead to unexpected behavior.");
            return;
        }

        if (sourceUrl != null
            && items.Any(i => string.Equals(i.GetMetadataValue(SourceUrlAttrName), sourceUrl, StringComparison.Ordinal)))
        {
            Warning.WriteLine($"A reference to '{sourceUrl}' already exists in '{project.FullPath}'.");
            return;
        }

        var metadata = new Dictionary<string, string>();

        if (!string.IsNullOrEmpty(sourceUrl))
        {
            metadata[SourceUrlAttrName] = sourceUrl;
        }

        if (codeGenerator != null)
        {
            metadata[CodeGeneratorAttrName] = codeGenerator.ToString();
        }

        project.AddElementWithAttributes(tagName, sourceFile, metadata);
        project.Save();
    }

    /// <summary>
    /// Adds every package required by <paramref name="codeGenerator"/> to the project, preferring versions
    /// published at <see cref="PackageVersionUrl"/> over the versions baked into this assembly.
    /// </summary>
    private async Task EnsurePackagesInProjectAsync(FileInfo projectFile, CodeGenerator? codeGenerator)
    {
        var urlPackages = await LoadPackageVersionsFromURLAsync();
        var attributePackages = GetServicePackages(codeGenerator);

        foreach (var kvp in attributePackages)
        {
            var packageId = kvp.Key;
            var version = urlPackages != null && urlPackages.TryGetValue(packageId, out var urlPackageVersion) ? urlPackageVersion : kvp.Value;

            await TryAddPackage(packageId, version, projectFile);
        }
    }

    /// <summary>
    /// Runs <c>dotnet add package &lt;id&gt; --version &lt;version&gt; --no-restore</c> in the project's directory.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// dotnet could not be located, the command timed out, or it failed and the expected
    /// <c>PackageReference</c> is not present in the project afterwards.
    /// </exception>
    private async Task TryAddPackage(string packageId, string packageVersion, FileInfo projectFile)
    {
        var args = new[] {
                "add",
                "package",
                packageId,
                "--version",
                packageVersion,
                "--no-restore"
            };

        var muxer = DotNetMuxer.MuxerPathOrDefault();
        if (string.IsNullOrEmpty(muxer))
        {
            throw new ArgumentException("dotnet was not found on the path.");
        }

        var startInfo = new ProcessStartInfo
        {
            FileName = muxer,
            Arguments = string.Join(" ", args),
            WorkingDirectory = projectFile.Directory.FullName,
            RedirectStandardError = true,
            RedirectStandardOutput = true,
        };

        using var process = Process.Start(startInfo);

        // Drain both redirected pipes while the child runs. Waiting for exit before reading can deadlock:
        // once a pipe buffer fills, the child blocks on write and never exits.
        // These tasks are only awaited on the failure path.
        var outputTask = process.StandardOutput.ReadToEndAsync();
        var errorTask = process.StandardError.ReadToEndAsync();

        const int timeoutSeconds = 20;
        using (var timeoutSource = new CancellationTokenSource(TimeSpan.FromSeconds(timeoutSeconds)))
        {
            try
            {
                await process.WaitForExitAsync(timeoutSource.Token);
            }
            catch (OperationCanceledException)
            {
                // Don't leave a hung dotnet process (and its children) behind after giving up on it.
                KillProcessTree(process);
                throw new ArgumentException($"Adding package `{packageId}` to `{projectFile.Directory}` took longer than {timeoutSeconds} seconds.");
            }
        }

        if (process.ExitCode != 0)
        {
            using var csprojStream = projectFile.OpenRead();
            using var csprojReader = new StreamReader(csprojStream);
            var csprojContent = await csprojReader.ReadToEndAsync();
            // We suspect that sometimes dotnet add package is giving a non-zero exit code when it has actually succeeded.
            if (!csprojContent.Contains($"<PackageReference Include=\"{packageId}\" Version=\"{packageVersion}\""))
            {
                var output = await outputTask;
                var error = await errorTask;
                await Out.WriteAsync(output);
                await Error.WriteAsync(error);

                throw new ArgumentException($"Adding package `{packageId}` to `{projectFile.Directory}` returned ExitCode `{process.ExitCode}` and gave error `{error}` and output `{output}`");
            }
        }
    }

    /// <summary>Best-effort termination of <paramref name="process"/> and its descendants.</summary>
    private static void KillProcessTree(Process process)
    {
        try
        {
            process.Kill(entireProcessTree: true);
        }
        catch (InvalidOperationException)
        {
            // The process already exited; nothing to clean up.
        }
        catch (System.ComponentModel.Win32Exception)
        {
            // The process could not be terminated; the timeout error is still reported to the caller.
        }
    }

    /// <summary>
    /// Downloads <paramref name="url"/> (with retries) to <paramref name="destinationPath"/>.
    /// </summary>
    /// <exception cref="ArgumentException">The server answered with a non-success status code.</exception>
    internal async Task DownloadToFileAsync(string url, string destinationPath, bool overwrite)
    {
        using var response = await RetryRequest(() => _httpClient.GetResponseAsync(url));

        // Never write an error page over the destination (which may be the user's existing document).
        EnsureSuccessfulResponse(response);

        await WriteToFileAsync(await response.Stream, destinationPath, overwrite);
    }

    /// <summary>
    /// Downloads <paramref name="url"/> to the path given by <paramref name="fileOption"/>, or, when that
    /// option is absent, to a unique file name derived from the response/URL in <see cref="WorkingDirectory"/>.
    /// </summary>
    /// <returns>The destination path as chosen (possibly relative to <see cref="WorkingDirectory"/>).</returns>
    /// <exception cref="ArgumentException">
    /// The server answered with a non-success status code, or the destination already exists.
    /// </exception>
    internal async Task<string> DownloadGivenOption(string url, CommandOption fileOption)
    {
        using var response = await RetryRequest(() => _httpClient.GetResponseAsync(url));
        EnsureSuccessfulResponse(response);

        string destinationPath;
        if (fileOption.HasValue())
        {
            destinationPath = fileOption.Value();
        }
        else
        {
            var fileName = GetFileNameFromResponse(response, url);
            var fullPath = GetFullPath(fileName);
            var directory = Path.GetDirectoryName(fullPath);
            destinationPath = GetUniqueFileName(directory, Path.GetFileNameWithoutExtension(fileName), Path.GetExtension(fileName));
        }

        await WriteToFileAsync(await response.Stream, GetFullPath(destinationPath), overwrite: false);

        return destinationPath;
    }

    /// <summary>Throws when <paramref name="response"/> does not carry a success status code.</summary>
    private static void EnsureSuccessfulResponse(IHttpResponseMessageWrapper response)
    {
        if (!response.IsSuccessCode())
        {
            throw new ArgumentException($"The given url returned '{response.StatusCode}', indicating failure. The url might be wrong, or there might be a networking issue.");
        }
    }

    /// <summary>
    /// Invokes <paramref name="retryBlock"/> until it produces a response whose status is not
    /// 503 Service Unavailable, making at most <paramref name="retryCount"/> attempts (60 by default) and
    /// waiting one second after a 503 or a network-level failure (<see cref="HttpRequestException"/> /
    /// <see cref="WebException"/>). Other exceptions are retried without waiting.
    /// </summary>
    /// <param name="retryBlock">Issues one request attempt; called again for each retry.</param>
    /// <param name="retryCount">Maximum number of attempts.</param>
    /// <param name="cancellationToken">Stops further attempts and interrupts the wait between them.</param>
    /// <returns>The first non-503 response; the caller owns it and must dispose it.</returns>
    /// <exception cref="OperationCanceledException">
    /// Cancellation was requested, or every attempt returned 503.
    /// </exception>
    /// <remarks>If the final attempt throws, that exception propagates unchanged.</remarks>
    private static async Task<IHttpResponseMessageWrapper> RetryRequest(
        Func<Task<IHttpResponseMessageWrapper>> retryBlock,
        int retryCount = 60,
        CancellationToken cancellationToken = default)
    {
        var retryDelay = TimeSpan.FromSeconds(1);

        for (var retry = 0; retry < retryCount; retry++)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                throw new OperationCanceledException("Failed to connect, retry canceled.", cancellationToken);
            }

            var isLastAttempt = retry == retryCount - 1;
            try
            {
                var response = await retryBlock().ConfigureAwait(false);
                if (response.StatusCode != HttpStatusCode.ServiceUnavailable)
                {
                    return response; // Went through successfully
                }

                // 503: the application may still be booting. Release this response before retrying.
                response.Dispose();
            }
            catch (Exception exception) when (!isLastAttempt)
            {
                if (exception is not (HttpRequestException or WebException))
                {
                    // Not a network-level failure: retry straight away, as before.
                    continue;
                }
            }

            if (!isLastAttempt)
            {
                await Task.Delay(retryDelay, cancellationToken).ConfigureAwait(false); // Wait for a while before retry.
            }
        }

        throw new OperationCanceledException("Failed to connect, retry limit exceeded.");
    }

    /// <summary>
    /// Returns <c>fileName + extension</c>, or <c>fileName + N + extension</c> with the smallest N ≥ 1 such
    /// that no file of that name exists in <paramref name="directory"/>.
    /// </summary>
    private static string GetUniqueFileName(string directory, string fileName, string extension)
    {
        var uniqueName = fileName;
        for (var count = 1; File.Exists(Path.Combine(directory, uniqueName + extension)); count++)
        {
            uniqueName = fileName + count;
        }

        return uniqueName + extension;
    }

    /// <summary>
    /// Chooses a file name for a downloaded document, in order of preference: the Content-Disposition file
    /// name, the last path segment of <paramref name="url"/>, or a label of the URL's host. A
    /// <see cref="DefaultExtension"/> is appended when the chosen name has no extension.
    /// </summary>
    private static string GetFileNameFromResponse(IHttpResponseMessageWrapper response, string url)
    {
        // The header value may still carry its surrounding quotes; strip them. Path.GetFileName drops any
        // directory component so the server cannot steer the download outside the target directory.
        var suggestedName = Path.GetFileName(response.ContentDisposition()?.FileName?.Trim('"'));
        if (!string.IsNullOrEmpty(suggestedName))
        {
            return WithDefaultExtension(suggestedName);
        }

        // Uri.Segments keeps the trailing '/' on a segment followed by a slash (".../api/" gives "api/"),
        // so trim it; the root path "/" trims to empty and falls through to the host.
        var uri = new Uri(url);
        var lastSegment = uri.Segments.LastOrDefault()?.TrimEnd('/');
        if (!string.IsNullOrEmpty(lastSegment))
        {
            return WithDefaultExtension(lastSegment);
        }

        // One label: localhost or an intranet name. Two labels: no "www", take the first (example.com -> example).
        // Three or more: take the label before the top-level domain (www.XYZ.com -> XYZ, api.dev.XYZ.com -> XYZ).
        var labels = uri.Host.Split('.');
        var domain = labels.Length <= 2 ? labels[0] : labels[^2];

        return domain + DefaultExtension;
    }

    /// <summary>Appends <see cref="DefaultExtension"/> when <paramref name="fileName"/> has no extension.</summary>
    private static string WithDefaultExtension(string fileName)
    {
        return Path.HasExtension(fileName) ? fileName : fileName + DefaultExtension;
    }

    /// <summary>
    /// Parses the code generator option, or returns <c>null</c> when it was not supplied.
    /// Call <see cref="ValidateCodeGenerator"/> first; an invalid value makes <see cref="Enum.Parse{TEnum}(string)"/> throw.
    /// </summary>
    internal static CodeGenerator? GetCodeGenerator(CommandOption codeGeneratorOption)
    {
        CodeGenerator? codeGenerator;
        if (codeGeneratorOption.HasValue())
        {
            codeGenerator = Enum.Parse<CodeGenerator>(codeGeneratorOption.Value());
        }
        else
        {
            codeGenerator = null;
        }

        return codeGenerator;
    }

    /// <summary>Throws when the code generator option is present but not a <see cref="CodeGenerator"/> name.</summary>
    /// <exception cref="ArgumentException">The supplied value does not parse.</exception>
    internal static void ValidateCodeGenerator(CommandOption codeGeneratorOption)
    {
        if (codeGeneratorOption.HasValue())
        {
            var value = codeGeneratorOption.Value();
            if (!Enum.TryParse(value, out CodeGenerator _))
            {
                throw new ArgumentException($"Invalid value '{value}' given as code generator.");
            }
        }
    }

    /// <summary>Resolves <paramref name="path"/> against <see cref="WorkingDirectory"/> unless it is already fully qualified.</summary>
    internal string GetFullPath(string path)
    {
        return Path.IsPathFullyQualified(path)
            ? path
            : Path.GetFullPath(path, WorkingDirectory);
    }

    /// <summary>
    /// Fetches the package-id → version map published at <see cref="PackageVersionUrl"/>.
    /// Best effort: returns <c>null</c> on any failure so callers fall back to built-in versions.
    /// </summary>
    private async Task<IDictionary<string, string>> LoadPackageVersionsFromURLAsync()
    {
        /* Example Json content
         {
          "Version" : "1.0",
          "Packages"  :  {
            "Microsoft.Azure.SignalR": "1.1.0-preview1-10442",
            "Grpc.AspNetCore.Server": "0.1.22-pre2",
            "Grpc.Net.ClientFactory": "0.1.22-pre2",
            "Google.Protobuf": "3.8.0",
            "Grpc.Tools": "1.22.0",
            "NSwag.ApiDescription.Client": "13.0.3",
            "Microsoft.Extensions.ApiDescription.Client": "0.3.0-preview7.19365.7",
            "Newtonsoft.Json": "12.0.2"
          }
        }*/
        try
        {
            // Dispose the response wrapper as well as its stream.
            using var response = await _httpClient.GetResponseAsync(PackageVersionUrl);
            using var packageVersionStream = await response.Stream;
            using var packageVersionDocument = await JsonDocument.ParseAsync(packageVersionStream);
            var packageVersionsElement = packageVersionDocument.RootElement.GetProperty("Packages");
            var packageVersionsDictionary = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);

            foreach (var packageVersion in packageVersionsElement.EnumerateObject())
            {
                packageVersionsDictionary[packageVersion.Name] = packageVersion.Value.GetString();
            }

            return packageVersionsDictionary;
        }
        catch
        {
            // TODO (johluo): Consider logging a message indicating what went wrong and actions, if any, to be taken to resolve possible issues.
            // Currently not logging anything since the fwlink is not published yet.
            return null;
        }
    }

    /// <summary>
    /// Returns the package id → version pairs declared via <see cref="OpenApiDependencyAttribute"/> on this
    /// assembly for <paramref name="type"/> (defaulting to <see cref="CodeGenerator.NSwagCSharp"/>).
    /// </summary>
    private static IDictionary<string, string> GetServicePackages(CodeGenerator? type)
    {
        var generator = type ?? CodeGenerator.NSwagCSharp;
        var result = new Dictionary<string, string>();

        foreach (var package in typeof(Program).Assembly.GetCustomAttributes<OpenApiDependencyAttribute>())
        {
            if (package.CodeGenerators.Contains(generator))
            {
                result[package.Name] = package.Version;
            }
        }

        return result;
    }

    /// <summary>Computes the SHA-256 of <paramref name="stream"/> from its current position to the end.</summary>
    private static byte[] GetHash(Stream stream)
    {
        using var algorithm = SHA256.Create();
        return algorithm.ComputeHash(stream);
    }

    /// <summary>
    /// Writes <paramref name="content"/> to <paramref name="destinationPath"/> (resolved against
    /// <see cref="WorkingDirectory"/>). An existing file with identical SHA-256 content is left untouched.
    /// Failures during the download/write are reported on <see cref="CommandLineApplication.Error"/> (not thrown),
    /// and a partially written file is deleted.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// The destination exists and <paramref name="overwrite"/> is <c>false</c>.
    /// </exception>
    private async Task WriteToFileAsync(Stream content, string destinationPath, bool overwrite)
    {
        if (content.CanSeek)
        {
            content.Seek(0, SeekOrigin.Begin);
        }

        destinationPath = GetFullPath(destinationPath);
        var destinationExists = File.Exists(destinationPath);
        if (destinationExists && !overwrite)
        {
            throw new ArgumentException($"File '{destinationPath}' already exists. Aborting to avoid conflicts. Provide the '--output-file' argument with an unused file to resolve.");
        }

        await Out.WriteLineAsync($"Downloading to '{destinationPath}'.");
        var reachedCopy = false;
        try
        {
            if (destinationExists)
            {
                // The content is read twice (hash, then copy). A forward-only stream would be exhausted by
                // the hash and leave an empty file, so buffer it first.
                if (!content.CanSeek)
                {
                    var buffered = new MemoryStream();
                    await content.CopyToAsync(buffered);
                    buffered.Position = 0;
                    content = buffered;
                }

                // Check hashes before using the downloaded information.
                var downloadHash = GetHash(content);

                byte[] destinationHash;
                using (var destinationStream = File.OpenRead(destinationPath))
                {
                    destinationHash = GetHash(destinationStream);
                }

                if (downloadHash.AsSpan().SequenceEqual(destinationHash))
                {
                    await Out.WriteLineAsync($"Not overwriting existing and matching file '{destinationPath}'.");
                    return;
                }
            }
            else
            {
                // May need to create directory to hold the file.
                var destinationDirectory = Path.GetDirectoryName(destinationPath);
                if (!string.IsNullOrEmpty(destinationDirectory) && !Directory.Exists(destinationDirectory))
                {
                    Directory.CreateDirectory(destinationDirectory);
                }
            }

            // Create or overwrite the destination file.
            reachedCopy = true;
            using var fileStream = new FileStream(destinationPath, FileMode.Create, FileAccess.Write);
            if (content.CanSeek)
            {
                content.Seek(0, SeekOrigin.Begin);
            }

            await content.CopyToAsync(fileStream);
        }
        catch (Exception ex)
        {
            await Error.WriteLineAsync("Downloading failed.");
            await Error.WriteLineAsync(ex.ToString());
            if (reachedCopy)
            {
                File.Delete(destinationPath);
            }
        }
    }
}