diff --git a/JSONAPI.EntityFramework/DbContextExtensions.cs b/JSONAPI.EntityFramework/DbContextExtensions.cs
index 37724b77..b18d0e1c 100644
--- a/JSONAPI.EntityFramework/DbContextExtensions.cs
+++ b/JSONAPI.EntityFramework/DbContextExtensions.cs
@@ -21,28 +21,27 @@ public static class DbContextExtensions
///
public static IEnumerable GetKeyNames(this DbContext dbContext, Type type)
{
- if (dbContext == null) throw new ArgumentNullException("dbContext");
- if (type == null) throw new ArgumentNullException("type");
+ if (dbContext == null) throw new ArgumentNullException(nameof(dbContext));
+ if (type == null) throw new ArgumentNullException(nameof(type));
- var originalType = type;
-
- while (type != null)
+ var baseEntityType = type;
+ while (baseEntityType.BaseType != typeof(Object))
{
- var openMethod = typeof(DbContextExtensions).GetMethod("GetKeyNamesFromGeneric", BindingFlags.Public | BindingFlags.Static);
- var method = openMethod.MakeGenericMethod(type);
-
- try
- {
- return (IEnumerable) method.Invoke(null, new object[] {dbContext});
- }
- catch (TargetInvocationException)
- {
- }
+ baseEntityType = baseEntityType.BaseType;
+ }
+
+ var openMethod = typeof(DbContextExtensions).GetMethod("GetKeyNamesFromGeneric", BindingFlags.Public | BindingFlags.Static);
+ var method = openMethod.MakeGenericMethod(baseEntityType);
- type = type.BaseType;
+ try
+ {
+ return (IEnumerable) method.Invoke(null, new object[] {dbContext});
+ }
+ catch (TargetInvocationException)
+ {
}
- throw new Exception(string.Format("Failed to identify the key names for {0} or any of its parent classes.", originalType.Name));
+ throw new Exception(string.Format("Failed to identify the key names for {0} or any of its parent classes.", type.Name));
}
///
@@ -59,7 +58,6 @@ public static IEnumerable GetKeyNamesFromGeneric(this DbContext dbCon
try
{
objectSet = objectContext.CreateObjectSet();
-
}
catch (InvalidOperationException e)
{