{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}

-- | Instantiate 'ObjectPoolReader' and 'ObjectPoolWriter' using Peras
-- certificates from the 'PerasCertDB' (or from the 'ChainDB', which wraps the
-- 'PerasCertDB').
module Ouroboros.Consensus.MiniProtocol.ObjectDiffusion.ObjectPool.PerasCert
  ( makePerasCertPoolReaderFromCertDB
  , makePerasCertPoolWriterFromCertDB
  , makePerasCertPoolReaderFromChainDB
  , makePerasCertPoolWriterFromChainDB
  ) where

import Data.Either (partitionEithers)
import Data.Map (Map)
import qualified Data.Map as Map
import GHC.Exception (throw)
import Ouroboros.Consensus.Block
import Ouroboros.Consensus.BlockchainTime.WallClock.Types
  ( SystemTime (..)
  , WithArrivalTime (..)
  )
import Ouroboros.Consensus.MiniProtocol.ObjectDiffusion.ObjectPool.API
import Ouroboros.Consensus.Storage.ChainDB.API (ChainDB)
import qualified Ouroboros.Consensus.Storage.ChainDB.API as ChainDB
import Ouroboros.Consensus.Storage.PerasCertDB.API
  ( PerasCertDB
  , PerasCertSnapshot
  , PerasCertTicketNo
  )
import qualified Ouroboros.Consensus.Storage.PerasCertDB.API as PerasCertDB
import Ouroboros.Consensus.Util.IOLike

-- | Keep only the @n@ entries with the smallest keys of a map.
--
-- Returns the empty map when @n <= 0@ and the whole map when @n@ is at least
-- the size of the map. Delegates to 'Map.take' (available since
-- containers-0.5.8), which runs in \(O(\log n)\) instead of round-tripping
-- through an ascending association list.
takeAscMap :: Int -> Map k v -> Map k v
takeAscMap = Map.take

-- | Build an 'ObjectPoolReader' over Peras certificates from an STM action
-- that reads the current 'PerasCertSnapshot'.
--
-- * Certificates are identified by their round number ('getPerasCertRound').
--
-- * 'oprObjectsAfter' reads a fresh snapshot and returns 'Nothing' when
--   'PerasCertDB.getCertsAfter' yields no certificates for the given ticket
--   number.
--
-- * Otherwise 'oprObjectsAfter' returns an action producing at most @limit@
--   of those certificates, taken in ascending ticket-number order and
--   stripped of their arrival time and validation wrapper. This action may
--   yield an empty map when @limit@ is 0.
makePerasCertPoolReaderFromSnapshot ::
  IOLike m =>
  STM m (PerasCertSnapshot blk) ->
  ObjectPoolReader PerasRoundNo (PerasCert blk) PerasCertTicketNo m
makePerasCertPoolReaderFromSnapshot getCertSnapshot =
  ObjectPoolReader
    { oprObjectId = getPerasCertRound
    , oprZeroTicketNo = PerasCertDB.zeroPerasCertTicketNo
    , oprObjectsAfter = \lastKnown limit -> do
        certSnapshot <- getCertSnapshot
        let certsAfterLastKnown =
              PerasCertDB.getCertsAfter certSnapshot lastKnown
            -- Clamp before converting: a 'limit' above 'maxBound :: Int'
            -- would otherwise wrap around (e.g. to a negative 'Int'), making
            -- 'takeAscMap' return no certificates at all even though some
            -- are available.
            maxCerts :: Int
            maxCerts = fromIntegral (min limit (fromIntegral (maxBound :: Int)))
            loadCertsAfterLastKnown =
              pure $
                fmap
                  (vpcCert . forgetArrivalTime)
                  (takeAscMap maxCerts certsAfterLastKnown)
        pure $
          if Map.null certsAfterLastKnown
            then Nothing
            else Just loadCertsAfterLastKnown
    }

-- | An 'ObjectPoolReader' serving the Peras certificates stored in a
-- 'PerasCertDB'. Every query reads the database's current cert snapshot.
makePerasCertPoolReaderFromCertDB ::
  IOLike m =>
  PerasCertDB m blk ->
  ObjectPoolReader PerasRoundNo (PerasCert blk) PerasCertTicketNo m
makePerasCertPoolReaderFromCertDB =
  makePerasCertPoolReaderFromSnapshot . PerasCertDB.getCertSnapshot

-- | An 'ObjectPoolWriter' that validates incoming Peras certificates and
-- stores them in a 'PerasCertDB' with 'PerasCertDB.addCert'.
--
-- Certificates are keyed by their round number. Arrival times are read from
-- the given 'SystemTime'. Membership queries consult the database's current
-- cert snapshot.
makePerasCertPoolWriterFromCertDB ::
  (StandardHash blk, IOLike m) =>
  SystemTime m ->
  PerasCertDB m blk ->
  ObjectPoolWriter PerasRoundNo (PerasCert blk) m
