﻿// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;

namespace System.Formats.Nrbf;

// This library tries to minimize the number of concepts users need to learn in order to use it.
// Since single-dimensional, zero-based arrays (SZArrays) are the most common, it provides an SZArrayRecord<T> abstraction.
// Every other array (jagged, multi-dimensional, etc.) is represented using ArrayRecord.
// The goal of this class is to expose SZArrays whose elements are records through the SZArrayRecord<SerializationRecord> abstraction.
internal sealed class SZArrayOfRecords : SZArrayRecord<SerializationRecord>
{
    // Filled in on first access by the TypeName getter and reused afterwards.
    private TypeName? _typeName;

    /// <summary>
    /// Creates an array record with no elements yet; elements are appended later through AddValue.
    /// </summary>
    /// <param name="arrayInfo">Array metadata passed to the base record.</param>
    /// <param name="memberTypeInfo">Describes the element type; used to build the type name and the allowed record types.</param>
    internal SZArrayOfRecords(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo)
        : base(arrayInfo)
    {
        MemberTypeInfo = memberTypeInfo;
        Records = new List<SerializationRecord>();
    }

    public override SerializationRecordType RecordType
    {
        get { return SerializationRecordType.BinaryArray; }
    }

    /// <summary>
    /// The element records in the order they were added. Entries may be MemberReferenceRecord
    /// instances (resolved lazily in ToArray), and a single NullsRecord entry may stand for several null elements.
    /// </summary>
    internal List<SerializationRecord> Records { get; }

    private MemberTypeInfo MemberTypeInfo { get; }

    public override TypeName TypeName
    {
        get
        {
            TypeName? cached = _typeName;
            if (cached is null)
            {
                cached = MemberTypeInfo.GetArrayTypeName(ArrayInfo);
                _typeName = cached;
            }

            return cached;
        }
    }

    /// <inheritdoc/>
    public override SerializationRecord?[] GetArray(bool allowNulls = true)
    {
        // The two variants are cached in separate fields; ToArray only runs while the matching field is null.
        if (allowNulls)
        {
            return (SerializationRecord?[])(_arrayNullsAllowed ??= ToArray(allowNulls: true));
        }

        return (SerializationRecord?[])(_arrayNullsNotAllowed ??= ToArray(allowNulls: false));
    }

    /// <summary>
    /// Materializes a new array of <c>Length</c> elements from <c>Records</c>: member references are
    /// replaced by the records they point to, and each NullsRecord is expanded into NullCount null slots.
    /// </summary>
    /// <param name="allowNulls">When <see langword="false"/>, the first NullsRecord encountered triggers ThrowHelper.ThrowArrayContainedNulls.</param>
    private SerializationRecord?[] ToArray(bool allowNulls)
    {
        SerializationRecord?[] values = new SerializationRecord?[Length];
        int next = 0;

        foreach (SerializationRecord item in Records)
        {
            SerializationRecord resolved = item switch
            {
                MemberReferenceRecord reference => reference.GetReferencedRecord(),
                _ => item,
            };

            if (resolved is NullsRecord nulls)
            {
                if (!allowNulls)
                {
                    ThrowHelper.ThrowArrayContainedNulls();
                }

                int remaining = nulls.NullCount;
                Debug.Assert(remaining > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");

                // One NullsRecord represents a run of consecutive null elements; at least one slot is always written.
                do
                {
                    values[next++] = null;
                }
                while (--remaining > 0);

                continue;
            }

            values[next++] = resolved;
        }

        Debug.Assert(next == values.Length, "We should have traversed the entirety of the newly created array.");

        return values;
    }

    private protected override void AddValue(object value)
    {
        // Elements are stored as records; anything else fails this cast.
        SerializationRecord element = (SerializationRecord)value;
        Records.Add(element);
    }

    internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
    {
        // NOTE(review): index 0 presumably selects the array's single element-type entry — confirm against MemberTypeInfo.
        (AllowedRecordTypes elementAllowed, PrimitiveType elementPrimitive) = MemberTypeInfo.GetNextAllowedRecordType(0);

        // It's an array, so unless nothing is allowed at all it may also contain (runs of) nulls.
        AllowedRecordTypes effective = elementAllowed == AllowedRecordTypes.None
            ? elementAllowed
            : elementAllowed | AllowedRecordTypes.Nulls;

        return (effective, elementPrimitive);
    }
}
