{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Test.Util.Corruption (
    Corruption (..)
  , applyCorruption
  , detectCorruption
  ) where

import           Codec.CBOR.Decoding (Decoder)
import           Codec.CBOR.Encoding (Encoding)
import           Codec.CBOR.Read (deserialiseFromBytes)
import           Codec.CBOR.Term (Term)
import           Codec.CBOR.Write (toLazyByteString)
import           Codec.Serialise (deserialise)
import qualified Data.ByteString.Lazy as Lazy
import           Test.QuickCheck


-- | A corruption to apply to a serialised bytestring. The wrapped 'Word'
-- selects which byte 'applyCorruption' modifies, taken modulo the length of
-- the bytestring.
--
-- 'Arbitrary' is derived from the underlying 'Word', so generation and
-- shrinking behave exactly like those of 'Word'.
newtype Corruption = Corruption Word
  deriving stock   (Show)
  deriving newtype (Arbitrary)

-- | Increment (wrapping around on overflow) the byte at position @i@ in the
-- bytestring, where @i = n `mod` length bs@.
--
-- If the bytestring is empty, return it unmodified.
applyCorruption :: Corruption -> Lazy.ByteString -> Lazy.ByteString
applyCorruption (Corruption n) bs =
    case Lazy.uncons atAfter of
      -- Only reachable when @bs@ is empty: a non-empty @bs@ is split at an
      -- offset strictly smaller than its length, so @atAfter@ is non-empty.
      Nothing       -> bs
      Just (hd, tl) -> before <> Lazy.cons (hd + 1) tl
  where
    len = Lazy.length bs

    -- Reduce in 'Word' before converting. Converting @n@ to the (signed)
    -- length type first would wrap for large @n@ and select a different
    -- byte than documented. The remainder is below @len@, so converting it
    -- back is lossless. The guard avoids a division by zero for empty input.
    offset
      | len == 0  = 0
      | otherwise = fromIntegral (n `mod` fromIntegral len)

    (before, atAfter) = Lazy.splitAt offset bs

-- | Check that a corruption of the serialised form of a value is detected.
--
-- Serialise @a@ with the given encoder, apply the given 'Corruption', and
-- try to deserialise the result with the given decoder:
--
-- * If decoding fails, or succeeds without consuming all of the input, the
--   corruption was detected by the decoder.
--
-- * Otherwise the decoded value is passed to the integrity check. If the
--   check returns 'False', the corruption was detected. If it returns 'True',
--   the corruption was not detected and the property fails. The failure
--   reports the decoded value and both the original and the corrupted bytes,
--   raw and decoded as generic CBOR 'Term's.
detectCorruption ::
     Show a
  => (a -> Encoding)
  -> (forall s. Decoder s (Lazy.ByteString -> a))
     -- ^ Decoder whose result is applied to the (corrupted) input bytes to
     -- build the final value.
  -> (a -> Bool)
     -- ^ Integrity check that should detect the corruption. Return 'False'
     -- when corrupt.
  -> a
  -> Corruption
  -> Property
detectCorruption enc dec isValid a cor =
    case deserialiseFromBytes dec corruptBytes of
      Left _ ->
        label "corruption detected by decoder" $ property True
      Right (leftover, mkA')
        -- Unconsumed trailing input is treated as the decoder noticing.
        | not (Lazy.null leftover)
        -> label "corruption detected by decoder" $ property True
        | not (isValid a')
        -> label "corruption detected" $ property True
        | otherwise
        -> counterexample ("Corruption not detected: " <> show a')
         $ counterexample ("Original bytes: " <> show origBytes)
         $ counterexample ("Corrupt bytes: " <> show corruptBytes)
         $ counterexample ("Original CBOR: " <> show (deserialise origBytes :: Term))
         $ counterexample ("Corrupt CBOR: " <> show (deserialise corruptBytes :: Term))
           False
        where
          a' = mkA' corruptBytes
  where
    origBytes    = toLazyByteString (enc a)
    corruptBytes = applyCorruption cor origBytes