File: CodeGenerator\CSharp\AzureCodeGenerator\AzureAttachModelCodeGenerator.cs
Web Access
Project: src\src\Microsoft.ML.CodeGenerator\Microsoft.ML.CodeGenerator.csproj (Microsoft.ML.CodeGenerator)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.ML.AutoML;
using Microsoft.ML.CodeGenerator.CodeGenerator.CSharp.Interface;
using Microsoft.ML.CodeGenerator.CSharp;
using Microsoft.ML.CodeGenerator.Templates.Azure.Model;
using Microsoft.ML.CodeGenerator.Templates.Console;
using Microsoft.ML.CodeGenerator.Utilities;
 
namespace Microsoft.ML.CodeGenerator.CodeGenerator.CSharp.AzureCodeGenerator
{
    internal class AzureAttachModelCodeGenerator : ICSharpProjectGenerator
    {
        private readonly Pipeline _pipeline;
        private readonly CodeGeneratorSettings _settings;
        private readonly ColumnInferenceResults _columnInferenceResult;
        private readonly string _nameSpaceValue;
 
        public ICSharpFile ModelInputClass { get; private set; }
        public ICSharpFile AzureImageModelOutputClass { get; private set; }
        public ICSharpFile AzureObjectDetectionModelOutputClass { get; private set; }
        public ICSharpFile ModelProject { get; private set; }
        public ICSharpFile ConsumeModel { get; private set; }
        public ICSharpFile LabelMapping { get; private set; }
        public ICSharpFile ImageLabelMapping { get; private set; }
        public ICSharpFile ObjectDetectionConsumeModel { get; private set; }
        public string Name { get; set; }
 
        public AzureAttachModelCodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInferenceResults, CodeGeneratorSettings options, string namespaceValue)
        {
            _pipeline = pipeline;
            _settings = options;
            _columnInferenceResult = columnInferenceResults;
            _nameSpaceValue = namespaceValue;
            Name = $"{_settings.OutputName}.Model";
 
            ModelInputClass = new CSharpCodeFile()
            {
                File = new ModelInputClass()
                {
                    Namespace = _nameSpaceValue,
                    ClassLabels = Utilities.Utils.GenerateClassLabels(_columnInferenceResult, _settings.OnnxInputMapping),
                    Target = _settings.Target
                }.TransformText(),
                Name = "ModelInput.cs",
            };
 
            var labelType = _columnInferenceResult.TextLoaderOptions.Columns.Where(t => t.Name == _settings.LabelName).First().DataKind;
            Type labelTypeCsharp = Utils.GetCSharpType(labelType);
 
            AzureImageModelOutputClass = new CSharpCodeFile()
            {
                File = new AzureImageModelOutputClass()
                {
                    Namespace = _nameSpaceValue,
                    Target = _settings.Target,
                    Labels = _settings.ClassificationLabel,
                }.TransformText(),
                Name = "ModelOutput.cs",
            };
 
            AzureObjectDetectionModelOutputClass = new CSharpCodeFile()
            {
                File = new AzureObjectDetectionModelOutputClass()
                {
                    Namespace = _nameSpaceValue,
                    Target = _settings.Target,
                    Labels = _settings.ObjectLabel,
                }.TransformText(),
                Name = "ModelOutput.cs",
            };
 
            ModelProject = new CSharpProjectFile()
            {
                File = new ModelProject()
                {
                    IncludeFastTreePackage = false,
                    IncludeImageClassificationPackage = false,
                    IncludeImageTransformerPackage = _settings.IsImage,
                    IncludeLightGBMPackage = false,
                    IncludeMklComponentsPackage = false,
                    IncludeOnnxModel = true,
                    IncludeOnnxRuntime = _settings.IsObjectDetection,
                    IncludeRecommenderPackage = false,
                    StablePackageVersion = _settings.StablePackageVersion,
                    UnstablePackageVersion = _settings.UnstablePackageVersion,
                    OnnxRuntimePackageVersion = _settings.OnnxRuntimePackageVersion,
                    Target = _settings.Target,
                }.TransformText(),
                Name = $"{_settings.OutputName}.Model.csproj",
            };
 
            ConsumeModel = new CSharpCodeFile()
            {
                File = new ConsumeModel()
                {
                    Namespace = _nameSpaceValue,
                    Target = _settings.Target,
                    MLNetModelName = _settings.ModelName,
                    OnnxModelName = _settings.OnnxModelName,
                    IsAzureImage = _settings.IsAzureAttach && _settings.IsImage,
                    IsAzureObjectDetection = _settings.IsObjectDetection && _settings.IsAzureAttach,
                }.TransformText(),
                Name = "ConsumeModel.cs",
            };
        }
 
        public ICSharpProject ToProject()
        {
            CSharpProject project;
            if (_settings.IsImage)
            {
                project = new CSharpProject()
                {
                    ModelInputClass,
                    AzureImageModelOutputClass,
                    ConsumeModel,
                    ModelProject,
                };
            }
            else if (_settings.IsObjectDetection)
            {
                project = new CSharpProject()
                {
                    ModelInputClass,
                    AzureObjectDetectionModelOutputClass,
                    ConsumeModel,
                    ModelProject,
                };
            }
            else
            {
                project = new CSharpProject()
                {
                    ModelInputClass,
                    AzureImageModelOutputClass,
                    ConsumeModel,
                    ModelProject,
                };
            }
            project.Name = Name;
            return project;
        }
    }
}