{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}

{-# OPTIONS_GHC -Wno-orphans #-}

module Ouroboros.Consensus.Util.IOLike (
    IOLike (..)
    -- * Re-exports
    -- *** MonadThrow
  , Exception (..)
  , ExitCase (..)
  , MonadCatch (..)
  , MonadMask (..)
  , MonadThrow (..)
  , SomeException
    -- *** Variables with NoThunks invariants
  , module Ouroboros.Consensus.Util.MonadSTM.NormalForm
  , module Ouroboros.Consensus.Util.NormalForm.StrictMVar
  , module Ouroboros.Consensus.Util.NormalForm.StrictTVar
    -- *** MonadFork, TODO: Should we hide this in favour of MonadAsync?
  , MonadFork (..)
  , MonadThread (..)
  , labelThisThread
    -- *** MonadAsync
  , ExceptionInLinkedThread (..)
  , MonadAsync (..)
  , link
  , linkTo
    -- *** MonadST
  , MonadST (..)
  , PrimMonad (..)
    -- *** MonadTime
  , DiffTime
  , MonadMonotonicTime (..)
  , Time (..)
  , addTime
  , diffTime
    -- *** MonadDelay
  , MonadDelay (..)
    -- *** MonadEventlog
  , MonadEventlog (..)
    -- *** MonadEvaluate
  , MonadEvaluate (..)
    -- *** NoThunks
  , NoThunks (..)
  ) where

import           Cardano.Crypto.KES (KESAlgorithm, SignKeyKES)
import qualified Cardano.Crypto.KES as KES
import           Control.Applicative (Alternative)
import           Control.Concurrent.Class.MonadMVar (MonadInspectMVar (..))
import qualified Control.Concurrent.Class.MonadMVar.Strict as Strict
import qualified Control.Concurrent.Class.MonadSTM.Strict as StrictSTM
import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadEventlog
import           Control.Monad.Class.MonadFork
import           Control.Monad.Class.MonadST
import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime.SI
import           Control.Monad.Class.MonadTimer.SI
import           Control.Monad.Primitive
import           Data.Functor (void)
import           Data.Proxy (Proxy (..))
import           NoThunks.Class (NoThunks (..))
import           Ouroboros.Consensus.Util.MonadSTM.NormalForm
import           Ouroboros.Consensus.Util.NormalForm.StrictMVar
import           Ouroboros.Consensus.Util.NormalForm.StrictTVar
import           Ouroboros.Consensus.Util.Orphans ()


{-------------------------------------------------------------------------------
  IOLike
-------------------------------------------------------------------------------}

-- | A single class collecting the monad capabilities listed as superclasses
-- below, so that code can ask for @IOLike m@ instead of spelling out each
-- constraint.
--
-- Besides the superclasses, it requires 'NoThunks' instances for @m a@ and for
-- the strict variable types over @m@, and adds one method of its own,
-- 'forgetSignKeyKES'.
class ( MonadAsync              m
      , MonadLabelledSTM        m
      , MonadTraceSTM           m
      , MonadMVar               m
      , MonadEventlog           m
      , MonadFork               m
      , MonadST                 m
      , MonadDelay              m
      , MonadThread             m
      , MonadThrow              m
      , MonadCatch              m
      , MonadMask               m
      , MonadMonotonicTime      m
      , MonadEvaluate           m
      , Alternative        (STM m)
      , MonadCatch         (STM m)
      , PrimMonad               m
        -- 'NoThunks' must be available for monadic actions, and for each
        -- strict variable type whenever its contents have a 'NoThunks'
        -- instance.
        --
        -- NOTE(review): both the qualified (@StrictSTM.@ / @Strict.@) and the
        -- unqualified @StrictTVar@ / @StrictMVar@ are constrained; presumably
        -- the unqualified names are the variants re-exported from the
        -- NormalForm modules imported by this file -- confirm.
      , forall a. NoThunks (m a)
      , forall a. NoThunks a => NoThunks (StrictSTM.StrictTVar m a)
      , forall a. NoThunks a => NoThunks (StrictSVar m a)
      , forall a. NoThunks a => NoThunks (Strict.StrictMVar m a)
      , forall a. NoThunks a => NoThunks (StrictTVar m a)
      , forall a. NoThunks a => NoThunks (StrictMVar m a)
      ) => IOLike m where
  -- | Securely forget a KES signing key.
  --
  -- No-op for the IOSim, but 'KES.forgetSignKeyKES' for IO.
  forgetSignKeyKES :: KESAlgorithm v => SignKeyKES v -> m ()

-- | In 'IO', forgetting a KES signing key is delegated to
-- 'KES.forgetSignKeyKES'.
instance IOLike IO where
  forgetSignKeyKES = KES.forgetSignKeyKES

-- | Generalization of 'link' that links an async to an arbitrary thread.
--
-- This is 'linkToOnly' with a predicate that accepts every exception except
-- those recognised by 'isCancel', so cancelling the async does not propagate
-- to the linked thread.
--
-- Non standard (not in 'async' library).
--
linkTo :: (MonadAsync m, MonadFork m, MonadMask m)
       => ThreadId m -> Async m a -> m ()
linkTo tid = linkToOnly tid (not . isCancel)

-- | Generalization of 'linkOnly' that links an async to an arbitrary thread.
--
-- A helper thread (forked with 'forkRepeat' and labelled
-- @"linkToOnly " <> show asyncThreadId@) waits for the async to finish. If the
-- async ended with an exception for which the predicate returns 'True', that
-- exception is wrapped in 'ExceptionInLinkedThread', tagged with the shown
-- thread id of the async, and thrown to the given thread. Any other outcome
-- is ignored.
--
-- Non standard (not in 'async' library).
--
linkToOnly :: forall m a. (MonadAsync m, MonadFork m, MonadMask m)
           => ThreadId m -> (SomeException -> Bool) -> Async m a -> m ()
linkToOnly tid shouldThrow a = do
    void $ forkRepeat ("linkToOnly " <> show linkedThreadId) $ do
      r <- waitCatch a
      case r of
        Left e | shouldThrow e -> throwTo tid (exceptionInLinkedThread e)
        _otherwise             -> return ()
  where
    -- Thread id of the async being linked; used for the helper thread's
    -- label and for the wrapped exception.
    linkedThreadId :: ThreadId m
    linkedThreadId = asyncThreadId a

    exceptionInLinkedThread :: SomeException -> ExceptionInLinkedThread
    exceptionInLinkedThread =
        ExceptionInLinkedThread (show linkedThreadId)

-- | Whether the given exception is 'AsyncCancelled'.
isCancel :: SomeException -> Bool
isCancel e
  | Just AsyncCancelled <- fromException e = True
  | otherwise = False

-- | Fork a thread, labelled with the given string, that runs the action and
-- runs it again each time it terminates with an exception; the loop stops once
-- the action returns normally.
--
-- The thread is forked inside 'mask'; only the action itself is wrapped in
-- the @restore@ function handed out by 'mask'.
forkRepeat :: (MonadFork m, MonadMask m) => String -> m a -> m (ThreadId m)
forkRepeat label action =
  mask $ \restore ->
    let go = do r <- tryAll (restore action)
                case r of
                  -- Deliberately retry on any exception.
                  Left _ -> go
                  _      -> return ()
    in forkIO (labelThisThread label >> go)

-- | 'try' with the exception type fixed to 'SomeException', so every
-- exception is returned as a 'Left'.
tryAll :: MonadCatch m => m a -> m (Either SomeException a)
tryAll = try


{-------------------------------------------------------------------------------
  NoThunks instance
-------------------------------------------------------------------------------}

-- | Checks the value currently stored in the variable for thunks.
instance NoThunks a => NoThunks (StrictSTM.StrictTVar IO a) where
  showTypeOf _ = "StrictTVar IO"
  wNoThunks ctxt tv = do
      -- We can't use @atomically $ readTVar ..@ here, as that will lead to a
      -- "Control.Concurrent.STM.atomically was nested" exception.
      a <- StrictSTM.readTVarIO tv
      noThunks ctxt a

-- | Checks the result of inspecting the variable for thunks. The inspection
-- yields a 'Maybe', and that 'Maybe' value is what gets checked.
instance NoThunks a => NoThunks (Strict.StrictMVar IO a) where
  showTypeOf _ = "StrictMVar IO"
  wNoThunks ctxt mvar = do
      -- Inspect the underlying lazy MVar via 'inspectMVar'.
      aMay <- inspectMVar (Proxy :: Proxy IO) (Strict.toLazyMVar mvar)
      noThunks ctxt aMay