{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Mempool with a mocked ledger interface
module Test.Consensus.Mempool.Mocked (
    InitialMempoolAndModelParams (..)
    -- * Mempool with a mocked LedgerDB interface
  , MockedMempool (getMempool)
  , openMockedMempool
  , setLedgerState
    -- * Mempool API functions
  , addTx
  , getTxs
  , removeTxs
  ) where

import           Control.Concurrent.Class.MonadSTM.Strict (StrictTVar,
                     atomically, newTVarIO, readTVar, writeTVar)
import           Control.DeepSeq (NFData (rnf))
import           Control.Tracer (Tracer)
import           Ouroboros.Consensus.HeaderValidation as Header
import           Ouroboros.Consensus.Ledger.Basics (LedgerState)
import qualified Ouroboros.Consensus.Ledger.Basics as Ledger
import qualified Ouroboros.Consensus.Ledger.SupportsMempool as Ledger
import           Ouroboros.Consensus.Mempool (Mempool)
import qualified Ouroboros.Consensus.Mempool as Mempool
import           Ouroboros.Consensus.Mempool.API (AddTxOnBehalfOf,
                     MempoolAddTxResult)

-- | A mempool together with the mocked ledger interface it was opened with.
--
-- The ledger state served to the mempool lives in 'getLedgerStateTVar', so
-- tests and benchmarks can change it (see 'setLedgerState').
data MockedMempool m blk = MockedMempool {
      -- | The ledger interface handed to the mempool when it was opened.
      getLedgerInterface :: !(Mempool.LedgerInterface m blk)
      -- | Mutable cell holding the ledger state read by 'getLedgerInterface'.
    , getLedgerStateTVar :: !(StrictTVar m (LedgerState blk))
      -- | The mempool under test.
    , getMempool         :: !(Mempool m blk)
    }

instance NFData (MockedMempool m blk) where
  -- TODO: check we're OK with skipping the evaluation of the
  -- 'MockedMempool'. The only data we could force here is the 'LedgerState'
  -- inside 'getLedgerStateTVar', but that would require adding an 'NFData'
  -- constraint and performing unsafe IO. Since we only require this instance
  -- to be able to use
  -- [env](https://hackage.haskell.org/package/tasty-bench-0.3.3/docs/Test-Tasty-Bench.html#v:env),
  -- and we only care about initializing the mempool before running the
  -- benchmarks, maybe this definition is enough.
  --
  -- Matching on the constructor forces the record to WHNF. Because every
  -- field of 'MockedMempool' is strict, this also forces each field to WHNF.
  rnf MockedMempool {} = ()

-- | Parameters needed by 'openMockedMempool' to create a 'MockedMempool'.
data InitialMempoolAndModelParams blk = MempoolAndModelParams {
      -- | Initial ledger state for the mocked Ledger DB interface.
      immpInitialState :: !(Ledger.LedgerState blk)
      -- | Ledger configuration, which is needed to open the mempool.
    , immpLedgerConfig :: !(Ledger.LedgerConfig blk)
    }

-- | Open a mempool whose ledger interface is backed by a 'StrictTVar'.
--
-- The TVar starts out holding 'immpInitialState'. The mempool reads the
-- current ledger state from it via 'Mempool.getCurrentLedgerState', and
-- 'setLedgerState' overwrites it. The mempool itself is created with
-- 'Mempool.openMempoolWithoutSyncThread', using 'immpLedgerConfig', the given
-- capacity override and the given tracer.
--
-- NOTE(review): judging by its name, 'Mempool.openMempoolWithoutSyncThread'
-- does not start a background thread that resyncs the mempool when the TVar
-- changes; confirm against its documentation before relying on that.
openMockedMempool ::
     ( Ledger.LedgerSupportsMempool blk
     , Ledger.HasTxId (Ledger.GenTx blk)
     , Header.ValidateEnvelope blk
     )
  => Mempool.MempoolCapacityBytesOverride
  -> Tracer IO (Mempool.TraceEventMempool blk)
  -> InitialMempoolAndModelParams blk
  -> IO (MockedMempool IO blk)
openMockedMempool capacityOverride tracer initialParams = do
    currentLedgerStateTVar <- newTVarIO (immpInitialState initialParams)
    let ledgerItf = Mempool.LedgerInterface {
            Mempool.getCurrentLedgerState = readTVar currentLedgerStateTVar
          }
    mempool <- Mempool.openMempoolWithoutSyncThread
                 ledgerItf
                 (immpLedgerConfig initialParams)
                 capacityOverride
                 tracer
    pure MockedMempool {
        getLedgerInterface = ledgerItf
      , getLedgerStateTVar = currentLedgerStateTVar
      , getMempool         = mempool
      }

-- | Replace the ledger state served by the mocked ledger interface.
--
-- This only writes the backing TVar, in a single STM transaction. It does not
-- itself touch the mempool's contents.
setLedgerState ::
     MockedMempool IO blk
  -> LedgerState blk
  -> IO ()
setLedgerState mockedMempool newSt =
  atomically $ writeTVar (getLedgerStateTVar mockedMempool) newSt

-- | Add a transaction to the wrapped mempool.
--
-- Thin wrapper around 'Mempool.addTx' applied to 'getMempool'.
addTx ::
     MockedMempool m blk
  -> AddTxOnBehalfOf
  -> Ledger.GenTx blk
  -> m (MempoolAddTxResult blk)
addTx = Mempool.addTx . getMempool

-- | Remove the transactions with the given ids from the wrapped mempool.
--
-- Thin wrapper around 'Mempool.removeTxs' applied to 'getMempool'.
removeTxs ::
     MockedMempool m blk
  -> [Ledger.GenTxId blk]
  -> m ()
removeTxs = Mempool.removeTxs . getMempool

-- | Return the transactions in the current mempool snapshot.
--
-- The snapshot is taken atomically. Transactions are returned in the order
-- 'Mempool.snapshotTxs' lists them. Validation information is dropped with
-- 'Ledger.txForgetValidated', and the ticket number and size are discarded.
getTxs :: forall blk.
     (Ledger.LedgerSupportsMempool blk)
  => MockedMempool IO blk -> IO [Ledger.GenTx blk]
getTxs mockedMempool = do
    snapshot <- atomically $ Mempool.getSnapshot (getMempool mockedMempool)
    pure [ Ledger.txForgetValidated validatedTx
         | (validatedTx, _ticketNo, _size) <- Mempool.snapshotTxs snapshot
         ]