File: Util\ProcessorState.cs
Web Access
Project: src\src\sdk\src\TemplateEngine\Microsoft.TemplateEngine.Core\Microsoft.TemplateEngine.Core.csproj (Microsoft.TemplateEngine.Core)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Concurrent;
using System.Text;
using Microsoft.TemplateEngine.Core.Contracts;
using Microsoft.TemplateEngine.Core.Matching;

namespace Microsoft.TemplateEngine.Core.Util
{
    public class ProcessorState : IProcessorState
    {
        // Cache of already built operation tries: looked up first by the operation provider list,
        //  then by the encoding the trie was built for (see the constructor).
        private static readonly ConcurrentDictionary<IReadOnlyList<IOperationProvider>, ConcurrentDictionary<Encoding, Trie<OperationTerminal>>> TrieLookup = new();
        // Per operation provider list: the IDs of the operations whose flag the constructor sets to true in the config.
        private static readonly ConcurrentDictionary<IReadOnlyList<IOperationProvider>, List<string>> OperationsToExplicitlySetOnByDefault = new();
        // Proxy over the target stream; every write, seek and flush of the output goes through it.
        private readonly StreamProxy _target;
        // Evaluator over the operation trie, fed one source byte at a time by Run.
        private readonly TrieEvaluator<OperationTerminal> _trie;
        // Number of bytes written since the last flush at which Run flushes _target.
        private readonly int _flushThreshold;
        // Length in bytes of the byte order mark detected at the start of the source (0 if none was returned).
        private readonly int _bomSize;
        // Stream the buffer is refilled from; not readonly because Inject replaces it.
        private Stream _source;

