diff --git a/src/System.Linq.Dynamic.Core/ExpressionParser.cs b/src/System.Linq.Dynamic.Core/ExpressionParser.cs index 36267d46..c77622eb 100644 --- a/src/System.Linq.Dynamic.Core/ExpressionParser.cs +++ b/src/System.Linq.Dynamic.Core/ExpressionParser.cs @@ -606,11 +606,13 @@ Expression ParseLogicalAndOrOperator() } else { + ConvertNumericTypeToBiggestCommonTypeForBinaryOperator(ref left, ref right); left = Expression.And(left, right); } break; case TokenId.Bar: + ConvertNumericTypeToBiggestCommonTypeForBinaryOperator(ref left, ref right); left = Expression.Or(left, right); break; } @@ -2019,7 +2021,7 @@ Expression PromoteExpression(Expression expr, Type type, bool exact, bool conver if (ce != null) { - if (ce == NullLiteral) + if (ce == NullLiteral || ce.Value == null) { if (!type.GetTypeInfo().IsValueType || IsNullableType(type)) return Expression.Constant(null, type); @@ -2678,5 +2680,47 @@ internal static void ResetDynamicLinqTypes() { _keywords = null; } + + static void ConvertNumericTypeToBiggestCommonTypeForBinaryOperator(ref Expression left, ref Expression right) + { + if (left.Type == right.Type) + return; + + if (left.Type == typeof(UInt64) || right.Type == typeof(UInt64)) + { + right = right.Type != typeof(UInt64) ? Expression.Convert(right, typeof(UInt64)) : right; + left = left.Type != typeof(UInt64) ? Expression.Convert(left, typeof(UInt64)) : left; + } + else if (left.Type == typeof(Int64) || right.Type == typeof(Int64)) + { + right = right.Type != typeof(Int64) ? Expression.Convert(right, typeof(Int64)) : right; + left = left.Type != typeof(Int64) ? Expression.Convert(left, typeof(Int64)) : left; + } + else if (left.Type == typeof(UInt32) || right.Type == typeof(UInt32)) + { + right = right.Type != typeof(UInt32) ? Expression.Convert(right, typeof(UInt32)) : right; + left = left.Type != typeof(UInt32) ? Expression.Convert(left, typeof(UInt32)) : left; + } + else if (left.Type == typeof(Int32) || right.Type == typeof(Int32)) + { + right = right.Type != typeof(Int32) ? Expression.Convert(right, typeof(Int32)) : right; + left = left.Type != typeof(Int32) ? Expression.Convert(left, typeof(Int32)) : left; + } + else if (left.Type == typeof(UInt16) || right.Type == typeof(UInt16)) + { + right = right.Type != typeof(UInt16) ? Expression.Convert(right, typeof(UInt16)) : right; + left = left.Type != typeof(UInt16) ? Expression.Convert(left, typeof(UInt16)) : left; + } + else if (left.Type == typeof(Int16) || right.Type == typeof(Int16)) + { + right = right.Type != typeof(Int16) ? Expression.Convert(right, typeof(Int16)) : right; + left = left.Type != typeof(Int16) ? Expression.Convert(left, typeof(Int16)) : left; + } + else if (left.Type == typeof(Byte) || right.Type == typeof(Byte)) + { + right = right.Type != typeof(Byte) ? Expression.Convert(right, typeof(Byte)) : right; + left = left.Type != typeof(Byte) ? Expression.Convert(left, typeof(Byte)) : left; + } + } } } diff --git a/test/System.Linq.Dynamic.Core.Tests/ExpressionTests.cs b/test/System.Linq.Dynamic.Core.Tests/ExpressionTests.cs index b0dacb84..1d3f5340 100644 --- a/test/System.Linq.Dynamic.Core.Tests/ExpressionTests.cs +++ b/test/System.Linq.Dynamic.Core.Tests/ExpressionTests.cs @@ -34,6 +34,127 @@ public void ExpressionTests_ArrayInitializer() Assert.Throws(() => list.AsQueryable().SelectMany("new] {}")); } + public enum TestEnum2 : sbyte + { + Var1 = 0, + Var2 = 1, + Var3 = 2, + Var4 = 4, + Var5 = 8, + Var6 = 16, + } + + public class TestEnumClass + { + public TestEnum A { get; set; } + + public TestEnum2 B { get; set; } + + public int Id { get; set; } + } + + [Fact] + public void ExpressionTests_BinaryAndNumericConvert() + { + // Arrange + var lst = new List + { + new TestEnumClass {A = TestEnum.Var3, B = TestEnum2.Var3, Id = 1}, + new TestEnumClass {A = TestEnum.Var4, B = TestEnum2.Var4, Id = 2}, + new TestEnumClass {A = TestEnum.Var2, B = TestEnum2.Var2, Id = 3} + }; + var qry = lst.AsQueryable(); + + // Act + var result0 = qry.FirstOrDefault("(it.A & @0) == 1", 1); + var result1 = qry.FirstOrDefault("(it.A & @0) == 1", (uint)1); + var result2 = qry.FirstOrDefault("(it.A & @0) == 1", (long)1); + var result3 = qry.FirstOrDefault("(it.A & @0) == 1", (ulong)1); + var result4 = qry.FirstOrDefault("(it.A & @0) == 1", (byte)1); + var result5 = qry.FirstOrDefault("(it.A & @0) == 1", (sbyte)1); + var result6 = qry.FirstOrDefault("(it.A & @0) == 1", (ushort)1); + var result7 = qry.FirstOrDefault("(it.A & @0) == 1", (short)1); + + var result10 = qry.FirstOrDefault("(it.B & @0) == 1", 1); + var result11 = qry.FirstOrDefault("(it.B & @0) == 1", (uint)1); + var result12 = qry.FirstOrDefault("(it.B & @0) == 1", (long)1); + var result13 = qry.FirstOrDefault("(it.B & @0) == 1", (ulong)1); + var result14 = qry.FirstOrDefault("(it.B & @0) == 1", (byte)1); + var result15 = qry.FirstOrDefault("(it.B & @0) == 1", (sbyte)1); + var result16 = qry.FirstOrDefault("(it.B & @0) == 1", (ushort)1); + var result17 = qry.FirstOrDefault("(it.B & @0) == 1", (short)1); + + //Assert + Assert.Equal(3, result0.Id); + Assert.Equal(3, result1.Id); + Assert.Equal(3, result2.Id); + Assert.Equal(3, result3.Id); + Assert.Equal(3, result4.Id); + Assert.Equal(3, result5.Id); + Assert.Equal(3, result6.Id); + Assert.Equal(3, result7.Id); + + Assert.Equal(3, result10.Id); + Assert.Equal(3, result11.Id); + Assert.Equal(3, result12.Id); + Assert.Equal(3, result13.Id); + Assert.Equal(3, result14.Id); + Assert.Equal(3, result15.Id); + Assert.Equal(3, result16.Id); + Assert.Equal(3, result17.Id); + } + + [Fact] + public void ExpressionTests_BinaryOrNumericConvert() + { + // Arrange + var lst = new List + { + new TestEnumClass {A = TestEnum.Var3, B = TestEnum2.Var3, Id = 1}, + new TestEnumClass {A = TestEnum.Var4, B = TestEnum2.Var4, Id = 2}, + new TestEnumClass {A = TestEnum.Var2, B = TestEnum2.Var2, Id = 3} + }; + var qry = lst.AsQueryable(); + + // Act + var result0 = qry.FirstOrDefault("(it.A | @0) == 1", 1); + var result1 = qry.FirstOrDefault("(it.A | @0) == 1", (uint)1); + var result2 = qry.FirstOrDefault("(it.A | @0) == 1", (long)1); + var result3 = qry.FirstOrDefault("(it.A | @0) == 1", (ulong)1); + var result4 = qry.FirstOrDefault("(it.A | @0) == 1", (byte)1); + var result5 = qry.FirstOrDefault("(it.A | @0) == 1", (sbyte)1); + var result6 = qry.FirstOrDefault("(it.A | @0) == 1", (ushort)1); + var result7 = qry.FirstOrDefault("(it.A | @0) == 1", (short)1); + + var result10 = qry.FirstOrDefault("(it.B | @0) == 1", 1); + var result11 = qry.FirstOrDefault("(it.B | @0) == 1", (uint)1); + var result12 = qry.FirstOrDefault("(it.B | @0) == 1", (long)1); + var result13 = qry.FirstOrDefault("(it.B | @0) == 1", (ulong)1); + var result14 = qry.FirstOrDefault("(it.B | @0) == 1", (byte)1); + var result15 = qry.FirstOrDefault("(it.B | @0) == 1", (sbyte)1); + var result16 = qry.FirstOrDefault("(it.B | @0) == 1", (ushort)1); + var result17 = qry.FirstOrDefault("(it.B | @0) == 1", (short)1); + + //Assert + Assert.Equal(3, result0.Id); + Assert.Equal(3, result1.Id); + Assert.Equal(3, result2.Id); + Assert.Equal(3, result3.Id); + Assert.Equal(3, result4.Id); + Assert.Equal(3, result5.Id); + Assert.Equal(3, result6.Id); + Assert.Equal(3, result7.Id); + + Assert.Equal(3, result10.Id); + Assert.Equal(3, result11.Id); + Assert.Equal(3, result12.Id); + Assert.Equal(3, result13.Id); + Assert.Equal(3, result14.Id); + Assert.Equal(3, result15.Id); + Assert.Equal(3, result16.Id); + Assert.Equal(3, result17.Id); + } + [Fact] public void ExpressionTests_Cast_To_nullableint() { @@ -580,12 +701,12 @@ public void ExpressionTests_Guid_CompareTo_Null() // Act var result2a = qry.FirstOrDefault("it.GuidNull = null"); var result2b = qry.FirstOrDefault("null = it.GuidNull"); - // var result1a = qry.FirstOrDefault("it.GuidNull = @0", null); TODO: fails? - // var result1b = qry.FirstOrDefault("@0 = it.GuidNull", null); TODO: fails? + var result1a = qry.FirstOrDefault("it.GuidNull = @0", new object[] { null }); + var result1b = qry.FirstOrDefault("@0 = it.GuidNull", new object[] { null }); // Assert - // Assert.Equal(1, result1a.Id); - // Assert.Equal(1, result1b.Id); + Assert.Equal(1, result1a.Id); + Assert.Equal(1, result1b.Id); Assert.Equal(1, result2a.Id); Assert.Equal(1, result2b.Id); } @@ -1307,4 +1428,4 @@ public void ExpressionTests_Where_WithCachedLambda() Assert.Equal(res9, list[1]); } } -} \ No newline at end of file +}