{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}

{-# OPTIONS_GHC -Wno-orphans #-}

-- | Defines the 'IOLike' class, which bundles a set of monad-class
-- constraints into a single class, and re-exports the classes, types and
-- strict-variable modules listed under \"Re-exports\" below.
module Ouroboros.Consensus.Util.IOLike (
    IOLike (..)
    -- * Re-exports
    -- *** MonadThrow
  , Exception (..)
  , ExitCase (..)
  , MonadCatch (..)
  , MonadMask (..)
  , MonadThrow (..)
  , SomeException
    -- *** Variables with NoThunks invariants
  , module Ouroboros.Consensus.Util.MonadSTM.NormalForm
  , module Ouroboros.Consensus.Util.NormalForm.StrictMVar
  , module Ouroboros.Consensus.Util.NormalForm.StrictTVar
    -- *** MonadFork, TODO: Should we hide this in favour of MonadAsync?
  , MonadFork (..)
  , MonadThread (..)
  , labelThisThread
    -- *** MonadAsync
  , ExceptionInLinkedThread (..)
  , MonadAsync (..)
  , link
  , linkTo
    -- *** MonadST
  , MonadST (..)
  , PrimMonad (..)
    -- *** MonadTime
  , DiffTime
  , MonadMonotonicTime (..)
  , Time (..)
  , addTime
  , diffTime
    -- *** MonadDelay
  , MonadDelay (..)
    -- *** MonadEventlog
  , MonadEventlog (..)
    -- *** MonadEvaluate
  , MonadEvaluate (..)
    -- *** NoThunks
  , NoThunks (..)
  ) where

import           Cardano.Crypto.KES (KESAlgorithm, SignKeyKES)
import qualified Cardano.Crypto.KES as KES
import           Control.Applicative (Alternative)
import           Control.Concurrent.Class.MonadMVar (MonadInspectMVar (..))
import qualified Control.Concurrent.Class.MonadMVar.Strict as Strict
import qualified Control.Concurrent.Class.MonadSTM.Strict as StrictSTM
import           Control.Monad.Base (MonadBase)
import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadEventlog
import           Control.Monad.Class.MonadFork
import           Control.Monad.Class.MonadST
import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime.SI
import           Control.Monad.Class.MonadTimer.SI
import           Control.Monad.Primitive
import           Data.Functor (void)
import           Data.Proxy (Proxy (..))
import           NoThunks.Class (NoThunks (..))
import           Ouroboros.Consensus.Util.MonadSTM.NormalForm
import           Ouroboros.Consensus.Util.NormalForm.StrictMVar
import           Ouroboros.Consensus.Util.NormalForm.StrictTVar
import           Ouroboros.Consensus.Util.Orphans ()


{-------------------------------------------------------------------------------
  IOLike
-------------------------------------------------------------------------------}

-- | A single class bundling the monad-class constraints listed in its
-- superclass context, so that code can ask for @IOLike m@ instead of
-- enumerating them individually.
--
-- The quantified constraints additionally require 'NoThunks' instances for
-- actions in @m@ and for the strict variable types over @m@ named below.
class ( MonadAsync              m
      , MonadLabelledSTM        m
      , MonadTraceSTM           m
      , MonadMVar               m
      , MonadEventlog           m
      , MonadFork               m
      , MonadST                 m
      , MonadDelay              m
      , MonadThread             m
      , MonadThrow              m
      , MonadCatch              m
      , MonadMask               m
      , MonadMonotonicTime      m
      , MonadEvaluate           m
      , Alternative        (STM m)
      , MonadCatch         (STM m)
      , PrimMonad               m
      , MonadBase               m m
      , forall a. NoThunks (m a)
      , forall a. NoThunks a => NoThunks (StrictSTM.StrictTVar m a)
      , forall a. NoThunks a => NoThunks (StrictSVar m a)
      , forall a. NoThunks a => NoThunks (Strict.StrictMVar m a)
      , forall a. NoThunks a => NoThunks (StrictTVar m a)
      , forall a. NoThunks a => NoThunks (StrictMVar m a)
      , forall a. NoThunks a => NoThunks (StrictSTM.StrictTMVar m a)
      ) => IOLike m where
  -- | Securely forget a KES signing key.
  --
  -- No-op for the IOSim, but 'KES.forgetSignKeyKES' for IO.
  forgetSignKeyKES :: KESAlgorithm v => SignKeyKES v -> m ()

-- | In 'IO', forgetting a KES signing key is delegated to
-- 'KES.forgetSignKeyKES'.
instance IOLike IO where
  forgetSignKeyKES = KES.forgetSignKeyKES

-- | Generalization of 'link' that links an async to an arbitrary thread.
--
-- Implemented as 'linkToOnly' with a predicate that accepts every exception
-- except those recognised by 'isCancel', so a cancelled async does not raise
-- anything in the linked thread.
--
-- Non standard (not in 'async' library).
--
linkTo :: (MonadAsync m, MonadFork m, MonadMask m)
       => ThreadId m -> Async m a -> m ()
linkTo tid = linkToOnly tid (not . isCancel)

-- | Generalization of 'linkOnly' that links an async to an arbitrary thread.
--
-- A helper thread is forked (via 'forkRepeat') that waits for the async to
-- finish. If the async ended with an exception for which the predicate
-- returns 'True', that exception is wrapped in 'ExceptionInLinkedThread',
-- tagged with the shown thread id of the async, and thrown to the given
-- thread. In every other case the helper thread does nothing.
--
-- Non standard (not in 'async' library).
--
linkToOnly :: forall m a. (MonadAsync m, MonadFork m, MonadMask m)
           => ThreadId m -> (SomeException -> Bool) -> Async m a -> m ()
linkToOnly tid shouldThrow a = do
    -- The id of the helper thread is of no use to the caller.
    void $ forkRepeat ("linkToOnly " <> show linkedThreadId) $ do
      r <- waitCatch a
      case r of
        Left e | shouldThrow e -> throwTo tid (exceptionInLinkedThread e)
        _otherwise             -> return ()
  where
    -- The thread running the async we are linking to.
    linkedThreadId :: ThreadId m
    linkedThreadId = asyncThreadId a

    -- Wrap an exception together with the shown id of the async's thread.
    exceptionInLinkedThread :: SomeException -> ExceptionInLinkedThread
    exceptionInLinkedThread =
        ExceptionInLinkedThread (show linkedThreadId)

-- | 'True' exactly when the exception is 'AsyncCancelled'.
isCancel :: SomeException -> Bool
isCancel e
  | Just AsyncCancelled <- fromException e = True
  | otherwise = False

-- | Fork a thread, labelled with the given string, that runs the action and
-- re-runs it for as long as it terminates with an exception; the thread ends
-- once the action returns normally.
--
-- The fork is performed under 'mask'; the action itself is run with the
-- caller's masking state restored.
--
-- NOTE: every exception ('SomeException') raised by the action is caught and
-- triggers a retry; none is propagated or reported.
forkRepeat :: (MonadFork m, MonadMask m) => String -> m a -> m (ThreadId m)
forkRepeat label action =
  mask $ \restore ->
    let go = do r <- tryAll (restore action)
                case r of
                  Left _ -> go
                  _      -> return ()
    in forkIO (labelThisThread label >> go)

-- | 'try' with the exception type fixed to 'SomeException'.
tryAll :: MonadCatch m => m a -> m (Either SomeException a)
tryAll = try


{-------------------------------------------------------------------------------
  NoThunks instance
-------------------------------------------------------------------------------}

-- | The thunk check is run on the value currently stored in the variable.
instance NoThunks a => NoThunks (StrictSTM.StrictTVar IO a) where
  showTypeOf _ = "StrictTVar IO"
  wNoThunks ctxt tv = do
      -- We can't use @atomically $ readTVar ..@ here, as that will lead to a
      -- "Control.Concurrent.STM.atomically was nested" exception.
      a <- StrictSTM.readTVarIO tv
      noThunks ctxt a

-- | The thunk check is run on the @Maybe@ value that 'inspectMVar' returns
-- for the underlying lazy MVar.
instance NoThunks a => NoThunks (Strict.StrictMVar IO a) where
  showTypeOf _ = "StrictMVar IO"
  wNoThunks ctxt mvar = do
      aMay <- inspectMVar (Proxy :: Proxy IO) (Strict.toLazyMVar mvar)
      noThunks ctxt aMay


-- | The thunk check is run on the @Maybe@ value that 'inspectTMVar' returns
-- for the underlying lazy TMVar.
instance NoThunks a => NoThunks (StrictSTM.StrictTMVar IO a) where
  showTypeOf _ = "StrictTMVar IO"
  wNoThunks ctxt t = do
      a <- inspectTMVar (Proxy :: Proxy IO) $ toLazyTMVar t
      noThunks ctxt a