        /// <summary>
        /// Initializes the processor state: reads the first chunk of <paramref name="source"/>, detects its
        /// encoding and byte order mark, copies the byte order mark to <paramref name="target"/> and builds
        /// (or reuses from the static cache) the trie of operation tokens for the detected encoding.
        /// </summary>
        /// <param name="source">Stream the content is read from.</param>
        /// <param name="target">Stream the processed content is written to.</param>
        /// <param name="bufferSize">
        /// Requested buffer size in bytes. It is lowered to the source length when a seekable source is shorter,
        /// raised to 4 for non-seekable sources, and the buffer is enlarged to <c>_trie.MaxLength + 1</c> bytes
        /// when it would otherwise be smaller than that.
        /// </param>
        /// <param name="flushThreshold">Number of written bytes after which <see cref="Run"/> flushes the target.</param>
        /// <param name="config">Engine configuration; flags of operations that are on by default are set to <c>true</c> in it.</param>
        /// <param name="operationProviders">Providers of the operations to match in the source.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="target"/>, <paramref name="operationProviders"/> or
        /// <paramref name="config"/> is <c>null</c>.
        /// </exception>
        public ProcessorState(Stream source, Stream target, int bufferSize, int flushThreshold, IEngineConfig config, IReadOnlyList<IOperationProvider> operationProviders)
        {
            if (source == null)
            {
                throw new ArgumentNullException(nameof(source));
            }

            if (target == null)
            {
                throw new ArgumentNullException(nameof(target));
            }

            if (operationProviders == null)
            {
                throw new ArgumentNullException(nameof(operationProviders));
            }

            //Validate the configuration together with the other arguments, before anything is allocated
            if (config == null)
            {
                throw new ArgumentNullException(nameof(config));
            }

            if (source.CanSeek)
            {
                try
                {
                    if (source.Length < bufferSize)
                    {
                        bufferSize = (int)source.Length;
                    }
                }
                catch
                {
                    //The stream may not support getting the length property (in NetworkStream for instance, which throws a NotSupportedException), suppress any errors in
                    //  accessing the property and continue with the specified buffer size
                }
            }
            //Buffer has to be at least as large as the largest BOM we could expect
            else if (bufferSize < 4)
            {
                bufferSize = 4;
            }

            _source = source;
            _target = new StreamProxy(target, bufferSize);
            Config = config;
            _flushThreshold = flushThreshold;
            CurrentBuffer = new byte[bufferSize];
            CurrentBufferLength = ReadExactBytes(source, CurrentBuffer, 0, CurrentBuffer.Length);

            //Detect the encoding from the first chunk; the BOM is copied to the target as is and
            //  processing starts right behind it
            Encoding encoding = EncodingUtil.Detect(CurrentBuffer, CurrentBufferLength, out byte[] bom);
            EncodingConfig = new EncodingConfig(Config, encoding);
            _bomSize = bom.Length;
            CurrentBufferPosition = _bomSize;
            CurrentSequenceNumber = _bomSize;
            WriteToTarget(bom, 0, _bomSize);

            //The value factory only runs when this provider list has no entry yet; in that case the
            //  list of operations that are on by default is collected while the trie is built below
            bool explicitOnConfigurationRequired = false;
            ConcurrentDictionary<Encoding, Trie<OperationTerminal>> byEncoding = TrieLookup.GetOrAdd(operationProviders, x => new());
            List<string> turnOnByDefault = OperationsToExplicitlySetOnByDefault.GetOrAdd(operationProviders, x =>
            {
                explicitOnConfigurationRequired = true;
                return new List<string>();
            });

            if (!byEncoding.TryGetValue(encoding, out Trie<OperationTerminal> trie))
            {
                trie = new Trie<OperationTerminal>();

                for (int i = 0; i < operationProviders.Count; ++i)
                {
                    IOperation op = operationProviders[i].GetOperation(encoding, this);

                    if (op != null)
                    {
                        //Register every non-null token of the operation, remembering its index within the operation
                        for (int j = 0; j < op.Tokens.Count; ++j)
                        {
                            if (op.Tokens[j] != null)
                            {
                                trie.AddPath(op.Tokens[j]!.Value, new OperationTerminal(op, j, op.Tokens[j]!.Length, op.Tokens[j]!.Start, op.Tokens[j]!.End));
                            }
                        }

                        //NOTE(review): the list is shared through the static cache and is mutated here without
                        //  synchronization - confirm instances for the same provider list are not constructed concurrently
                        if (explicitOnConfigurationRequired && op.IsInitialStateOn && !string.IsNullOrEmpty(op.Id))
                        {
                            turnOnByDefault.Add(op.Id!);
                        }
                    }
                }

                byEncoding.TryAdd(encoding, trie);
            }

            foreach (string state in turnOnByDefault)
            {
                config.Flags[state] = true;
            }

            _trie = new TrieEvaluator<OperationTerminal>(trie);

            if (bufferSize < _trie.MaxLength + 1)
            {
                //The buffer is too small for the trie: move the bytes behind the BOM to the head of a
                //  larger buffer and fill ALL of the remaining space. AdvanceBuffer treats a buffer that is
                //  not completely filled as the end of the source, so the space freed by dropping the BOM
                //  has to be refilled as well, otherwise remaining source data could be lost.
                byte[] tmp = new byte[_trie.MaxLength + 1];
                int unprocessed = CurrentBufferLength - CurrentBufferPosition;
                Buffer.BlockCopy(CurrentBuffer, CurrentBufferPosition, tmp, 0, unprocessed);
                int nRead = ReadExactBytes(_source, tmp, unprocessed, tmp.Length - unprocessed);
                CurrentBuffer = tmp;
                CurrentBufferLength = unprocessed + nRead;
                CurrentBufferPosition = 0;
                CurrentSequenceNumber = 0;
            }
        }

        /// <summary>Engine configuration passed to the constructor.</summary>
        public IEngineConfig Config { get; }

        /// <summary>Buffer holding the chunk of the source that is currently being processed.</summary>
        public byte[] CurrentBuffer { get; }

        /// <summary>Number of bytes in <see cref="CurrentBuffer"/> that were actually read from the source.</summary>
        public int CurrentBufferLength { get; private set; }

