{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}

-- | Instantiate 'ObjectPoolReader' and 'ObjectPoolWriter' using Peras
-- certificates from the 'PerasCertDB' (or the 'ChainDB' which is wrapping the
-- 'PerasCertDB').
module Ouroboros.Consensus.MiniProtocol.ObjectDiffusion.ObjectPool.PerasCert
  ( makePerasCertPoolReaderFromCertDB
  , makePerasCertPoolWriterFromCertDB
  , makePerasCertPoolReaderFromChainDB
  , makePerasCertPoolWriterFromChainDB
  ) where

import Control.Monad (join)
import Data.Either (partitionEithers)
import Data.Functor (void)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import GHC.Exception (throw)
import Ouroboros.Consensus.Block
import Ouroboros.Consensus.BlockchainTime.WallClock.Types
  ( SystemTime (..)
  , WithArrivalTime (..)
  )
import Ouroboros.Consensus.MiniProtocol.ObjectDiffusion.ObjectPool.API
import Ouroboros.Consensus.Storage.ChainDB.API (ChainDB)
import qualified Ouroboros.Consensus.Storage.ChainDB.API as ChainDB
import Ouroboros.Consensus.Storage.PerasCertDB.API
  ( PerasCertDB
  , PerasCertTicketNo
  )
import qualified Ouroboros.Consensus.Storage.PerasCertDB.API as PerasCertDB
import Ouroboros.Consensus.Util.IOLike

-- | Keep only the @n@ entries with the smallest keys of a map.
--
-- This is 'Map.take' (available since @containers-0.5.8@, whose documented
-- definition is @fromDistinctAscList . Prelude.take n . toAscList@). It avoids
-- the round-trip through an association list: O(log size) instead of O(n).
-- A non-positive @n@ yields the empty map, and an @n@ at least as large as the
-- map returns the map unchanged.
takeAscMap :: Int -> Map k v -> Map k v
takeAscMap = Map.take

-------------------------------------------------------------------------------
-- Readers
-------------------------------------------------------------------------------

-- | Internal helper: create a pool reader from a @getCertsAfter@ function.
--
-- The argument maps a ticket number to an STM transaction. That transaction
-- returns a map from ticket numbers to @m@ actions, each producing one
-- certificate (presumably the certificates registered after the given ticket;
-- see the wrapped database for the exact contract).
--
-- 'oprObjectsAfter' returns 'Nothing' when that map is empty. Otherwise it
-- returns an action that runs the producing actions of at most @limit@
-- entries, in ascending ticket order. It strips the arrival time and the
-- validation wrapper from each certificate it produces.
makePerasCertPoolReader ::
  IOLike m =>
  ( PerasCertTicketNo ->
    STM m (Map PerasCertTicketNo (m (WithArrivalTime (ValidatedPerasCert blk))))
  ) ->
  ObjectPoolReader PerasRoundNo (PerasCert blk) PerasCertTicketNo m
makePerasCertPoolReader getCertsAfterSTM =
  ObjectPoolReader
    { oprObjectId = getPerasCertRound
    , oprZeroTicketNo = PerasCertDB.zeroPerasCertTicketNo
    , oprObjectsAfter = \lastKnown limit -> do
        certsAfterLastKnownNoLimit <- getCertsAfterSTM lastKnown
        if Map.null certsAfterLastKnownNoLimit
          then pure Nothing
          else pure . Just $ do
            -- Cap the limit at 'maxBound :: Int' before converting it. A bare
            -- 'fromIntegral' would wrap very large limits around to a negative
            -- 'Int', for which 'takeAscMap' returns no certificates at all.
            let limitInt =
                  fromIntegral (min limit (fromIntegral (maxBound :: Int)))
                certsAfterLastKnown =
                  takeAscMap limitInt certsAfterLastKnownNoLimit
            traverse
              (fmap (vpcCert . forgetArrivalTime))
              certsAfterLastKnown
    }

-- | Create a pool reader backed directly by a 'PerasCertDB'. Certificates are
-- enumerated with 'PerasCertDB.getCertsAfter'.
makePerasCertPoolReaderFromCertDB ::
  IOLike m =>
  PerasCertDB m blk ->
  ObjectPoolReader PerasRoundNo (PerasCert blk) PerasCertTicketNo m
makePerasCertPoolReaderFromCertDB =
  makePerasCertPoolReader . PerasCertDB.getCertsAfter

-- | Create a pool reader backed by the 'ChainDB'. Certificates are enumerated
-- with 'ChainDB.getPerasCertsAfter'.
makePerasCertPoolReaderFromChainDB ::
  IOLike m =>
  ChainDB m blk ->
  ObjectPoolReader PerasRoundNo (PerasCert blk) PerasCertTicketNo m
makePerasCertPoolReaderFromChainDB =
  makePerasCertPoolReader . ChainDB.getPerasCertsAfter

-------------------------------------------------------------------------------
-- Writers
-------------------------------------------------------------------------------

