diff --git a/src/embed_tests/EnumTests.cs b/src/embed_tests/EnumTests.cs new file mode 100644 index 000000000..f8f1789d2 --- /dev/null +++ b/src/embed_tests/EnumTests.cs @@ -0,0 +1,628 @@ +using System; +using System.Collections.Generic; + +using NUnit.Framework; + +using Python.Runtime; + +namespace Python.EmbeddingTest +{ + public class EnumTests + { + private static VerticalDirection[] VerticalDirectionEnumValues = Enum.GetValues(); + private static HorizontalDirection[] HorizontalDirectionEnumValues = Enum.GetValues(); + + [OneTimeSetUp] + public void SetUp() + { + PythonEngine.Initialize(); + } + + [OneTimeTearDown] + public void Dispose() + { + PythonEngine.Shutdown(); + } + + public enum VerticalDirection + { + Down = -2, + Flat = 0, + Up = 2, + } + + public enum HorizontalDirection + { + Left = -2, + Flat = 0, + Right = 2, + } + + [Test] + public void CSharpEnumsBehaveAsEnumsInPython() + { + using var _ = Py.GIL(); + using var module = PyModule.FromString("CSharpEnumsBehaveAsEnumsInPython", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +def enum_is_right_type(enum_value={nameof(EnumTests)}.{nameof(VerticalDirection)}.{nameof(VerticalDirection.Up)}): + return isinstance(enum_value, {nameof(EnumTests)}.{nameof(VerticalDirection)}) +"); + + Assert.IsTrue(module.InvokeMethod("enum_is_right_type").As()); + + // Also test passing the enum value from C# to Python + using var pyEnumValue = VerticalDirection.Up.ToPython(); + Assert.IsTrue(module.InvokeMethod("enum_is_right_type", pyEnumValue).As()); + } + + private PyModule GetTestOperatorsModule(string @operator, VerticalDirection operand1, double operand2) + { + var operand1Str = $"{nameof(EnumTests)}.{nameof(VerticalDirection)}.{operand1}"; + return PyModule.FromString("GetTestOperatorsModule", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +def operation1(): + return {operand1Str} {@operator} {operand2} + +def operation2(): + return {operand2} {@operator} {operand1Str} +"); + } + + [TestCase("*", VerticalDirection.Down, 2, -4, -4)] + [TestCase("/", VerticalDirection.Down, 2, -1, -1)] + [TestCase("+", VerticalDirection.Down, 2, 0, 0)] + [TestCase("-", VerticalDirection.Down, 2, -4, 4)] + [TestCase("*", VerticalDirection.Flat, 2, 0, 0)] + [TestCase("/", VerticalDirection.Flat, 2, 0, 0)] + [TestCase("+", VerticalDirection.Flat, 2, 2, 2)] + [TestCase("-", VerticalDirection.Flat, 2, -2, 2)] + [TestCase("*", VerticalDirection.Up, 2, 4, 4)] + [TestCase("/", VerticalDirection.Up, 2, 1, 1)] + [TestCase("+", VerticalDirection.Up, 2, 4, 4)] + [TestCase("-", VerticalDirection.Up, 2, 0, 0)] + [TestCase("*", VerticalDirection.Down, -2, 4, 4)] + [TestCase("/", VerticalDirection.Down, -2, 1, 1)] + [TestCase("+", VerticalDirection.Down, -2, -4, -4)] + [TestCase("-", VerticalDirection.Down, -2, 0, 0)] + [TestCase("*", VerticalDirection.Flat, -2, 0, 0)] + [TestCase("/", VerticalDirection.Flat, -2, 0, 0)] + [TestCase("+", VerticalDirection.Flat, -2, -2, -2)] + [TestCase("-", VerticalDirection.Flat, -2, 2, -2)] + [TestCase("*", VerticalDirection.Up, -2, -4, -4)] + [TestCase("/", VerticalDirection.Up, -2, -1, -1)] + [TestCase("+", VerticalDirection.Up, -2, 0, 0)] + [TestCase("-", VerticalDirection.Up, -2, 4, -4)] + public void ArithmeticOperatorsWorkWithoutExplicitCast(string @operator, VerticalDirection operand1, double operand2, double expectedResult, double invertedOperationExpectedResult) + { + using var _ = Py.GIL(); + using var module = GetTestOperatorsModule(@operator, operand1, operand2); + + Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As()); + + if (Convert.ToInt64(operand1) != 0 || @operator != "/") + { + Assert.AreEqual(invertedOperationExpectedResult, module.InvokeMethod("operation2").As()); + } + } + + [TestCase("==", VerticalDirection.Down, -2, true)] + [TestCase("==", VerticalDirection.Down, 0, false)] + [TestCase("==", VerticalDirection.Down, 2, false)] + [TestCase("==", VerticalDirection.Flat, -2, false)] + [TestCase("==", VerticalDirection.Flat, 0, true)] + [TestCase("==", VerticalDirection.Flat, 2, false)] + [TestCase("==", VerticalDirection.Up, -2, false)] + [TestCase("==", VerticalDirection.Up, 0, false)] + [TestCase("==", VerticalDirection.Up, 2, true)] + [TestCase("!=", VerticalDirection.Down, -2, false)] + [TestCase("!=", VerticalDirection.Down, 0, true)] + [TestCase("!=", VerticalDirection.Down, 2, true)] + [TestCase("!=", VerticalDirection.Flat, -2, true)] + [TestCase("!=", VerticalDirection.Flat, 0, false)] + [TestCase("!=", VerticalDirection.Flat, 2, true)] + [TestCase("!=", VerticalDirection.Up, -2, true)] + [TestCase("!=", VerticalDirection.Up, 0, true)] + [TestCase("!=", VerticalDirection.Up, 2, false)] + [TestCase("<", VerticalDirection.Down, -3, false)] + [TestCase("<", VerticalDirection.Down, -2, false)] + [TestCase("<", VerticalDirection.Down, 0, true)] + [TestCase("<", VerticalDirection.Down, 2, true)] + [TestCase("<", VerticalDirection.Flat, -2, false)] + [TestCase("<", VerticalDirection.Flat, 0, false)] + [TestCase("<", VerticalDirection.Flat, 2, true)] + [TestCase("<", VerticalDirection.Up, -2, false)] + [TestCase("<", VerticalDirection.Up, 0, false)] + [TestCase("<", VerticalDirection.Up, 2, false)] + [TestCase("<", VerticalDirection.Up, 3, true)] + [TestCase("<=", VerticalDirection.Down, -3, false)] + [TestCase("<=", VerticalDirection.Down, -2, true)] + [TestCase("<=", VerticalDirection.Down, 0, true)] + [TestCase("<=", VerticalDirection.Down, 2, true)] + [TestCase("<=", VerticalDirection.Flat, -2, false)] + [TestCase("<=", VerticalDirection.Flat, 0, true)] + [TestCase("<=", VerticalDirection.Flat, 2, true)] + [TestCase("<=", VerticalDirection.Up, -2, false)] + [TestCase("<=", VerticalDirection.Up, 0, false)] + [TestCase("<=", VerticalDirection.Up, 2, true)] + [TestCase("<=", VerticalDirection.Up, 3, true)] + [TestCase(">", VerticalDirection.Down, -3, true)] + [TestCase(">", VerticalDirection.Down, -2, false)] + [TestCase(">", VerticalDirection.Down, 0, false)] + [TestCase(">", VerticalDirection.Down, 2, false)] + [TestCase(">", VerticalDirection.Flat, -2, true)] + [TestCase(">", VerticalDirection.Flat, 0, false)] + [TestCase(">", VerticalDirection.Flat, 2, false)] + [TestCase(">", VerticalDirection.Up, -2, true)] + [TestCase(">", VerticalDirection.Up, 0, true)] + [TestCase(">", VerticalDirection.Up, 2, false)] + [TestCase(">", VerticalDirection.Up, 3, false)] + [TestCase(">=", VerticalDirection.Down, -3, true)] + [TestCase(">=", VerticalDirection.Down, -2, true)] + [TestCase(">=", VerticalDirection.Down, 0, false)] + [TestCase(">=", VerticalDirection.Down, 2, false)] + [TestCase(">=", VerticalDirection.Flat, -2, true)] + [TestCase(">=", VerticalDirection.Flat, 0, true)] + [TestCase(">=", VerticalDirection.Flat, 2, false)] + [TestCase(">=", VerticalDirection.Up, -2, true)] + [TestCase(">=", VerticalDirection.Up, 0, true)] + [TestCase(">=", VerticalDirection.Up, 2, true)] + [TestCase(">=", VerticalDirection.Up, 3, false)] + public void IntComparisonOperatorsWorkWithoutExplicitCast(string @operator, VerticalDirection operand1, int operand2, bool expectedResult) + { + using var _ = Py.GIL(); + using var module = GetTestOperatorsModule(@operator, operand1, operand2); + + Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As()); + + var invertedOperationExpectedResult = (@operator.StartsWith('<') || @operator.StartsWith('>')) && Convert.ToInt64(operand1) != operand2 + ? !expectedResult + : expectedResult; + Assert.AreEqual(invertedOperationExpectedResult, module.InvokeMethod("operation2").As()); + } + + [TestCase("==", VerticalDirection.Down, -2.0, true)] + [TestCase("==", VerticalDirection.Down, -2.00001, false)] + [TestCase("==", VerticalDirection.Down, -1.99999, false)] + [TestCase("==", VerticalDirection.Down, 0.0, false)] + [TestCase("==", VerticalDirection.Down, 2.0, false)] + [TestCase("==", VerticalDirection.Flat, -2.0, false)] + [TestCase("==", VerticalDirection.Flat, 0.0, true)] + [TestCase("==", VerticalDirection.Flat, 0.00001, false)] + [TestCase("==", VerticalDirection.Flat, -0.00001, false)] + [TestCase("==", VerticalDirection.Flat, 2.0, false)] + [TestCase("==", VerticalDirection.Up, -2.0, false)] + [TestCase("==", VerticalDirection.Up, 0.0, false)] + [TestCase("==", VerticalDirection.Up, 2.0, true)] + [TestCase("==", VerticalDirection.Up, 2.00001, false)] + [TestCase("==", VerticalDirection.Up, 1.99999, false)] + [TestCase("!=", VerticalDirection.Down, -2.0, false)] + [TestCase("!=", VerticalDirection.Down, -2.00001, true)] + [TestCase("!=", VerticalDirection.Down, -1.99999, true)] + [TestCase("!=", VerticalDirection.Down, 0.0, true)] + [TestCase("!=", VerticalDirection.Down, 2.0, true)] + [TestCase("!=", VerticalDirection.Flat, -2.0, true)] + [TestCase("!=", VerticalDirection.Flat, 0.0, false)] + [TestCase("!=", VerticalDirection.Flat, 0.00001, true)] + [TestCase("!=", VerticalDirection.Flat, -0.00001, true)] + [TestCase("!=", VerticalDirection.Flat, 2.0, true)] + [TestCase("!=", VerticalDirection.Up, -2.0, true)] + [TestCase("!=", VerticalDirection.Up, 0.0, true)] + [TestCase("!=", VerticalDirection.Up, 2.0, false)] + [TestCase("!=", VerticalDirection.Up, 2.00001, true)] + [TestCase("!=", VerticalDirection.Up, 1.99999, true)] + [TestCase("<", VerticalDirection.Down, -3.0, false)] + [TestCase("<", VerticalDirection.Down, -2.00001, false)] + [TestCase("<", VerticalDirection.Down, -2.0, false)] + [TestCase("<", VerticalDirection.Down, -1.99999, true)] + [TestCase("<", VerticalDirection.Down, 0.0, true)] + [TestCase("<", VerticalDirection.Down, 2.0, true)] + [TestCase("<", VerticalDirection.Flat, -2.0, false)] + [TestCase("<", VerticalDirection.Flat, -0.00001, false)] + [TestCase("<", VerticalDirection.Flat, 0.0, false)] + [TestCase("<", VerticalDirection.Flat, 0.00001, true)] + [TestCase("<", VerticalDirection.Flat, 2.0, true)] + [TestCase("<", VerticalDirection.Up, -2.0, false)] + [TestCase("<", VerticalDirection.Up, 0.0, false)] + [TestCase("<", VerticalDirection.Up, 1.99999, false)] + [TestCase("<", VerticalDirection.Up, 2.0, false)] + [TestCase("<", VerticalDirection.Up, 2.00001, true)] + [TestCase("<", VerticalDirection.Up, 3.0, true)] + [TestCase("<=", VerticalDirection.Down, -3.0, false)] + [TestCase("<=", VerticalDirection.Down, -2.00001, false)] + [TestCase("<=", VerticalDirection.Down, -2.0, true)] + [TestCase("<=", VerticalDirection.Down, -1.99999, true)] + [TestCase("<=", VerticalDirection.Down, 0.0, true)] + [TestCase("<=", VerticalDirection.Down, 2.0, true)] + [TestCase("<=", VerticalDirection.Flat, -2.0, false)] + [TestCase("<=", VerticalDirection.Flat, -0.00001, false)] + [TestCase("<=", VerticalDirection.Flat, 0.0, true)] + [TestCase("<=", VerticalDirection.Flat, 0.00001, true)] + [TestCase("<=", VerticalDirection.Flat, 2.0, true)] + [TestCase("<=", VerticalDirection.Up, -2.0, false)] + [TestCase("<=", VerticalDirection.Up, 0.0, false)] + [TestCase("<=", VerticalDirection.Up, 1.99999, false)] + [TestCase("<=", VerticalDirection.Up, 2.0, true)] + [TestCase("<=", VerticalDirection.Up, 2.00001, true)] + [TestCase("<=", VerticalDirection.Up, 3.0, true)] + [TestCase(">", VerticalDirection.Down, -3.0, true)] + [TestCase(">", VerticalDirection.Down, -2.00001, true)] + [TestCase(">", VerticalDirection.Down, -2.0, false)] + [TestCase(">", VerticalDirection.Down, -1.99999, false)] + [TestCase(">", VerticalDirection.Down, 0.0, false)] + [TestCase(">", VerticalDirection.Down, 2.0, false)] + [TestCase(">", VerticalDirection.Flat, -2.0, true)] + [TestCase(">", VerticalDirection.Flat, -0.00001, true)] + [TestCase(">", VerticalDirection.Flat, 0.0, false)] + [TestCase(">", VerticalDirection.Flat, 0.00001, false)] + [TestCase(">", VerticalDirection.Flat, 2.0, false)] + [TestCase(">", VerticalDirection.Up, -2.0, true)] + [TestCase(">", VerticalDirection.Up, 0.0, true)] + [TestCase(">", VerticalDirection.Up, 1.99999, true)] + [TestCase(">", VerticalDirection.Up, 2.0, false)] + [TestCase(">", VerticalDirection.Up, 2.00001, false)] + [TestCase(">", VerticalDirection.Up, 3.0, false)] + [TestCase(">=", VerticalDirection.Down, -3.0, true)] + [TestCase(">=", VerticalDirection.Down, -2.00001, true)] + [TestCase(">=", VerticalDirection.Down, -2.0, true)] + [TestCase(">=", VerticalDirection.Down, -1.99999, false)] + [TestCase(">=", VerticalDirection.Down, 0.0, false)] + [TestCase(">=", VerticalDirection.Down, 2.0, false)] + [TestCase(">=", VerticalDirection.Flat, -2.0, true)] + [TestCase(">=", VerticalDirection.Flat, -0.00001, true)] + [TestCase(">=", VerticalDirection.Flat, 0.0, true)] + [TestCase(">=", VerticalDirection.Flat, 0.00001, false)] + [TestCase(">=", VerticalDirection.Flat, 2.0, false)] + [TestCase(">=", VerticalDirection.Up, -2.0, true)] + [TestCase(">=", VerticalDirection.Up, 0.0, true)] + [TestCase(">=", VerticalDirection.Up, 1.99999, true)] + [TestCase(">=", VerticalDirection.Up, 2.0, true)] + [TestCase(">=", VerticalDirection.Up, 2.00001, false)] + [TestCase(">=", VerticalDirection.Up, 3.0, false)] + public void FloatComparisonOperatorsWorkWithoutExplicitCast(string @operator, VerticalDirection operand1, double operand2, bool expectedResult) + { + using var _ = Py.GIL(); + using var module = GetTestOperatorsModule(@operator, operand1, operand2); + + Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As()); + + var invertedOperationExpectedResult = (@operator.StartsWith('<') || @operator.StartsWith('>')) && Convert.ToInt64(operand1) != operand2 + ? !expectedResult + : expectedResult; + Assert.AreEqual(invertedOperationExpectedResult, module.InvokeMethod("operation2").As()); + } + + public static IEnumerable SameEnumTypeComparisonOperatorsTestCases + { + get + { + var operators = new[] { "==", "!=", "<", "<=", ">", ">=" }; + + foreach (var enumValue in VerticalDirectionEnumValues) + { + foreach (var enumValue2 in VerticalDirectionEnumValues) + { + yield return new TestCaseData("==", enumValue, enumValue2, enumValue == enumValue2); + yield return new TestCaseData("!=", enumValue, enumValue2, enumValue != enumValue2); + yield return new TestCaseData("<", enumValue, enumValue2, enumValue < enumValue2); + yield return new TestCaseData("<=", enumValue, enumValue2, enumValue <= enumValue2); + yield return new TestCaseData(">", enumValue, enumValue2, enumValue > enumValue2); + yield return new TestCaseData(">=", enumValue, enumValue2, enumValue >= enumValue2); + } + } + } + } + + [TestCaseSource(nameof(SameEnumTypeComparisonOperatorsTestCases))] + public void SameEnumTypeComparisonOperatorsWorkWithoutExplicitCast(string @operator, VerticalDirection operand1, VerticalDirection operand2, bool expectedResult) + { + using var _ = Py.GIL(); + using var module = PyModule.FromString("SameEnumTypeComparisonOperatorsWorkWithoutExplicitCast", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +def operation(): + return {nameof(EnumTests)}.{nameof(VerticalDirection)}.{operand1} {@operator} {nameof(EnumTests)}.{nameof(VerticalDirection)}.{operand2} +"); + + Assert.AreEqual(expectedResult, module.InvokeMethod("operation").As()); + } + + [TestCase("==", VerticalDirection.Down, "Down", true)] + [TestCase("==", VerticalDirection.Down, "Flat", false)] + [TestCase("==", VerticalDirection.Down, "Up", false)] + [TestCase("==", VerticalDirection.Flat, "Down", false)] + [TestCase("==", VerticalDirection.Flat, "Flat", true)] + [TestCase("==", VerticalDirection.Flat, "Up", false)] + [TestCase("==", VerticalDirection.Up, "Down", false)] + [TestCase("==", VerticalDirection.Up, "Flat", false)] + [TestCase("==", VerticalDirection.Up, "Up", true)] + [TestCase("!=", VerticalDirection.Down, "Down", false)] + [TestCase("!=", VerticalDirection.Down, "Flat", true)] + [TestCase("!=", VerticalDirection.Down, "Up", true)] + [TestCase("!=", VerticalDirection.Flat, "Down", true)] + [TestCase("!=", VerticalDirection.Flat, "Flat", false)] + [TestCase("!=", VerticalDirection.Flat, "Up", true)] + [TestCase("!=", VerticalDirection.Up, "Down", true)] + [TestCase("!=", VerticalDirection.Up, "Flat", true)] + [TestCase("!=", VerticalDirection.Up, "Up", false)] + public void EnumComparisonOperatorsWorkWithString(string @operator, VerticalDirection operand1, string operand2, bool expectedResult) + { + using var _ = Py.GIL(); + using var module = PyModule.FromString("EnumComparisonOperatorsWorkWithString", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +def operation1(): + return {nameof(EnumTests)}.{nameof(VerticalDirection)}.{operand1} {@operator} ""{operand2}"" + +def operation2(): + return ""{operand2}"" {@operator} {nameof(EnumTests)}.{nameof(VerticalDirection)}.{operand1} +"); + + Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As()); + Assert.AreEqual(expectedResult, module.InvokeMethod("operation2").As()); + } + + public static IEnumerable OtherEnumsComparisonOperatorsTestCases + { + get + { + var operators = new[] { "==", "!=", "<", "<=", ">", ">=" }; + + foreach (var enumValue in VerticalDirectionEnumValues) + { + foreach (var enum2Value in HorizontalDirectionEnumValues) + { + var intEnumValue = Convert.ToInt64(enumValue); + var intEnum2Value = Convert.ToInt64(enum2Value); + + yield return new TestCaseData("==", enumValue, enum2Value, intEnumValue == intEnum2Value, intEnum2Value == intEnumValue); + yield return new TestCaseData("!=", enumValue, enum2Value, intEnumValue != intEnum2Value, intEnum2Value != intEnumValue); + yield return new TestCaseData("<", enumValue, enum2Value, intEnumValue < intEnum2Value, intEnum2Value < intEnumValue); + yield return new TestCaseData("<=", enumValue, enum2Value, intEnumValue <= intEnum2Value, intEnum2Value <= intEnumValue); + yield return new TestCaseData(">", enumValue, enum2Value, intEnumValue > intEnum2Value, intEnum2Value > intEnumValue); + yield return new TestCaseData(">=", enumValue, enum2Value, intEnumValue >= intEnum2Value, intEnum2Value >= intEnumValue); + } + } + } + } + + [TestCaseSource(nameof(OtherEnumsComparisonOperatorsTestCases))] + public void OtherEnumsComparisonOperatorsWorkWithoutExplicitCast(string @operator, VerticalDirection operand1, HorizontalDirection operand2, bool expectedResult, bool invertedOperationExpectedResult) + { + using var _ = Py.GIL(); + using var module = PyModule.FromString("OtherEnumsComparisonOperatorsWorkWithoutExplicitCast", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +def operation1(): + return {nameof(EnumTests)}.{nameof(VerticalDirection)}.{operand1} {@operator} {nameof(EnumTests)}.{nameof(HorizontalDirection)}.{operand2} + +def operation2(): + return {nameof(EnumTests)}.{nameof(HorizontalDirection)}.{operand2} {@operator} {nameof(EnumTests)}.{nameof(VerticalDirection)}.{operand1} +"); + + Assert.AreEqual(expectedResult, module.InvokeMethod("operation1").As()); + Assert.AreEqual(invertedOperationExpectedResult, module.InvokeMethod("operation2").As()); + } + + private static IEnumerable IdentityComparisonTestCases + { + get + { + foreach (var enumValue1 in VerticalDirectionEnumValues) + { + foreach (var enumValue2 in VerticalDirectionEnumValues) + { + if (enumValue2 != enumValue1) + { + yield return new TestCaseData(enumValue1, enumValue2); + } + } + } + } + } + + [TestCaseSource(nameof(IdentityComparisonTestCases))] + public void CSharpEnumsAreSingletonsInPthonAndIdentityComparisonWorks(VerticalDirection enumValue1, VerticalDirection enumValue2) + { + var enumValue1Str = $"{nameof(EnumTests)}.{nameof(VerticalDirection)}.{enumValue1}"; + var enumValue2Str = $"{nameof(EnumTests)}.{nameof(VerticalDirection)}.{enumValue2}"; + + using var _ = Py.GIL(); + using var module = PyModule.FromString("CSharpEnumsAreSingletonsInPthonAndIdentityComparisonWorks", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +def are_same1(): + return {enumValue1Str} is {enumValue1Str} + +def are_same2(): + enum_value = {enumValue1Str} + return enum_value is {enumValue1Str} + +def are_same3(): + enum_value = {enumValue1Str} + return {enumValue1Str} is enum_value + +def are_same4(): + enum_value1 = {enumValue1Str} + enum_value2 = {enumValue1Str} + return enum_value1 is enum_value2 + +def are_not_same1(): + return {enumValue1Str} is not {enumValue2Str} + +def are_not_same2(): + enum_value = {enumValue1Str} + return enum_value is not {enumValue2Str} + +def are_not_same3(): + enum_value = {enumValue2Str} + return {enumValue1Str} is not enum_value + +def are_not_same4(): + enum_value1 = {enumValue1Str} + enum_value2 = {enumValue2Str} + return enum_value1 is not enum_value2 + + +"); + + Assert.IsTrue(module.InvokeMethod("are_same1").As()); + Assert.IsTrue(module.InvokeMethod("are_same2").As()); + Assert.IsTrue(module.InvokeMethod("are_same3").As()); + Assert.IsTrue(module.InvokeMethod("are_same4").As()); + + Assert.IsTrue(module.InvokeMethod("are_not_same1").As()); + Assert.IsTrue(module.InvokeMethod("are_not_same2").As()); + Assert.IsTrue(module.InvokeMethod("are_not_same3").As()); + Assert.IsTrue(module.InvokeMethod("are_not_same4").As()); + } + + [Test] + public void IdentityComparisonBetweenDifferentEnumTypesIsNeverTrue( + [ValueSource(nameof(VerticalDirectionEnumValues))] VerticalDirection enumValue1, + [ValueSource(nameof(HorizontalDirectionEnumValues))] HorizontalDirection enumValue2) + { + var enumValue1Str = $"{nameof(EnumTests)}.{nameof(VerticalDirection)}.{enumValue1}"; + var enumValue2Str = $"{nameof(EnumTests)}.{nameof(HorizontalDirection)}.{enumValue2}"; + + using var _ = Py.GIL(); + using var module = PyModule.FromString("IdentityComparisonBetweenDifferentEnumTypesIsNeverTrue", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +enum_value1 = {enumValue1Str} +enum_value2 = {enumValue2Str} + +def are_same1(): + return {enumValue1Str} is {enumValue2Str} + +def are_same2(): + return enum_value1 is {enumValue2Str} + +def are_same3(): + return {enumValue2Str} is enum_value1 + +def are_same4(): + return enum_value2 is {enumValue1Str} + +def are_same5(): + return {enumValue1Str} is enum_value2 + +def are_same6(): + return enum_value1 is enum_value2 + +def are_same7(): + return enum_value2 is enum_value1 +"); + + Assert.IsFalse(module.InvokeMethod("are_same1").As()); + Assert.IsFalse(module.InvokeMethod("are_same2").As()); + Assert.IsFalse(module.InvokeMethod("are_same3").As()); + Assert.IsFalse(module.InvokeMethod("are_same4").As()); + Assert.IsFalse(module.InvokeMethod("are_same5").As()); + Assert.IsFalse(module.InvokeMethod("are_same6").As()); + Assert.IsFalse(module.InvokeMethod("are_same7").As()); + } + + private PyModule GetCSharpObjectsComparisonTestModule(string @operator) + { + return PyModule.FromString("GetCSharpObjectsComparisonTestModule", $@" +from clr import AddReference +AddReference(""Python.EmbeddingTest"") + +from Python.EmbeddingTest import * + +enum_value = {nameof(EnumTests)}.{nameof(VerticalDirection)}.{VerticalDirection.Up} + +def compare_with_none1(): + return enum_value {@operator} None + +def compare_with_none2(): + return None {@operator} enum_value + +def compare_with_csharp_object1(csharp_object): + return enum_value {@operator} csharp_object + +def compare_with_csharp_object2(csharp_object): + return csharp_object {@operator} enum_value +"); + } + + [TestCase("==", false)] + [TestCase("!=", true)] + public void EqualityComparisonWithNull(string @operator, bool expectedResult) + { + using var _ = Py.GIL(); + using var module = GetCSharpObjectsComparisonTestModule(@operator); + + Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_none1").As()); + Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_none2").As()); + + using var pyNull = ((TestClass)null).ToPython(); + Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_csharp_object1", pyNull).As()); + Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_csharp_object2", pyNull).As()); + } + + [TestCase("==", false)] + [TestCase("!=", true)] + public void EqualityOperatorsWithNonEnumObjects(string @operator, bool expectedResult) + { + using var _ = Py.GIL(); + using var module = GetCSharpObjectsComparisonTestModule(@operator); + + using var pyCSharpObject = new TestClass().ToPython(); + Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_csharp_object1", pyCSharpObject).As()); + Assert.AreEqual(expectedResult, module.InvokeMethod("compare_with_csharp_object2", pyCSharpObject).As()); + } + + [Test] + public void ThrowsOnObjectComparisonOperators([Values("<", "<=", ">", ">=")] string @operator) + { + using var _ = Py.GIL(); + using var module = GetCSharpObjectsComparisonTestModule(@operator); + + using var pyCSharpObject = new TestClass().ToPython(); + Assert.Throws(() => module.InvokeMethod("compare_with_csharp_object1", pyCSharpObject)); + Assert.Throws(() => module.InvokeMethod("compare_with_csharp_object2", pyCSharpObject)); + } + + [Test] + public void ThrowsOnNullComparisonOperators([Values("<", "<=", ">", ">=")] string @operator) + { + using var _ = Py.GIL(); + using var module = GetCSharpObjectsComparisonTestModule(@operator); + + Assert.Throws(() => module.InvokeMethod("compare_with_none1").As()); + Assert.Throws(() => module.InvokeMethod("compare_with_none2").As()); + + using var pyNull = ((TestClass)null).ToPython(); + Assert.Throws(() => module.InvokeMethod("compare_with_csharp_object1", pyNull)); + Assert.Throws(() => module.InvokeMethod("compare_with_csharp_object2", pyNull)); + } + + public class TestClass + { + } + } +} diff --git a/src/embed_tests/TestMethodBinder.cs b/src/embed_tests/TestMethodBinder.cs index 7f4c58d7e..0b3f6497c 100644 --- a/src/embed_tests/TestMethodBinder.cs +++ b/src/embed_tests/TestMethodBinder.cs @@ -815,6 +815,20 @@ public string VariableArgumentsMethod(params PyObject[] paramsParams) return "VariableArgumentsMethod(PyObject[])"; } + // ---- + + public string MethodWithEnumParam(SomeEnu enumValue, string symbol) + { + return $"MethodWithEnumParam With Enum"; + } + + public string MethodWithEnumParam(PyObject pyObject, string symbol) + { + return $"MethodWithEnumParam With PyObject"; + } + + // ---- + public string ConstructorMessage { get; set; } public OverloadsTestClass(params CSharpModel[] paramsParams) @@ -1117,6 +1131,26 @@ def get_instance(): Assert.AreEqual("OverloadsTestClass(PyObject[])", instance.GetAttr("ConstructorMessage").As()); } + [Test] + public void EnumHasPrecedenceOverPyObject() + { + using var _ = Py.GIL(); + + var module = PyModule.FromString("EnumHasPrecedenceOverPyObject", @$" +from clr import AddReference +AddReference(""System"") +from Python.EmbeddingTest import * + +class PythonModel(TestMethodBinder.CSharpModel): + pass + +def call_method(): + return TestMethodBinder.OverloadsTestClass().MethodWithEnumParam(TestMethodBinder.SomeEnu.A, ""Some string"") +"); + + var result = module.GetAttr("call_method").Invoke(); + Assert.AreEqual("MethodWithEnumParam With Enum", result.As()); + } // Used to test that we match this function with Py DateTime & Date Objects public static int GetMonth(DateTime test) diff --git a/src/perf_tests/Python.PerformanceTests.csproj b/src/perf_tests/Python.PerformanceTests.csproj index ee239ff12..aa3a04adb 100644 --- a/src/perf_tests/Python.PerformanceTests.csproj +++ b/src/perf_tests/Python.PerformanceTests.csproj @@ -13,7 +13,7 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - + compile @@ -25,7 +25,7 @@ - + diff --git a/src/runtime/Converter.cs b/src/runtime/Converter.cs index 19fb1c883..fc6437bc1 100644 --- a/src/runtime/Converter.cs +++ b/src/runtime/Converter.cs @@ -18,6 +18,13 @@ namespace Python.Runtime [SuppressUnmanagedCodeSecurity] internal class Converter { + /// + /// We use a cache of the enum values references so that we treat them as singletons in Python. + /// We just try to mimic Python enums behavior, since Python enum values are singletons, + /// so the `is` identity comparison operator works for C# enums as well. + /// + + private static readonly Dictionary _enumCache = new(); private Converter() { } @@ -226,6 +233,16 @@ internal static NewReference ToPython(object? value, Type type) return resultlist.NewReferenceOrNull(); } + if (type.IsEnum) + { + if (!_enumCache.TryGetValue(value, out var cachedValue)) + { + _enumCache[value] = cachedValue = CLRObject.GetReference(value, type).MoveToPyObject(); + } + + return cachedValue.NewReferenceOrNull(); + } + // it the type is a python subclass of a managed type then return the // underlying python object rather than construct a new wrapper object. var pyderived = value as IPythonDerivedType; diff --git a/src/runtime/MethodBinder.cs b/src/runtime/MethodBinder.cs index 8c8bac65d..42fe0ba91 100644 --- a/src/runtime/MethodBinder.cs +++ b/src/runtime/MethodBinder.cs @@ -383,6 +383,17 @@ internal static int ArgPrecedence(Type t, bool isOperatorMethod) return 3000; } + if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) + { + // Nullable is a special case, we treat it as the underlying type + return ArgPrecedence(Nullable.GetUnderlyingType(t), isOperatorMethod); + } + + if (t.IsEnum) + { + return -2; + } + if (t.IsAssignableFrom(typeof(PyObject)) && !isOperatorMethod) { return -1; diff --git a/src/runtime/Properties/AssemblyInfo.cs b/src/runtime/Properties/AssemblyInfo.cs index c3e7c304f..6941d1ac1 100644 --- a/src/runtime/Properties/AssemblyInfo.cs +++ b/src/runtime/Properties/AssemblyInfo.cs @@ -4,5 +4,5 @@ [assembly: InternalsVisibleTo("Python.EmbeddingTest, PublicKey=00240000048000009400000006020000002400005253413100040000110000005ffd8f49fb44ab0641b3fd8d55e749f716e6dd901032295db641eb98ee46063cbe0d4a1d121ef0bc2af95f8a7438d7a80a3531316e6b75c2dae92fb05a99f03bf7e0c03980e1c3cfb74ba690aca2f3339ef329313bcc5dccced125a4ffdc4531dcef914602cd5878dc5fbb4d4c73ddfbc133f840231343e013762884d6143189")] [assembly: InternalsVisibleTo("Python.Test, PublicKey=00240000048000009400000006020000002400005253413100040000110000005ffd8f49fb44ab0641b3fd8d55e749f716e6dd901032295db641eb98ee46063cbe0d4a1d121ef0bc2af95f8a7438d7a80a3531316e6b75c2dae92fb05a99f03bf7e0c03980e1c3cfb74ba690aca2f3339ef329313bcc5dccced125a4ffdc4531dcef914602cd5878dc5fbb4d4c73ddfbc133f840231343e013762884d6143189")] -[assembly: AssemblyVersion("2.0.44")] -[assembly: AssemblyFileVersion("2.0.44")] +[assembly: AssemblyVersion("2.0.45")] +[assembly: AssemblyFileVersion("2.0.45")] diff --git a/src/runtime/Python.Runtime.csproj b/src/runtime/Python.Runtime.csproj index 9b870ed44..035bc6214 100644 --- a/src/runtime/Python.Runtime.csproj +++ b/src/runtime/Python.Runtime.csproj @@ -5,7 +5,7 @@ Python.Runtime Python.Runtime QuantConnect.pythonnet - 2.0.44 + 2.0.45 false LICENSE https://github.com/pythonnet/pythonnet diff --git a/src/runtime/Util/OpsHelper.cs b/src/runtime/Util/OpsHelper.cs index ab623f3de..89ce79e20 100644 --- a/src/runtime/Util/OpsHelper.cs +++ b/src/runtime/Util/OpsHelper.cs @@ -1,6 +1,7 @@ using System; using System.Linq.Expressions; using System.Reflection; +using System.Runtime.CompilerServices; using static Python.Runtime.OpsHelper; @@ -35,7 +36,7 @@ public static Expression EnumUnderlyingValue(Expression enumValue) } [AttributeUsage(AttributeTargets.Class, AllowMultiple = false)] - internal class OpsAttribute: Attribute { } + internal class OpsAttribute : Attribute { } [Ops] internal static class FlagEnumOps where T : Enum @@ -78,12 +79,505 @@ static Func UnaryOp(Func op) [Ops] internal static class EnumOps where T : Enum { + private static bool IsUnsigned = typeof(T).GetEnumUnderlyingType() == typeof(UInt64); + [ForbidPythonThreads] #pragma warning disable IDE1006 // Naming Styles - must match Python public static PyInt __int__(T value) #pragma warning restore IDE1006 // Naming Styles - => typeof(T).GetEnumUnderlyingType() == typeof(UInt64) + => IsUnsigned ? new PyInt(Convert.ToUInt64(value)) : new PyInt(Convert.ToInt64(value)); + + #region Arithmetic operators + + public static double op_Addition(T a, double b) + { + if (IsUnsigned) + { + return Convert.ToUInt64(a) + b; + } + return Convert.ToInt64(a) + b; + } + + public static double op_Addition(double a, T b) + { + return op_Addition(b, a); + } + + public static double op_Subtraction(T a, double b) + { + if (IsUnsigned) + { + return Convert.ToUInt64(a) - b; + } + return Convert.ToInt64(a) - b; + } + + public static double op_Subtraction(double a, T b) + { + if (IsUnsigned) + { + return a - Convert.ToUInt64(b); + } + return a - Convert.ToInt64(b); + } + + public static double op_Multiply(T a, double b) + { + if (IsUnsigned) + { + return Convert.ToUInt64(a) * b; + } + return Convert.ToInt64(a) * b; + } + + public static double op_Multiply(double a, T b) + { + return op_Multiply(b, a); + } + + public static double op_Division(T a, double b) + { + if (IsUnsigned) + { + return Convert.ToUInt64(a) / b; + } + return Convert.ToInt64(a) / b; + } + + public static double op_Division(double a, T b) + { + if (IsUnsigned) + { + return a / Convert.ToUInt64(b); + } + return a / Convert.ToInt64(b); + } + + #endregion + + #region Int comparison operators + + public static bool op_Equality(T a, long b) + { + if (IsUnsigned) + { + var uvalue = Convert.ToUInt64(a); + return b >= 0 && ((ulong)b) == uvalue; + } + return Convert.ToInt64(a) == b; + } + + public static bool op_Equality(T a, ulong b) + { + if (IsUnsigned) + { + var uvalue = Convert.ToUInt64(a); + return b == uvalue; + } + var ivalue = Convert.ToInt64(a); + return ivalue >= 0 && ((ulong)ivalue) == b; + } + + public static bool op_Equality(long a, T b) + { + return op_Equality(b, a); + } + + public static bool op_Equality(ulong a, T b) + { + return op_Equality(b, a); + } + + public static bool op_Inequality(T a, long b) + { + return !op_Equality(a, b); + } + + public static bool op_Inequality(T a, ulong b) + { + return !op_Equality(a, b); + } + + public static bool op_Inequality(long a, T b) + { + return !op_Equality(b, a); + } + + public static bool op_Inequality(ulong a, T b) + { + return !op_Equality(b, a); + } + + public static bool op_LessThan(T a, long b) + { + if (IsUnsigned) + { + var uvalue = Convert.ToUInt64(a); + return b >= 0 && ((ulong)b) > uvalue; + } + return Convert.ToInt64(a) < b; + } + + public static bool op_LessThan(T a, ulong b) + { + if (IsUnsigned) + { + var uvalue = Convert.ToUInt64(a); + return b > uvalue; + } + var ivalue = Convert.ToInt64(a); + return ivalue >= 0 && ((ulong)ivalue) < b; + } + + public static bool op_LessThan(long a, T b) + { + return op_GreaterThan(b, a); + } + + public static bool op_LessThan(ulong a, T b) + { + return op_GreaterThan(b, a); + } + + public static bool op_GreaterThan(T a, long b) + { + if (IsUnsigned) + { + var uvalue = Convert.ToUInt64(a); + return b >= 0 && ((ulong)b) < uvalue; + } + return Convert.ToInt64(a) > b; + } + + public static bool op_GreaterThan(T a, ulong b) + { + if (IsUnsigned) + { + var uvalue = Convert.ToUInt64(a); + return b < uvalue; + } + var ivalue = Convert.ToInt64(a); + return ivalue >= 0 && ((ulong)ivalue) > b; + } + + public static bool op_GreaterThan(long a, T b) + { + return op_LessThan(b, a); + } + + public static bool op_GreaterThan(ulong a, T b) + { + return op_LessThan(b, a); + } + + public static bool op_LessThanOrEqual(T a, long b) + { + if (IsUnsigned) + { + var uvalue = Convert.ToUInt64(a); + return b >= 0 && ((ulong)b) >= uvalue; + } + return Convert.ToInt64(a) <= b; + } + + public static bool op_LessThanOrEqual(T a, ulong b) + { + if (IsUnsigned) + { + var uvalue = Convert.ToUInt64(a); + return b >= uvalue; + } + var ivalue = Convert.ToInt64(a); + return ivalue >= 0 && ((ulong)ivalue) <= b; + } + + public static bool op_LessThanOrEqual(long a, T b) + { + return op_GreaterThanOrEqual(b, a); + } + + public static bool op_LessThanOrEqual(ulong a, T b) + { + return op_GreaterThanOrEqual(b, a); + } + + public static bool op_GreaterThanOrEqual(T a, long b) + { + if (IsUnsigned) + { + var uvalue = Convert.ToUInt64(a); + return b >= 0 && ((ulong)b) <= uvalue; + } + return Convert.ToInt64(a) >= b; + } + + public static bool op_GreaterThanOrEqual(T a, ulong b) + { + if (IsUnsigned) + { + var uvalue = Convert.ToUInt64(a); + return b <= uvalue; + } + var ivalue = Convert.ToInt64(a); + return ivalue >= 0 && ((ulong)ivalue) >= b; + } + + public static bool op_GreaterThanOrEqual(long a, T b) + { + return op_LessThanOrEqual(b, a); + } + + public static bool op_GreaterThanOrEqual(ulong a, T b) + { + return op_LessThanOrEqual(b, a); + } + + #endregion + + #region Double comparison operators + + public static bool op_Equality(T a, double b) + { + if (IsUnsigned) + { + return Convert.ToUInt64(a) == b; + } + return Convert.ToInt64(a) == b; + } + + public static bool op_Equality(double a, T b) + { + return op_Equality(b, a); + } + + public static bool op_Inequality(T a, double b) + { + return !op_Equality(a, b); + } + + public static bool op_Inequality(double a, T b) + { + return !op_Equality(b, a); + } + + public static bool op_LessThan(T a, double b) + { + if (IsUnsigned) + { + return Convert.ToUInt64(a) < b; + } + return Convert.ToInt64(a) < b; + } + + public static bool op_LessThan(double a, T b) + { + return op_GreaterThan(b, a); + } + + public static bool op_GreaterThan(T a, double b) + { + if (IsUnsigned) + { + return Convert.ToUInt64(a) > b; + } + return Convert.ToInt64(a) > b; + } + + public static bool op_GreaterThan(double a, T b) + { + return op_LessThan(b, a); + } + + public static bool op_LessThanOrEqual(T a, double b) + { + if (IsUnsigned) + { + return Convert.ToUInt64(a) <= b; + } + return Convert.ToInt64(a) <= b; + } + + public static bool op_LessThanOrEqual(double a, T b) + { + return op_GreaterThanOrEqual(b, a); + } + + public static bool op_GreaterThanOrEqual(T a, double b) + { + if (IsUnsigned) + { + return Convert.ToUInt64(a) >= b; + } + return Convert.ToInt64(a) >= b; + } + + public static bool op_GreaterThanOrEqual(double a, T b) + { + return op_LessThanOrEqual(b, a); + } + + #endregion + + #region String comparison operators + public static bool op_Equality(T a, string b) + { + return a.ToString().Equals(b, StringComparison.InvariantCultureIgnoreCase); + } + public static bool op_Equality(string a, T b) + { + return op_Equality(b, a); + } + + public static bool op_Inequality(T a, string b) + { + return !op_Equality(a, b); + } + + public static bool op_Inequality(string a, T b) + { + return !op_Equality(b, a); + } + + #endregion + + #region Enum comparison operators + + public static bool op_Equality(T a, Enum b) + { + if (b == null) + { + return false; + } + + if (b.GetType().GetEnumUnderlyingType() == typeof(UInt64)) + { + return op_Equality(a, Convert.ToUInt64(b)); + } + return op_Equality(a, Convert.ToInt64(b)); + } + + public static bool op_Equality(Enum a, T b) + { + return op_Equality(b, a); + } + + public static bool op_Inequality(T a, Enum b) + { + return !op_Equality(a, b); + } + + public static bool op_Inequality(Enum a, T b) + { + return !op_Equality(b, a); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ThrowOnNull(object obj, string @operator) + { + if (obj == null) + { + using (Py.GIL()) + { + Exceptions.RaiseTypeError($"'{@operator}' not supported between instances of '{typeof(T).Name}' and null/None"); + PythonException.ThrowLastAsClrException(); + } + } + } + + public static bool op_LessThan(T a, Enum b) + { + ThrowOnNull(b, "<"); + + if (b.GetType().GetEnumUnderlyingType() == typeof(UInt64)) + { + return op_LessThan(a, Convert.ToUInt64(b)); + } + return op_LessThan(a, Convert.ToInt64(b)); + } + + public static bool op_LessThan(Enum a, T b) + { + ThrowOnNull(a, "<"); + return op_GreaterThan(b, a); + } + + public static bool op_GreaterThan(T a, Enum b) + { + ThrowOnNull(b, ">"); + + if (b.GetType().GetEnumUnderlyingType() == typeof(UInt64)) + { + return op_GreaterThan(a, Convert.ToUInt64(b)); + } + return op_GreaterThan(a, Convert.ToInt64(b)); + } + + public static bool op_GreaterThan(Enum a, T b) + { + ThrowOnNull(a, ">"); + return op_LessThan(b, a); + } + + public static bool op_LessThanOrEqual(T a, Enum b) + { + ThrowOnNull(b, "<="); + + if (b.GetType().GetEnumUnderlyingType() == typeof(UInt64)) + { + return op_LessThanOrEqual(a, Convert.ToUInt64(b)); + } + return op_LessThanOrEqual(a, Convert.ToInt64(b)); + } + + public static bool op_LessThanOrEqual(Enum a, T b) + { + ThrowOnNull(a, "<="); + return op_GreaterThanOrEqual(b, a); + } + + public static bool op_GreaterThanOrEqual(T a, Enum b) + { + ThrowOnNull(b, ">="); + + if (b.GetType().GetEnumUnderlyingType() == typeof(UInt64)) + { + return op_GreaterThanOrEqual(a, Convert.ToUInt64(b)); + } + return op_GreaterThanOrEqual(a, Convert.ToInt64(b)); + } + + public static bool op_GreaterThanOrEqual(Enum a, T b) + { + ThrowOnNull(a, ">="); + return op_LessThanOrEqual(b, a); + } + + #endregion + + #region Object equality operators + + public static bool op_Equality(T a, object b) + { + return false; + } + + public static bool op_Equality(object a, T b) + { + return false; + } + + public static bool op_Inequality(T a, object b) + { + return true; + } + + public static bool op_Inequality(object a, T b) + { + return true; + } + + #endregion } }