        /// <summary>Index in <see cref="CurrentBuffer"/> of the next byte to process.</summary>
        public int CurrentBufferPosition { get; private set; }

        /// <summary>
        /// Sequence number of the byte at <see cref="CurrentBufferPosition"/>; it is moved by the same amount
        /// whenever the buffer position is moved.
        /// </summary>
        public int CurrentSequenceNumber { get; private set; }

        /// <summary>Encoding configuration created from <see cref="Config"/> and the encoding detected in the source.</summary>
        public IEncodingConfig EncodingConfig { get; }

        /// <summary>Shortcut for the encoding of <see cref="EncodingConfig"/>.</summary>
        public Encoding Encoding => EncodingConfig.Encoding;

        /// <summary>
        /// Makes the byte at <paramref name="bufferPosition"/> the new head of <see cref="CurrentBuffer"/>:
        /// the bytes from that position on are moved to the start and the freed space is refilled from the source.
        /// </summary>
        /// <param name="bufferPosition">Index in <see cref="CurrentBuffer"/> that becomes index 0.</param>
        /// <returns>
        /// <c>false</c> when there is nothing to advance (empty buffer or position 0), or when the buffer was
        /// only partially filled and no bytes remain to keep; otherwise <c>true</c>.
        /// </returns>
        public bool AdvanceBuffer(int bufferPosition)
        {
            bool nothingToAdvance = CurrentBufferLength == 0 || bufferPosition == 0;
            if (nothingToAdvance)
            {
                return false;
            }

            //The sequence number follows the buffer position, so it moves by
            //  the same distance the position is moved by
            CurrentSequenceNumber += bufferPosition - CurrentBufferPosition;

            //Bytes from the requested position up to the end of the valid data
            //  have to survive the shift
            int tailLength = CurrentBufferLength - bufferPosition;
            bool bufferWasOnlyPartiallyFilled = CurrentBufferLength < CurrentBuffer.Length;

            if (bufferWasOnlyPartiallyFilled && tailLength == 0)
            {
                //An earlier read came up short and nothing is left to keep: the buffer is done
                CurrentBufferLength = 0;
                CurrentBufferPosition = 0;
                return false;
            }

            if (tailLength > 0)
            {
                //Slide the surviving bytes to the head of the buffer
                Buffer.BlockCopy(CurrentBuffer, bufferPosition, CurrentBuffer, 0, tailLength);
            }

            //Top the buffer back up to its previous length, right behind the surviving bytes
            int refillCount = CurrentBufferLength - tailLength;
            int freshBytes = ReadExactBytes(_source, CurrentBuffer, tailLength, refillCount);
            CurrentBufferLength = tailLength + freshBytes;

            //The byte the caller pointed at is now the first one in the buffer
            CurrentBufferPosition = 0;
            return true;
        }

