{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE TypeApplications #-}

-- | Tests fairness aspects of the mempool.
--
-- See 'testTxSizeFairness' for more details on the tests we run in this module.
module Test.Consensus.Mempool.Fairness
  ( testTxSizeFairness
  , tests
  ) where

import Cardano.Ledger.BaseTypes (knownNonZeroBounded)
import qualified Cardano.Slotting.Time as Time
import Control.Arrow ((***))
import Control.Concurrent (threadDelay)
import qualified Control.Concurrent.Async as Async
import Control.Exception (assert)
import Control.Monad (forever, void)
import qualified Control.Tracer as Tracer
import Data.Foldable (asum)
import qualified Data.List as List
import Data.List.NonEmpty hiding (length)
import Data.Void (Void, vacuous)
import Ouroboros.Consensus.Config.SecurityParam as Consensus
import qualified Ouroboros.Consensus.HardFork.History as HardFork
import Ouroboros.Consensus.Ledger.SupportsMempool (ByteSize32 (..))
import qualified Ouroboros.Consensus.Ledger.SupportsMempool as Mempool
import Ouroboros.Consensus.Ledger.Tables.Utils
import Ouroboros.Consensus.Mempool (Mempool)
import qualified Ouroboros.Consensus.Mempool as Mempool
import qualified Ouroboros.Consensus.Mempool.Capacity as Mempool
import Ouroboros.Consensus.Util.IOLike (STM, atomically, retry)
import System.Random (randomIO)
import Test.Consensus.Mempool.Fairness.TestBlock
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.HUnit (testCase, (@?), (@?=))
import Test.Util.TestBlock
  ( testBlockLedgerConfigFrom
  , testInitLedgerWithState
  )

-- | Test tree for the mempool fairness tests.
--
-- Runs 'testTxSizeFairness' once with a fixed set of 'TestParams'.
tests :: TestTree
tests =
  testGroup
    "Mempool fairness"
    [ testCase "There is no substantial bias in added transaction sizes" $
        testTxSizeFairness
          TestParams
            { mempoolMaxCapacity = ByteSize32 100
            , smallTxSize = ByteSize32 1
            , largeTxSize = ByteSize32 10
            , nrOftxsToCollect = 1_000
            , toleranceThreshold = 0.2 -- Somewhat arbitrarily chosen.
            }
    ]

-- | The mempool type used throughout this module: a 'Mempool' running in 'IO'
-- over 'TestBlock'.
type TestMempool = Mempool IO TestBlock

-- | Test if the mempool treats small and large transactions in the same way.
--
-- We run the following test:
--
-- - Given a mempool 'mp'.
-- - Concurrently:
--     - Run 'N' threads that add small transactions to 'mp'.
--     - Run 'N' threads that add large transactions to 'mp'.
--     - Remove transactions from 'mp' one by one, with a small delay between
--       removals. Collect the removed transactions.
--
-- We give the threads that add small transactions a head start to make sure
-- that the mempool fills up with small transactions. In this way the thread
-- that removes transactions one by one will remove small transactions first.
-- Then, if the mempool is fair, it will not always try to add a small
-- transaction as soon as it can, but it will eventually wait until enough
-- capacity has been freed (by the remover thread) so that a large transaction
-- can be added.
--
-- After collecting 'M' removed transactions, let 'diff' be the difference between
-- the number of small and large transactions that were added to 'mp', then we
-- check that:
--
-- > diff / M <= toleranceThreshold
--
-- See 'TestParams' for an explanation of the different parameters that
-- influence this test.
testTxSizeFairness :: TestParams -> IO ()
testTxSizeFairness
  TestParams
    { mempoolMaxCapacity
    , smallTxSize
    , largeTxSize
    , nrOftxsToCollect
    , toleranceThreshold
    } = do
    ----------------------------------------------------------------------------
    --  Obtain a mempool.
    ----------------------------------------------------------------------------
    let
      -- A ledger interface that always reports the same initial ledger state
      -- and empty ledger tables.
      ledgerItf :: Mempool.LedgerInterface IO TestBlock
      ledgerItf =
        Mempool.LedgerInterface
          { Mempool.getCurrentLedgerState =
              pure $
                testInitLedgerWithState NoPayLoadDependentState
          , Mempool.getLedgerTablesAtFor = \_ _ ->
              pure $
                Just emptyLedgerTables
          }

      eraParams =
        HardFork.defaultEraParams
          (Consensus.SecurityParam $ knownNonZeroBounded @10)
          (Time.slotLengthFromSec 2)
    mempool <-
      Mempool.openMempoolWithoutSyncThread
        ledgerItf
        (testBlockLedgerConfigFrom eraParams)
        (Mempool.mkCapacityBytesOverride mempoolMaxCapacity)
        Tracer.nullTracer
    ----------------------------------------------------------------------------
    --  Add and collect transactions
    ----------------------------------------------------------------------------
    -- 'threadDelay' takes microseconds, so this is a 1ms head start for the
    -- threads adding small transactions.
    let waitForSmallAddersToFillMempool = threadDelay 1_000
    -- The adders never return, so the result is always the remover's.
    txs <-
      runConcurrently
        [ adders mempool smallTxSize
        , waitForSmallAddersToFillMempool >> adders mempool largeTxSize
        , waitForSmallAddersToFillMempool >> remover mempool nrOftxsToCollect
        ]

    ----------------------------------------------------------------------------
    --  Count the small and large transactions
    ----------------------------------------------------------------------------
    let
      nrSmall :: Double
      nrLarge :: Double
      (nrSmall, nrLarge) =
        (fromIntegral . length *** fromIntegral . length) $
          List.partition (<= smallTxSize) $
            fmap txSize txs
    length txs @?= nrOftxsToCollect
    theRatioOfTheDifferenceBetween nrSmall nrLarge `isBelow` toleranceThreshold
 where
  -- Relative difference between the two counts, returned together with the
  -- counts themselves so that the failure message can report them.
  theRatioOfTheDifferenceBetween :: Double -> Double -> (Double, Double, Double)
  theRatioOfTheDifferenceBetween x y = (abs (x - y) / (x + y), x, y)

  -- At the end of the tests the proportion of small and large
  -- transactions that were added should be roughly the same. We tolerate
  -- the given threshold.
  isBelow :: (Double, Double, Double) -> Double -> IO ()
  isBelow (ratioDiff, nrSmall, nrLarge) threshold =
    ratioDiff <= threshold
      @? ( "The difference between the number of large and small transactions added "
             <> "exeeds the threshold ("
             <> show threshold
             <> ")\n"
             <> "Total number of small transactions that were added: "
             <> show nrSmall
             <> "\n"
             <> "Total number of large transactions that were added: "
             <> show nrLarge
             <> "\n"
             <> "Difference / Total: "
             <> show ratioDiff
         )

-- | Run all the given actions concurrently and combine them with the
-- 'Alternative' instance of 'Async.Concurrently' (via 'asum').
--
-- NOTE(review): per the @async@ documentation, '<|>' on 'Async.Concurrently'
-- races its operands, so this returns the result of whichever action finishes
-- first, and an empty list never returns — confirm against the @async@
-- version in use.
runConcurrently :: [IO a] -> IO a
runConcurrently = Async.runConcurrently . asum . fmap Async.Concurrently

-- | Test parameters.
--
-- When choosing the parameters bear in mind that:
--
-- - The smaller the difference between 'smallTxSize' and 'largeTxSize', the
--   harder it will be to detect potential differences in the way the mempool
--   handles small and large transactions.
--
-- - The larger the capacity, the higher the chance large transactions can be
--   added before the mempool is saturated.
data TestParams = TestParams
  { mempoolMaxCapacity :: ByteSize32
  -- ^ Capacity override with which the mempool under test is opened.
  , smallTxSize :: ByteSize32
  -- ^ Size of what we consider to be a small transaction.
  , largeTxSize :: ByteSize32
  -- ^ Size of what we consider to be a large transaction.
  , nrOftxsToCollect :: Int
  -- ^ How many added transactions we count.
  , toleranceThreshold :: Double
  -- ^ We tolerate a certain ratio between the difference of small and large
  -- transactions added, and the total transactions that were added. For
  -- instance, given a threshold of 0.2, if we measure the sizes of 100 added
  -- transactions, the difference between the number of small and large
  -- transactions we counted should not be larger than 20.
  }

-- | Add transactions with the specified size to the mempool.
--
-- We launch a fixed number of adder threads (three), each of which keeps
-- adding transactions of the given size forever.
--
-- This process does not finish. Hence the 'a' type parameter.
adders ::
  -- | Mempool to which transactions will be added
  TestMempool ->
  -- | Transaction size
  ByteSize32 ->
  IO a
adders mempool fixedTxSize = vacuous $ runConcurrently $ fmap adder [0 .. 2]
 where
  -- The argument only determines how many adders are spawned; it is not used
  -- by the adder itself.
  adder :: Int -> IO Void
  adder _i = forever $ do
    -- We pick a random number for the transaction id.
    thisTxId <- randomIO
    void $ Mempool.addTxs mempool [mkGenTx thisTxId fixedTxSize]

-- | Remove the given number of transactions and return them.
--
-- The removal loop runs exactly once. Take care not to return the loop
-- /action/ from the final 'assert': that would run the loop a second time,
-- remove another @total@ transactions, and return that second batch instead
-- of the one that was checked.
remover ::
  -- | Mempool to remove transactions from.
  TestMempool ->
  -- | Number of transactions to remove.
  Int ->
  IO [Tx]
remover mempool total = do
  removedTxs <- loop [] total
  assert (length removedTxs == total) $ pure removedTxs
 where
  -- Remove transactions one by one till we reach 'n'.
  --
  -- Note that the transactions will come out in reverse order wrt the order
  -- in which they were removed from the mempool. That is ok since we only
  -- care about the transaction sizes.
  loop :: [Tx] -> Int -> IO [Tx]
  loop txs 0 = pure txs
  loop txs n = do
    -- We wait 1ms to give the other threads the possibility to add
    -- transactions.
    threadDelay 1000
    gtx <- atomically getATxFromTheMempool
    Mempool.removeTxsEvenIfValid mempool (Mempool.txId gtx :| [])
    loop (unGenTx gtx : txs) (n - 1)

  -- Take the first transaction listed in the mempool snapshot, blocking (via
  -- 'retry') for as long as the snapshot contains no transactions.
  getATxFromTheMempool :: STM IO (Mempool.GenTx TestBlock)
  getATxFromTheMempool =
    getTxsInSnapshot mempool >>= \case
      [] -> retry
      x : _ -> pure x

-- | Return the transactions in the mempool's current snapshot, with their
-- validation evidence forgotten. The ticket number and measure that accompany
-- each transaction in the snapshot are dropped.
--
-- TODO: consider moving this to O.C.Mempool.API
getTxsInSnapshot :: Mempool IO TestBlock -> STM IO [Mempool.GenTx TestBlock]
getTxsInSnapshot mempool =
  fmap txsInSnapshot $
    Mempool.getSnapshot mempool
 where
  txsInSnapshot =
    fmap prjTx
      . Mempool.snapshotTxs

  -- Keep only the transaction component of a snapshot entry.
  prjTx (a, _b, _c) = Mempool.txForgetValidated a :: Mempool.GenTx TestBlock