{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

-- | Data structure for tracking the weight of blocks due to Peras boosts.
module Ouroboros.Consensus.Peras.Weight
  ( -- * 'PerasWeightSnapshot' type
    PerasWeightSnapshot

    -- * Construction
  , emptyPerasWeightSnapshot
  , mkPerasWeightSnapshot

    -- * Conversion
  , perasWeightSnapshotToList

    -- * Insertion
  , addToPerasWeightSnapshot

    -- * Pruning
  , prunePerasWeightSnapshot

    -- * Query
  , isEmptyPerasWeightSnapshot
  , weightBoostOfPoint
  , weightBoostOfFragment
  , totalWeightOfFragment
  , takeVolatileSuffix
  ) where

import Data.Foldable as Foldable (foldl')
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import GHC.Generics (Generic)
import NoThunks.Class
import Ouroboros.Consensus.Block
import Ouroboros.Consensus.Config.SecurityParam
import Ouroboros.Consensus.Util.AnchoredSeq (takeLongestSuffix)
import Ouroboros.Network.AnchoredFragment (AnchoredFragment)
import qualified Ouroboros.Network.AnchoredFragment as AF

-- | Data structure for tracking the weight of blocks due to Peras boosts.
--
-- Internally this is a map from boosted points to their weight; values must
-- be built and inspected through the functions exported by this module.
newtype PerasWeightSnapshot blk = PerasWeightSnapshot
  { getPerasWeightSnapshot :: Map (Point blk) PerasWeight
  -- ^ The weight recorded for every boosted point.
  }
  deriving stock Eq
  deriving Generic
  deriving newtype NoThunks

-- | A snapshot is shown as the list of @(point, weight)@ pairs returned by
-- 'perasWeightSnapshotToList'.
instance StandardHash blk => Show (PerasWeightSnapshot blk) where
  show = show . perasWeightSnapshotToList

-- | An empty 'PerasWeightSnapshot' not containing any boosted blocks.
emptyPerasWeightSnapshot :: PerasWeightSnapshot blk
emptyPerasWeightSnapshot = PerasWeightSnapshot Map.empty

-- | Create a weight snapshot from a list of boosted points with an associated
-- weight. In case of duplicate points, their weights are combined.
--
-- >>> :{
-- weights :: [(Point Blk, PerasWeight)]
-- weights =
--   [ (BlockPoint 2 "foo", PerasWeight 2)
--   , (GenesisPoint,       PerasWeight 3)
--   , (BlockPoint 3 "bar", PerasWeight 2)
--   , (BlockPoint 2 "foo", PerasWeight 2)
--   ]
-- :}
--
-- >>> snap = mkPerasWeightSnapshot weights
-- >>> snap
-- [(Point Origin,PerasWeight 3),(Point (At (Block {blockPointSlot = SlotNo 2, blockPointHash = "foo"})),PerasWeight 4),(Point (At (Block {blockPointSlot = SlotNo 3, blockPointHash = "bar"})),PerasWeight 2)]
mkPerasWeightSnapshot ::
  StandardHash blk =>
  [(Point blk, PerasWeight)] ->
  PerasWeightSnapshot blk
mkPerasWeightSnapshot =
  -- Strict left fold starting from the empty snapshot; duplicates are merged
  -- by 'addToPerasWeightSnapshot'.
  Foldable.foldl'
    (\s (pt, weight) -> addToPerasWeightSnapshot pt weight s)
    emptyPerasWeightSnapshot

-- | Return the list of boosted points with their associated weight, sorted
-- based on their point. Does not contain duplicate points.
--
-- >>> :{
-- weights :: [(Point Blk, PerasWeight)]
-- weights =
--   [ (BlockPoint 2 "foo", PerasWeight 2)
--   , (GenesisPoint,       PerasWeight 3)
--   , (BlockPoint 3 "bar", PerasWeight 2)
--   , (BlockPoint 2 "foo", PerasWeight 2)
--   ]
-- :}
--
-- >>> snap = mkPerasWeightSnapshot weights
-- >>> perasWeightSnapshotToList snap
-- [(Point Origin,PerasWeight 3),(Point (At (Block {blockPointSlot = SlotNo 2, blockPointHash = "foo"})),PerasWeight 4),(Point (At (Block {blockPointSlot = SlotNo 3, blockPointHash = "bar"})),PerasWeight 2)]
perasWeightSnapshotToList :: PerasWeightSnapshot blk -> [(Point blk, PerasWeight)]
perasWeightSnapshotToList = Map.toAscList . getPerasWeightSnapshot

-- | Add weight for the given point to the 'PerasWeightSnapshot'. If the point
-- already has some weight, it is added on top.
--
-- >>> :{
-- weights :: [(Point Blk, PerasWeight)]
-- weights =
--   [ (BlockPoint 2 "foo", PerasWeight 2)
--   , (GenesisPoint,       PerasWeight 3)
--   ]
-- :}
--
-- >>> snap0 = mkPerasWeightSnapshot weights
-- >>> snap0
-- [(Point Origin,PerasWeight 3),(Point (At (Block {blockPointSlot = SlotNo 2, blockPointHash = "foo"})),PerasWeight 2)]
--
-- >>> snap1 = addToPerasWeightSnapshot (BlockPoint 3 "bar") (PerasWeight 2) snap0
-- >>> snap1
-- [(Point Origin,PerasWeight 3),(Point (At (Block {blockPointSlot = SlotNo 2, blockPointHash = "foo"})),PerasWeight 2),(Point (At (Block {blockPointSlot = SlotNo 3, blockPointHash = "bar"})),PerasWeight 2)]
--
-- >>> snap2 = addToPerasWeightSnapshot (BlockPoint 2 "foo") (PerasWeight 2) snap1
-- >>> snap2
-- [(Point Origin,PerasWeight 3),(Point (At (Block {blockPointSlot = SlotNo 2, blockPointHash = "foo"})),PerasWeight 4),(Point (At (Block {blockPointSlot = SlotNo 3, blockPointHash = "bar"})),PerasWeight 2)]
addToPerasWeightSnapshot ::
  StandardHash blk =>
  Point blk ->
  PerasWeight ->
  PerasWeightSnapshot blk ->
  PerasWeightSnapshot blk
addToPerasWeightSnapshot pt weight =
  -- 'Map.insertWith' combines the new weight with an existing entry via '<>',
  -- or inserts it as-is when the point is not yet present.
  PerasWeightSnapshot . Map.insertWith (<>) pt weight . getPerasWeightSnapshot

-- | Prune the given 'PerasWeightSnapshot' by removing the weight of all blocks
-- strictly older than the given slot.
--
-- This function is used to garbage-collect boosted blocks which are older than
-- our immutable tip as we will never adopt a chain containing them.
--
-- >>> :{
-- weights :: [(Point Blk, PerasWeight)]
-- weights =
--   [ (BlockPoint 2 "foo", PerasWeight 2)
--   , (GenesisPoint,       PerasWeight 3)
--   , (BlockPoint 3 "bar", PerasWeight 2)
--   , (BlockPoint 2 "foo", PerasWeight 2)
--   ]
-- :}
--
-- >>> snap = mkPerasWeightSnapshot weights
--
-- >>> prunePerasWeightSnapshot (SlotNo 2) snap
-- [(Point (At (Block {blockPointSlot = SlotNo 2, blockPointHash = "foo"})),PerasWeight 4),(Point (At (Block {blockPointSlot = SlotNo 3, blockPointHash = "bar"})),PerasWeight 2)]
--
-- >>> prunePerasWeightSnapshot (SlotNo 3) snap
-- [(Point (At (Block {blockPointSlot = SlotNo 3, blockPointHash = "bar"})),PerasWeight 2)]
prunePerasWeightSnapshot ::
  SlotNo ->
  PerasWeightSnapshot blk ->
  PerasWeightSnapshot blk
prunePerasWeightSnapshot slot =
  -- 'Map.dropWhileAntitone' requires the predicate to be antitone with respect
  -- to the key order.
  -- NOTE(review): this presumably relies on @Ord (Point blk)@ comparing by
  -- slot first; confirm against the definition of 'Point'.
  PerasWeightSnapshot . Map.dropWhileAntitone isTooOld . getPerasWeightSnapshot
 where
  -- A point is too old when its slot is strictly before @slot@. The origin
  -- point is compared as 'WithOrigin' against @NotOrigin slot@.
  isTooOld :: Point blk -> Bool
  isTooOld pt = pointSlot pt < NotOrigin slot

-- | Check whether the snapshot is empty, i.e. does not contain weights for any
-- blocks.
--
-- >>> isEmptyPerasWeightSnapshot emptyPerasWeightSnapshot
-- True
--
-- >>> :{
-- weights :: [(Point Blk, PerasWeight)]
-- weights =
--   [ (BlockPoint 2 "foo", PerasWeight 2)
--   , (GenesisPoint,       PerasWeight 3)
--   , (BlockPoint 3 "bar", PerasWeight 2)
--   , (BlockPoint 2 "foo", PerasWeight 2)
--   ]
-- :}
--
-- >>> snap = mkPerasWeightSnapshot weights
--
-- >>> isEmptyPerasWeightSnapshot snap
-- False
isEmptyPerasWeightSnapshot :: PerasWeightSnapshot blk -> Bool
isEmptyPerasWeightSnapshot = Map.null . getPerasWeightSnapshot

-- | Get the weight boost for a point, or @'mempty' :: 'PerasWeight'@ if the
-- point has no weight recorded in the snapshot.
--
-- >>> :{
-- weights :: [(Point Blk, PerasWeight)]
-- weights =
--   [ (BlockPoint 2 "foo", PerasWeight 2)
--   , (GenesisPoint,       PerasWeight 3)
--   , (BlockPoint 3 "bar", PerasWeight 2)
--   , (BlockPoint 2 "foo", PerasWeight 2)
--   ]
-- :}
--
-- >>> snap = mkPerasWeightSnapshot weights
--
-- >>> weightBoostOfPoint snap (BlockPoint 2 "foo")
-- PerasWeight 4
--
-- >>> weightBoostOfPoint snap (BlockPoint 2 "baz")
-- PerasWeight 0
weightBoostOfPoint ::
  forall blk.
  StandardHash blk =>
  PerasWeightSnapshot blk -> Point blk -> PerasWeight
weightBoostOfPoint (PerasWeightSnapshot weightByPoint) pt =
  Map.findWithDefault mempty pt weightByPoint

-- | Get the weight boost for a fragment, ie the sum of all
-- 'weightBoostOfPoint' for all points on the fragment (excluding the anchor).
--
-- Note that this quantity is relative to the anchor of the fragment, so it
-- should only be compared against other fragments with the same anchor.
--
-- >>> :{
-- weights :: [(Point Blk, PerasWeight)]
-- weights =
--   [ (BlockPoint 2 "foo", PerasWeight 2)
--   , (GenesisPoint,       PerasWeight 3)
--   , (BlockPoint 3 "bar", PerasWeight 2)
--   , (BlockPoint 2 "foo", PerasWeight 2)
--   ]
-- :}
--
-- >>> :{
-- snap = mkPerasWeightSnapshot weights
-- foo = HeaderFields (SlotNo 2) (BlockNo 1) "foo"
-- bar = HeaderFields (SlotNo 3) (BlockNo 2) "bar"
-- frag0 :: AnchoredFragment (HeaderFields Blk)
-- frag0 = Empty AnchorGenesis :> foo :> bar
-- :}
--
-- >>> weightBoostOfFragment snap frag0
-- PerasWeight 6
--
-- Only keeping the last block from @frag0@:
--
-- >>> frag1 = AF.anchorNewest 1 frag0
-- >>> weightBoostOfFragment snap frag1
-- PerasWeight 2
--
-- Dropping the head from @frag0@, and instead adding an unboosted point:
--
-- >>> frag2 = AF.dropNewest 1 frag0 :> HeaderFields (SlotNo 4) (BlockNo 2) "baz"
-- >>> weightBoostOfFragment snap frag2
-- PerasWeight 4
weightBoostOfFragment ::
  forall blk h.
  (StandardHash blk, HasHeader h, HeaderHash blk ~ HeaderHash h) =>
  PerasWeightSnapshot blk ->
  AnchoredFragment h ->
  PerasWeight
weightBoostOfFragment weightSnap frag
  -- Fast path: with no boosted points at all, the result is 'mempty' without
  -- traversing the fragment.
  | Map.null $ getPerasWeightSnapshot weightSnap =
      mempty
  | otherwise =
      -- TODO: think about whether this could be done in sublinear complexity
      -- see https://github.com/IntersectMBO/ouroboros-consensus/pull/1613
      foldMap
        -- Look up each header's point in the snapshot; 'castPoint' converts
        -- from @Point h@ to @Point blk@.
        (weightBoostOfPoint weightSnap . castPoint . blockPoint)
        (AF.toOldestFirst frag)

-- | Get the total weight for a fragment, ie the length plus the weight boost
-- ('weightBoostOfFragment') of the fragment.
--
-- Note that this quantity is relative to the anchor of the fragment, so it
-- should only be compared against other fragments with the same anchor.
--
-- >>> :{
-- weights :: [(Point Blk, PerasWeight)]
-- weights =
--   [ (BlockPoint 2 "foo", PerasWeight 2)
--   , (GenesisPoint,       PerasWeight 3)
--   , (BlockPoint 3 "bar", PerasWeight 2)
--   , (BlockPoint 2 "foo", PerasWeight 2)
--   ]
-- :}
--
-- >>> :{
-- snap = mkPerasWeightSnapshot weights
-- foo = HeaderFields (SlotNo 2) (BlockNo 1) "foo"
-- bar = HeaderFields (SlotNo 3) (BlockNo 2) "bar"
-- frag0 :: AnchoredFragment (HeaderFields Blk)
-- frag0 = Empty AnchorGenesis :> foo :> bar
-- :}
--
-- >>> totalWeightOfFragment snap frag0
-- PerasWeight 8
--
-- Only keeping the last block from @frag0@:
--
-- >>> frag1 = AF.anchorNewest 1 frag0
-- >>> totalWeightOfFragment snap frag1
-- PerasWeight 3
--
-- Dropping the head from @frag0@, and instead adding an unboosted point:
--
-- >>> frag2 = AF.dropNewest 1 frag0 :> HeaderFields (SlotNo 4) (BlockNo 2) "baz"
-- >>> totalWeightOfFragment snap frag2
-- PerasWeight 6
totalWeightOfFragment ::
  forall blk h.
  (StandardHash blk, HasHeader h, HeaderHash blk ~ HeaderHash h) =>
  PerasWeightSnapshot blk ->
  AnchoredFragment h ->
  PerasWeight
totalWeightOfFragment weightSnap frag =
  weightLength <> weightBoost
 where
  weightLength, weightBoost :: PerasWeight
  -- The length of the fragment, converted to a weight.
  weightLength = PerasWeight $ fromIntegral $ AF.length frag
  -- The additional weight from boosted points on the fragment.
  weightBoost = weightBoostOfFragment weightSnap frag

-- | Take the longest suffix of the given fragment with total weight
-- ('totalWeightOfFragment') at most @k@. This is the volatile suffix of blocks
-- which are subject to rollback.
--
-- If the total weight of the input fragment is at least @k@, then the anchor of
-- the output fragment is the most recent point on the input fragment that is
-- buried under at least weight @k@ (also counting the weight boost of that
-- point).
--
-- See 'mkPerasWeightSnapshot' for context.
--
-- >>> :{
-- weights :: [(Point Blk, PerasWeight)]
-- weights =
--   [ (BlockPoint 2 "foo", PerasWeight 2)
--   , (GenesisPoint,       PerasWeight 3)
--   , (BlockPoint 3 "bar", PerasWeight 2)
--   , (BlockPoint 2 "foo", PerasWeight 2)
--   ]
-- snap = mkPerasWeightSnapshot weights
-- foo = HeaderFields (SlotNo 2) (BlockNo 1) "foo"
-- bar = HeaderFields (SlotNo 3) (BlockNo 2) "bar"
-- frag :: AnchoredFragment (HeaderFields Blk)
-- frag = Empty AnchorGenesis :> foo :> bar
-- :}
--
-- >>> k1 = SecurityParam $ knownNonZeroBounded @1
-- >>> k3 = SecurityParam $ knownNonZeroBounded @3
-- >>> k6 = SecurityParam $ knownNonZeroBounded @6
-- >>> k9 = SecurityParam $ knownNonZeroBounded @9
--
-- >>> AF.toOldestFirst $ takeVolatileSuffix snap k1 frag
-- []
--
-- >>> AF.toOldestFirst $ takeVolatileSuffix snap k3 frag
-- [HeaderFields {headerFieldSlot = SlotNo 3, headerFieldBlockNo = BlockNo 2, headerFieldHash = "bar"}]
--
-- >>> AF.toOldestFirst $ takeVolatileSuffix snap k6 frag
-- [HeaderFields {headerFieldSlot = SlotNo 3, headerFieldBlockNo = BlockNo 2, headerFieldHash = "bar"}]
--
-- >>> AF.toOldestFirst $ takeVolatileSuffix snap k9 frag
-- [HeaderFields {headerFieldSlot = SlotNo 2, headerFieldBlockNo = BlockNo 1, headerFieldHash = "foo"},HeaderFields {headerFieldSlot = SlotNo 3, headerFieldBlockNo = BlockNo 2, headerFieldHash = "bar"}]
takeVolatileSuffix ::
  forall blk h.
  (StandardHash blk, HasHeader h, HeaderHash blk ~ HeaderHash h) =>
  PerasWeightSnapshot blk ->
  -- | The security parameter @k@ is interpreted as a weight.
  SecurityParam ->
  AnchoredFragment h ->
  AnchoredFragment h
takeVolatileSuffix snap secParam
  | Map.null $ getPerasWeightSnapshot snap =
      -- Optimize the case where Peras is disabled.
      AF.anchorNewest (unPerasWeight k)
  | otherwise =
      -- Keep the longest suffix whose total weight does not exceed @k@.
      takeLongestSuffix (totalWeightOfFragment snap) (<= k)
 where
  k :: PerasWeight
  k = maxRollbackWeight secParam

-- $setup
-- >>> import Cardano.Ledger.BaseTypes
-- >>> import Ouroboros.Consensus.Block
-- >>> import Ouroboros.Consensus.Config.SecurityParam
-- >>> import Ouroboros.Network.AnchoredFragment (AnchoredFragment, AnchoredSeq(..), Anchor(..))
-- >>> import qualified Ouroboros.Network.AnchoredFragment as AF
-- >>> :set -XDataKinds -XTypeApplications -XTypeFamilies
-- >>> data Blk = Blk
-- >>> type instance HeaderHash Blk = String
-- >>> instance StandardHash Blk