{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}

module Test.Ouroboros.Storage.PerasCertDB.Model
  ( Model (..)
  , initModel
  , openDB
  , closeDB
  , addCert
  , getWeightSnapshot
  , getLatestCertSeen
  , garbageCollect
  , hasRoundNo
  ) where

import Data.Set (Set)
import qualified Data.Set as Set
import GHC.Generics (Generic)
import Ouroboros.Consensus.Block
import Ouroboros.Consensus.BlockchainTime.WallClock.Types (WithArrivalTime, forgetArrivalTime)
import Ouroboros.Consensus.Peras.Weight
  ( PerasWeightSnapshot
  , mkPerasWeightSnapshot
  )
import Ouroboros.Consensus.Util (safeMaximumOn)

-- | Pure, in-memory model of the Peras certificate DB state.
data Model blk = Model
  { certs :: Set (WithArrivalTime (ValidatedPerasCert blk))
  -- ^ Certificates currently held, each tagged with its arrival time.
  -- 'addCert' never inserts a second certificate for a round that is
  -- already present.
  , latestCertSeen :: Maybe (WithArrivalTime (ValidatedPerasCert blk))
  -- ^ Highest-round certificate in 'certs' as of the last successful
  -- 'addCert'. 'garbageCollect' does not touch this field, so it may refer
  -- to a certificate that has since been removed from 'certs'.
  , open :: Bool
  -- ^ Whether the DB is open: set by 'openDB', reset by 'closeDB'.
  }
  deriving Generic

deriving instance StandardHash blk => Show (Model blk)

-- | The initial model: closed, with no certificates and no latest
-- certificate recorded.
initModel :: Model blk
initModel =
  Model
    { open = False
    , certs = Set.empty
    , latestCertSeen = Nothing
    }

-- | Mark the DB as open. Certificates and 'latestCertSeen' are kept as-is.
openDB :: Model blk -> Model blk
openDB model = model{open = True}

-- | Close the DB. The previous model is ignored entirely: the result is
-- 'initModel', so all certificates and 'latestCertSeen' are forgotten.
closeDB :: Model blk -> Model blk
closeDB _ = initModel

-- | Add a certificate to the model.
--
-- If 'certs' already holds a certificate for the same round (see
-- 'hasRoundNo'), the model is returned unchanged, even if the new
-- certificate differs in other respects. Otherwise the certificate is
-- inserted and 'latestCertSeen' is recomputed as the highest-round
-- certificate in the updated set.
--
-- NOTE(review): the maximum is taken over the current 'certs' only, so a
-- higher-round certificate removed earlier by 'garbageCollect' no longer
-- counts and 'latestCertSeen' can move to a lower round here. Confirm this
-- matches the implementation under test.
addCert ::
  StandardHash blk =>
  Model blk -> WithArrivalTime (ValidatedPerasCert blk) -> Model blk
addCert model@Model{certs} cert
  | certs `hasRoundNo` cert = model
  | otherwise =
      model
        { certs = certs'
        , latestCertSeen = safeMaximumOn roundNo (Set.toList certs')
        }
 where
  certs' = Set.insert cert certs
  roundNo = getPerasCertRound . forgetArrivalTime

-- | Does the set already contain a certificate for the same round as the
-- given certificate? Only round numbers are compared.
hasRoundNo ::
  Set (WithArrivalTime (ValidatedPerasCert blk)) ->
  WithArrivalTime (ValidatedPerasCert blk) ->
  Bool
hasRoundNo existing cert =
  -- A single linear scan; no need to build an intermediate 'Set' of round
  -- numbers just to test membership once.
  any ((== targetRound) . getPerasCertRound) existing
 where
  targetRound = getPerasCertRound cert

-- | Build a weight snapshot from the certificates currently in the model.
-- Each certificate contributes one entry pairing its boosted block
-- ('getPerasCertBoostedBlock') with its boost ('getPerasCertBoost'). How
-- entries for the same point are combined is up to 'mkPerasWeightSnapshot'.
getWeightSnapshot ::
  StandardHash blk =>
  Model blk -> PerasWeightSnapshot blk
getWeightSnapshot Model{certs} =
  mkPerasWeightSnapshot
    [ (getPerasCertBoostedBlock cert, getPerasCertBoost cert)
    | cert <- Set.toList certs
    ]

-- | The certificate currently recorded in 'latestCertSeen', if any.
getLatestCertSeen ::
  Model blk -> Maybe (WithArrivalTime (ValidatedPerasCert blk))
getLatestCertSeen = latestCertSeen

-- | Keep only the certificates whose boosted block's slot is at or after the
-- given slot (compared as @'NotOrigin' slot@); all others are dropped.
--
-- Only 'certs' is filtered; 'open' and 'latestCertSeen' are left unchanged.
garbageCollect :: SlotNo -> Model blk -> Model blk
garbageCollect slot model@Model{certs} =
  model{certs = Set.filter keepCert certs}
 where
  keepCert cert = pointSlot (getPerasCertBoostedBlock cert) >= NotOrigin slot