        /// <summary>
        /// Processes the source to its end: every byte of the buffer is fed to the operation trie, bytes that
        /// are not part of a handled match are copied to the target, and for every match whose operation is
        /// enabled the operation's <c>HandleMatch</c> is invoked. The target is flushed once done.
        /// </summary>
        /// <returns><c>true</c> if <c>HandleMatch</c> was invoked for at least one match; otherwise <c>false</c>.</returns>
        public bool Run()
        {
            //Sequence number of the first source byte that has not been written to the target
            //  (or handed to an operation) yet
            int nextSequenceNumberThatCouldBeWritten = CurrentSequenceNumber;
            int bytesWrittenSinceLastFlush = 0;
            bool anyOperationsExecuted = false;

            while (true)
            {
                //Loop until we run out of data in the buffer
                while (CurrentBufferPosition < CurrentBufferLength)
                {
                    int posedPosition = CurrentSequenceNumber;
                    bool skipAdvanceBuffer = false;
                    if (_trie.Accept(CurrentBuffer[CurrentBufferPosition], ref posedPosition, out TerminalLocation<OperationTerminal>? terminal))
                    {
                        IOperation operation = terminal!.Terminal.Operation;
                        int matchLength = terminal.Terminal.End - terminal.Terminal.Start + 1;
                        int handoffBufferPosition = CurrentBufferPosition + matchLength - (CurrentSequenceNumber - terminal.Location);

                        //Copy everything between the last written byte and the start of the match to the target
                        if (terminal.Location > nextSequenceNumberThatCouldBeWritten)
                        {
                            int toWrite = terminal.Location - nextSequenceNumberThatCouldBeWritten;
                            //Console.WriteLine("UnmatchedBlock");
                            //string text = System.Text.Encoding.UTF8.GetString(CurrentBuffer, handoffBufferPosition - toWrite - matchLength, toWrite).Replace("\0", "\\0");
                            //Console.WriteLine(text);
                            _target.Write(CurrentBuffer, handoffBufferPosition - toWrite - matchLength, toWrite);
                            bytesWrittenSinceLastFlush += toWrite;
                            //NOTE(review): the end-of-input branch further down assigns terminal.Location here
                            //  instead - confirm both expressions are meant to yield the same value
                            nextSequenceNumberThatCouldBeWritten = posedPosition - matchLength + 1;
                        }

                        //Operations without an ID always run; operations with an ID only run when their flag is set to true
                        if (operation.Id == null || (Config.Flags.TryGetValue(operation.Id, out bool opEnabledFlag) && opEnabledFlag))
                        {
                            //Move to the hand-off position, let the operation process the match, then adopt
                            //  the buffer position the operation reports back through posedPosition
                            CurrentSequenceNumber += handoffBufferPosition - CurrentBufferPosition;
                            CurrentBufferPosition = handoffBufferPosition;
                            posedPosition = handoffBufferPosition;
                            int bytesWritten = operation.HandleMatch(this, CurrentBufferLength, ref posedPosition, terminal.Terminal.Token);
                            bytesWrittenSinceLastFlush += bytesWritten;

                            CurrentSequenceNumber += posedPosition - CurrentBufferPosition;
                            CurrentBufferPosition = posedPosition;
                            nextSequenceNumberThatCouldBeWritten = CurrentSequenceNumber;
                            skipAdvanceBuffer = true;
                            anyOperationsExecuted = true;
                        }
                        else
                        {
                            //The operation is disabled: only reposition, the matched bytes stay unwritten
                            //  and are copied to the target later like any other content
                            int oldSequenceNumber = CurrentSequenceNumber;
                            CurrentSequenceNumber = terminal.Location + terminal.Terminal.End + 1;
                            CurrentBufferPosition += CurrentSequenceNumber - oldSequenceNumber;
                        }

                        if (bytesWrittenSinceLastFlush >= _flushThreshold)
                        {
                            _target.Flush();
                            bytesWrittenSinceLastFlush = 0;
                        }
                    }

                    if (!skipAdvanceBuffer)
                    {
                        ++CurrentSequenceNumber;
                        ++CurrentBufferPosition;
                    }
                }

                //Calculate the sequence number at the head of the buffer
                int headSequenceNumber = CurrentSequenceNumber - CurrentBufferPosition;

                int bufferPositionToAdvanceTo;
                if (headSequenceNumber > _trie.OldestRequiredSequenceNumber)
                {
                    // if headSequenceNumber is higher than _trie.OldestRequiredSequenceNumber
                    // the window is already missed
                    // we won't be able to continue with current tries anyway
                    // advance to new chunk of the buffer.
                    bufferPositionToAdvanceTo = CurrentBufferLength;
                }
                else
                {
                    bufferPositionToAdvanceTo = _trie.OldestRequiredSequenceNumber - headSequenceNumber;
                }
                int numberOfUncommittedBytesBeforeThePositionToAdvanceTo = _trie.OldestRequiredSequenceNumber - nextSequenceNumberThatCouldBeWritten;

                //If we'd advance data out of the buffer that hasn't been
                //  handled already, write it out
                if (numberOfUncommittedBytesBeforeThePositionToAdvanceTo > 0)
                {
                    int toWrite = numberOfUncommittedBytesBeforeThePositionToAdvanceTo;
                    // Console.WriteLine("AdvancePreserve");
                    // Console.WriteLine($"nextSequenceNumberThatCouldBeWritten {nextSequenceNumberThatCouldBeWritten}");
                    // Console.WriteLine($"headSequenceNumber {headSequenceNumber}");
                    // Console.WriteLine($"bufferPositionToAdvanceTo {bufferPositionToAdvanceTo}");
                    // Console.WriteLine($"numberOfUncommittedBytesBeforeThePositionToAdvanceTo {numberOfUncommittedBytesBeforeThePositionToAdvanceTo}");
                    // Console.WriteLine($"CurrentBufferPosition {CurrentBufferPosition}");
                    // Console.WriteLine($"CurrentBufferLength {CurrentBufferLength}");
                    // Console.WriteLine($"CurrentBuffer.Length {CurrentBuffer.Length}");
                    // string text = System.Text.Encoding.UTF8.GetString(CurrentBuffer, bufferPositionToAdvanceTo - toWrite, toWrite).Replace("\0", "\\0");
                    // Console.WriteLine(text);
                    _target.Write(CurrentBuffer, bufferPositionToAdvanceTo - toWrite, toWrite);
                    bytesWrittenSinceLastFlush += toWrite;
                    nextSequenceNumberThatCouldBeWritten = _trie.OldestRequiredSequenceNumber;
                }

                //We ran out of data in the buffer, so attempt to advance
                //  if we fail,
                if (!AdvanceBuffer(bufferPositionToAdvanceTo))
                {
                    //No more data could be loaded: let the trie report the matches that were still in
                    //  progress and handle each of them the same way as matches found in the main loop
                    int posedPosition = CurrentSequenceNumber;
                    _trie.FinalizeMatchesInProgress(ref posedPosition, out TerminalLocation<OperationTerminal>? terminal);

                    while (terminal != null)
                    {
                        IOperation operation = terminal.Terminal.Operation;
                        int matchLength = terminal.Terminal.End - terminal.Terminal.Start + 1;
                        int handoffBufferPosition = CurrentBufferPosition + matchLength - (CurrentSequenceNumber - terminal.Location);

                        if (terminal.Location > nextSequenceNumberThatCouldBeWritten)
                        {
                            int toWrite = terminal.Location - nextSequenceNumberThatCouldBeWritten;
                            // Console.WriteLine("TailUnmatchedBlock");
                            // string text = System.Text.Encoding.UTF8.GetString(CurrentBuffer, handoffBufferPosition - toWrite - matchLength, toWrite).Replace("\0", "\\0");
                            // Console.WriteLine(text);
                            _target.Write(CurrentBuffer, handoffBufferPosition - toWrite - matchLength, toWrite);
                            bytesWrittenSinceLastFlush += toWrite;
                            nextSequenceNumberThatCouldBeWritten = terminal.Location;
                        }

                        if (operation.Id == null || (Config.Flags.TryGetValue(operation.Id, out bool opEnabledFlag) && opEnabledFlag))
                        {
                            CurrentSequenceNumber += handoffBufferPosition - CurrentBufferPosition;
                            CurrentBufferPosition = handoffBufferPosition;
                            posedPosition = handoffBufferPosition;
                            int bytesWritten = operation.HandleMatch(this, CurrentBufferLength, ref posedPosition, terminal.Terminal.Token);
                            bytesWrittenSinceLastFlush += bytesWritten;

                            CurrentSequenceNumber += posedPosition - CurrentBufferPosition;
                            CurrentBufferPosition = posedPosition;
                            nextSequenceNumberThatCouldBeWritten = CurrentSequenceNumber;
                            anyOperationsExecuted = true;
                        }
                        else
                        {
                            int oldSequenceNumber = CurrentSequenceNumber;
                            CurrentSequenceNumber = terminal.Location + terminal.Terminal.End + 1;
                            CurrentBufferPosition += CurrentSequenceNumber - oldSequenceNumber;
                        }

                        _trie.FinalizeMatchesInProgress(ref posedPosition, out terminal);
                    }

                    break;
                }
            }

            //Write whatever is left in the buffer behind the last byte that was written or handled
            int endSequenceNumber = CurrentSequenceNumber - CurrentBufferPosition + CurrentBufferLength;
            if (endSequenceNumber > nextSequenceNumberThatCouldBeWritten)
            {
                int toWrite = endSequenceNumber - nextSequenceNumberThatCouldBeWritten;
                // Console.WriteLine("LastBlock");
                // string text = System.Text.Encoding.UTF8.GetString(CurrentBuffer, CurrentBufferLength - toWrite, toWrite).Replace("\0", "\\0");
                // Console.WriteLine(text);
                _target.Write(CurrentBuffer, CurrentBufferLength - toWrite, toWrite);
            }

            _target.FlushToTarget();
            return anyOperationsExecuted;
        }

