{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Intended for qualified import
--
-- > import Ouroboros.Consensus.Fragment.ValidatedDiff (ValidatedChainDiff (..))
-- > import qualified Ouroboros.Consensus.Fragment.ValidatedDiff as ValidatedDiff
module Ouroboros.Consensus.Fragment.ValidatedDiff
  ( ValidatedChainDiff (ValidatedChainDiff)
  , getChainDiff
  , getLedger
  , new
  , rollbackExceedsSuffix
  , toValidatedFragment

    -- * Monadic
  , newM
  , toValidatedFragmentM
  ) where

import Control.Monad.Except (throwError)
import GHC.Stack (HasCallStack)
import Ouroboros.Consensus.Block
import Ouroboros.Consensus.Fragment.Diff (ChainDiff)
import qualified Ouroboros.Consensus.Fragment.Diff as Diff
import Ouroboros.Consensus.Fragment.Validated (ValidatedFragment)
import qualified Ouroboros.Consensus.Fragment.Validated as VF
import Ouroboros.Consensus.Ledger.Abstract
import Ouroboros.Consensus.Util.Assert
import Ouroboros.Consensus.Util.IOLike (MonadSTM (..))

-- | A 'ChainDiff' along with the ledger state after validation.
--
-- INVARIANT:
--
-- > getTip chainDiff == ledgerTipPoint ledger
--
-- The invariant is only checked on construction; maintaining it afterwards is
-- up to the user.
--
-- The raw constructor is deliberately named @Unsafe...@: building a value
-- with it directly bypasses the invariant check performed by 'new' and 'newM'.
data ValidatedChainDiff b l = UnsafeValidatedChainDiff
  { getChainDiff :: ChainDiff b
  -- ^ The chain diff that was validated.
  , getLedger :: l
  -- ^ The ledger state paired with the diff.
  }

-- | Allow for pattern matching on a 'ValidatedChainDiff' without exposing the
-- (unsafe) constructor. Use 'new' to construct a 'ValidatedChainDiff'.
--
-- This is a unidirectional pattern synonym (@<-@): it can be used to match on
-- a value but not to build one, so construction always goes through the
-- checked smart constructors.
pattern ValidatedChainDiff ::
  ChainDiff b -> l -> ValidatedChainDiff b l
pattern ValidatedChainDiff d l <- UnsafeValidatedChainDiff d l

{-# COMPLETE ValidatedChainDiff #-}

-- | Create a 'ValidatedChainDiff'.
--
-- PRECONDITION:
--
-- > getTip chainDiff == ledgerTipPoint ledger
--
-- The precondition is checked with 'assertWithMsg', using the tip obtained
-- from the ledger via 'getTip'; on a mismatch the assertion message reports
-- both tips.
new ::
  forall b l mk.
  (GetTip l, HasHeader b, HeaderHash l ~ HeaderHash b, HasCallStack) =>
  ChainDiff b ->
  l mk ->
  ValidatedChainDiff b (l mk)
new chainDiff ledger =
  assertWithMsg (pointInvariant (getTip ledger) chainDiff) $
    UnsafeValidatedChainDiff chainDiff ledger

-- | Check that the tip of a 'ChainDiff' coincides with the given ledger tip.
--
-- Returns @Right ()@ when both points are equal, and otherwise a @Left@
-- carrying a message that shows both points. The result is meant to be fed to
-- 'assertWithMsg'.
pointInvariant ::
  forall l b.
  (HeaderHash b ~ HeaderHash l, HasHeader b) =>
  Point l ->
  ChainDiff b ->
  Either String ()
pointInvariant ledgerTip0 chainDiff = precondition
 where
  chainDiffTip, ledgerTip :: Point b
  -- 'Diff.getTip' already yields a @Point b@, so no cast is needed here.
  chainDiffTip = Diff.getTip chainDiff
  -- Re-index the ledger's point to the block type so the two are comparable;
  -- this is justified by the @HeaderHash b ~ HeaderHash l@ constraint.
  ledgerTip = castPoint ledgerTip0

  precondition :: Either String ()
  precondition
    | chainDiffTip == ledgerTip =
        pure ()
    | otherwise =
        throwError $
          "tip of ChainDiff doesn't match ledger: "
            <> show chainDiffTip
            <> " /= "
            <> show ledgerTip

-- | Convert a 'ValidatedChainDiff' into a 'ValidatedFragment' by pairing the
-- suffix of the diff ('Diff.getSuffix') with the ledger state.
--
-- The rollback part of the diff is not carried over; only the suffix is used.
toValidatedFragment ::
  (GetTip l, HasHeader b, HeaderHash l ~ HeaderHash b, HasCallStack) =>
  ValidatedChainDiff b (l mk) ->
  ValidatedFragment b (l mk)
toValidatedFragment (UnsafeValidatedChainDiff cs l) =
  VF.ValidatedFragment (Diff.getSuffix cs) l

-- | Lifted version of 'Diff.rollbackExceedsSuffix': applies it to the
-- underlying 'ChainDiff' of a 'ValidatedChainDiff'.
rollbackExceedsSuffix :: HasHeader b => ValidatedChainDiff b l -> Bool
rollbackExceedsSuffix = Diff.rollbackExceedsSuffix . getChainDiff

{-------------------------------------------------------------------------------
  Monadic
-------------------------------------------------------------------------------}

-- | Monadic variant of 'new': the ledger tip is obtained with 'getTipM'
-- instead of the pure 'getTip'.
--
-- PRECONDITION:
--
-- > getTip chainDiff == ledgerTipPoint ledger
--
-- The precondition is checked with 'assertWithMsg' against the tip read from
-- the ledger.
newM ::
  forall m b l.
  ( MonadSTM m
  , GetTipSTM m l
  , HasHeader b
  , HeaderHash l ~ HeaderHash b
  , HasCallStack
  ) =>
  ChainDiff b ->
  l ->
  m (ValidatedChainDiff b l)
newM chainDiff ledger = do
  ledgerTip <- getTipM ledger
  pure $
    assertWithMsg (pointInvariant ledgerTip chainDiff) $
      UnsafeValidatedChainDiff chainDiff ledger

-- | Monadic variant of 'toValidatedFragment': builds the 'ValidatedFragment'
-- with 'VF.newM' from the suffix of the diff ('Diff.getSuffix') and the
-- ledger state.
toValidatedFragmentM ::
  ( MonadSTM m
  , GetTipSTM m l
  , HasHeader b
  , HeaderHash l ~ HeaderHash b
  , HasCallStack
  ) =>
  ValidatedChainDiff b l ->
  m (ValidatedFragment b l)
toValidatedFragmentM (UnsafeValidatedChainDiff cs l) =
  VF.newM (Diff.getSuffix cs) l