diff --git a/JSONAPI.EntityFramework/DbContextExtensions.cs b/JSONAPI.EntityFramework/DbContextExtensions.cs index 37724b77..b18d0e1c 100644 --- a/JSONAPI.EntityFramework/DbContextExtensions.cs +++ b/JSONAPI.EntityFramework/DbContextExtensions.cs @@ -21,28 +21,27 @@ public static class DbContextExtensions /// public static IEnumerable GetKeyNames(this DbContext dbContext, Type type) { - if (dbContext == null) throw new ArgumentNullException("dbContext"); - if (type == null) throw new ArgumentNullException("type"); + if (dbContext == null) throw new ArgumentNullException(nameof(dbContext)); + if (type == null) throw new ArgumentNullException(nameof(type)); - var originalType = type; - - while (type != null) + var baseEntityType = type; + while (baseEntityType.BaseType != typeof(Object)) { - var openMethod = typeof(DbContextExtensions).GetMethod("GetKeyNamesFromGeneric", BindingFlags.Public | BindingFlags.Static); - var method = openMethod.MakeGenericMethod(type); - - try - { - return (IEnumerable) method.Invoke(null, new object[] {dbContext}); - } - catch (TargetInvocationException) - { - } + baseEntityType = baseEntityType.BaseType; + } + + var openMethod = typeof(DbContextExtensions).GetMethod("GetKeyNamesFromGeneric", BindingFlags.Public | BindingFlags.Static); + var method = openMethod.MakeGenericMethod(baseEntityType); - type = type.BaseType; + try + { + return (IEnumerable) method.Invoke(null, new object[] {dbContext}); + } + catch (TargetInvocationException) + { } - throw new Exception(string.Format("Failed to identify the key names for {0} or any of its parent classes.", originalType.Name)); + throw new Exception(string.Format("Failed to identify the key names for {0} or any of its parent classes.", type.Name)); } /// @@ -59,7 +58,6 @@ public static IEnumerable GetKeyNamesFromGeneric(this DbContext dbCon try { objectSet = objectContext.CreateObjectSet(); - } catch (InvalidOperationException e) {