        /// <summary>
        /// Walks the target backwards, one window of <c>match.MaxLength</c> bytes at a time and never past the
        /// byte order mark, until a token of <paramref name="match"/> is found, then truncates the target there.
        /// When nothing is found, the target is truncated right behind the byte order mark.
        /// </summary>
        /// <param name="match">Tokens to look for.</param>
        /// <param name="consume">
        /// When <c>true</c>, the target is additionally shortened by the length of the matched token.
        /// </param>
        public void SeekTargetBackUntil(ITokenTrie match, bool consume = false)
        {
            byte[] window = new byte[match.MaxLength];

            while (_target.Position > _bomSize)
            {
                StepBackOneWindow();

                int bytesInWindow = ReadExactBytes(_target, window, 0, window.Length);

                //Probe every start index from the back of the window, keeping the match that leaves
                //  the index furthest to the right (presumably the end of the matched token - TODO confirm)
                int matchedToken = -1;
                int matchedIndex = -1;
                int probe = bytesInWindow - match.MinLength;
                while (probe >= 0)
                {
                    int cursor = probe;
                    if (match.GetOperation(window, bytesInWindow, ref cursor, out int candidate) && cursor >= matchedIndex)
                    {
                        matchedIndex = cursor;
                        matchedToken = candidate;
                    }

                    --probe;
                }

                if (matchedToken != -1)
                {
                    //The read left the position at the end of the window; move it back to the match
                    var rewind = bytesInWindow - matchedIndex + (consume ? match.TokenLength[matchedToken] : 0);
                    _target.Position -= rewind;
                    _target.SetLength(_target.Position);
                    return;
                }

                //Undo the read so the next iteration loads the window in front of this one
                StepBackOneWindow();
            }

            if (_target.Position == _bomSize)
            {
                _target.SetLength(_bomSize);
            }

            //Moves the target position back by one window, but never in front of the byte order mark
            void StepBackOneWindow()
            {
                if (_target.Position - _bomSize < window.Length)
                {
                    _target.Position = _bomSize;
                    return;
                }

                _target.Position -= window.Length;
            }
        }

