{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Mempool with a mocked ledger interface
module Test.Consensus.Mempool.Mocked
  ( InitialMempoolAndModelParams (..)

    -- * Mempool with a mocked LedgerDB interface
  , MockedMempool (getMempool)
  , openMockedMempool
  , setLedgerState

    -- * Mempool API functions
  , addTx
  , getTxs
  , removeTxsEvenIfValid
  ) where

import Control.Concurrent.Class.MonadSTM.Strict
  ( StrictTVar
  , atomically
  , newTVarIO
  , readTVar
  , readTVarIO
  , writeTVar
  )
import Control.DeepSeq (NFData (rnf))
import Control.Tracer (Tracer)
import qualified Data.List.NonEmpty as NE
import Ouroboros.Consensus.Block (castPoint)
import Ouroboros.Consensus.HeaderValidation as Header
import Ouroboros.Consensus.Ledger.Basics
import qualified Ouroboros.Consensus.Ledger.Basics as Ledger
import qualified Ouroboros.Consensus.Ledger.SupportsMempool as Ledger
import Ouroboros.Consensus.Ledger.Tables.Utils
  ( forgetLedgerTables
  , restrictValues'
  )
import Ouroboros.Consensus.Mempool (Mempool)
import qualified Ouroboros.Consensus.Mempool as Mempool
import Ouroboros.Consensus.Mempool.API
  ( AddTxOnBehalfOf
  , MempoolAddTxResult
  )

-- | A mempool bundled with the mocked ledger interface it was opened with and
-- the mutable variable that holds the ledger state served by that interface.
data MockedMempool m blk = MockedMempool
  { getLedgerInterface :: !(Mempool.LedgerInterface m blk)
  -- ^ The ledger interface handed to the mempool when it was opened.
  , getLedgerStateTVar :: !(StrictTVar m (LedgerState blk ValuesMK))
  -- ^ Variable holding the current ledger state; 'setLedgerState' overwrites
  -- it.
  , getMempool :: !(Mempool m blk)
  -- ^ The underlying mempool.
  }

-- | Shallow instance: 'rnf' only matches on the constructor and does not
-- deeply evaluate any of the fields.
instance NFData (MockedMempool m blk) where
  -- TODO: check we're OK with skipping the evaluation of the
  -- MockedMempool. The only data we could force here is the
  -- 'LedgerState' inside 'getLedgerStateTVar', but that would require adding a
  -- 'NFData' constraint and performing unsafe IO. Since we only require this
  -- instance to be able to use
  -- [env](https://hackage.haskell.org/package/tasty-bench-0.3.3/docs/Test-Tasty-Bench.html#v:env),
  -- and we only care about initializing the mempool before running the
  -- benchmarks, maybe this definition is enough.
  rnf MockedMempool{} = ()

-- | Parameters consumed by 'openMockedMempool' to set up the mocked ledger
-- interface and open the mempool.
data InitialMempoolAndModelParams blk = MempoolAndModelParams
  { immpInitialState :: !(Ledger.LedgerState blk ValuesMK)
  -- ^ Initial ledger state for the mocked Ledger DB interface.
  , immpLedgerConfig :: !(Ledger.LedgerConfig blk)
  -- ^ Ledger configuration, which is needed to open the mempool.
  }

-- | Open a mempool on top of a mocked ledger interface.
--
-- The ledger state lives in a fresh 'StrictTVar' initialised with
-- 'immpInitialState'. The mocked interface is defined as follows:
--
-- * 'Mempool.getCurrentLedgerState' reads that variable and applies
--   'forgetLedgerTables' to the result.
--
-- * 'Mempool.getLedgerTablesAtFor' reads that variable and, only when the
--   requested point equals the tip of the state it read, returns
--   @'restrictValues'' st keys@; otherwise it returns 'Nothing'.
--
-- The mempool itself is opened with 'Mempool.openMempoolWithoutSyncThread',
-- using 'immpLedgerConfig' as the ledger configuration.
openMockedMempool ::
  ( Ledger.LedgerSupportsMempool blk
  , Ledger.HasTxId (Ledger.GenTx blk)
  , Header.ValidateEnvelope blk
  ) =>
  Mempool.MempoolCapacityBytesOverride ->
  Tracer IO (Mempool.TraceEventMempool blk) ->
  InitialMempoolAndModelParams blk ->
  IO (MockedMempool IO blk)
openMockedMempool capacityOverride tracer initialParams = do
  currentLedgerStateTVar <- newTVarIO (immpInitialState initialParams)
  let ledgerItf =
        Mempool.LedgerInterface
          { Mempool.getCurrentLedgerState =
              forgetLedgerTables <$> readTVar currentLedgerStateTVar
          , Mempool.getLedgerTablesAtFor = \pt keys -> do
              st <- readTVarIO currentLedgerStateTVar
              -- Only answer for the point the mocked state is currently at.
              if castPoint (getTip st) == pt
                then pure $ Just $ restrictValues' st keys
                else pure Nothing
          }
  mempool <-
    Mempool.openMempoolWithoutSyncThread
      ledgerItf
      (immpLedgerConfig initialParams)
      capacityOverride
      tracer
  pure
    MockedMempool
      { getLedgerInterface = ledgerItf
      , getLedgerStateTVar = currentLedgerStateTVar
      , getMempool = mempool
      }

-- | Atomically overwrite the ledger state stored in the mocked mempool's
-- 'getLedgerStateTVar' with the given state.
setLedgerState ::
  MockedMempool IO blk ->
  LedgerState blk ValuesMK ->
  IO ()
setLedgerState MockedMempool{getLedgerStateTVar} newSt =
  atomically $ writeTVar getLedgerStateTVar newSt

-- | 'Mempool.addTx' applied to the mempool wrapped by the given
-- 'MockedMempool'.
addTx ::
  MockedMempool m blk ->
  AddTxOnBehalfOf ->
  Ledger.GenTx blk ->
  m (MempoolAddTxResult blk)
addTx = Mempool.addTx . getMempool

-- | 'Mempool.removeTxsEvenIfValid' applied to the mempool wrapped by the given
-- 'MockedMempool'.
removeTxsEvenIfValid ::
  MockedMempool m blk ->
  NE.NonEmpty (Ledger.GenTxId blk) ->
  m ()
removeTxsEvenIfValid = Mempool.removeTxsEvenIfValid . getMempool

-- | Take a snapshot of the wrapped mempool and return the transactions listed
-- by 'Mempool.snapshotTxs', each passed through 'Ledger.txForgetValidated'.
-- The other two components of every snapshot entry are discarded.
getTxs ::
  forall blk.
  Ledger.LedgerSupportsMempool blk =>
  MockedMempool IO blk -> IO [Ledger.GenTx blk]
getTxs mockedMempool = do
  snapshotTxs <-
    fmap Mempool.snapshotTxs $
      atomically $
        Mempool.getSnapshot $
          getMempool mockedMempool
  pure $ fmap prjTx snapshotTxs
 where
  -- The annotation refers to the @blk@ bound by the @forall@ in the signature
  -- of 'getTxs' (via ScopedTypeVariables).
  prjTx (a, _b, _c) = Ledger.txForgetValidated a :: Ledger.GenTx blk