{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Intended for qualified import
--
-- > import Ouroboros.Consensus.Fragment.ValidatedDiff (ValidatedChainDiff (..))
-- > import qualified Ouroboros.Consensus.Fragment.ValidatedDiff as ValidatedDiff
module Ouroboros.Consensus.Fragment.ValidatedDiff (
    ValidatedChainDiff (ValidatedChainDiff)
  , getChainDiff
  , getLedger
  , new
  , rollbackExceedsSuffix
  , toValidatedFragment
  ) where

import           Control.Monad.Except (throwError)
import           GHC.Stack (HasCallStack)
import           Ouroboros.Consensus.Block
import           Ouroboros.Consensus.Fragment.Diff (ChainDiff)
import qualified Ouroboros.Consensus.Fragment.Diff as Diff
import           Ouroboros.Consensus.Fragment.Validated (ValidatedFragment)
import qualified Ouroboros.Consensus.Fragment.Validated as VF
import           Ouroboros.Consensus.Ledger.Abstract
import           Ouroboros.Consensus.Util.Assert

-- | A 'ChainDiff' along with the ledger state after validation.
--
-- INVARIANT:
--
-- > Diff.getTip (getChainDiff vcd) == castPoint (getTip (getLedger vcd))
--
-- The 'UnsafeValidatedChainDiff' constructor is not exported. Use 'new',
-- which checks the invariant with 'assertWithMsg', to construct a value.
-- Use the 'ValidatedChainDiff' pattern to deconstruct one.
data ValidatedChainDiff b l = UnsafeValidatedChainDiff
    { getChainDiff :: ChainDiff b
      -- ^ The chain diff that was validated.
    , getLedger    :: l
      -- ^ The ledger state after validating the diff.
    }

-- | Allow for pattern matching on a 'ValidatedChainDiff' without exposing the
-- (unsafe) constructor. Use 'new' to construct a 'ValidatedChainDiff'.
--
-- The pattern is unidirectional (@<-@), so it cannot be used as a
-- constructor. This keeps 'new' as the only way to build a value.
pattern ValidatedChainDiff
  :: ChainDiff b -> l -> ValidatedChainDiff b l
pattern ValidatedChainDiff d l <- UnsafeValidatedChainDiff d l
{-# COMPLETE ValidatedChainDiff #-}

-- | Create a 'ValidatedChainDiff'.
--
-- PRECONDITION:
--
-- > Diff.getTip chainDiff == castPoint (getTip ledger)
--
-- The precondition is checked with 'assertWithMsg'. On mismatch, the error
-- message contains both tip points.
new ::
     forall b l.
     (GetTip l, HasHeader b, HeaderHash l ~ HeaderHash b, HasCallStack)
  => ChainDiff b
  -> l
  -> ValidatedChainDiff b l
new chainDiff ledger =
    assertWithMsg precondition $
    UnsafeValidatedChainDiff chainDiff ledger
  where
    chainDiffTip, ledgerTip :: Point b
    chainDiffTip = Diff.getTip chainDiff
    -- The ledger's tip is a 'Point l'. Cast it to 'Point b' so the two
    -- tips can be compared. This is sound because of the equality
    -- constraint 'HeaderHash l ~ HeaderHash b'.
    ledgerTip    = castPoint $ getTip ledger

    precondition :: Either String ()
    precondition
      | chainDiffTip == ledgerTip
      = return ()
      | otherwise
      = throwError $
          "tip of ChainDiff doesn't match ledger: " <>
          show chainDiffTip <> " /= " <> show ledgerTip

-- | Convert to a 'ValidatedFragment'.
--
-- The result pairs the suffix of the diff ('Diff.getSuffix') with the
-- ledger state. The rest of the 'ChainDiff' is not included.
toValidatedFragment ::
     (GetTip l, HasHeader b, HeaderHash l ~ HeaderHash b, HasCallStack)
  => ValidatedChainDiff b l
  -> ValidatedFragment b l
toValidatedFragment (UnsafeValidatedChainDiff cs l) =
    VF.ValidatedFragment (Diff.getSuffix cs) l

-- | 'Diff.rollbackExceedsSuffix' applied to the underlying 'ChainDiff'.
-- The ledger state is not consulted.
rollbackExceedsSuffix :: HasHeader b => ValidatedChainDiff b l -> Bool
rollbackExceedsSuffix = Diff.rollbackExceedsSuffix . getChainDiff