File: Program.cs
Web Access
Project: src\src\sdk\src\Dotnet.Watch\dotnet-watch\dotnet-watch.csproj (dotnet-watch)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.CommandLine;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.Loader;
using Microsoft.Build.Locator;
using Microsoft.DotNet.Cli.Commands.Run;
using Microsoft.DotNet.ProjectTools;
using Microsoft.Extensions.Logging;

namespace Microsoft.DotNet.Watch;

internal sealed class Program(
    IConsole console,
    ILoggerFactory loggerFactory,
    ILogger logger,
    IProcessOutputReporter processOutputReporter,
    ProjectOptions mainProjectOptions,
    CommandLineOptions options,
    EnvironmentOptions environmentOptions)
{
    /// <summary>
    /// Log component name for this class; evaluates to the string "Program".
    /// </summary>
    public const string LogComponentName = nameof(Program);

    /// <summary>
    /// Log message prefix that <see cref="Main"/> passes to <c>EnvironmentOptions.FromEnvironment</c>.
    /// </summary>
    private const string LogMessagePrefix = "dotnet watch";

    /// <summary>
    /// Process entry point. Locates the SDK root, registers the SDK's MSBuild, wires up
    /// Roslyn assembly resolution, then creates and runs the <see cref="Program"/>.
    /// </summary>
    /// <param name="args">Raw command line arguments.</param>
    /// <returns>
    /// The exit code reported while creating the program if creation did not produce an instance,
    /// the result of <see cref="RunAsync"/> otherwise, or 1 if an unexpected exception escaped.
    /// </returns>
    public static async Task<int> Main(string[] args)
    {
        try
        {
            // We can register the MSBuild that is bundled with the SDK to perform MSBuild things.
            // In production deployment dotnet-watch is in a nested folder of the SDK's root, so when the
            // root is not given explicitly we back up to it from the application base directory:
            //   AppContext.BaseDirectory = $sdkRoot\$sdkVersion\DotnetTools\dotnet-watch\$version\tools\net6.0\any\
            //   MSBuild.dll is located at  $sdkRoot\$sdkVersion\MSBuild.dll
            var configuredSdkRoot = EnvironmentVariables.SdkRootDirectory;
            var sdkRoot = string.IsNullOrEmpty(configuredSdkRoot)
                ? Path.Combine(AppContext.BaseDirectory, "..", "..", "..", "..", "..", "..")
                : configuredSdkRoot;

            MSBuildLocator.RegisterMSBuildPath(sdkRoot);

            var envOptions = EnvironmentOptions.FromEnvironment(sdkRoot, LogMessagePrefix);

            // Register listeners that load Roslyn-related assemblies from the `Roslyn/bincore` directory.
            RegisterAssemblyResolutionEvents(sdkRoot);

            // msbuild tasks depend on host path variable:
            Environment.SetEnvironmentVariable(EnvironmentVariables.Names.DotnetHostPath, envOptions.GetMuxerPath());

            var physicalConsole = new PhysicalConsole(envOptions.TestFlags);
            if (TryCreate(args, physicalConsole, envOptions, out var exitCode) is not { } program)
            {
                return exitCode;
            }

            return await program.RunAsync();
        }
        catch (Exception unexpected)
        {
            // Last-resort handler: no logger is guaranteed to exist here, so write straight to stderr.
            Console.Error.WriteLine("Unexpected error:");
            Console.Error.WriteLine(unexpected.ToString());
            return 1;
        }
    }

    /// <summary>
    /// Parses the command line and, if parsing yields options, creates the <see cref="Program"/>.
    /// </summary>
    /// <param name="args">Raw command line arguments.</param>
    /// <param name="console">Console used for reporting and for parser output.</param>
    /// <param name="environmentOptions">Options derived from the environment.</param>
    /// <param name="errorCode">
    /// Exit code: set by the command line parser when no options are produced,
    /// otherwise by the <see cref="CommandLineOptions"/>-based overload.
    /// </param>
    /// <returns>The program, or <see langword="null"/> if it could not be created.</returns>
    private static Program? TryCreate(IReadOnlyList<string> args, IConsole console, EnvironmentOptions environmentOptions, out int errorCode)
    {
        var consoleReporter = new ConsoleReporter(console, environmentOptions.LogMessagePrefix, environmentOptions.SuppressEmojis);

        // The log level requested on the command line is not known until parsing is done,
        // so parsing itself logs at the environment-specified level or Information.
        var parsingLoggerFactory = new LoggerFactory(consoleReporter, environmentOptions.CliLogLevel ?? LogLevel.Information);
        var parsingLogger = parsingLoggerFactory.CreateLogger(DotNetWatchContext.DefaultLogComponentName);

        if (CommandLineOptions.Parse(args, parsingLogger, console.Out, out errorCode) is not { } parsedOptions)
        {
            // an error reported or help printed:
            return null;
        }

        // The environment-specified level, when present, takes precedence over the command line one.
        var runLoggerFactory = new LoggerFactory(consoleReporter, environmentOptions.CliLogLevel ?? parsedOptions.GlobalOptions.LogLevel);
        return TryCreate(parsedOptions, console, environmentOptions, runLoggerFactory, consoleReporter, out errorCode);
    }

    /// <summary>
    /// Creates the <see cref="Program"/> from already parsed command line options.
    /// </summary>
    /// <remarks>Internal for testing.</remarks>
    /// <param name="options">Parsed command line options.</param>
    /// <param name="console">Console handed to the created program.</param>
    /// <param name="environmentOptions">Options derived from the environment; supplies the working directory.</param>
    /// <param name="loggerFactory">Factory used to create the program's logger.</param>
    /// <param name="processOutputReporter">Reporter handed to the created program.</param>
    /// <param name="errorCode">0 when a program is returned, 1 when the main project could not be determined.</param>
    /// <returns>The program, or <see langword="null"/> if the main project could not be determined.</returns>
    internal static Program? TryCreate(CommandLineOptions options, IConsole console, EnvironmentOptions environmentOptions, ILoggerFactory loggerFactory, IProcessOutputReporter processOutputReporter, out int errorCode)
    {
        var log = loggerFactory.CreateLogger(DotNetWatchContext.DefaultLogComponentName);

        var directory = environmentOptions.WorkingDirectory;
        log.LogDebug("Working directory: '{Directory}'", directory);

        if (environmentOptions.TestFlags != TestFlags.None)
        {
            log.LogDebug("Test flags: {Flags}", environmentOptions.TestFlags);
        }

        if (GetMainProjectOptions(options, directory, log) is not { } projectOptions)
        {
            errorCode = 1;
            return null;
        }

        errorCode = 0;
        return new Program(console, loggerFactory, log, processOutputReporter, projectOptions, options, environmentOptions);
    }

    /// <summary>
    /// Determines the main project to watch. Candidates are tried in this order:
    /// the explicitly specified file path, an MSBuild project found via <see cref="TryFindProject"/>,
    /// and finally a file entry point found via <see cref="TryFindFileEntryPoint"/>.
    /// </summary>
    /// <remarks>Internal for testing.</remarks>
    /// <param name="options">Parsed command line options.</param>
    /// <param name="workingDirectory">Directory that relative paths are resolved against.</param>
    /// <param name="logger">Logger that receives error messages.</param>
    /// <returns>The main project options, or <see langword="null"/> if an error was logged.</returns>
    internal static ProjectOptions? GetMainProjectOptions(CommandLineOptions options, string workingDirectory, ILogger logger)
    {
        // 1. An explicitly specified file always wins.
        if (options.FilePath != null)
        {
            ProjectRepresentation fileBasedProject;
            try
            {
                var entryPoint = Path.GetFullPath(Path.Combine(workingDirectory, options.FilePath));
                fileBasedProject = new ProjectRepresentation(projectPath: null, entryPointFilePath: entryPoint);
            }
            catch (Exception error)
            {
                logger.LogError(Resources.The_specified_path_0_is_invalid_1, options.FilePath, error.Message);
                return null;
            }

            return options.GetMainProjectOptions(fileBasedProject, workingDirectory);
        }

        // 2. Look for an MSBuild project. The result is tri-state: found, failed, or nothing to report yet.
        var searchResult = TryFindProject(workingDirectory, options, logger, out var projectPath);
        if (searchResult == false)
        {
            // error already reported
            return null;
        }

        if (searchResult == true)
        {
            return options.GetMainProjectOptions(new ProjectRepresentation(projectPath, entryPointFilePath: null), workingDirectory);
        }

        // 3. No project was found and none was reported missing: fall back to a file entry point.
        if (!TryFindFileEntryPoint(workingDirectory, options, logger, out var entryPointFilePath))
        {
            logger.LogError(Resources.Could_not_find_msbuild_project_file_in_0, projectPath);
            return null;
        }

        return options.GetMainProjectOptions(new ProjectRepresentation(projectPath: null, entryPointFilePath), workingDirectory);
    }

    /// <summary>
    /// Checks whether the first application argument of a run command designates a valid file entry point.
    /// </summary>
    /// <param name="workingDirectory">Directory that a relative argument is resolved against.</param>
    /// <param name="options">Parsed command line options.</param>
    /// <param name="logger">Not used by this method.</param>
    /// <param name="entryPointPath">
    /// The full path built from the first application argument, or <see langword="null"/> when the command
    /// is not a run command, there are no application arguments, or the path could not be resolved.
    /// </param>
    /// <returns>The result of <c>VirtualProjectBuilder.IsValidEntryPointPath</c> for the resolved path; <see langword="false"/> when no path was resolved.</returns>
    private static bool TryFindFileEntryPoint(string workingDirectory, CommandLineOptions options, ILogger logger, [NotNullWhen(true)] out string? entryPointPath)
    {
        entryPointPath = null;

        if (options.Command is not RunCommandDefinition runDefinition)
        {
            return false;
        }

        var parsedRun = runDefinition.Parse(options.CommandArgumentsWithoutBinLog, CommandLineOptions.ParserConfiguration);
        var applicationArguments = parsedRun.GetValue(runDefinition.ApplicationArguments);
        if (applicationArguments is not [var candidate, ..])
        {
            return false;
        }

        try
        {
            entryPointPath = Path.GetFullPath(Path.Combine(workingDirectory, candidate));
        }
        catch
        {
            // An argument that cannot be turned into a path is simply not an entry point;
            // entryPointPath is still null at this point.
            return false;
        }

        return VirtualProjectBuilder.IsValidEntryPointPath(entryPointPath);
    }

    /// <summary>
    /// Finds a compatible MSBuild project.
    /// </summary>
    /// <param name="workingDirectory">The base directory to search, and the base for resolving a relative project path.</param>
    /// <param name="options">Parsed command line options; supplies the optional project path.</param>
    /// <param name="logger">Logger that receives error messages.</param>
    /// <param name="projectPath">
    /// On <see langword="true"/>, the full path of the project file. Otherwise the path that was searched:
    /// the full path once it has been resolved, or the unresolved input if resolution failed.
    /// </param>
    /// <returns>
    /// <see langword="true"/> if a project file was found;
    /// <see langword="false"/> if an error was logged;
    /// <see langword="null"/> if the searched directory contains no project file and no project path was specified on the command line.
    /// </returns>
    private static bool? TryFindProject(string workingDirectory, CommandLineOptions options, ILogger logger, out string? projectPath)
    {
        projectPath = options.ProjectPath ?? workingDirectory;

        try
        {
            projectPath = Path.GetFullPath(Path.Combine(workingDirectory, projectPath));
        }
        catch (Exception error)
        {
            logger.LogError(Resources.The_specified_path_0_is_invalid_1, projectPath, error.Message);
            return false;
        }

        // Not a directory: the path has to designate the project file itself.
        if (!Directory.Exists(projectPath))
        {
            Debug.Assert(options.ProjectPath != null);

            if (File.Exists(projectPath))
            {
                return true;
            }

            logger.LogError(Resources.Error_ProjectPath_NotFound, projectPath);
            return false;
        }

        // A directory: it must contain exactly one project file, not counting shared projects (.shproj).
        string[] candidates;
        try
        {
            candidates = Directory.GetFiles(projectPath, "*.*proj")
                .Where(file => !PathUtilities.OSSpecificPathComparer.Equals(Path.GetExtension(file), ".shproj"))
                .ToArray();
        }
        catch (Exception error)
        {
            logger.LogError(Resources.The_specified_path_0_is_invalid_1, projectPath, error.Message);
            return false;
        }

        switch (candidates.Length)
        {
            case 0 when options.ProjectPath != null:
                // The directory was named explicitly, so an empty result is an error.
                logger.LogError(Resources.Could_not_find_msbuild_project_file_in_0, projectPath);
                return false;

            case 0:
                // Nothing found in the implicit location; the caller decides what to do next.
                return null;

            case 1:
                projectPath = candidates[0];
                return true;

            default:
                logger.LogError(Resources.Error_MultipleProjectsFound, projectPath);
                return false;
        }
    }

    /// <summary>
    /// Runs the watcher: lists the watched files when requested, otherwise watches the main project
    /// with or without Hot Reload until shutdown is requested.
    /// </summary>
    /// <remarks>Internal for testing.</remarks>
    /// <returns>
    /// 0 on completion or when cancelled by a requested shutdown; 1 when shutdown was already requested
    /// on entry, when an unsupported combination with a file-based program was requested, or when an
    /// unexpected exception occurred. With the list option, the result of listing the files.
    /// </returns>
    internal async Task<int> RunAsync()
    {
        var runner = new ProcessRunner(environmentOptions.GetProcessCleanupTimeout());

        using var shutdown = new ShutdownHandler(console, logger);

        try
        {
            if (shutdown.CancellationToken.IsCancellationRequested)
            {
                return 1;
            }

            if (options.List)
            {
                if (mainProjectOptions.Representation.EntryPointFilePath == null)
                {
                    return await ListFilesAsync(runner, shutdown.CancellationToken);
                }

                logger.LogError("--list does not support file-based programs");
                return 1;
            }

            if (environmentOptions.IsPollingEnabled)
            {
                logger.LogInformation("Polling file watcher is enabled");
            }

            using var watchContext = CreateContext(runner);

            if (!IsHotReloadEnabled())
            {
                if (mainProjectOptions.Representation.EntryPointFilePath != null)
                {
                    logger.LogError("File-based programs are only supported when Hot Reload is enabled");
                    return 1;
                }

                await DotNetWatcher.WatchAsync(watchContext, shutdown.CancellationToken);
                return 0;
            }

            // The prompt is only created for interactive sessions.
            using var selectionPrompt = watchContext.Options.NonInteractive ? null : new SpectreBuildParametersSelectionPrompt(console);
            var hotReloadWatcher = new HotReloadDotNetWatcher(watchContext, console, runtimeProcessLauncherFactory: null, selectionPrompt);
            await hotReloadWatcher.WatchAsync(shutdown.CancellationToken);

            return 0;
        }
        catch (OperationCanceledException) when (shutdown.CancellationToken.IsCancellationRequested)
        {
            // Ctrl+C forced an exit
            return 0;
        }
        catch (Exception unexpected)
        {
            logger.LogError("An unexpected error occurred: {Exception}", unexpected.ToString());
            return 1;
        }
    }

    /// <summary>
    /// Builds the <see cref="DotNetWatchContext"/> for this program instance from the
    /// constructor-supplied services and options.
    /// </summary>
    /// <remarks>Internal for testing.</remarks>
    /// <param name="processRunner">Process runner stored in the context.</param>
    /// <returns>A new context whose only root project is the main project.</returns>
    internal DotNetWatchContext CreateContext(ProcessRunner processRunner)
    {
        var contextLogger = loggerFactory.CreateLogger(DotNetWatchContext.DefaultLogComponentName);
        var buildLogger = loggerFactory.CreateLogger(DotNetWatchContext.BuildLogComponentName);
        var refreshServerFactory = new BrowserRefreshServerFactory();
        var browserLauncher = new BrowserLauncher(contextLogger, processOutputReporter, environmentOptions);

        return new DotNetWatchContext
        {
            ProcessOutputReporter = processOutputReporter,
            LoggerFactory = loggerFactory,
            Logger = contextLogger,
            BuildLogger = buildLogger,
            ProcessRunner = processRunner,
            Options = options.GlobalOptions,
            EnvironmentOptions = environmentOptions,
            MainProjectOptions = mainProjectOptions,
            RootProjects = [mainProjectOptions.Representation],
            BuildArguments = options.BuildArguments,
            BrowserRefreshServerFactory = refreshServerFactory,
            BrowserLauncher = browserLauncher,
        };
    }

    /// <summary>
    /// Decides whether to watch with Hot Reload and logs the decision.
    /// </summary>
    /// <returns>
    /// <see langword="true"/> only when the main project command is "run" and Hot Reload
    /// has not been disabled through the global options.
    /// </returns>
    private bool IsHotReloadEnabled()
    {
        bool enabled;

        if (mainProjectOptions.Command != "run")
        {
            logger.Log(MessageDescriptor.CommandDoesNotSupportHotReload, mainProjectOptions.Command);
            enabled = false;
        }
        else if (options.GlobalOptions.NoHotReload)
        {
            logger.Log(MessageDescriptor.HotReloadDisabledByCommandLineSwitch);
            enabled = false;
        }
        else
        {
            logger.Log(MessageDescriptor.WatchingWithHotReload);
            enabled = true;
        }

        return enabled;
    }

    /// <summary>
    /// Evaluates the main project and writes the resulting file paths, ordered by path,
    /// to the console output, one per line.
    /// </summary>
    /// <param name="processRunner">Process runner handed to the file set factory.</param>
    /// <param name="cancellationToken">Token that cancels the evaluation.</param>
    /// <returns>0 when the files were listed; 1 when the evaluation produced no result.</returns>
    private async Task<int> ListFilesAsync(ProcessRunner processRunner, CancellationToken cancellationToken)
    {
        // file-based programs are not supported with --list
        Debug.Assert(mainProjectOptions.Representation.PhysicalPath != null);

        var msbuildLogger = loggerFactory.CreateLogger(DotNetWatchContext.BuildLogComponentName);

        var factory = new MSBuildFileSetFactory(
            mainProjectOptions.Representation.PhysicalPath,
            options.TargetFramework,
            options.BuildArguments,
            processRunner,
            msbuildLogger,
            options.GlobalOptions,
            environmentOptions);

        if (await factory.TryCreateAsync(requireProjectGraph: null, cancellationToken) is not { } evaluation)
        {
            return 1;
        }

        var output = console.Out;
        foreach (var entry in evaluation.Files.OrderBy(file => file.Key))
        {
            output.WriteLine(entry.Key);
        }

        return 0;
    }

    /// <summary>
    /// Subscribes to assembly resolution on the default load context so that
    /// <c>Microsoft.CodeAnalysis</c> and <c>Microsoft.CodeAnalysis.CSharp</c> are loaded
    /// from the <c>Roslyn/bincore</c> directory under the SDK root.
    /// </summary>
    /// <param name="sdkRootDirectory">SDK root directory that contains the <c>Roslyn</c> folder.</param>
    private static void RegisterAssemblyResolutionEvents(string sdkRootDirectory)
    {
        var roslynDirectory = Path.Combine(sdkRootDirectory, "Roslyn", "bincore");

        AssemblyLoadContext.Default.Resolving += (loadContext, requested) =>
        {
            // Any other assembly is left to the remaining resolution mechanisms.
            if (requested.Name is not ("Microsoft.CodeAnalysis" or "Microsoft.CodeAnalysis.CSharp"))
            {
                return null;
            }

            var candidatePath = Path.Combine(roslynDirectory, requested.Name + ".dll");
            var resolved = loadContext.LoadFromAssemblyPath(candidatePath);

            // Avoid scenarios where the assembly in the Roslyn directory is older than what we expect
            if (resolved.GetName().Version < requested.Version)
            {
                throw new Exception($"Found a version of {requested.Name} that was lower than the target version of {requested.Version}");
            }

            return resolved;
        };
    }
}