-- | Create a pool writer directly from a 'PerasCertDB'. This is mostly meant
-- for tests against the 'PerasCertDB' in isolation. For actual production use,
-- see 'makePerasCertPoolWriterFromChainDB', which creates a pool writer from
-- the 'ChainDB' with proper handling of chain selection side-effects.
--
-- Incoming batches go through 'processCerts'. Certificates for rounds already
-- in the database are skipped, the rest are validated, and each valid one is
-- inserted with 'PerasCertDB.addCert'. The transaction run by 'addCert' hands
-- back a follow-up action, which is executed and whose result is discarded.
makePerasCertPoolWriterFromCertDB ::
  (StandardHash blk, IOLike m) =>
  SystemTime m ->
  PerasCertDB m blk ->
  ObjectPoolWriter PerasRoundNo (PerasCert blk) m
makePerasCertPoolWriterFromCertDB systemTime perasCertDB =
  ObjectPoolWriter
    { opwObjectId = getPerasCertRound
    , opwAddObjects =
        processCerts
          systemTime
          knownRounds
          -- TODO replace when actual plumbing is in place
          (validatePerasCert mkPerasParams)
          addToCertDB
    , opwHasObject = do
        rounds <- knownRounds
        pure (`Set.member` rounds)
    }
  where
    -- Round numbers of the certificates currently stored in the database.
    knownRounds = PerasCertDB.getCertIds perasCertDB

    -- Insert the certificate, then run the action the transaction returns.
    addToCertDB cert = do
      afterAdd <- atomically (PerasCertDB.addCert perasCertDB cert)
      void afterAdd

-- | Create a pool writer from the 'ChainDB'. This properly handles any needed
-- chain selection side-effects.
--
-- Incoming batches go through 'processCerts'. Certificates for rounds already
-- known to the 'ChainDB' are skipped, the rest are validated, and each valid
-- one is handed to 'ChainDB.addPerasCertAsync'.
makePerasCertPoolWriterFromChainDB ::
  (StandardHash blk, IOLike m) =>
  SystemTime m ->
  ChainDB m blk ->
  ObjectPoolWriter PerasRoundNo (PerasCert blk) m
makePerasCertPoolWriterFromChainDB systemTime chainDB =
  ObjectPoolWriter
    { opwObjectId = getPerasCertRound
    , opwAddObjects =
        processCerts
          systemTime
          knownRounds
          -- TODO replace when actual plumbing is in place
          (validatePerasCert mkPerasParams)
          -- We do not want to block the writer thread on waiting for ChainSel
          -- side-effects to complete, so we use the async version of adding
          -- certs to the ChainDB and ignore the returned promise.
          -- The async action is still launched and executed behind the scenes
          -- even though we drop the promise.
          (void . ChainDB.addPerasCertAsync chainDB)
    , opwHasObject = flip Set.member <$> knownRounds
    }
  where
    -- Round numbers of the certificates currently known to the ChainDB.
    knownRounds = ChainDB.getPerasCertIds chainDB

-- | Exception raised by 'processCerts' when a batch of inbound Peras
-- certificates contains at least one certificate that fails validation.
--
-- The block type is existentially hidden, so the exception type itself is not
-- indexed by @blk@.
data PerasCertInboundException where
  -- | The validation errors collected for the rejected batch.
  PerasCertValidationError ::
    [PerasValidationErr blk] ->
    PerasCertInboundException

deriving instance Show PerasCertInboundException

instance Exception PerasCertInboundException

-- | Process a batch of inbound Peras certificates received from a peer.
--
-- Certificates whose round number is already present in the database (as
-- determined by @alreadyInDbSTM@) are silently skipped. The remaining
-- certificates are validated. If /any/ certificate in the batch fails
-- validation, the entire batch is rejected by throwing a
-- 'PerasCertInboundException'. This should make us disconnect from the
-- distant peer; see the 'withPeer' bracket function from
-- @ouroboros-network@. Otherwise, each valid certificate is timestamped with
-- the current wall-clock time and added to the database via @addCert@.
--
-- The exception is raised with 'throwIO', so it is an effect of @m@ sequenced
-- with the surrounding actions. It is not a pure exception hidden inside an
-- unevaluated @m ()@ value.
processCerts ::
  (MonadSTM m, MonadThrow m) =>
  SystemTime m ->
  STM m (Set PerasRoundNo) ->
  (PerasCert blk -> Either (PerasValidationErr blk) (ValidatedPerasCert blk)) ->
  (WithArrivalTime (ValidatedPerasCert blk) -> m ()) ->
  [PerasCert blk] ->
  m ()
processCerts systemTime alreadyInDbSTM validateCert addCert certs = do
  alreadyInDb <- atomically alreadyInDbSTM
  let certsNotAlreadyInDb =
        filter (not . (`Set.member` alreadyInDb) . getPerasCertRound) certs
  now <- systemTimeCurrent systemTime
  case partitionEithers (validateCert <$> certsNotAlreadyInDb) of
    -- All certs are valid => add them to the pool
    ([], validatedCerts) ->
      mapM_ (addCert . WithArrivalTime now) validatedCerts
    -- Some certs are invalid => reject the whole batch
    --
    -- N.B. it has been requested in PR review
    -- https://github.com/IntersectMBO/ouroboros-consensus/pull/1768#discussion_r2747873186
    -- to gather all validation errors and report them together in the exception
    -- rather than just report the first error encountered.
    -- This assumes that cert validation is cheap, which may not be true in
    -- practice depending on the actual crypto/committee selection scheme.
    -- Hence we may revisit this to lazily abort validation upon the first error
    -- encountered.
    (errs, _) ->
      throwIO (PerasCertValidationError errs)