makePerasCertPoolWriterFromCertDB systemTime perasCertDB =
  ObjectPoolWriter
    { opwObjectId = getPerasCertRound
    , opwAddObjects = addPerasCerts systemTime (PerasCertDB.addCert perasCertDB)
    , opwHasObject =
        PerasCertDB.containsCert <$> PerasCertDB.getCertSnapshot perasCertDB
    }

-- | An 'ObjectPoolReader' serving the Peras certificates known to a
-- 'ChainDB'. Every query reads the ChainDB's current Peras cert snapshot.
makePerasCertPoolReaderFromChainDB ::
  IOLike m =>
  ChainDB m blk ->
  ObjectPoolReader PerasRoundNo (PerasCert blk) PerasCertTicketNo m
makePerasCertPoolReaderFromChainDB =
  makePerasCertPoolReaderFromSnapshot . ChainDB.getPerasCertSnapshot

-- | An 'ObjectPoolWriter' that validates incoming Peras certificates and
-- hands them to a 'ChainDB' with 'ChainDB.addPerasCertAsync'.
--
-- Certificates are keyed by their round number. Arrival times are read from
-- the given 'SystemTime'. The promise returned for each certificate is
-- discarded. Membership queries consult the current Peras cert snapshot of
-- the 'ChainDB'.
makePerasCertPoolWriterFromChainDB ::
  (StandardHash blk, IOLike m) =>
  SystemTime m ->
  ChainDB m blk ->
  ObjectPoolWriter PerasRoundNo (PerasCert blk) m
makePerasCertPoolWriterFromChainDB systemTime chainDB =
  ObjectPoolWriter
    { opwObjectId = getPerasCertRound
    , opwAddObjects = addPerasCerts systemTime (ChainDB.addPerasCertAsync chainDB)
    , opwHasObject =
        PerasCertDB.containsCert <$> ChainDB.getPerasCertSnapshot chainDB
    }

-- | Exception used to reject a batch of inbound Peras certificates.
--
-- 'PerasCertValidationError' carries the validation errors of the batch. The
-- block type @blk@ is existentially hidden, so the exception type itself is
-- not indexed by it.
data PerasCertInboundException where
  PerasCertValidationError ::
    [PerasValidationErr blk] -> PerasCertInboundException

deriving instance Show PerasCertInboundException

instance Exception PerasCertInboundException

-- | Add a batch of Peras certs to a pool after validating them.
--
-- The current time is read once from the 'SystemTime' and used as the arrival
-- time of every cert in the batch. Validation is all-or-nothing:
--
-- * If every cert is valid, each validated cert is passed, in input order,
--   to @addCert@. The results of @addCert@ are discarded.
--
-- * If any cert is invalid, nothing is added. Instead, a
--   'PerasCertValidationError' carrying all validation errors is thrown.
addPerasCerts ::
  (StandardHash blk, IOLike m) =>
  SystemTime m ->
  (WithArrivalTime (ValidatedPerasCert blk) -> m a) ->
  [PerasCert blk] ->
  m ()
addPerasCerts systemTime addCert certs = do
  now <- systemTimeCurrent systemTime
  case validatePerasCerts certs of
    -- All certs are valid => add them to the pool
    ([], validatedCerts) ->
      mapM_ (addCert . WithArrivalTime now) validatedCerts
    -- Some certs are invalid => reject the whole batch. Use 'throwIO' rather
    -- than the pure 'throw': the exception then becomes an effect of this
    -- action in @m@, ordered after the preceding actions. With 'throw' it
    -- would instead be an imprecise exception hidden inside the action value.
    (errs, _) ->
      throwIO (PerasCertValidationError errs)

-- | Validate a batch of Peras certs.
--
-- Each cert is validated independently with 'validatePerasCert'. The result
-- pairs the errors of the rejected certs with the successfully validated
-- certs. Both lists keep the relative order of the input.
validatePerasCerts ::
  StandardHash blk =>
  [PerasCert blk] ->
  ([PerasValidationErr blk], [ValidatedPerasCert blk])
validatePerasCerts certs =
  partitionEithers (validatePerasCert perasParams <$> certs)
 where
  -- TODO pass down 'BlockConfig' when all the plumbing is in place
  -- see https://github.com/tweag/cardano-peras/issues/73
  -- see https://github.com/tweag/cardano-peras/issues/120
  perasParams :: PerasParams
  perasParams = mkPerasParams