{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

module Test.Consensus.Committee.WFALS.Conformance
  ( CommitteeSize
  , NumPersistent
  , NumNonPersistent
  , PoolId
  , Stake
  , StakeDistr
  , conformsToRustImplementation
  )
where

import Data.Aeson
  ( FromJSON (..)
  , eitherDecodeStrict
  , withObject
  , (.:)
  )
import Data.Array (Array)
import qualified Data.Array as Array
import qualified Data.FileEmbed as FileEmbed
import Data.Map.Strict (Map)
import Data.Word (Word64)
import System.FilePath ((</>))
import Test.Tasty (TestTree)
import Test.Tasty.HUnit (assertEqual, testCase)

-- | Target size of a committee.
type CommitteeSize = Word64
-- | Number of persistent committee members selected.
type NumPersistent = Word64
-- | Number of non-persistent committee members selected.
type NumNonPersistent = Word64
-- | Identifier of a pool, used as the key of a 'StakeDistr'.
type PoolId = String
-- | Stake associated with a pool.
type Stake = Rational
-- | Stake distribution: the 'Stake' of each pool, keyed by 'PoolId'.
type StakeDistr = Map PoolId Stake

-- | A single result of the Rust implementation.
--
-- Each value records, for one target committee size, how many persistent and
-- non-persistent committee members the Rust implementation reported.
data RustResult = RustResult
  { targetCommitteeSize :: CommitteeSize
  -- ^ Target committee size this result was computed for
  , numPersistent :: NumPersistent
  -- ^ Number of persistent committee members reported
  , numNonPersistent :: NumNonPersistent
  -- ^ Number of non-persistent committee members reported
  }
  deriving (Show, Eq)

-- | Decode a 'RustResult' from a JSON object with the keys @target@,
-- @persistent@ and @nonpersistent@, in that field order.
instance FromJSON RustResult where
  parseJSON = withObject "RustResult" $ \obj ->
    RustResult
      <$> obj .: "target"
      <*> obj .: "persistent"
      <*> obj .: "nonpersistent"

-- | Embedded Rust implementation results.
--
-- These are the results of the Rust implementation when applied to
-- 'exampleStakeDistr' for every valid target committee size, i.e., from 1 to
-- the maximum number of pools with strictly positive stake, which, in the case
-- of 'exampleStakeDistr', corresponds to 2052 pools.
--
-- The JSON file is embedded at compile time via Template Haskell and decoded
-- as a list of 'RustResult's. The resulting array is indexed from 0 in the
-- order the results appear in the file. If decoding fails, forcing this value
-- calls 'error' with the decoder's message.
rustResults :: Array Int RustResult
rustResults =
  either error toArray $
    eitherDecodeStrict $
      $( FileEmbed.embedFile $
           "ouroboros-consensus"
             </> "test"
             </> "consensus-test"
             </> "data"
             </> "rust_results.json"
       )
 where
  -- Build a zero-based array holding the list elements in their original
  -- order. An empty list yields the empty array with bounds @(0, -1)@.
  toArray :: [e] -> Array Int e
  toArray results =
    Array.listArray (0, length results - 1) results

-- | Embedded stake distribution.
--
-- The JSON file is embedded at compile time via Template Haskell and decoded
-- into a 'StakeDistr'. If decoding fails, forcing this value calls 'error'
-- with the decoder's message.
exampleStakeDistr :: StakeDistr
exampleStakeDistr =
  either error id $
    eitherDecodeStrict $
      $( FileEmbed.embedFile $
           "ouroboros-consensus"
             </> "test"
             </> "consensus-test"
             </> "data"
             </> "stake_distr.json"
       )

-- | Check that a weighted fait accompli committee selection implementation
-- conforms to the Rust implementation by comparing the number of persistent
-- and non-persistent committee members it selects for a given target committee
-- size.
--
-- The resulting test case walks over every entry of 'rustResults' in index
-- order and stops at the first entry for which the implementation under test
-- disagrees with the Rust implementation.
conformsToRustImplementation ::
  -- | Name of the resulting test case
  String ->
  -- | Conversion from the embedded stake distribution into the representation
  -- expected by the implementation under test
  (Map PoolId Stake -> stakeDistr) ->
  -- | Implementation under test: given a stake distribution and a target
  -- committee size, return the number of persistent and non-persistent
  -- committee members selected
  ( stakeDistr ->
    CommitteeSize ->
    ( NumPersistent
    , NumNonPersistent
    )
  ) ->
  TestTree
conformsToRustImplementation name mkStakeDistr wfa =
  -- 'Array' is 'Foldable', so 'mapM_' visits the results in index order and
  -- does nothing for an empty array.
  testCase name $ mapM_ step rustResults
 where
  -- Converted once and shared across all steps.
  stakeDistr = mkStakeDistr exampleStakeDistr

  -- Compare the implementation under test against a single Rust result.
  step RustResult{targetCommitteeSize, numPersistent, numNonPersistent} = do
    let (actualNumPersistent, actualNumNonPersistent) =
          wfa stakeDistr targetCommitteeSize
    assertEqual
      ( unlines
          [ "Target committee size: "
              <> show targetCommitteeSize
          , "Expected (persistent, non-persistent): "
              <> show (numPersistent, numNonPersistent)
          , "Actual (persistent, non-persistent): "
              <> show (actualNumPersistent, actualNumNonPersistent)
          ]
      )
      (numPersistent, numNonPersistent)
      (actualNumPersistent, actualNumNonPersistent)