{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}

module Test.Consensus.Committee.WFALS.Conformance
  ( conformsToRustImplementation
  )
where

import Data.Aeson
  ( FromJSON (..)
  , eitherDecodeStrict
  , withObject
  , (.:)
  )
import Data.Array (Array)
import qualified Data.Array as Array
import qualified Data.FileEmbed as FileEmbed
import Data.Map.Strict (Map)
import System.FilePath ((</>))
import Test.QuickCheck (Property)
import Test.QuickCheck.Gen (Gen, choose)
import Test.QuickCheck.Property (counterexample, forAll, (===))

-- | Identifier of a pool; the keys of a 'StakeDistr'.
type PoolId = String

-- | Stake of a single pool, kept as an exact 'Rational'.
type Stake = Rational

-- | Stake distribution: the stake of every pool, keyed by 'PoolId'.
type StakeDistr = Map PoolId Stake

-- | Target size of the committee to select.
type CommitteeSize = Int

-- | Number of persistent committee members selected.
type NumPersistent = Int

-- | Number of non-persistent committee members selected.
type NumNonPersistent = Int

-- | One reference result of the Rust implementation: for a given target
-- committee size, how many persistent and non-persistent committee members
-- it selected.
data RustResult = RustResult
  { targetCommitteeSize :: CommitteeSize
  -- ^ Target committee size the Rust implementation was run with
  , numPersistent :: NumPersistent
  -- ^ Number of persistent members the Rust implementation selected
  , numNonPersistent :: NumNonPersistent
  -- ^ Number of non-persistent members the Rust implementation selected
  }
  deriving (Eq, Show)

-- | Decode a 'RustResult' from a JSON object carrying the fields @target@,
-- @persistent@ and @nonpersistent@ (looked up in that order).
instance FromJSON RustResult where
  parseJSON = withObject "RustResult" $ \o -> do
    target <- o .: "target"
    persistent <- o .: "persistent"
    nonPersistent <- o .: "nonpersistent"
    pure
      RustResult
        { targetCommitteeSize = target
        , numPersistent = persistent
        , numNonPersistent = nonPersistent
        }

-- | Reference results of the Rust implementation, embedded at compile time
-- from @rust_results.json@ and stored in an array indexed from 0.
--
-- Decoding happens on first use. A malformed file, or one that contains no
-- results, aborts with an 'error' naming the file. Without this check, an
-- empty file would only surface later as an out-of-range array index.
--
-- NOTE(review): @embedFile@ resolves this path relative to GHC's working
-- directory at compile time; confirm it matches how the suite is built
-- (file-embed's @makeRelativeToProject@ exists for this purpose).
rustResults :: Array Int RustResult
rustResults =
  either decodeFailure toArray $
    eitherDecodeStrict $
      $( FileEmbed.embedFile $
           "ouroboros-consensus"
             </> "test"
             </> "consensus-test"
             </> "data"
             </> "rust_results.json"
       )
 where
  -- Report which embedded file failed to decode, not just aeson's message.
  decodeFailure :: String -> Array Int RustResult
  decodeFailure err =
    error $ "rustResults: failed to decode rust_results.json: " <> err

  -- An empty array would give bounds (0, -1), which cannot be sampled.
  toArray :: [RustResult] -> Array Int RustResult
  toArray [] =
    error "rustResults: rust_results.json contains no results"
  toArray results =
    Array.listArray (0, length results - 1) results

-- | Stake distribution the implementation under test is run on, embedded at
-- compile time from @stake_distr.json@.
--
-- Decoding happens on first use. A malformed file aborts with an 'error'
-- naming the file.
--
-- NOTE(review): @embedFile@ resolves this path relative to GHC's working
-- directory at compile time; confirm it matches how the suite is built
-- (file-embed's @makeRelativeToProject@ exists for this purpose).
stakeDistr :: StakeDistr
stakeDistr =
  either decodeFailure id $
    eitherDecodeStrict $
      $( FileEmbed.embedFile $
           "ouroboros-consensus"
             </> "test"
             </> "consensus-test"
             </> "data"
             </> "stake_distr.json"
       )
 where
  -- Report which embedded file failed to decode, not just aeson's message.
  decodeFailure :: String -> StakeDistr
  decodeFailure err =
    error $ "stakeDistr: failed to decode stake_distr.json: " <> err

-- | Pick a random element of an array by choosing an index within the
-- array's own bounds, so any lower bound works.
--
-- Sampling from an empty array is a programming error. It is reported with
-- a descriptive message instead of the out-of-range index failure that
-- @choose@ on the empty bounds @(lo, lo - 1)@ would otherwise cause.
sampleArray :: Array Int a -> Gen a
sampleArray array
  | null array = error "sampleArray: cannot sample from an empty array"
  | otherwise = (array Array.!) <$> choose (Array.bounds array)

-- | Check that a weighted fait accompli committee selection implementation
-- conforms to the Rust implementation.
--
-- A reference result is sampled from the embedded Rust results. The
-- implementation under test is then run on the embedded stake distribution
-- with that result's target committee size. The number of persistent and
-- non-persistent committee members it selects must equal the Rust
-- implementation's numbers. On mismatch, the target size and both pairs are
-- reported.
conformsToRustImplementation ::
  ( Map PoolId Stake ->
    CommitteeSize ->
    ( NumPersistent
    , NumNonPersistent
    )
  ) ->
  Property
conformsToRustImplementation wfals =
  forAll (sampleArray rustResults) checkAgainst
 where
  -- Compare the implementation's (persistent, non-persistent) split against
  -- one Rust reference result.
  checkAgainst :: RustResult -> Property
  checkAgainst RustResult{targetCommitteeSize, numPersistent, numNonPersistent} =
    let expected = (numPersistent, numNonPersistent)
        actual = wfals stakeDistr targetCommitteeSize
        report =
          unlines
            [ "Target committee size: " <> show targetCommitteeSize
            , "Expected (persistent, non-persistent): " <> show expected
            , "Actual (persistent, non-persistent): " <> show actual
            ]
     in counterexample report (actual === expected)