// Source: mirror of https://github.com/chylex/.NET-Community-Toolkit.git (synced 2024-11-24 07:42:45 +01:00).
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Numerics;
using System.Runtime.CompilerServices;

namespace CommunityToolkit.HighPerformance.Helpers.Internals;
/// <summary>
/// Helpers to process sequences of values by reference.
/// </summary>
internal static partial class SpanHelper
{
    /// <summary>
    /// Counts the number of occurrences of a given value into a target search space.
    /// </summary>
    /// <param name="r0">A <typeparamref name="T"/> reference to the start of the search space.</param>
    /// <param name="length">The number of items in the search space.</param>
    /// <param name="value">The <typeparamref name="T"/> value to look for.</param>
    /// <typeparam name="T">The type of value to look for.</typeparam>
    /// <returns>The number of occurrences of <paramref name="value"/> in the search space.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static nint Count<T>(ref T r0, nint length, T value)
        where T : IEquatable<T>
    {
        if (!Vector.IsHardwareAccelerated)
        {
            return CountSequential(ref r0, length, value);
        }

        // Special vectorized version when using a supported type. Each typeof(T) check below
        // is a constant for value types once the generic instantiation is compiled, so only
        // the matching branch survives in the generated code. Every supported primitive is
        // reinterpreted to the signed integer type of the same size, which CountSimd relies
        // on (vector equality produces all-1s lanes, ie. -1 in two's complement).
        if (typeof(T) == typeof(byte) ||
            typeof(T) == typeof(sbyte) ||
            typeof(T) == typeof(bool))
        {
            ref sbyte r1 = ref Unsafe.As<T, sbyte>(ref r0);
            sbyte target = Unsafe.As<T, sbyte>(ref value);

            return CountSimd(ref r1, length, target);
        }

        if (typeof(T) == typeof(char) ||
            typeof(T) == typeof(ushort) ||
            typeof(T) == typeof(short))
        {
            ref short r1 = ref Unsafe.As<T, short>(ref r0);
            short target = Unsafe.As<T, short>(ref value);

            return CountSimd(ref r1, length, target);
        }

        if (typeof(T) == typeof(int) ||
            typeof(T) == typeof(uint))
        {
            ref int r1 = ref Unsafe.As<T, int>(ref r0);
            int target = Unsafe.As<T, int>(ref value);

            return CountSimd(ref r1, length, target);
        }

        if (typeof(T) == typeof(long) ||
            typeof(T) == typeof(ulong))
        {
            ref long r1 = ref Unsafe.As<T, long>(ref r0);
            long target = Unsafe.As<T, long>(ref value);

            return CountSimd(ref r1, length, target);
        }

#if NET6_0_OR_GREATER
        if (typeof(T) == typeof(nint) ||
            typeof(T) == typeof(nuint))
        {
            ref nint r1 = ref Unsafe.As<T, nint>(ref r0);
            nint target = Unsafe.As<T, nint>(ref value);

            return CountSimd(ref r1, length, target);
        }
#endif

        // Fall back to the sequential implementation for all other types
        return CountSequential(ref r0, length, value);
    }

