{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}

module Test.Ouroboros.Storage.PerasCertDB.Model
  ( Model (..)
  , initModel
  , openDB
  , closeDB
  , addCert
  , getWeightSnapshot
  , garbageCollect
  , hasRoundNo
  ) where

import Data.Set (Set)
import qualified Data.Set as Set
import GHC.Generics (Generic)
import Ouroboros.Consensus.Block
import Ouroboros.Consensus.Peras.Weight
  ( PerasWeightSnapshot
  , mkPerasWeightSnapshot
  )

-- | State of the Peras certificate database as tracked by this test model.
data Model blk = Model
  { certs :: Set (ValidatedPerasCert blk)
  -- ^ Certificates currently stored. 'addCert' only inserts a certificate
  -- when no stored certificate has the same round number.
  , open :: Bool
  -- ^ Whether the database is currently open ('openDB' / 'closeDB').
  }
  deriving (Generic)

-- | The 'StandardHash' context is presumably required by the 'Show'
-- instance of the contained certificates (NOTE(review): confirm).
deriving instance StandardHash blk => Show (Model blk)

-- | The starting model: the database is closed and holds no certificates.
initModel :: Model blk
initModel =
  Model
    { certs = Set.empty
    , open = False
    }

-- | Mark the database as open. The stored certificates are carried over
-- unchanged.
openDB :: Model blk -> Model blk
openDB Model{certs} = Model{certs, open = True}

-- | Close the database. All stored certificates are discarded too, so the
-- result is exactly the initial model, whatever the input was. A subsequent
-- 'openDB' therefore starts from an empty store.
closeDB :: Model blk -> Model blk
closeDB = const initModel

-- | Insert a certificate into the model.
--
-- If a certificate for the same round is already stored (see 'hasRoundNo'),
-- the model is returned unchanged: the first certificate seen for a round
-- is the one that is kept.
addCert ::
  StandardHash blk =>
  Model blk -> ValidatedPerasCert blk -> Model blk
addCert model@Model{certs = held} cert =
  if held `hasRoundNo` cert
    then model
    else model{certs = Set.insert cert held}

-- | Does the set already contain a certificate with the same round number as
-- the given certificate?
--
-- Implemented as a single linear scan. This avoids materialising an
-- intermediate @Set PerasRoundNo@ via 'Set.map' (O(n log n)) just to perform
-- one membership test.
hasRoundNo ::
  StandardHash blk =>
  Set (ValidatedPerasCert blk) ->
  ValidatedPerasCert blk ->
  Bool
hasRoundNo existing cert = any sameRound existing
 where
  -- Round of the candidate certificate, computed once for the whole scan.
  roundNo = getPerasCertRound cert
  sameRound other = getPerasCertRound other == roundNo

-- | Build a weight snapshot from the stored certificates.
--
-- Each certificate is turned into one (boosted block point, boost) pair,
-- in ascending set order, and the pairs are handed to
-- 'mkPerasWeightSnapshot'.
getWeightSnapshot ::
  StandardHash blk =>
  Model blk -> PerasWeightSnapshot blk
getWeightSnapshot Model{certs = held} =
  mkPerasWeightSnapshot (map toWeightEntry (Set.toList held))
 where
  toWeightEntry c = (getPerasCertBoostedBlock c, getPerasCertBoost c)

-- | Drop every certificate whose boosted block lies strictly before the given
-- slot. Certificates boosting a block in that slot or later are kept.
--
-- A boosted point at 'Origin' sorts below every 'NotOrigin' slot, so such
-- certificates are always dropped. The open/closed flag is left untouched.
garbageCollect :: StandardHash blk => SlotNo -> Model blk -> Model blk
garbageCollect slot model@Model{certs = held} =
  model{certs = Set.filter (not . isStale) held}
 where
  isStale c = pointSlot (getPerasCertBoostedBlock c) < NotOrigin slot