{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}

-- | Instantiate 'ObjectPoolReader' and 'ObjectPoolWriter' using Peras
-- votes from the 'PerasVoteDB' (or the 'ChainDB' which is wrapping the
-- 'PerasVoteDB').
module Ouroboros.Consensus.MiniProtocol.ObjectDiffusion.ObjectPool.PerasVote
  ( makePerasVotePoolReaderFromVoteDB
  , makePerasVotePoolWriterFromVoteDB
  ) where

import Control.Monad (join)
import Data.Either (partitionEithers)
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Set as Set
import GHC.Exception (throw)
import Ouroboros.Consensus.Block
import Ouroboros.Consensus.BlockchainTime.WallClock.Types
  ( SystemTime (..)
  , WithArrivalTime (..)
  )
import Ouroboros.Consensus.MiniProtocol.ObjectDiffusion.ObjectPool.API
import Ouroboros.Consensus.Storage.PerasVoteDB.API
  ( PerasVoteDB (..)
  , PerasVoteTicketNo
  , zeroPerasVoteTicketNo
  )
import Ouroboros.Consensus.Util.IOLike

-- | Keep only the @n@ entries with the smallest keys.
--
-- The whole map is returned when it holds at most @n@ entries, and the empty
-- map when @n <= 0@.
--
-- This is a thin wrapper around 'Map.take', which avoids the round-trip
-- through an intermediate ascending list.
takeAscMap :: Int -> Map k v -> Map k v
takeAscMap = Map.take

-- | Build an 'ObjectPoolReader' for Peras votes on top of a 'PerasVoteDB'.
--
-- * 'oprObjectId' is 'getPerasVoteId'.
-- * 'oprZeroTicketNo' is 'zeroPerasVoteTicketNo'.
-- * 'oprObjectsAfter' queries 'getVotesAfter' for the given ticket number.
--   It yields 'Nothing' when the database returns no votes, and otherwise
--   'Just' an action producing at most @limit@ votes (those with the smallest
--   ticket numbers), stripped of their arrival time and validation wrapper.
makePerasVotePoolReaderFromVoteDB ::
  IOLike m =>
  PerasVoteDB m blk ->
  ObjectPoolReader (PerasVoteId blk) (PerasVote blk) PerasVoteTicketNo m
makePerasVotePoolReaderFromVoteDB PerasVoteDB{..} =
  ObjectPoolReader
    { oprObjectId = getPerasVoteId
    , oprZeroTicketNo = zeroPerasVoteTicketNo
    , oprObjectsAfter = \lastKnown limit -> do
        votesAfterLastKnown <- getVotesAfter lastKnown
        let
          -- Clamp before narrowing: a 'Word64' above @maxBound :: Int@ would
          -- otherwise wrap to a negative count and select no votes at all.
          maxVotes :: Int
          maxVotes = fromIntegral (min limit (fromIntegral (maxBound :: Int)))
          loadVotesAfterLastKnown =
            pure $
              fmap
                (vpvVote . forgetArrivalTime)
                (takeAscMap maxVotes votesAfterLastKnown)
        pure $
          -- Emptiness is decided on the unrestricted query result, not on
          -- the truncated map.
          if Map.null votesAfterLastKnown
            then Nothing
            else Just loadVotesAfterLastKnown
    }

-- | Build an 'ObjectPoolWriter' for Peras votes on top of a 'PerasVoteDB'.
--
-- Incoming votes are validated against the stake distribution currently
-- stored in the given 'StrictTVar' and stamped with the time read from the
-- given 'SystemTime' before being handed to 'addVote'.
--
-- A batch is accepted only as a whole: if any vote fails validation, a
-- 'PerasVoteValidationError' carrying all validation errors is thrown and no
-- vote of the batch is added.
makePerasVotePoolWriterFromVoteDB ::
  (StandardHash blk, IOLike m) =>
  StrictTVar m PerasVoteStakeDistr ->
  SystemTime m ->
  PerasVoteDB m blk ->
  ObjectPoolWriter (PerasVoteId blk) (PerasVote blk) m
makePerasVotePoolWriterFromVoteDB distrVar systemTime PerasVoteDB{..} =
  ObjectPoolWriter
    { opwObjectId = getPerasVoteId
    , opwAddObjects = addVotes
    , opwHasObject = fmap (\knownIds voteId -> Set.member voteId knownIds) getVoteIds
    }
 where
  -- The arrival time is sampled once, before the transaction, and shared by
  -- every vote of the batch.
  addVotes votes = do
    arrival <- systemTimeCurrent systemTime
    -- The transaction returns the follow-up action, which 'join' then runs.
    join . atomically $ do
      distr <- readTVar distrVar
      case validatePerasVotes distr votes of
        -- All votes are valid => add them to the pool
        ([], accepted) ->
          sequence_ <$> traverse (addVote . WithArrivalTime arrival) accepted
        -- Some votes are invalid => reject the whole batch
        (rejected, _) ->
          throw (PerasVoteValidationError rejected)

-- | Exception thrown when inbound Peras votes cannot be accepted.
--
-- The block type is existentially quantified, so the exception type itself
-- does not mention @blk@.
data PerasVoteInboundException where
  -- | Validation of an inbound batch failed; carries the validation errors.
  PerasVoteValidationError ::
    forall blk.
    [PerasValidationErr blk] ->
    PerasVoteInboundException

deriving instance Show PerasVoteInboundException

instance Exception PerasVoteInboundException

-- | Validate a batch of Peras votes against the given stake distribution.
--
-- Each vote is checked on its own with 'validatePerasVote'. The first
-- component of the result collects the validation errors, the second the
-- successfully validated votes; both keep the order of the input list.
validatePerasVotes ::
  StandardHash blk =>
  PerasVoteStakeDistr ->
  [PerasVote blk] ->
  ([PerasValidationErr blk], [ValidatedPerasVote blk])
validatePerasVotes distr votes =
  partitionEithers [checkVote vote | vote <- votes]
 where
  -- TODO pass down 'BlockConfig' when all the plumbing is in place
  -- see https://github.com/tweag/cardano-peras/issues/73
  -- see https://github.com/tweag/cardano-peras/issues/120
  checkVote = validatePerasVote mkPerasParams distr