{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Intended for qualified import
--
-- > import Ouroboros.Consensus.Fragment.ValidatedDiff (ValidatedChainDiff (..))
-- > import qualified Ouroboros.Consensus.Fragment.ValidatedDiff as ValidatedDiff
module Ouroboros.Consensus.Fragment.ValidatedDiff
  ( ValidatedChainDiff (ValidatedChainDiff)
  , getChainDiff
  , getLedger

    -- * Monadic
  , newM
  ) where

import Control.Monad.Except (throwError)
import GHC.Stack (HasCallStack)
import Ouroboros.Consensus.Block
import Ouroboros.Consensus.Fragment.Diff (ChainDiff)
import qualified Ouroboros.Consensus.Fragment.Diff as Diff
import Ouroboros.Consensus.Ledger.Abstract
import Ouroboros.Consensus.Util.Assert
import Ouroboros.Consensus.Util.IOLike (MonadSTM (..))

-- | A 'ChainDiff' paired with the ledger state that results from validating
-- it.
--
-- INVARIANT:
--
-- > getTip chainDiff == ledgerTipPoint ledger
--
-- 'newM' checks this invariant through 'assertWithMsg' when it builds a
-- value. Nothing re-checks it afterwards, so keeping it intact is the
-- caller's responsibility. Whether the check is active at runtime depends on
-- 'assertWithMsg' (presumably assertion-flag dependent; confirm in
-- "Ouroboros.Consensus.Util.Assert").
data ValidatedChainDiff b l = UnsafeValidatedChainDiff
  { getChainDiff :: ChainDiff b
  -- ^ The chain diff that was validated.
  , getLedger :: l
  -- ^ The ledger state whose tip, per the invariant, matches the diff's tip.
  }

-- | Match on a 'ValidatedChainDiff' without exposing its (unsafe)
-- constructor.
--
-- The synonym is match-only (unidirectional), so it cannot be used to build
-- values. Use 'newM' to construct a 'ValidatedChainDiff'.
pattern ValidatedChainDiff ::
  ChainDiff b -> l -> ValidatedChainDiff b l
pattern ValidatedChainDiff d l <- UnsafeValidatedChainDiff d l

-- The synonym covers the type's single constructor, so a match on it alone is
-- exhaustive; this pragma tells the pattern-match checker so.
{-# COMPLETE ValidatedChainDiff #-}

-- | Check that the tip of a 'ChainDiff' coincides with a ledger tip.
--
-- Yields @Right ()@ when the two points are equal. Otherwise it yields an
-- error message that shows both points. The ledger tip is converted with
-- 'castPoint' (possible because the header hash types agree) so that both
-- sides have type @'Point' b@ and can be compared.
pointInvariant ::
  forall l b.
  (HeaderHash b ~ HeaderHash l, HasHeader b) =>
  Point l ->
  ChainDiff b ->
  Either String ()
pointInvariant ledgerTip0 chainDiff =
  if diffTip == ledgerTipAsB
    then pure ()
    else
      throwError $
        concat
          [ "tip of ChainDiff doesn't match ledger: "
          , show diffTip
          , " /= "
          , show ledgerTipAsB
          ]
 where
  -- Tip of the diff; already a @Point b@, so no cast is needed.
  diffTip :: Point b
  diffTip = Diff.getTip chainDiff

  -- Ledger tip re-typed to the block type of the diff.
  ledgerTipAsB :: Point b
  ledgerTipAsB = castPoint ledgerTip0

{-------------------------------------------------------------------------------
  Monadic
-------------------------------------------------------------------------------}

-- | Create a 'ValidatedChainDiff'.
--
-- PRECONDITION:
--
-- > getTip chainDiff == ledgerTipPoint ledger
--
-- The ledger tip is read with 'getTipM' and compared with the tip of the
-- 'ChainDiff' by 'pointInvariant', wrapped in 'assertWithMsg'. The checked
-- value is forced with '$!', so the assertion is evaluated when this action
-- runs. Without that, it would wait until a caller first inspects the result.
newM ::
  forall m b l.
  ( MonadSTM m
  , GetTipSTM m l
  , HasHeader b
  , HeaderHash l ~ HeaderHash b
  , HasCallStack
  ) =>
  ChainDiff b ->
  l ->
  m (ValidatedChainDiff b l)
newM chainDiff ledger = do
  ledgerTip <- getTipM ledger
  -- A plain 'pure' would leave the assertion inside an unevaluated thunk. It
  -- would then fire only when (and if) the result is demanded. '$!' forces it
  -- here, at construction time, which is where the invariant is meant to be
  -- checked.
  pure $!
    assertWithMsg (pointInvariant ledgerTip chainDiff) $
      UnsafeValidatedChainDiff chainDiff ledger