    /// <summary>
    /// Implements <see cref="Count{T}"/> with a sequential search.
    /// </summary>
    /// <param name="r0">A <typeparamref name="T"/> reference to the start of the search space.</param>
    /// <param name="length">The number of items in the search space.</param>
    /// <param name="value">The <typeparamref name="T"/> value to look for.</param>
    /// <typeparam name="T">The type of value to look for.</typeparam>
    /// <returns>The number of occurrences of <paramref name="value"/> in the search space.</returns>
    private static nint CountSequential<T>(ref T r0, nint length, T value)
        where T : IEquatable<T>
    {
        nint result = 0;
        nint offset = 0;

        // Main loop with 8 unrolled iterations. The branchless bool-to-byte
        // conversion adds 1 to the result for every matching item.
        while (length >= 8)
        {
            result += Unsafe.Add(ref r0, offset + 0).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 1).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 2).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 3).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 4).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 5).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 6).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 7).Equals(value).ToByte();

            length -= 8;
            offset += 8;
        }

        // Optional 4 unrolled iterations for the remaining items
        if (length >= 4)
        {
            result += Unsafe.Add(ref r0, offset + 0).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 1).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 2).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 3).Equals(value).ToByte();

            length -= 4;
            offset += 4;
        }

        // Iterate over the remaining values and count those that match
        while (length > 0)
        {
            result += Unsafe.Add(ref r0, offset).Equals(value).ToByte();

            length -= 1;
            offset += 1;
        }

        return result;
    }

    /// <summary>
    /// Implements <see cref="Count{T}"/> with a vectorized search.
    /// </summary>
    /// <param name="r0">A <typeparamref name="T"/> reference to the start of the search space.</param>
    /// <param name="length">The number of items in the search space.</param>
    /// <param name="value">The <typeparamref name="T"/> value to look for.</param>
    /// <typeparam name="T">
    /// The type of value to look for. <see cref="Count{T}"/> only dispatches here with signed
    /// integer types (sbyte, short, int, long, nint), which the match accumulation relies on.
    /// </typeparam>
    /// <returns>The number of occurrences of <paramref name="value"/> in the search space.</returns>
    private static nint CountSimd<T>(ref T r0, nint length, T value)
        where T : unmanaged, IEquatable<T>
    {
        nint result = 0;
        nint offset = 0;

        // Skip the initialization overhead if there are not enough items
        if (length >= Vector<T>.Count)
        {
            Vector<T> vc = new(value);

            do
            {
                // Calculate the maximum sequential area that can be processed in
                // one pass without the risk of numeric overflow in the dot product
                // to sum the partial results. We also backup the current offset to
                // be able to track how many items have been processed, which lets
                // us avoid updating a third counter (length) in the loop body.
                nint max = GetUpperBound<T>();
                nint chunkLength = length <= max ? length : max;
                nint initialOffset = offset;

                Vector<T> partials = Vector<T>.Zero;

                // Unrolled vectorized loop, with 8 unrolled iterations. We only run this when the
                // current type T is at least 2 bytes in size, otherwise the average chunk length
                // would always be too small to be able to trigger the unrolled loop, and the overall
                // performance would just be slightly worse due to the additional conditional branches.
                if (typeof(T) != typeof(sbyte))
                {
                    while (chunkLength >= Vector<T>.Count * 8)
                    {
                        ref T ri0 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 0));
                        Vector<T> vi0 = Unsafe.As<T, Vector<T>>(ref ri0);
                        Vector<T> ve0 = Vector.Equals(vi0, vc);

                        partials -= ve0;

                        ref T ri1 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 1));
                        Vector<T> vi1 = Unsafe.As<T, Vector<T>>(ref ri1);
                        Vector<T> ve1 = Vector.Equals(vi1, vc);

                        partials -= ve1;

                        ref T ri2 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 2));
                        Vector<T> vi2 = Unsafe.As<T, Vector<T>>(ref ri2);
                        Vector<T> ve2 = Vector.Equals(vi2, vc);

                        partials -= ve2;

                        ref T ri3 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 3));
                        Vector<T> vi3 = Unsafe.As<T, Vector<T>>(ref ri3);
                        Vector<T> ve3 = Vector.Equals(vi3, vc);

                        partials -= ve3;

                        ref T ri4 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 4));
                        Vector<T> vi4 = Unsafe.As<T, Vector<T>>(ref ri4);
                        Vector<T> ve4 = Vector.Equals(vi4, vc);

                        partials -= ve4;

                        ref T ri5 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 5));
                        Vector<T> vi5 = Unsafe.As<T, Vector<T>>(ref ri5);
                        Vector<T> ve5 = Vector.Equals(vi5, vc);

                        partials -= ve5;

                        ref T ri6 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 6));
                        Vector<T> vi6 = Unsafe.As<T, Vector<T>>(ref ri6);
                        Vector<T> ve6 = Vector.Equals(vi6, vc);

                        partials -= ve6;

                        ref T ri7 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 7));
                        Vector<T> vi7 = Unsafe.As<T, Vector<T>>(ref ri7);
                        Vector<T> ve7 = Vector.Equals(vi7, vc);

                        partials -= ve7;

                        chunkLength -= Vector<T>.Count * 8;
                        offset += Vector<T>.Count * 8;
                    }
                }

                while (chunkLength >= Vector<T>.Count)
                {
                    ref T ri = ref Unsafe.Add(ref r0, offset);

                    // Load the current Vector<T> register, and then use
                    // Vector.Equals to check for matches. This API sets the
                    // values corresponding to matching pairs to all 1s.
                    // Since the input type is guaranteed to always be signed,
                    // this means that a value with all 1s represents -1, as
                    // signed numbers are represented in two's complement.
                    // So we can just subtract this intermediate value to the
                    // partial results, which effectively sums 1 for each match.
                    Vector<T> vi = Unsafe.As<T, Vector<T>>(ref ri);
                    Vector<T> ve = Vector.Equals(vi, vc);

                    partials -= ve;

                    chunkLength -= Vector<T>.Count;
                    offset += Vector<T>.Count;
                }

                // Horizontally sum the per-lane match counters for this chunk. On .NET 6+
                // Vector.Sum is used directly; on older targets the same reduction is done
                // via a dot product with a vector of all ones.