        /// <summary>
        /// Walks the target backwards, one window of <c>match.MaxLength</c> bytes at a time and never past the
        /// byte order mark, for as long as a token of <paramref name="match"/> is found that reaches the end of
        /// the window; the target is truncated at the end of the first window for which that is not the case.
        /// When the byte order mark is reached, the target is truncated right behind it.
        /// </summary>
        /// <param name="match">Tokens to skip backwards over.</param>
        public void SeekTargetBackWhile(ITokenTrie match)
        {
            byte[] window = new byte[match.MaxLength];

            while (_target.Position > _bomSize)
            {
                StepBackOneWindow();

                int bytesInWindow = ReadExactBytes(_target, window, 0, window.Length);
                int token = -1;
                bool foundToken = false;
                int cursor = bytesInWindow - match.MinLength;

                //Search from the back of the window for the first index a token is recognized at
                while (cursor >= 0)
                {
                    if (match.GetOperation(window, bytesInWindow, ref cursor, out token))
                    {
                        //Take the token length off the index left by the lookup
                        cursor -= match.TokenLength[token];
                        foundToken = true;
                        break;
                    }

                    --cursor;
                }

                bool tokenReachesWindowEnd = foundToken && (token == -1 || cursor + match.TokenLength[token] == bytesInWindow);
                if (!tokenReachesWindowEnd)
                {
                    //Stop here: cut the target at the current position (the end of the window just read)
                    _target.SetLength(_target.Position);
                    return;
                }

                //Undo the read so the next iteration loads the window in front of this one
                StepBackOneWindow();
            }

            if (_target.Position == _bomSize)
            {
                _target.SetLength(_bomSize);
            }

            //Moves the target position back by one window, but never in front of the byte order mark
            void StepBackOneWindow()
            {
                if (_target.Position - _bomSize < window.Length)
                {
                    _target.Position = _bomSize;
                    return;
                }

                _target.Position -= window.Length;
            }
        }

