{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}

-- | This module provides a watcher of the invariants that are specific to the
-- ChainSync jumping (CSJ) implementation. Those invariants are typically
-- documented in the codebase but are not checked in any way, yet they are
-- crucial for CSJ to work properly. This watcher monitors the ChainSync
-- handlers and throws a 'Violation' exception when an invariant stops holding.
-- It is intended for testing purposes.
module Test.Consensus.PeerSimulator.CSJInvariants (
    Violation
  , watcher
  ) where

import           Control.Monad (forM_, when)
import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import           Data.Typeable (Typeable)
import           Ouroboros.Consensus.Block (Point, StandardHash, castPoint)
import qualified Ouroboros.Consensus.MiniProtocol.ChainSync.Client.State as CSState
import           Ouroboros.Consensus.Util.IOLike (Exception, MonadSTM (STM),
                     MonadThrow (throwIO), StrictTVar, readTVar)
import           Ouroboros.Consensus.Util.STM (Watcher (..))

--------------------------------------------------------------------------------
-- Idealised view of the ChainSync client's state
--------------------------------------------------------------------------------

-- | Our idealised view of the ChainSync client's state with respect to
-- ChainSync jumping in particular. It maps every known peer to the idealised
-- 'State' of that peer.
type View peer blk = Map peer (State blk)

-- | Idealised version of 'ChainSyncJumpingState': the role of a single peer
-- with respect to ChainSync jumping, reduced to plain data that can be shown
-- and compared for equality.
data State blk
  = Dynamo
  | Objector
      -- | The point where the objector dissented with the dynamo when it was a
      -- jumper.
      !(Point blk)
  | Disengaged
  | Jumper !(JumperState blk)
  deriving (Show, Eq)

-- | Idealised version of 'ChainSyncJumpingJumperState'. Every constructor only
-- keeps the points that delimit the jumps accepted or rejected so far.
data JumperState blk
  = Happy
      -- | Latest accepted jump, if there is one
      !(Maybe (Point blk))
  | LookingForIntersection
      -- | Latest accepted jump
      !(Point blk)
      -- | Earliest rejected jump
      !(Point blk)
  | FoundIntersection
      -- | Latest accepted jump
      !(Point blk)
      -- | Earliest rejected jump
      !(Point blk)
  deriving (Show, Eq)

--------------------------------------------------------------------------------
-- Invariants on views
--------------------------------------------------------------------------------

-- | Every invariant that is checked against a 'View'. Add new invariants to
-- this list for them to be enforced.
allInvariants :: [Invariant peer blk]
allInvariants =
  [ thereIsAlwaysOneDynamoUnlessDisengaged,
    thereIsAlwaysAtMostOneObjector
  ]

-- | Holds when either every peer of the view is 'Disengaged' (which includes
-- the empty view) or exactly one peer of the view is a 'Dynamo'.
thereIsAlwaysOneDynamoUnlessDisengaged :: Invariant peer blk
thereIsAlwaysOneDynamoUnlessDisengaged =
  Invariant
    { name = "There is always one dynamo, unless all are disengaged",
      check = \view ->
        let states = Map.elems view
         in all isDisengaged states
              || length (filter isDynamo states) == 1
    }

-- | Holds when the view contains zero or one peer in the 'Objector' state.
thereIsAlwaysAtMostOneObjector :: Invariant peer blk
thereIsAlwaysAtMostOneObjector =
  Invariant
    { name = "There is always at most one objector",
      check = \view ->
        length (filter isObjector $ Map.elems view) <= 1
    }

--------------------------------------------------------------------------------
-- Helpers for the invariants
--------------------------------------------------------------------------------

-- | Whether the given state is 'Dynamo'.
isDynamo :: State blk -> Bool
isDynamo (Dynamo {}) = True
isDynamo _           = False

-- | Whether the given state is 'Objector', regardless of the point it carries.
isObjector :: State blk -> Bool
isObjector (Objector {}) = True
isObjector _             = False

-- | Whether the given state is 'Disengaged'.
isDisengaged :: State blk -> Bool
isDisengaged (Disengaged {}) = True
isDisengaged _               = False

--------------------------------------------------------------------------------
-- Invariant enforcement implementation
--------------------------------------------------------------------------------