#if NET6_0_OR_GREATER
                result += CastToNativeInt(Vector.Sum(partials));
#else
                result += CastToNativeInt(Vector.Dot(partials, Vector<T>.One));
#endif
                length -= offset - initialOffset;
            }
            while (length >= Vector<T>.Count);
        }

        // Optional 8 unrolled iterations. This is only done when a single SIMD
        // register can contain over 8 values of the current type, as otherwise
        // there could never be enough items left after the vectorized path
        if (Vector<T>.Count > 8 &&
            length >= 8)
        {
            result += Unsafe.Add(ref r0, offset + 0).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 1).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 2).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 3).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 4).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 5).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 6).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 7).Equals(value).ToByte();

            length -= 8;
            offset += 8;
        }

        // Optional 4 unrolled iterations
        if (Vector<T>.Count > 4 &&
            length >= 4)
        {
            result += Unsafe.Add(ref r0, offset + 0).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 1).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 2).Equals(value).ToByte();
            result += Unsafe.Add(ref r0, offset + 3).Equals(value).ToByte();

            length -= 4;
            offset += 4;
        }

        // Iterate over the remaining values and count those that match
        while (length > 0)
        {
            result += Unsafe.Add(ref r0, offset).Equals(value).ToByte();

            length -= 1;
            offset += 1;
        }

        return result;
    }

    /// <summary>
    /// Gets the upper bound for partial sums with a given <typeparamref name="T"/> parameter.
    /// This is the maximum number of items that can be processed in one chunk without the
    /// per-lane match counters in a <see cref="Vector{T}"/> risking numeric overflow.
    /// </summary>
    /// <typeparam name="T">The type argument currently in use.</typeparam>
    /// <returns>The native <see cref="int"/> value representing the upper bound.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static unsafe nint GetUpperBound<T>()
        where T : unmanaged
    {
        if (typeof(T) == typeof(sbyte))
        {
            return sbyte.MaxValue;
        }

        if (typeof(T) == typeof(short))
        {
            return short.MaxValue;
        }

        if (typeof(T) == typeof(int))
        {
            return int.MaxValue;
        }

        if (typeof(T) == typeof(long))
        {
            // On a 32 bit architecture, nint cannot hold long.MaxValue, so clamp to int.MaxValue
            if (sizeof(nint) == sizeof(int))
            {
                return int.MaxValue;
            }

            // If we are on a 64 bit architecture and we are counting with a SIMD vector of 64
            // bit values, we can use long.MaxValue as the upper bound, as a native integer will
            // be able to contain such a value with no overflows. This will allow the count tight
            // loop to process all the items in the target area in a single pass (except the mod).
            // The (void*) cast is necessary to ensure the right constant is produced on runtimes
            // before .NET 5 that don't natively support C# 9. For instance, removing that (void*)
            // cast results in the value 0xFFFFFFFFFFFFFFFF (-1) instead of 0x7FFFFFFFFFFFFFFF.
            return (nint)(void*)long.MaxValue;
        }

#if NET6_0_OR_GREATER
        if (typeof(T) == typeof(nint))
        {
            return nint.MaxValue;
        }
#endif

        // Unreachable: Count<T> only dispatches to the SIMD path with the types handled above
        throw null!;
    }

    /// <summary>
    /// Casts a value of a given type to a native <see cref="int"/>.
    /// </summary>
    /// <typeparam name="T">The input type to cast.</typeparam>
    /// <param name="value">The input <typeparamref name="T"/> value to cast to native <see cref="int"/>.</param>
    /// <returns>The native <see cref="int"/> cast of <paramref name="value"/>.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static nint CastToNativeInt<T>(T value)
        where T : unmanaged
    {
        // In each branch, the value is first cast through the unsigned type of the same size,
        // so that the widening to nint is a zero-extension (the summed values are counts of
        // matches, which are never negative). The (object) casts are erased by the JIT for
        // value types, so no boxing actually occurs at runtime.
        if (typeof(T) == typeof(sbyte))
        {
            return (byte)(sbyte)(object)value;
        }

        if (typeof(T) == typeof(short))
        {
            return (ushort)(short)(object)value;
        }

        if (typeof(T) == typeof(int))
        {
            return (nint)(uint)(int)(object)value;
        }

        if (typeof(T) == typeof(long))
        {
            return (nint)(ulong)(long)(object)value;
        }

#if NET6_0_OR_GREATER
        if (typeof(T) == typeof(nint))
        {
            return (nint)(object)value;
        }
#endif

        // Unreachable: CountSimd<T> only runs with the types handled above
        throw null!;
    }
}
|