{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Mempool with a mocked ledger interface
module Test.Consensus.Mempool.Mocked
  ( InitialMempoolAndModelParams (..)

    -- * Mempool with a mocked LedgerDB interface
  , MockedMempool (getMempool)
  , openMockedMempool
  , setLedgerState

    -- * Mempool API functions
  , addTx
  , getTxs
  , removeTxsEvenIfValid
  ) where

import Control.Concurrent.Class.MonadSTM.Strict
  ( StrictTVar
  , atomically
  , newTVarIO
  , readTVar
  , writeTVar
  )
import Control.DeepSeq (NFData (rnf))
import Control.ResourceRegistry
import Control.Tracer (Tracer)
import qualified Data.List.NonEmpty as NE
import Ouroboros.Consensus.HeaderValidation as Header
import Ouroboros.Consensus.Ledger.Basics
import qualified Ouroboros.Consensus.Ledger.Basics as Ledger
import qualified Ouroboros.Consensus.Ledger.SupportsMempool as Ledger
import Ouroboros.Consensus.Ledger.Tables.Utils
  ( emptyLedgerTables
  , forgetLedgerTables
  , restrictValues'
  )
import Ouroboros.Consensus.Mempool (Mempool)
import qualified Ouroboros.Consensus.Mempool as Mempool
import Ouroboros.Consensus.Mempool.API
  ( AddTxOnBehalfOf
  , MempoolAddTxResult
  )
import Ouroboros.Consensus.Mempool.Impl.Common (MempoolLedgerDBView (MempoolLedgerDBView))
import Ouroboros.Consensus.Storage.LedgerDB.Forker

-- | A 'Mempool' opened against a mocked LedgerDB interface. The ledger state
-- served by that interface lives in a 'StrictTVar', which can be overwritten
-- with 'setLedgerState'.
data MockedMempool m blk = MockedMempool
  { getLedgerInterface :: !(Mempool.LedgerInterface m blk)
  -- ^ The mocked ledger interface the mempool was opened with.
  , getLedgerStateTVar :: !(StrictTVar m (LedgerState blk ValuesMK))
  -- ^ The ledger state returned by 'getLedgerInterface'.
  , getMempool :: !(Mempool m blk)
  -- ^ The mempool itself.
  }

-- | Shallow 'NFData' instance: evaluates the record only to weak head normal
-- form (its fields are strict, so they are evaluated to WHNF as well).
instance NFData (MockedMempool m blk) where
  -- TODO: check we're OK with skipping the deep evaluation of the
  -- MockedMempool. The only data we could force here is the 'LedgerState'
  -- inside 'getLedgerStateTVar', but that would require adding an 'NFData'
  -- constraint and performing unsafe IO. Since we only need this instance to
  -- be able to use
  -- [env](<https://hackage.haskell.org/package/tasty-bench-0.3.3/docs/Test-Tasty-Bench.html#v:env),
  -- and we only care about initializing the mempool before running the
  -- benchmarks, this definition is probably enough.
  rnf MockedMempool{} = ()

-- | Parameters needed to open a 'MockedMempool'.
data InitialMempoolAndModelParams blk = MempoolAndModelParams
  { immpInitialState :: !(LedgerState blk ValuesMK)
  -- ^ Initial ledger state for the mocked Ledger DB interface.
  , immpLedgerConfig :: !(Ledger.LedgerConfig blk)
  -- ^ Ledger configuration, which is needed to open the mempool.
  }

-- | Open a mempool, without its background sync thread, on top of a mocked
-- LedgerDB interface.
--
-- The mocked interface reads the current ledger state from a 'StrictTVar'
-- initialised with 'immpInitialState'. Each call to
-- 'Mempool.getCurrentLedgerState' reads that TVar once and returns a
-- read-only forker over that single snapshot. The forker behaves as follows:
--
-- * table reads restrict the snapshot's in-memory 'ValuesMK' tables to the
--   requested keys;
-- * range reads always return empty tables;
-- * statistics are always 'Nothing';
-- * closing is a no-op.
--
-- NOTE(review): the resource registry is created with 'unsafeNewRegistry'
-- and is never closed by this function; the mempool's lifetime is left to
-- the caller (e.g. a benchmark @env@).
openMockedMempool ::
  ( Ledger.LedgerSupportsMempool blk
  , Ledger.HasTxId (Ledger.GenTx blk)
  , Header.ValidateEnvelope blk
  ) =>
  Mempool.MempoolCapacityBytesOverride ->
  Tracer IO (Mempool.TraceEventMempool blk) ->
  InitialMempoolAndModelParams blk ->
  IO (MockedMempool IO blk)
openMockedMempool capacityOverride tracer initialParams = do
  currentLedgerStateTVar <- newTVarIO (immpInitialState initialParams)
  reg <- unsafeNewRegistry
  let ledgerItf =
        Mempool.LedgerInterface
          { Mempool.getCurrentLedgerState = \_reg -> do
              -- One STM read per request, so the forker below is consistent
              -- with the ledger state returned alongside it.
              st <- readTVar currentLedgerStateTVar
              pure $
                MempoolLedgerDBView
                  (forgetLedgerTables st)
                  ( pure $
                      Right $
                        ReadOnlyForker
                          { roforkerClose = pure ()
                          , roforkerGetLedgerState = pure (forgetLedgerTables st)
                          , roforkerReadTables = \keys ->
                              pure $ projectLedgerTables st `restrictValues'` keys
                          , roforkerReadStatistics = pure Nothing
                          , roforkerRangeReadTables = \_ -> pure emptyLedgerTables
                          }
                  )
          }
  mempool <-
    Mempool.openMempoolWithoutSyncThread
      reg
      ledgerItf
      (immpLedgerConfig initialParams)
      capacityOverride
      tracer
  pure
    MockedMempool
      { getLedgerInterface = ledgerItf
      , getLedgerStateTVar = currentLedgerStateTVar
      , getMempool = mempool
      }

-- | Replace the ledger state served by the mocked ledger interface.
--
-- This only writes the TVar atomically. The mempool sees the new state the
-- next time it queries its ledger interface; no resynchronisation is
-- triggered here. (Presumably callers must trigger a sync themselves — verify
-- at call sites.)
setLedgerState ::
  MockedMempool IO blk ->
  LedgerState blk ValuesMK ->
  IO ()
setLedgerState MockedMempool{getLedgerStateTVar} newSt =
  atomically $ writeTVar getLedgerStateTVar newSt

-- | Add a transaction to the underlying mempool ('Mempool.addTx').
addTx ::
  MockedMempool m blk ->
  AddTxOnBehalfOf ->
  Ledger.GenTx blk ->
  m (MempoolAddTxResult blk)
addTx = Mempool.addTx . getMempool

-- | Remove the given transactions from the underlying mempool, whether or not
-- they are still valid ('Mempool.removeTxsEvenIfValid').
removeTxsEvenIfValid ::
  MockedMempool m blk ->
  NE.NonEmpty (Ledger.GenTxId blk) ->
  m ()
removeTxsEvenIfValid = Mempool.removeTxsEvenIfValid . getMempool

-- | Return the transactions currently in the mempool, in snapshot order.
--
-- The snapshot is taken in a single STM transaction. Each transaction's
-- validation evidence is then dropped via 'Ledger.txForgetValidated'.
getTxs ::
  forall blk.
  Ledger.LedgerSupportsMempool blk =>
  MockedMempool IO blk -> IO [Ledger.GenTx blk]
getTxs mockedMempool = do
  snapshot <- atomically $ Mempool.getSnapshot (getMempool mockedMempool)
  pure
    [ Ledger.txForgetValidated vtx
    | (vtx, _ticketNo, _measure) <- Mempool.snapshotTxs snapshot
    ]