diff --git a/src/Microsoft.EntityFrameworkCore.DynamicLinq/EFDynamicQueryableExtensions.cs b/src/Microsoft.EntityFrameworkCore.DynamicLinq/EFDynamicQueryableExtensions.cs index 38c8f8ac..d11ef3e9 100644 --- a/src/Microsoft.EntityFrameworkCore.DynamicLinq/EFDynamicQueryableExtensions.cs +++ b/src/Microsoft.EntityFrameworkCore.DynamicLinq/EFDynamicQueryableExtensions.cs @@ -661,6 +661,93 @@ public static Task SingleOrDefaultAsync([NotNull] this IQueryable sourc } #endregion SingleOrDefault + #region SumAsync + + /// + /// Asynchronously computes the sum of a sequence of values. + /// + /// + /// Multiple active operations on the same context instance are not supported. Use 'await' to ensure + /// that any asynchronous operations have completed before calling another method on this context. + /// + /// + /// An that contains the elements to be summed. + /// + /// + /// A to observe while waiting for the task to complete. + /// + /// + /// A task that represents the asynchronous operation. + /// The task result contains sum of the values in the sequence. + /// + [PublicAPI] + public static Task SumAsync([NotNull] this IQueryable source, CancellationToken cancellationToken = default(CancellationToken)) + { + Check.NotNull(source, nameof(source)); + Check.NotNull(cancellationToken, nameof(cancellationToken)); + + var sum = GetMethod(nameof(Queryable.Sum), source.ElementType); + + return ExecuteDynamicAsync(sum, source, cancellationToken); + } + + /// + /// Asynchronously computes the sum of a sequence of values. + /// + /// + /// Multiple active operations on the same context instance are not supported. Use 'await' to ensure + /// that any asynchronous operations have completed before calling another method on this context. + /// + /// + /// An that contains the elements to be summed. + /// + /// A projection function to apply to each element. + /// An object array that contains zero or more objects to insert into the predicate as parameters. Similar to the way String.Format formats strings. + /// + /// A task that represents the asynchronous operation. + /// The task result contains the number of elements in the sequence that satisfy the condition in the predicate + /// function. + /// + [PublicAPI] + public static Task SumAsync([NotNull] this IQueryable source, [NotNull] string selector, [CanBeNull] params object[] args) + { + return SumAsync(source, default(CancellationToken), selector, args); + } + + /// + /// Asynchronously computes the sum of a sequence of values. + /// + /// + /// Multiple active operations on the same context instance are not supported. Use 'await' to ensure + /// that any asynchronous operations have completed before calling another method on this context. + /// + /// + /// An that contains the elements to be summed. + /// + /// A projection function to apply to each element. + /// An object array that contains zero or more objects to insert into the predicate as parameters. Similar to the way String.Format formats strings. + /// + /// A to observe while waiting for the task to complete. + /// + /// + /// A task that represents the asynchronous operation. + /// The task result contains the sum of the projected values. + /// + [PublicAPI] + public static Task SumAsync([NotNull] this IQueryable source, CancellationToken cancellationToken, [NotNull] string selector, [CanBeNull] params object[] args) + { + Check.NotNull(source, nameof(source)); + Check.NotNull(selector, nameof(selector)); + Check.NotNull(cancellationToken, nameof(cancellationToken)); + + LambdaExpression lambda = DynamicExpressionParser.ParseLambda(false, source.ElementType, null, selector, args); + + var sumSelector = GetMethod(nameof(Queryable.Sum), lambda.ReturnType, 1); + + return ExecuteDynamicAsync(sumSelector, source, Expression.Quote(lambda), cancellationToken); + } + #endregion SumAsync + #region Private Helpers // Copied from https://github.com/aspnet/EntityFramework/blob/9186d0b78a3176587eeb0f557c331f635760fe92/src/Microsoft.EntityFrameworkCore/EntityFrameworkQueryableExtensions.cs //private static Task ExecuteAsync(MethodInfo operatorMethodInfo, IQueryable source, CancellationToken cancellationToken = default(CancellationToken)) @@ -681,6 +768,24 @@ public static Task SingleOrDefaultAsync([NotNull] this IQueryable sourc // throw new InvalidOperationException(Res.IQueryableProviderNotAsync); //} + private static readonly MethodInfo _executeAsyncMethod = + typeof(EntityFrameworkDynamicQueryableExtensions) +#if NETSTANDARD + .GetMethods(BindingFlags.Static | BindingFlags.NonPublic) + .Single(m => m.Name == nameof(ExecuteAsync) && m.GetParameters().Select(p => p.ParameterType).SequenceEqual(new[] { typeof(MethodInfo), typeof(IQueryable), typeof(CancellationToken) })); +#else + .GetMethod(nameof(ExecuteAsync), BindingFlags.Static | BindingFlags.NonPublic, null, new[] { typeof(MethodInfo), typeof(IQueryable), typeof(CancellationToken) }, null); +#endif + + private static Task ExecuteDynamicAsync(MethodInfo operatorMethodInfo, IQueryable source, CancellationToken cancellationToken = default(CancellationToken)) + { + var executeAsyncMethod = _executeAsyncMethod.MakeGenericMethod(operatorMethodInfo.ReturnType); + + var task = (Task)executeAsyncMethod.Invoke(null, new object[] { operatorMethodInfo, source, cancellationToken }); + var castedTask = task.ContinueWith(t => (dynamic)t.GetType().GetProperty(nameof(Task.Result)).GetValue(t)); + + return castedTask; + } private static Task ExecuteAsync(MethodInfo operatorMethodInfo, IQueryable source, CancellationToken cancellationToken = default(CancellationToken)) { @@ -707,6 +812,25 @@ public static Task SingleOrDefaultAsync([NotNull] this IQueryable sourc private static Task ExecuteAsync(MethodInfo operatorMethodInfo, IQueryable source, LambdaExpression expression, CancellationToken cancellationToken = default(CancellationToken)) => ExecuteAsync(operatorMethodInfo, source, Expression.Quote(expression), cancellationToken); + private static readonly MethodInfo _executeAsyncMethodWithExpression = + typeof(EntityFrameworkDynamicQueryableExtensions) +#if NETSTANDARD + .GetMethods(BindingFlags.Static | BindingFlags.NonPublic) + .Single(m => m.Name == nameof(ExecuteAsync) && m.GetParameters().Select(p => p.ParameterType).SequenceEqual(new[] { typeof(MethodInfo), typeof(IQueryable), typeof(Expression), typeof(CancellationToken) })); +#else + .GetMethod(nameof(ExecuteAsync), BindingFlags.Static | BindingFlags.NonPublic, null, new[] { typeof(MethodInfo), typeof(IQueryable), typeof(Expression), typeof(CancellationToken)}, null); +#endif + + private static Task ExecuteDynamicAsync(MethodInfo operatorMethodInfo, IQueryable source, Expression expression, CancellationToken cancellationToken = default(CancellationToken)) + { + var executeAsyncMethod = _executeAsyncMethodWithExpression.MakeGenericMethod(operatorMethodInfo.ReturnType); + + var task = (Task)executeAsyncMethod.Invoke(null, new object[] { operatorMethodInfo, source, expression, cancellationToken }); + var castedTask = task.ContinueWith(t => (dynamic)t.GetType().GetProperty(nameof(Task.Result)).GetValue(t)); + + return castedTask; + } + private static Task ExecuteAsync(MethodInfo operatorMethodInfo, IQueryable source, Expression expression, CancellationToken cancellationToken = default(CancellationToken)) { #if EFCORE @@ -730,10 +854,13 @@ public static Task SingleOrDefaultAsync([NotNull] this IQueryable sourc } private static MethodInfo GetMethod(string name, int parameterCount = 0, Func predicate = null) => - GetMethod(name, parameterCount, mi => (mi.ReturnType == typeof(TResult)) && ((predicate == null) || predicate(mi))); + GetMethod(name, typeof(TResult), parameterCount, predicate); + + private static MethodInfo GetMethod(string name, Type returnType, int parameterCount = 0, Func predicate = null) => + GetMethod(name, parameterCount, mi => (mi.ReturnType == returnType) && ((predicate == null) || predicate(mi))); private static MethodInfo GetMethod(string name, int parameterCount = 0, Func predicate = null) => - typeof(Queryable).GetTypeInfo().GetDeclaredMethods(name).Single(mi => (mi.GetParameters().Length == parameterCount + 1) && ((predicate == null) || predicate(mi))); - #endregion Private Helpers + typeof(Queryable).GetTypeInfo().GetDeclaredMethods(name).First(mi => (mi.GetParameters().Length == parameterCount + 1) && ((predicate == null) || predicate(mi))); +#endregion Private Helpers } } diff --git a/test/System.Linq.Dynamic.Core.Tests/EntitiesTests.SumAsync.cs b/test/System.Linq.Dynamic.Core.Tests/EntitiesTests.SumAsync.cs new file mode 100644 index 00000000..16013225 --- /dev/null +++ b/test/System.Linq.Dynamic.Core.Tests/EntitiesTests.SumAsync.cs @@ -0,0 +1,45 @@ +#if EFCORE +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.DynamicLinq; +#else +using System.Data.Entity; +using EntityFramework.DynamicLinq; +#endif +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Dynamic.Core.Tests +{ + public partial class EntitiesTests + { + [Fact] + public async Task Entities_SumAsync() + { + //Arrange + PopulateTestData(1, 0); + + var expectedSum = await _context.Blogs.Select(b => b.BlogId).SumAsync(); + + //Act + var actualSum = await _context.Blogs.Select(b => b.BlogId).SumAsync(); + + //Assert + Assert.Equal(expectedSum, actualSum); + } + + [Fact] + public async Task Entities_SumAsync_Selector() + { + //Arrange + PopulateTestData(1, 0); + + var expectedSum = await _context.Blogs.SumAsync(b => b.BlogId); + + //Act + var actualSum = await _context.Blogs.SumAsync("BlogId"); + + //Assert + Assert.Equal(expectedSum, actualSum); + } + } +}