{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Intended for qualified import
--
-- > import Ouroboros.Consensus.Fragment.ValidatedDiff (ValidatedChainDiff (..))
-- > import qualified Ouroboros.Consensus.Fragment.ValidatedDiff as ValidatedDiff
module Ouroboros.Consensus.Fragment.ValidatedDiff (
    ValidatedChainDiff (ValidatedChainDiff)
  , getChainDiff
  , getLedger
  , new
  , rollbackExceedsSuffix
  , toValidatedFragment
    -- * Monadic
  , newM
  , toValidatedFragmentM
  ) where

import           Control.Monad.Except (throwError)
import           GHC.Stack (HasCallStack)
import           Ouroboros.Consensus.Block
import           Ouroboros.Consensus.Fragment.Diff (ChainDiff)
import qualified Ouroboros.Consensus.Fragment.Diff as Diff
import           Ouroboros.Consensus.Fragment.Validated (ValidatedFragment)
import qualified Ouroboros.Consensus.Fragment.Validated as VF
import           Ouroboros.Consensus.Ledger.Abstract
import           Ouroboros.Consensus.Util.Assert
import           Ouroboros.Consensus.Util.IOLike (MonadSTM (..))

-- | A 'ChainDiff' along with the ledger state after validation.
--
-- INVARIANT: the tip of the chain diff matches the tip of the ledger, i.e.
--
-- > Diff.getTip chainDiff == castPoint (getTip ledger)
--
-- The invariant is only checked on construction (by 'new' / 'newM');
-- maintaining it afterwards is up to the user.
data ValidatedChainDiff b l = UnsafeValidatedChainDiff
    { getChainDiff :: ChainDiff b
      -- ^ The validated chain diff.
    , getLedger    :: l
      -- ^ The ledger state obtained after validating the diff.
    }

-- | Allow for pattern matching on a 'ValidatedChainDiff' without exposing the
-- (unsafe) constructor. Use 'new' (or 'newM') to construct a
-- 'ValidatedChainDiff'.
--
-- This pattern is unidirectional: it can only be used for matching, never for
-- building a value, so the invariant check in 'new' cannot be bypassed.
pattern ValidatedChainDiff
  :: ChainDiff b -> l -> ValidatedChainDiff b l
pattern ValidatedChainDiff d l <- UnsafeValidatedChainDiff d l
{-# COMPLETE ValidatedChainDiff #-}

-- | Create a 'ValidatedChainDiff'.
--
-- PRECONDITION: the tip of the chain diff matches the tip of the ledger, i.e.
--
-- > Diff.getTip chainDiff == castPoint (getTip ledger)
--
-- The precondition is checked with 'assertWithMsg', using 'pointInvariant' to
-- produce the failure message.
new ::
     forall b l mk. (GetTip l, HasHeader b, HeaderHash l ~ HeaderHash b, HasCallStack)
  => ChainDiff b
  -> l mk
  -> ValidatedChainDiff b (l mk)
new chainDiff ledger =
    assertWithMsg (pointInvariant (getTip ledger) chainDiff) $
      UnsafeValidatedChainDiff chainDiff ledger

-- | Check that the tip of the given 'ChainDiff' matches the given ledger tip.
--
-- Returns @Right ()@ when both points are equal. Otherwise it returns a
-- @Left@ whose message shows both points.
pointInvariant ::
     forall l b. (HeaderHash b ~ HeaderHash l, HasHeader b)
  => Point l
  -> ChainDiff b
  -> Either String ()
pointInvariant ledgerTip0 chainDiff = precondition
  where
    chainDiffTip, ledgerTip :: Point b
    -- 'Diff.getTip' already yields a @Point b@, so no cast is needed here.
    chainDiffTip = Diff.getTip chainDiff
    -- The ledger tip is indexed by @l@; cast it so both points can be compared
    -- (allowed because @HeaderHash b ~ HeaderHash l@).
    ledgerTip    = castPoint ledgerTip0

    precondition :: Either String ()
    precondition
      | chainDiffTip == ledgerTip
      = pure ()
      | otherwise
      = throwError $
          "tip of ChainDiff doesn't match ledger: " <>
          show chainDiffTip <> " /= " <> show ledgerTip

-- | Convert a 'ValidatedChainDiff' into a 'ValidatedFragment'.
--
-- The fragment is the suffix of the chain diff ('Diff.getSuffix'), paired with
-- the ledger state stored in the 'ValidatedChainDiff'.
toValidatedFragment ::
     (GetTip l, HasHeader b, HeaderHash l ~ HeaderHash b, HasCallStack)
  => ValidatedChainDiff b (l mk)
  -> ValidatedFragment b (l mk)
toValidatedFragment (UnsafeValidatedChainDiff cs l) =
    VF.ValidatedFragment (Diff.getSuffix cs) l

-- | 'Diff.rollbackExceedsSuffix' lifted to a 'ValidatedChainDiff': applies it
-- to the underlying 'ChainDiff'.
rollbackExceedsSuffix :: HasHeader b => ValidatedChainDiff b l -> Bool
rollbackExceedsSuffix = Diff.rollbackExceedsSuffix . getChainDiff

{-------------------------------------------------------------------------------
  Monadic
-------------------------------------------------------------------------------}

-- | Monadic variant of 'new'.
--
-- The ledger tip is obtained with 'getTipM' in @m@. It is then checked against
-- the chain diff's tip with 'assertWithMsg' / 'pointInvariant', in the same
-- way as 'new'.
newM ::
     forall m b l. (
       MonadSTM m, GetTipSTM m l, HasHeader b, HeaderHash l ~ HeaderHash b
     , HasCallStack
     )
  => ChainDiff b
  -> l
  -> m (ValidatedChainDiff b l)
newM chainDiff ledger = do
    ledgerTip <- getTipM ledger
    pure $ assertWithMsg (pointInvariant ledgerTip chainDiff)
         $ UnsafeValidatedChainDiff chainDiff ledger

-- | Monadic variant of 'toValidatedFragment'.
--
-- Builds the fragment with 'VF.newM' from the chain diff's suffix
-- ('Diff.getSuffix') and the stored ledger state.
toValidatedFragmentM ::
     ( MonadSTM m, GetTipSTM m l, HasHeader b, HeaderHash l ~ HeaderHash b
     , HasCallStack
     )
  => ValidatedChainDiff b l
  -> m (ValidatedFragment b l)
toValidatedFragmentM (UnsafeValidatedChainDiff cs l) =
    VF.newM (Diff.getSuffix cs) l