{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Test.Consensus.Mempool.Mocked
( InitialMempoolAndModelParams (..)
, MockedMempool (getMempool)
, openMockedMempool
, setLedgerState
, addTx
, getTxs
, removeTxsEvenIfValid
) where
import Control.Concurrent.Class.MonadSTM.Strict
( StrictTVar
, atomically
, newTVarIO
, readTVar
, readTVarIO
, writeTVar
)
import Control.DeepSeq (NFData (rnf))
import Control.Tracer (Tracer)
import qualified Data.List.NonEmpty as NE
import Ouroboros.Consensus.Block (castPoint)
import Ouroboros.Consensus.HeaderValidation as Header
import Ouroboros.Consensus.Ledger.Basics
import qualified Ouroboros.Consensus.Ledger.Basics as Ledger
import qualified Ouroboros.Consensus.Ledger.SupportsMempool as Ledger
import Ouroboros.Consensus.Ledger.Tables.Utils
( forgetLedgerTables
, restrictValues'
)
import Ouroboros.Consensus.Mempool (Mempool)
import qualified Ouroboros.Consensus.Mempool as Mempool
import Ouroboros.Consensus.Mempool.API
( AddTxOnBehalfOf
, MempoolAddTxResult
)
-- | A mempool whose ledger interface is backed by a single mutable ledger
-- state held in a 'StrictTVar', instead of a real LedgerDB.
--
-- All fields are strict, so evaluating the constructor evaluates every field
-- to weak head normal form.
data MockedMempool m blk = MockedMempool
  { getLedgerInterface :: !(Mempool.LedgerInterface m blk)
  -- ^ Mocked ledger interface the mempool was opened with.
  , getLedgerStateTVar :: !(StrictTVar m (LedgerState blk ValuesMK))
  -- ^ Ledger state that 'getLedgerInterface' reports as current; replace it
  -- with 'setLedgerState'.
  , getMempool :: !(Mempool m blk)
  -- ^ The mempool opened on top of 'getLedgerInterface'.
  }
-- | Shallow 'NFData' instance.
--
-- Matching on the 'MockedMempool' constructor evaluates the record and, since
-- its fields are strict, each field to WHNF. The ledger state stored inside
-- 'getLedgerStateTVar' is /not/ forced. Forcing it would need an 'NFData'
-- constraint on the ledger state and a read of the TVar from pure code.
--
-- NOTE(review): presumably this instance exists only so a mocked mempool can
-- be built as a test/benchmark fixture (e.g. via an @env@-style combinator);
-- confirm that shallow evaluation is sufficient for those callers.
instance NFData (MockedMempool m blk) where
  rnf MockedMempool{} = ()
-- | Parameters needed by 'openMockedMempool' to set up a mocked mempool.
data InitialMempoolAndModelParams blk = MempoolAndModelParams
  { immpInitialState :: !(Ledger.LedgerState blk ValuesMK)
  -- ^ Initial contents of the mocked ledger state (the value first reported
  -- by the mocked ledger interface).
  , immpLedgerConfig :: !(Ledger.LedgerConfig blk)
  -- ^ Ledger configuration passed to the mempool when it is opened.
  }
-- | Open a mempool whose ledger interface is mocked by a single 'StrictTVar'
-- holding the current ledger state (initialised from 'immpInitialState').
--
-- The mocked interface behaves as follows:
--
-- * @getCurrentLedgerState@ returns the TVar contents with the ledger tables
--   dropped ('forgetLedgerTables').
--
-- * @getLedgerTablesAtFor pt keys@ returns the values for @keys@, restricted
--   from the current state ('restrictValues''), but only when @pt@ equals the
--   tip of that state. Otherwise it returns 'Nothing'.
--
-- The mempool is opened with 'Mempool.openMempoolWithoutSyncThread'.
-- NOTE(review): presumably this means the mempool does not react on its own
-- to 'setLedgerState'; tests must drive any resynchronisation — confirm.
openMockedMempool ::
  ( Ledger.LedgerSupportsMempool blk
  , Ledger.HasTxId (Ledger.GenTx blk)
  , Header.ValidateEnvelope blk
  ) =>
  Mempool.MempoolCapacityBytesOverride ->
  Tracer IO (Mempool.TraceEventMempool blk) ->
  InitialMempoolAndModelParams blk ->
  IO (MockedMempool IO blk)
openMockedMempool capacityOverride tracer initialParams = do
  currentLedgerStateTVar <- newTVarIO (immpInitialState initialParams)
  let ledgerItf =
        Mempool.LedgerInterface
          { Mempool.getCurrentLedgerState =
              forgetLedgerTables <$> readTVar currentLedgerStateTVar
          , Mempool.getLedgerTablesAtFor = \pt keys -> do
              st <- readTVarIO currentLedgerStateTVar
              -- Only serve tables for the state we actually hold; any other
              -- point is reported as unavailable.
              if castPoint (getTip st) == pt
                then pure $ Just $ restrictValues' st keys
                else pure Nothing
          }
  mempool <-
    Mempool.openMempoolWithoutSyncThread
      ledgerItf
      (immpLedgerConfig initialParams)
      capacityOverride
      tracer
  pure
    MockedMempool
      { getLedgerInterface = ledgerItf
      , getLedgerStateTVar = currentLedgerStateTVar
      , getMempool = mempool
      }
-- | Replace, in a single STM transaction, the ledger state reported by the
-- mocked ledger interface.
--
-- The mempool sees the new state only when it next queries its ledger
-- interface; this function does not touch the mempool itself.
setLedgerState ::
  MockedMempool IO blk ->
  LedgerState blk ValuesMK ->
  IO ()
setLedgerState MockedMempool{getLedgerStateTVar} newSt =
  atomically $ writeTVar getLedgerStateTVar newSt
-- | Add a transaction to the wrapped mempool ('Mempool.addTx' applied to
-- 'getMempool').
addTx ::
  MockedMempool m blk ->
  AddTxOnBehalfOf ->
  Ledger.GenTx blk ->
  m (MempoolAddTxResult blk)
addTx = Mempool.addTx . getMempool
-- | Remove the given transactions from the wrapped mempool
-- ('Mempool.removeTxsEvenIfValid' applied to 'getMempool').
removeTxsEvenIfValid ::
  MockedMempool m blk ->
  NE.NonEmpty (Ledger.GenTxId blk) ->
  m ()
removeTxsEvenIfValid = Mempool.removeTxsEvenIfValid . getMempool
-- | Return the transactions currently in the mempool, in the order given by
-- the mempool snapshot, with their validation evidence discarded.
--
-- The snapshot is taken in a single STM transaction, so the result is a
-- consistent view of the mempool at one moment.
getTxs ::
  forall blk.
  Ledger.LedgerSupportsMempool blk =>
  MockedMempool IO blk -> IO [Ledger.GenTx blk]
getTxs mockedMempool = do
  snapshot <- atomically $ Mempool.getSnapshot (getMempool mockedMempool)
  -- Each snapshot entry is (validated tx, ticket number, tx measure); only
  -- the transaction itself is of interest here.
  pure
    [ Ledger.txForgetValidated validatedTx
    | (validatedTx, _ticketNo, _measure) <- Mempool.snapshotTxs snapshot
    ]