diff --git a/quickfixj-core/src/main/java/quickfix/DefaultMessageFactory.java b/quickfixj-core/src/main/java/quickfix/DefaultMessageFactory.java index d55440bc0..6bc17a42e 100644 --- a/quickfixj-core/src/main/java/quickfix/DefaultMessageFactory.java +++ b/quickfixj-core/src/main/java/quickfix/DefaultMessageFactory.java @@ -118,8 +118,12 @@ public void addFactory(String beginString, String factoryClassName) throws Class // try using our own classloader factoryClass = (Class) Class.forName(factoryClassName); } catch (ClassNotFoundException e) { - // try using context classloader (i.e. allow caller to specify it) - Thread.currentThread().getContextClassLoader().loadClass(factoryClassName); + // try using context classloader (i.e. allow caller to specify it) + ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); + + if (contextClassLoader != null) { + factoryClass = (Class) contextClassLoader.loadClass(factoryClassName); + } } // if factory is found, add it if (factoryClass != null) { diff --git a/quickfixj-core/src/test/java/quickfix/DefaultMessageFactoryTest.java b/quickfixj-core/src/test/java/quickfix/DefaultMessageFactoryTest.java index 4803934bb..a6390464b 100644 --- a/quickfixj-core/src/test/java/quickfix/DefaultMessageFactoryTest.java +++ b/quickfixj-core/src/test/java/quickfix/DefaultMessageFactoryTest.java @@ -1,6 +1,10 @@ package quickfix; import static org.junit.Assert.*; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static quickfix.FixVersions.*; import static quickfix.field.ApplVerID.*; @@ -9,6 +13,7 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import quickfix.field.*; +import quickfix.fix44.MessageFactory; import quickfix.test.util.ExpectedTestFailure; /** @@ -81,6 +86,24 @@ protected void execute() throws Throwable { factory.create(BEGINSTRING_FIX40, MsgType.MARKET_DATA_SNAPSHOT_FULL_REFRESH, NoMDEntries.FIELD)); } + @Test + public void testContextClassLoaderFactory() throws ClassNotFoundException { + ClassLoader customLoader = mock(ClassLoader.class); + doReturn(MessageFactory.class).when(customLoader).loadClass("foo.DefaultMessageFactory"); + + ClassLoader previousClassLoader = Thread.currentThread().getContextClassLoader(); + Thread.currentThread().setContextClassLoader(customLoader); + + try { + factory.addFactory(BEGINSTRING_FIX44, "foo.DefaultMessageFactory"); + } finally { + Thread.currentThread().setContextClassLoader(previousClassLoader); + } + + verify(customLoader).loadClass("foo.DefaultMessageFactory"); + verifyNoMoreInteractions(customLoader); + } + private static void assertMessage(Class expectedMessageClass, String expectedMessageType, Message message) throws Exception { assertEquals(expectedMessageClass, message.getClass()); assertEquals(expectedMessageType, message.getHeader().getString(MsgType.FIELD));