        /// <summary>
        /// Writes <paramref name="count"/> bytes of <paramref name="buffer"/>, starting at
        /// <paramref name="offset"/>, to the target.
        /// </summary>
        /// <param name="buffer">Bytes to write.</param>
        /// <param name="offset">Index in <paramref name="buffer"/> of the first byte to write.</param>
        /// <param name="count">Number of bytes to write.</param>
        public void WriteToTarget(byte[] buffer, int offset, int count)
        {
            _target.Write(buffer, offset, count);
        }

        /// <summary>
        /// Scans the source forward, refilling the buffer as needed, until a token of <paramref name="match"/>
        /// is found. When the data runs out first, the position is set to the end of the buffered data.
        /// </summary>
        /// <param name="match">Tokens to look for.</param>
        /// <param name="bufferLength">Number of valid bytes in the buffer; updated whenever the buffer is advanced.</param>
        /// <param name="currentBufferPosition">Position to start scanning at; receives the position the scan stopped at.</param>
        /// <param name="consumeToken">
        /// When <c>false</c>, the length of the matched token is subtracted from the position left by the lookup.
        /// </param>
        public void SeekSourceForwardUntil(ITokenTrie match, ref int bufferLength, ref int currentBufferPosition, bool consumeToken = false)
        {
            while (bufferLength >= match.MinLength)
            {
                //Make sure a token of maximum length can fit behind the current position
                int unscanned = bufferLength - currentBufferPosition;
                if (unscanned < match.MaxLength)
                {
                    AdvanceBuffer(currentBufferPosition);
                    currentBufferPosition = CurrentBufferPosition;
                    bufferLength = CurrentBufferLength;
                }

                //With a completely filled buffer the last MaxLength - 1 bytes are left for the next
                //  round; with a partially filled one every position a shortest token fits at is tried
                int reservedTail;
                if (bufferLength == CurrentBuffer.Length)
                {
                    reservedTail = match.MaxLength;
                }
                else
                {
                    reservedTail = match.MinLength;
                }

                while (currentBufferPosition < bufferLength - reservedTail + 1)
                {
                    if (bufferLength == 0)
                    {
                        currentBufferPosition = 0;
                        return;
                    }

                    if (match.GetOperation(CurrentBuffer, bufferLength, ref currentBufferPosition, false, out int token))
                    {
                        if (!consumeToken)
                        {
                            currentBufferPosition -= match.Tokens[token].Length;
                        }

                        return;
                    }

                    ++currentBufferPosition;
                }
            }

            //Nothing matched and there is no room left for a token: consume everything that is buffered
            currentBufferPosition = bufferLength;
        }