-- | Read the map of ChainSync client handles and, within the same STM
-- transaction, read the jumping state of every handle and convert it to its
-- idealised 'State'. The result has the same keys as the map of handles.
readAndView ::
  forall m peer blk.
  ( MonadSTM m
  ) =>
  StrictTVar m (Map peer (CSState.ChainSyncClientHandle m blk)) ->
  STM m (View peer blk)
readAndView handles =
  traverse (fmap idealiseState . readTVar . CSState.cschJumping) =<< readTVar handles
  where
    -- Idealise the state of a ChainSync peer with respect to ChainSync jumping.
    -- In particular, we get rid of non-comparable information such as the TVars
    -- it may contain.
    idealiseState :: CSState.ChainSyncJumpingState m blk -> State blk
    idealiseState (CSState.Dynamo {}) = Dynamo
    idealiseState (CSState.Objector _ point _) = Objector $ idealiseJumpInfo point
    idealiseState (CSState.Disengaged _) = Disengaged
    idealiseState (CSState.Jumper _ state) = Jumper $ idealiseJumperState state
    -- Idealise the jumper state by stripping away everything that is more of a
    -- technical necessity and not actually relevant for the invariants.
    idealiseJumperState :: CSState.ChainSyncJumpingJumperState blk -> JumperState blk
    idealiseJumperState (CSState.Happy _ lastAccepted) = Happy $ idealiseJumpInfo <$> lastAccepted
    idealiseJumperState (CSState.LookingForIntersection lastAccepted firstRejected) =
      LookingForIntersection (idealiseJumpInfo lastAccepted) (idealiseJumpInfo firstRejected)
    -- Here the rejected jump is stored as a header point rather than as a
    -- 'JumpInfo', hence the 'castPoint'.
    idealiseJumperState (CSState.FoundIntersection _ lastAccepted firstRejected) =
      FoundIntersection (idealiseJumpInfo lastAccepted) (castPoint firstRejected)
    -- Jumpers actually carry a lot of information regarding the jump. From our
    -- idealised point of view, we only care about the points where the jumpers
    -- agree or disagree with the dynamo.
    idealiseJumpInfo :: CSState.JumpInfo blk -> Point blk
    idealiseJumpInfo = CSState.jMostRecentIntersection

-- | The type of an invariant. Basically a glorified pair of a name and a check
-- function.
data Invariant peer blk = Invariant
  { -- | Human-readable description, reported in the 'Violation' when the
    -- invariant does not hold.
    name  :: !String,
    -- | Returns 'True' when the invariant holds for the given view.
    check :: !(View peer blk -> Bool)
  }

-- | An exception that is thrown when an invariant is violated. It carries the
-- name of the invariant and the view of the state that triggered the invariant
-- violation.
data Violation peer blk = Violation !String !(View peer blk)
  deriving (Eq, Show)

-- | Allows 'Violation' to be thrown as an exception. No method is overridden,
-- so the defaults of the 'Exception' class apply.
instance
  ( Typeable blk,
    StandardHash blk,
    Eq peer,
    Show peer,
    Typeable peer
  ) =>
  Exception (Violation peer blk)

-- | The watcher of ChainSync jumping invariants. It receives the ChainSync
-- handles and monitors them for changes. When a change is detected, it runs all
-- the invariants and throws 'Violation' if any of the invariants is violated.
--
-- The idealised 'View' is used as its own fingerprint, so only changes that are
-- visible in the view are notified. Invariants are checked in the order of
-- 'allInvariants' and the first one that fails is the one reported.
watcher ::
  ( MonadSTM m,
    MonadThrow m,
    Eq peer,
    Show peer,
    Typeable peer,
    Typeable blk,
    StandardHash blk
  ) =>
  StrictTVar m (Map peer (CSState.ChainSyncClientHandle m blk)) ->
  Watcher m (View peer blk) (View peer blk)
watcher handles =
  Watcher
    { wFingerprint = id,
      wInitial = Nothing,
      wReader = readAndView handles,
      wNotify = \view ->
        forM_ allInvariants $ \Invariant {name, check} ->
          when (not $ check view) $
            throwIO $ Violation name view
    }