线程安全的缓存枚举器 - 带产量的锁

2024-04-30

我有一个自定义的“CachedEnumerable”类(灵感来自缓存 IEnumerable https://stackoverflow.com/q/1537043/5683904)我需要确保我的 ASP.NET Core Web 应用程序的线程安全。

Enumerator 线程的以下实现安全吗? (对 IList _cache 的所有其他读/写都被适当锁定)(可能与C# Yield 是否会释放锁? https://stackoverflow.com/q/4608215/5683904)

更具体地说,如果有 2 个线程访问枚举器,我如何防止一个线程递增“索引”导致第二个枚举线程从 _cache 获取错误的元素(即索引 + 1 处的元素而不是索引处的元素) ?这种竞争条件真的值得关注吗?

public IEnumerator<T> GetEnumerator()
{
    var index = 0;

    while (true)
    {
        T current;
        lock (_enumeratorLock)
        {
            if (index >= _cache.Count && !MoveNext()) break;
            current = _cache[index];
            index++;
        }
        yield return current;
    }
}

我的 CachedEnumerable 版本的完整代码:

 public class CachedEnumerable<T> : IDisposable, IEnumerable<T>
    {
        IEnumerator<T> _enumerator;
        private IList<T> _cache = new List<T>();
        public bool CachingComplete { get; private set; } = false;

        public CachedEnumerable(IEnumerable<T> enumerable)
        {
            switch (enumerable)
            {
                case CachedEnumerable<T> cachedEnumerable: //This case is actually dealt with by the extension method.
                    _cache = cachedEnumerable._cache;
                    CachingComplete = cachedEnumerable.CachingComplete;
                    _enumerator = cachedEnumerable.GetEnumerator();

                    break;
                case IList<T> list:
                    //_cache = list; //without clone...
                    //Clone:
                    _cache = new T[list.Count];
                    list.CopyTo((T[]) _cache, 0);
                    CachingComplete = true;
                    break;
                default:
                    _enumerator = enumerable.GetEnumerator();
                    break;
            }
        }

        public CachedEnumerable(IEnumerator<T> enumerator)
        {
            _enumerator = enumerator;
        }

        private int CurCacheCount
        {
            get
            {
                lock (_enumeratorLock)
                {
                    return _cache.Count;
                }
            }
        }

        public IEnumerator<T> GetEnumerator()
        {
            var index = 0;

            while (true)
            {
                T current;
                lock (_enumeratorLock)
                {
                    if (index >= _cache.Count && !MoveNext()) break;
                    current = _cache[index];
                    index++;
                }
                yield return current;
            }
        }

        //private readonly AsyncLock _enumeratorLock = new AsyncLock();
        private readonly object _enumeratorLock = new object();

        private bool MoveNext()
        {
            if (CachingComplete) return false;

            if (_enumerator != null && _enumerator.MoveNext()) //The null check should have been unnecessary b/c of the lock...
            {
                _cache.Add(_enumerator.Current);
                return true;
            }
            else
            {
                CachingComplete = true;
                DisposeWrappedEnumerator(); //Release the enumerator, as it is no longer needed.
            }

            return false;
        }

        public T ElementAt(int index)
        {
            lock (_enumeratorLock)
            {
                if (index < _cache.Count)
                {
                    return _cache[index];
                }
            }

            EnumerateUntil(index);

            lock (_enumeratorLock)
            {
                if (_cache.Count <= index) throw new ArgumentOutOfRangeException(nameof(index));
                return _cache[index];
            }
        }


        public bool TryGetElementAt(int index, out T value)
        {
            lock (_enumeratorLock)
            {
                value = default;
                if (index < CurCacheCount)
                {
                    value = _cache[index];
                    return true;
                }
            }

            EnumerateUntil(index);

            lock (_enumeratorLock)
            {
                if (_cache.Count <= index) return false;
                value = _cache[index];
            }

            return true;
        }

        private void EnumerateUntil(int index)
        {
            while (true)
            {
                lock (_enumeratorLock)
                {
                    if (_cache.Count > index || !MoveNext()) break;
                }
            }
        }


        public void Dispose()
        {
            DisposeWrappedEnumerator();
        }

        private void DisposeWrappedEnumerator()
        {
            if (_enumerator != null)
            {
                _enumerator.Dispose();
                _enumerator = null;
                if (_cache is List<T> list)
                {
                    list.Trim();
                }
            }
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        public int CachedCount
        {
            get
            {
                lock (_enumeratorLock)
                {
                    return _cache.Count;
                }
            }
        }

        public int Count()
        {
            if (CachingComplete)
            {
                return _cache.Count;
            }

            EnsureCachingComplete();

            return _cache.Count;
        }

        private void EnsureCachingComplete()
        {
            if (CachingComplete)
            {
                return;
            }

            //Enumerate the rest of the collection
            while (!CachingComplete)
            {
                lock (_enumeratorLock)
                {
                    if (!MoveNext()) break;
                }
            }
        }

        public T[] ToArray()
        {
            EnsureCachingComplete();
            //Once Caching is complete, we don't need to lock
            if (!(_cache is T[] array))
            {
                array = _cache.ToArray();
                _cache = array;
            }

            return array;
        }

        public T this[int index] => ElementAt(index);
    }

    public static CachedEnumerable<T> Cached<T>(this IEnumerable<T> source)
    {
        //no gain in caching a cache.
        if (source is CachedEnumerable<T> cached)
        {
            return cached;
        }

        return new CachedEnumerable<T>(source);
    }
}

基本用法:(虽然不是一个有意义的用例)

var cached = expensiveEnumerable.Cached();
foreach (var element in cached) {
   Console.WriteLine(element);
}

Update

我根据 @Theodors 答案测试了当前的实现https://stackoverflow.com/a/58547863/5683904 https://stackoverflow.com/a/58547863/5683904并确认(AFAICT)使用 foreach 枚举时它是线程安全的,而不会创建重复值():

class Program
{
    static async Task Main(string[] args)
    {
        var enumerable = Enumerable.Range(0, 1_000_000);
        var cachedEnumerable = new CachedEnumerable<int>(enumerable);
        var c = new ConcurrentDictionary<int, List<int>>();
        var tasks = Enumerable.Range(1, 100).Select(id => Test(id, cachedEnumerable, c));
        Task.WaitAll(tasks.ToArray());
        foreach (var keyValuePair in c)
        {
            var hasDuplicates = keyValuePair.Value.Distinct().Count() != keyValuePair.Value.Count;
            Console.WriteLine($"Task #{keyValuePair.Key} count: {keyValuePair.Value.Count}. Has duplicates? {hasDuplicates}");
        }
    }

    static async Task Test(int id, IEnumerable<int> cache, ConcurrentDictionary<int, List<int>> c)
    {
        foreach (var i in cache)
        {
            //await Task.Delay(10);
            c.AddOrUpdate(id, v => new List<int>() {i}, (k, v) =>
            {
                v.Add(i);
                return v;
            });
        }
    }
}

您的类不是线程安全的,因为共享状态在类内未受保护的区域中发生了变化。未受保护的区域是:

  1. 构造函数
  2. The Dispose method

共享状态为:

  1. The _enumerator私人领域
  2. The _cache私人领域
  3. The CachingComplete公共财产

关于您的班级的一些其他问题:

  1. 实施IDisposable为调用者创建了处理您的类的责任。没有必要IEnumerable是一次性的。相反IEnumerator是一次性的,但有语言支持它们的自动处置(功能foreach陈述)。
  2. 您的类提供了预期不具备的扩展功能IEnumerable (ElementAt, CountETC)。也许您打算实施一个CachedList反而?如果不实施IList<T>接口,LINQ 方法,例如Count() and ToArray()无法利用您的扩展功能,并将使用慢速路径,就像使用普通香草一样IEnumerables.

Update:我刚刚注意到另一个线程安全问题。这一项与public IEnumerator<T> GetEnumerator()方法。枚举器是编译器生成的,因为该方法是迭代器(利用yield return)。编译器生成的枚举器不是线程安全的。例如考虑以下代码:

var enumerable = Enumerable.Range(0, 1_000_000);
var cachedEnumerable = new CachedEnumerable<int>(enumerable);
var enumerator = cachedEnumerable.GetEnumerator();
var tasks = Enumerable.Range(1, 4).Select(id => Task.Run(() =>
{
    int count = 0;
    while (enumerator.MoveNext())
    {
        count++;
    }
    Console.WriteLine($"Task #{id} count: {count}");
})).ToArray();
Task.WaitAll(tasks);

四个线程同时使用相同的IEnumerator。可枚举有 1,000,000 个项目。您可能期望每个线程会枚举约 250,000 个项目,但事实并非如此。

Output:

任务#1 计数:0
任务#4 计数:0
任务#3 计数:0
任务#2 计数:1000000

The MoveNext在行中while (enumerator.MoveNext())不安全MoveNext。这是编译器生成的不安全MoveNext。虽然不安全,但包括可能用于处理异常的机制 https://stackoverflow.com/questions/58246922/why-the-compiler-generated-state-machine-restores-repeatedly-the-state-to-1,在调用外部提供的代码之前将枚举器暂时标记为已完成。所以当多个线程调用时MoveNext同时,除了第一个之外的所有其他都将获得返回值false,并在完成零循环后立即终止枚举。要解决这个问题,您可能必须自己编写代码IEnumerator class.


Update:实际上我关于线程安全枚举的最后一点有点不公平,因为用IEnumerator接口本质上是一种不安全的操作,如果没有调用代码的配合,这是不可能修复的。这是因为获取下一个元素不是原子操作,因为它涉及两个步骤(调用MoveNext() + read Current)。因此,您的线程安全问题仅限于保护类的内部状态(字段_enumerator, _cache and CachingComplete)。这些仅在构造函数和Dispose方法,但我认为类的正常使用可能不会遵循创建竞争条件的代码路径,从而导致内部状态损坏。

就我个人而言,我也更愿意处理这些代码路径,并且我不会让它随心所欲。


Update:我写了一个缓存IAsyncEnumerables,展示替代技术。源的枚举IAsyncEnumerable不是由调用者驱动,使用锁或信号量来获取独占访问,而是由单独的工作任务驱动。第一个调用者启动工作任务。每个调用者首先生成已缓存的所有项目,然后等待更多项目,或者等待没有更多项目的通知。作为通知机制,我使用了TaskCompletionSource<bool> https://learn.microsoft.com/en-us/dotnet/api/system.threading.tasks.taskcompletionsource-1. A lock仍然用于确保对共享资源的所有访问都是同步的。

public class CachedAsyncEnumerable<T> : IAsyncEnumerable<T>
{
    private readonly object _locker = new object();
    private IAsyncEnumerable<T> _source;
    private Task _sourceEnumerationTask;
    private List<T> _buffer;
    private TaskCompletionSource<bool> _moveNextTCS;
    private Exception _sourceEnumerationException;
    private int _sourceEnumerationVersion; // Incremented on exception

    public CachedAsyncEnumerable(IAsyncEnumerable<T> source)
    {
        _source = source ?? throw new ArgumentNullException(nameof(source));
    }

    public async IAsyncEnumerator<T> GetAsyncEnumerator(
        CancellationToken cancellationToken = default)
    {
        lock (_locker)
        {
            if (_sourceEnumerationTask == null)
            {
                _buffer = new List<T>();
                _moveNextTCS = new TaskCompletionSource<bool>();
                _sourceEnumerationTask = Task.Run(
                    () => EnumerateSourceAsync(cancellationToken));
            }
        }
        int index = 0;
        int localVersion = -1;
        while (true)
        {
            T current = default;
            Task<bool> moveNextTask = null;
            lock (_locker)
            {
                if (localVersion == -1)
                {
                    localVersion = _sourceEnumerationVersion;
                }
                else if (_sourceEnumerationVersion != localVersion)
                {
                    ExceptionDispatchInfo
                        .Capture(_sourceEnumerationException).Throw();
                }
                if (index < _buffer.Count)
                {
                    current = _buffer[index];
                    index++;
                }
                else
                {
                    moveNextTask = _moveNextTCS.Task;
                }
            }
            if (moveNextTask == null)
            {
                yield return current;
                continue;
            }
            var moved = await moveNextTask;
            if (!moved) yield break;
            lock (_locker)
            {
                current = _buffer[index];
                index++;
            }
            yield return current;
        }
    }

    private async Task EnumerateSourceAsync(CancellationToken cancellationToken)
    {
        TaskCompletionSource<bool> localMoveNextTCS;
        try
        {
            await foreach (var item in _source.WithCancellation(cancellationToken))
            {
                lock (_locker)
                {
                    _buffer.Add(item);
                    localMoveNextTCS = _moveNextTCS;
                    _moveNextTCS = new TaskCompletionSource<bool>();
                }
                localMoveNextTCS.SetResult(true);
            }
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _buffer.TrimExcess();
                _source = null;
            }
            localMoveNextTCS.SetResult(false);
        }
        catch (Exception ex)
        {
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _sourceEnumerationException = ex;
                _sourceEnumerationVersion++;
                _sourceEnumerationTask = null;
            }
            localMoveNextTCS.SetException(ex);
        }
    }
}

此实现遵循处理异常的特定策略。如果枚举源时发生异常IAsyncEnumerable,异常将传播到所有当前调用者,当前使用的IAsyncEnumerator会被丢弃,不完整的缓存数据也会被丢弃。当接收到下一个枚举请求时,新的工作任务可以稍后再次启动。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

线程安全的缓存枚举器 - 带产量的锁 的相关文章

随机推荐