.NET — IAsyncEnumerable utility extensions

Collection of utility extensions for async streams

With the introduction of async streams in .NET Core 3.0, represented by the interface IAsyncEnumerable<T>, Microsoft standardized the way .NET developers implement asynchronous streams. C# 8 added direct language support: you can iterate with await foreach, or implement a new asynchronous stream by declaring an async method that returns IAsyncEnumerable<T> and using yield return/yield break, just like we did for IEnumerable<T>.

Even if we don’t realize it, we probably use async streams on a daily basis: from Entity Framework Core to ASP.NET Core, they have become an important, widely adopted part of .NET.
In this article I’m going to show some of the most common scenarios I face when working directly with IAsyncEnumerable<T> and how I usually solve them.


Utility extensions for async streams

When implementing an extension for an async stream there are a few guidelines to keep in mind.

Firstly, don’t validate function inputs (like null checks) inside the async iterator method; create a wrapper method instead that validates the arguments and then calls an internal implementation. Because iterator methods execute lazily, validation placed inside one would only run when the stream is first enumerated, while the wrapper throws immediately. It also keeps the stack trace cleaner, lighter, and easier to analyze.

Secondly, your code will surely need a CancellationToken, but don’t define it in the public method signature, even with a default value. To make developers’ lives easier, instead of requiring a CancellationToken on every method, Microsoft provides the WithCancellation extension method, which can be called on any IAsyncEnumerable<T>. If you decorate the token parameter in your iterator’s signature with the EnumeratorCancellation attribute, the compiler will automatically forward the token passed to WithCancellation into your method. Defining the token as a parameter in the public signature would only confuse the developer using it.

Imagine you were creating a simple Where extension to filter an async stream. Following these guidelines, it would be implemented as follows:

public static IAsyncEnumerable<T> Where<T>(
    this IAsyncEnumerable<T> source,
    Func<T, bool> predicate
)
{
    ArgumentNullException.ThrowIfNull(source);
    ArgumentNullException.ThrowIfNull(predicate);

    return Core(source, predicate);

    static async IAsyncEnumerable<T> Core(
        IAsyncEnumerable<T> source,
        Func<T, bool> predicate,
        [EnumeratorCancellation] CancellationToken ct = default
    )
    {
        await foreach (var item in source.WithCancellation(ct))
        {
            if (predicate(item))
                yield return item;
        }
    }
}

As you can see, the validations are done outside the actual implementation, and the CancellationToken is only defined internally.
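To see the cancellation guideline in action, here is a small sketch of how the token given to WithCancellation reaches an iterator that exposes no public CancellationToken parameter (Numbers is a made-up producer, used here purely for illustration):

```csharp
using System.Runtime.CompilerServices;

// Hypothetical producer for demonstration purposes: an endless counter.
static async IAsyncEnumerable<int> Numbers(
    [EnumeratorCancellation] CancellationToken ct = default)
{
    for (var i = 0; ; i++)
    {
        // the token handed to WithCancellation ends up here
        ct.ThrowIfCancellationRequested();
        yield return i;
        await Task.Yield();
    }
}

using var cts = new CancellationTokenSource();

try
{
    // Numbers has no public CancellationToken parameter; the compiler routes
    // the token from WithCancellation into the [EnumeratorCancellation] one.
    await foreach (var value in Numbers().WithCancellation(cts.Token))
    {
        Console.WriteLine(value); // prints 0, 1, 2
        if (value == 2)
            cts.Cancel();
    }
}
catch (OperationCanceledException)
{
    Console.WriteLine("cancelled");
}
```

Cancelling the source token makes the very next MoveNextAsync observe the cancellation, which is exactly what the utility extensions below rely on.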

I recommend taking a good look at the source code of System.Linq.Async, the officially supported LINQ extensions for async streams.

Let’s now drill down to some of my most used utility extensions.

Timeout between fetches

When receiving data from an IAsyncEnumerable<T>, we may not know how many items there will be or how long the whole enumeration will take, but we may want to enforce a maximum timeout between items to ensure the application doesn’t wait indefinitely.

One very common scenario nowadays is the integration with Large Language Model (LLM) APIs that use Server-Sent Events (SSE) to stream response tokens shown to the user in real time. Even with a loading indicator somewhere, if some response token takes too long, the user will probably think the application just stopped working.

The following code links a CancellationTokenSource to the original CancellationToken and uses its internal timer to automatically cancel after a given time has passed, throwing a TaskCanceledException while waiting for the next item.

public static IAsyncEnumerable<T> Timeout<T>(
    this IAsyncEnumerable<T> source,
    int millisecondsTimeout
)
{
    ArgumentNullException.ThrowIfNull(source);
    ArgumentOutOfRangeException.ThrowIfLessThan(millisecondsTimeout, -1);

    return Core(source, millisecondsTimeout);

    static async IAsyncEnumerable<T> Core(
        IAsyncEnumerable<T> source,
        int millisecondsTimeout,
        [EnumeratorCancellation] CancellationToken ct = default
    )
    {
        using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        cts.CancelAfter(millisecondsTimeout);

        // enumerate with the linked token so the timeout actually cancels the fetch
        await foreach (var element in source.WithCancellation(cts.Token))
        {
            // disable the timer while the item is being processed
            cts.CancelAfter(-1);

            yield return element;

            // re-enable the timer before fetching the next item
            cts.CancelAfter(millisecondsTimeout);
        }
    }
}

I like this approach because it’s simple and, apart from the extra allocation of a CancellationTokenSource and its internal Timer (which is reused across items), there’s not much overhead, making it a very efficient implementation. It can be used just like any LINQ method:

await foreach (var i in GetRandomIntegersAsync(20)
    .Timeout(500)
    .WithCancellation(ct))
{
    Console.WriteLine($"{DateTimeOffset.Now:O} -> {i}");
}
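The usage examples rely on a producer named GetRandomIntegersAsync that the article doesn’t define; a minimal sketch of it, with arbitrary delays just to simulate a slow stream, could look like this:

```csharp
using System.Runtime.CompilerServices;

// Hypothetical producer: yields 'count' random integers, pausing a random
// amount of time between items to simulate a slow source.
static async IAsyncEnumerable<int> GetRandomIntegersAsync(
    int count,
    [EnumeratorCancellation] CancellationToken ct = default)
{
    var random = new Random();
    for (var i = 0; i < count; i++)
    {
        await Task.Delay(random.Next(50, 300), ct);
        yield return random.Next(0, 100);
    }
}

await foreach (var i in GetRandomIntegersAsync(3))
    Console.WriteLine(i);
```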

If you are like me and prefer to receive a TimeoutException, you can change the implementation to manually iterate over the enumerator instead of using await foreach. If a TaskCanceledException is thrown while the received CancellationToken isn’t cancelled, you can assume it was due to a timeout and throw a TimeoutException instead.

public static IAsyncEnumerable<T> Timeout<T>(
    this IAsyncEnumerable<T> source,
    int millisecondsTimeout
)
{
    ArgumentNullException.ThrowIfNull(source);
    ArgumentOutOfRangeException.ThrowIfLessThan(millisecondsTimeout, -1);

    return Core(source, millisecondsTimeout);

    static async IAsyncEnumerable<T> Core(
        IAsyncEnumerable<T> source,
        int millisecondsTimeout,
        [EnumeratorCancellation] CancellationToken ct = default
    )
    {
        using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        cts.CancelAfter(millisecondsTimeout);

        await using var enumerator = source.GetAsyncEnumerator(cts.Token);

        while (await MoveNextCheckTimeoutAsync(enumerator, ct))
        {
            // disable the timer while the item is being processed
            cts.CancelAfter(-1);

            yield return enumerator.Current;

            // re-enable the timer before fetching the next item
            cts.CancelAfter(millisecondsTimeout);
        }
    }

    static async ValueTask<bool> MoveNextCheckTimeoutAsync(
        IAsyncEnumerator<T> enumerator,
        CancellationToken ct
    )
    {
        try
        {
            return await enumerator.MoveNextAsync();
        }
        catch (TaskCanceledException e) when (!ct.IsCancellationRequested)
        {
            throw new TimeoutException("The next item took longer than expected to be received", e);
        }
    }
}

Batch with maximum waiting time

When receiving data from an IAsyncEnumerable<T> to be persisted in some database, it may be a good idea to batch the items instead of storing them one by one, to reduce and optimize the number of insertions.

Imagine you are storing data into some relational table: if you receive 500 items, that’s 500 database accesses, which may affect the performance of your overall application. If instead you create batches of 10 items, that means only 50 database accesses with more time in between, certainly reducing the overall database load.

Using MoreLINQ’s Batch method as an example, we can easily convert it to an IAsyncEnumerable<T> implementation.

public static IAsyncEnumerable<T[]> Batch<T>(
    this IAsyncEnumerable<T> source,
    int size
)
{
    ArgumentNullException.ThrowIfNull(source);
    ArgumentOutOfRangeException.ThrowIfNegativeOrZero(size);

    return Core(source, size);

    static async IAsyncEnumerable<T[]> Core(
        IAsyncEnumerable<T> source,
        int size,
        [EnumeratorCancellation] CancellationToken ct = default
    )
    {
        T[] batch = null;
        var count = 0;
        await foreach (var item in source.WithCancellation(ct))
        {
            batch ??= new T[size];
            batch[count++] = item;

            if (count != size)
                continue;

            yield return batch;
            batch = null;
            count = 0; // reset the counter for the next batch
        }

        if (count > 0)
        {
            Array.Resize(ref batch, count);
            yield return batch;
        }
    }
}

Another common requirement when batching items is to support a timeout parameter: the method either returns a complete batch or, after a given time has passed, returns whatever items have been received so far.

This can easily be implemented by calculating a timeout date and checking whether it has passed every time an item is received. If the given amount of time has passed, you simply resize the array to the current count, return that batch, and calculate the next timeout date.

public static IAsyncEnumerable<T[]> Batch<T>(
    this IAsyncEnumerable<T> source,
    int size,
    int millisecondsTimeout
)
{
    ArgumentNullException.ThrowIfNull(source);
    ArgumentOutOfRangeException.ThrowIfNegativeOrZero(size);
    ArgumentOutOfRangeException.ThrowIfNegative(millisecondsTimeout);

    return Core(source, size, millisecondsTimeout);

    static async IAsyncEnumerable<T[]> Core(
        IAsyncEnumerable<T> source,
        int size,
        int millisecondsTimeout,
        [EnumeratorCancellation] CancellationToken ct = default
    )
    {
        T[] batch = null;
        var count = 0;
        var timeoutOn = DateTime.UtcNow.AddMilliseconds(millisecondsTimeout);
        await foreach (var item in source.WithCancellation(ct))
        {
            batch ??= new T[size];
            batch[count++] = item;

            if (count != size)
            {
                if (timeoutOn > DateTime.UtcNow)
                    continue;

                // the timeout has passed: return a partial batch
                Array.Resize(ref batch, count);
            }

            yield return batch;

            batch = null;
            timeoutOn = DateTime.UtcNow.AddMilliseconds(millisecondsTimeout);
            count = 0;
        }

        if (count > 0)
        {
            Array.Resize(ref batch, count);
            yield return batch;
        }
    }
}

I like this approach because it is very simple to understand and efficient, since it only allocates an extra DateTime on top of the initial Batch method. The only downside is that it only times out after an item is received, not while waiting for one. Solving that limitation would require a timer and constant allocations of TaskCompletionSource instances which, at least for my use cases, isn’t worth the complexity and performance overhead.

Again, this method can easily be used just like other LINQ extensions:

await foreach (var i in GetRandomIntegersAsync(20)
    .Batch(5, 1000)
    .Where(e => e.Length > 0)
    .WithCancellation(ct))
{
    Console.WriteLine($"{DateTimeOffset.Now:O} -> {i.Length}");
}

Throttling

One last scenario I usually face is applying some throttling to prevent overloading the application when items arrive from an IAsyncEnumerable<T> at a high throughput. It combines very well with the timeout Batch method to enforce a predictable cadence of data.

The implementation is very simple: you calculate the timeout date, fetch the next item and, if not enough time has passed, delay for the remaining time, repeating the process for each item.

public static IAsyncEnumerable<T> Throttling<T>(
    this IAsyncEnumerable<T> source,
    int millisecondsDelay
)
{
    ArgumentNullException.ThrowIfNull(source);
    ArgumentOutOfRangeException.ThrowIfNegative(millisecondsDelay);

    return Core(source, millisecondsDelay);

    static async IAsyncEnumerable<T> Core(
        IAsyncEnumerable<T> source,
        int millisecondsDelay,
        [EnumeratorCancellation] CancellationToken ct = default
    )
    {
        var timeoutOn = DateTime.UtcNow.AddMilliseconds(millisecondsDelay);
        await foreach (var item in source.WithCancellation(ct))
        {
            if (DateTime.UtcNow < timeoutOn)
                await Task.Delay(timeoutOn - DateTime.UtcNow, ct);

            yield return item;

            timeoutOn = DateTime.UtcNow.AddMilliseconds(millisecondsDelay);
        }
    }
}

Once again, use it just like other LINQ extensions:

await foreach (var i in GetRandomIntegersAsync(20)
    .Throttling(5000)
    .WithCancellation(ct))
{
    Console.WriteLine($"{DateTimeOffset.Now:O} -> {i}");
}

Conclusion

In this article I’ve shown some of the most common utility methods I use when working with IAsyncEnumerable<T>: implementing timeouts while waiting for the next item, batching with a maximum wait time, and throttling to prevent application overload.

Feel free to change them in ways that make sense to you, like creating overloads that receive a TimeSpan, or use them as a starting point for creating your own utility methods.
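As an example of such an overload, a TimeSpan-based Timeout could simply forward to the int-based version (this is a sketch that assumes the Timeout extension from this article is in scope; the class name here is arbitrary):

```csharp
public static class AsyncEnumerableTimeSpanExtensions
{
    // Hypothetical overload accepting a TimeSpan; it converts to whole
    // milliseconds and forwards to the int-based Timeout extension.
    public static IAsyncEnumerable<T> Timeout<T>(
        this IAsyncEnumerable<T> source,
        TimeSpan timeout
    )
        => source.Timeout((int)timeout.TotalMilliseconds);
}
```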

Here’s the full code sample so you can put it in your own projects.

using System.Runtime.CompilerServices;

namespace System.Collections.Generic;

public static class AsyncEnumerableExtensions
{
    public static IAsyncEnumerable<T> Timeout<T>(
        this IAsyncEnumerable<T> source,
        int millisecondsTimeout
    )
    {
        ArgumentNullException.ThrowIfNull(source);
        ArgumentOutOfRangeException.ThrowIfLessThan(millisecondsTimeout, -1);

        return Core(source, millisecondsTimeout);

        static async IAsyncEnumerable<T> Core(
            IAsyncEnumerable<T> source,
            int millisecondsTimeout,
            [EnumeratorCancellation] CancellationToken ct = default
        )
        {
            using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
            cts.CancelAfter(millisecondsTimeout);

            await using var enumerator = source.GetAsyncEnumerator(cts.Token);

            while (await MoveNextCheckTimeoutAsync(enumerator, ct))
            {
                // disable the timer while the item is being processed
                cts.CancelAfter(-1);

                yield return enumerator.Current;

                // re-enable the timer before fetching the next item
                cts.CancelAfter(millisecondsTimeout);
            }
        }

        static async ValueTask<bool> MoveNextCheckTimeoutAsync(
            IAsyncEnumerator<T> enumerator,
            CancellationToken ct
        )
        {
            try
            {
                return await enumerator.MoveNextAsync();
            }
            catch (TaskCanceledException e) when (!ct.IsCancellationRequested)
            {
                throw new TimeoutException("The next item took longer than expected to be received", e);
            }
        }
    }

    public static IAsyncEnumerable<T[]> Batch<T>(
        this IAsyncEnumerable<T> source,
        int size
    )
    {
        ArgumentNullException.ThrowIfNull(source);
        ArgumentOutOfRangeException.ThrowIfNegativeOrZero(size);

        return Core(source, size);

        static async IAsyncEnumerable<T[]> Core(
            IAsyncEnumerable<T> source,
            int size,
            [EnumeratorCancellation] CancellationToken ct = default
        )
        {
            T[] batch = null;
            var count = 0;
            await foreach (var item in source.WithCancellation(ct))
            {
                batch ??= new T[size];
                batch[count++] = item;

                if (count != size)
                    continue;

                yield return batch;
                batch = null;
                count = 0; // reset the counter for the next batch
            }

            if (count > 0)
            {
                Array.Resize(ref batch, count);
                yield return batch;
            }
        }
    }

    public static IAsyncEnumerable<T[]> Batch<T>(
        this IAsyncEnumerable<T> source,
        int size,
        int millisecondsTimeout
    )
    {
        ArgumentNullException.ThrowIfNull(source);
        ArgumentOutOfRangeException.ThrowIfNegativeOrZero(size);
        ArgumentOutOfRangeException.ThrowIfNegative(millisecondsTimeout);

        return Core(source, size, millisecondsTimeout);

        static async IAsyncEnumerable<T[]> Core(
            IAsyncEnumerable<T> source,
            int size,
            int millisecondsTimeout,
            [EnumeratorCancellation] CancellationToken ct = default
        )
        {
            T[] batch = null;
            var count = 0;
            var timeoutOn = DateTime.UtcNow.AddMilliseconds(millisecondsTimeout);
            await foreach (var item in source.WithCancellation(ct))
            {
                batch ??= new T[size];
                batch[count++] = item;

                if (count != size)
                {
                    if (timeoutOn > DateTime.UtcNow)
                        continue;

                    // the timeout has passed: return a partial batch
                    Array.Resize(ref batch, count);
                }

                yield return batch;

                batch = null;
                timeoutOn = DateTime.UtcNow.AddMilliseconds(millisecondsTimeout);
                count = 0;
            }

            if (count > 0)
            {
                Array.Resize(ref batch, count);
                yield return batch;
            }
        }
    }

    public static IAsyncEnumerable<T> Throttling<T>(
        this IAsyncEnumerable<T> source,
        int millisecondsDelay
    )
    {
        ArgumentNullException.ThrowIfNull(source);
        ArgumentOutOfRangeException.ThrowIfNegative(millisecondsDelay);

        return Core(source, millisecondsDelay);

        static async IAsyncEnumerable<T> Core(
            IAsyncEnumerable<T> source,
            int millisecondsDelay,
            [EnumeratorCancellation] CancellationToken ct = default
        )
        {
            var timeoutOn = DateTime.UtcNow.AddMilliseconds(millisecondsDelay);
            await foreach (var item in source.WithCancellation(ct))
            {
                if (DateTime.UtcNow < timeoutOn)
                    await Task.Delay(timeoutOn - DateTime.UtcNow, ct);

                yield return item;

                timeoutOn = DateTime.UtcNow.AddMilliseconds(millisecondsDelay);
            }
        }
    }
}