diff --git a/Data/Pool.hs b/Data/Pool.hs index 6764e8b..c92ad40 100644 --- a/Data/Pool.hs +++ b/Data/Pool.hs @@ -1,9 +1,5 @@ {-# LANGUAGE CPP, NamedFieldPuns, RecordWildCards, ScopedTypeVariables, RankNTypes, DeriveDataTypeable #-} -#if MIN_VERSION_monad_control(0,3,0) -{-# LANGUAGE FlexibleContexts #-} -#endif - #if !MIN_VERSION_base(4,3,0) {-# LANGUAGE RankNTypes #-} #endif @@ -54,24 +50,7 @@ import Data.Typeable (Typeable) import GHC.Conc.Sync (labelThread) import qualified Control.Exception as E import qualified Data.Vector as V - -#if MIN_VERSION_monad_control(0,3,0) -import Control.Monad.Trans.Control (MonadBaseControl, control) -import Control.Monad.Base (liftBase) -#else -import Control.Monad.IO.Control (MonadControlIO, controlIO) -import Control.Monad.IO.Class (liftIO) -#define control controlIO -#define liftBase liftIO -#endif - -#if MIN_VERSION_base(4,3,0) -import Control.Exception (mask) -#else --- Don't do any async exception protection for older GHCs. -mask :: ((forall a. IO a -> IO a) -> IO b) -> IO b -mask f = f id -#endif +import UnliftIO (MonadUnliftIO, mask, withRunInIO) -- | A single resource pool entry. data Entry a = Entry { @@ -247,15 +226,9 @@ purgeLocalPool destroy LocalPool{..} = do -- destroy a pooled resource, as doing so will almost certainly cause -- a subsequent user (who expects the resource to be valid) to throw -- an exception. -withResource :: -#if MIN_VERSION_monad_control(0,3,0) - (MonadBaseControl IO m) -#else - (MonadControlIO m) -#endif - => Pool a -> (a -> m b) -> m b +withResource :: MonadUnliftIO m => Pool a -> (a -> m b) -> m b {-# SPECIALIZE withResource :: Pool a -> (a -> IO b) -> IO b #-} -withResource pool act = control $ \runInIO -> mask $ \restore -> do +withResource pool act = withRunInIO $ \runInIO -> mask $ \restore -> do (resource, local) <- takeResource pool ret <- restore (runInIO (act resource)) `onException` destroyResource pool local resource @@ -275,7 +248,7 @@ withResource pool act = control $ \runInIO -> mask $ \restore -> do takeResource :: Pool a -> IO (a, LocalPool a) takeResource pool@Pool{..} = do local@LocalPool{..} <- getLocalPool pool - resource <- liftBase . join . atomically $ do + resource <- join . atomically $ do ents <- readTVar entries case ents of (Entry{..}:es) -> writeTVar entries es >> return (return entry) @@ -295,14 +268,8 @@ takeResource pool@Pool{..} = do -- returns immediately with 'Nothing' (ie. the action function is /not/ called). -- Conversely, if a resource can be borrowed from the pool without blocking, the -- action is performed and it's result is returned, wrapped in a 'Just'. -tryWithResource :: forall m a b. -#if MIN_VERSION_monad_control(0,3,0) - (MonadBaseControl IO m) -#else - (MonadControlIO m) -#endif - => Pool a -> (a -> m b) -> m (Maybe b) -tryWithResource pool act = control $ \runInIO -> mask $ \restore -> do +tryWithResource :: forall m a b. MonadUnliftIO m => Pool a -> (a -> m b) -> m (Maybe b) +tryWithResource pool act = withRunInIO $ \runInIO -> mask $ \restore -> do res <- tryTakeResource pool case res of Just (resource, local) -> do @@ -321,7 +288,7 @@ tryWithResource pool act = control $ \runInIO -> mask $ \restore -> do tryTakeResource :: Pool a -> IO (Maybe (a, LocalPool a)) tryTakeResource pool@Pool{..} = do local@LocalPool{..} <- getLocalPool pool - resource <- liftBase . join . atomically $ do + resource <- join . atomically $ do ents <- readTVar entries case ents of (Entry{..}:es) -> writeTVar entries es >> return (return . Just $ entry) @@ -343,7 +310,7 @@ tryTakeResource pool@Pool{..} = do -- Internal, just to not repeat code for 'takeResource' and 'tryTakeResource' getLocalPool :: Pool a -> IO (LocalPool a) getLocalPool Pool{..} = do - i <- liftBase $ ((`mod` numStripes) . hash) <$> myThreadId + i <- ((`mod` numStripes) . hash) <$> myThreadId return $ localPools V.! i #if __GLASGOW_HASKELL__ >= 700 {-# INLINABLE getLocalPool #-} diff --git a/resource-pool.cabal b/resource-pool.cabal index 6a9bc09..5234e1d 100644 --- a/resource-pool.cabal +++ b/resource-pool.cabal @@ -32,11 +32,11 @@ library build-depends: base >= 4.4 && < 5, hashable, - monad-control >= 0.2.0.1, transformers, transformers-base >= 0.4, stm >= 2.3, time, + unliftio, vector >= 0.7 if flag(developer)