        /// <summary>
        /// Moves forward in the source for as long as the lookup in <paramref name="trie"/> succeeds at the
        /// current position, refilling the buffer whenever its scannable part is used up.
        /// </summary>
        /// <param name="trie">Tokens to skip over.</param>
        /// <param name="bufferLength">Number of valid bytes in the buffer; updated whenever the buffer is advanced.</param>
        /// <param name="currentBufferPosition">Position to start at; receives the position the scan stopped at.</param>
        public void SeekSourceForwardWhile(ITokenTrie trie, ref int bufferLength, ref int currentBufferPosition)
        {
            //NOTE(review): SeekSourceForwardUntil loops while bufferLength >= MinLength, this loop requires
            //  strictly more - confirm the difference is intended
            while (bufferLength > trie.MinLength)
            {
                while (true)
                {
                    int scanLimit = bufferLength - trie.MinLength + 1;
                    if (currentBufferPosition >= scanLimit)
                    {
                        break;
                    }

                    if (bufferLength == 0)
                    {
                        currentBufferPosition = 0;
                        return;
                    }

                    bool matched = trie.GetOperation(CurrentBuffer, bufferLength, ref currentBufferPosition, out _);
                    if (!matched)
                    {
                        //First position without a token: this is where the caller continues
                        return;
                    }
                }

                //The scannable part of the buffer is used up, load more data and carry on
                AdvanceBuffer(currentBufferPosition);
                currentBufferPosition = CurrentBufferPosition;
                bufferLength = CurrentBufferLength;
            }
        }

        /// <summary>
        /// Makes <paramref name="staged"/> the next data to be read: the source is replaced by a
        /// <c>CombinedStream</c> over <paramref name="staged"/> and the previous source, and the buffer is
        /// reloaded from it starting at index 0.
        /// </summary>
        /// <param name="staged">Stream whose content is read in front of the remaining source.</param>
        public void Inject(Stream staged)
        {
            //The callback lets the combined stream put another stream in place of _source
            //  (presumably the original source once the staged data is used up - TODO confirm)
            Stream previousSource = _source;
            _source = new CombinedStream(staged, previousSource, inner => _source = inner);

            //NOTE(review): at most the previous CurrentBufferLength bytes are loaded (not the full
            //  buffer capacity) and the previous buffer content is overwritten - confirm this is intended
            int bytesToLoad = CurrentBufferLength;
            int bytesLoaded = ReadExactBytes(_source, CurrentBuffer, 0, bytesToLoad);
            CurrentBufferLength = bytesLoaded;
            CurrentBufferPosition = 0;
        }

        /// <summary>
        /// Reads from <paramref name="stream"/> until <paramref name="count"/> bytes were stored in
        /// <paramref name="buffer"/> or the stream reports its end by returning 0 bytes.
        /// </summary>
        /// <param name="stream">Stream to read from.</param>
        /// <param name="buffer">Buffer receiving the data.</param>
        /// <param name="offset">Index in <paramref name="buffer"/> the first byte is stored at.</param>
        /// <param name="count">
        /// Number of bytes wanted; limited to the space available in <paramref name="buffer"/> behind
        /// <paramref name="offset"/>.
        /// </param>
        /// <returns>Number of bytes actually read, which is smaller than requested when the stream ended early.</returns>
        private int ReadExactBytes(Stream stream, byte[] buffer, int offset, int count)
        {
            //Never ask for more than fits into the buffer behind the offset
            int wanted = count + offset > buffer.Length
                ? buffer.Length - offset
                : count;

            int filled = 0;
            int remaining = wanted;

            //A single Read may deliver less than requested, keep asking for the rest
            while (remaining > 0)
            {
                int chunk = stream.Read(buffer, offset + filled, remaining);
                if (chunk == 0)
                {
                    break;
                }

                filled += chunk;
                remaining -= chunk;
            }

            return filled;
        }
    }
}