{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- | This module is a derivative of "Data.MemPack" in which every packing and
-- unpacking operation takes an extra /index/ argument that is used to index
-- the serialization.
--
-- The idea is that we can use this in the Cardano block to avoid serializing a
-- tag next to the TxOut, as the Ledger layer establishes the property that
-- TxOuts are forwards deserializable, meaning we can read them in any later
-- era.
module Ouroboros.Consensus.Util.IndexedMemPack
  ( IndexedMemPack (..)
  , MemPack (..)
  , indexedPackByteString
  , indexedUnpackError
  ) where

import qualified Control.Monad as Monad
import Control.Monad.Trans.Fail (Fail, errorFail, failT)
import Data.Array.Byte (ByteArray (..))
import Data.ByteString
import Data.MemPack
import Data.MemPack.Buffer
import Data.MemPack.Error
import GHC.Stack

-- | Index-aware counterpart of 'MemPack'.
--
-- Every method takes an extra value of type @idx@ in addition to (or instead
-- of) the value being serialized. See 'MemPack' for the contract each method
-- is expected to follow.
class IndexedMemPack idx a where
  -- | Size in bytes of the packed form of a value at this index (the
  -- counterpart of 'packedByteCount').
  indexedPackedByteCount :: idx -> a -> Int

  -- | Write a value at this index (the counterpart of 'packM').
  indexedPackM :: idx -> a -> Pack s ()

  -- | Read a value at this index from any 'Buffer' (the counterpart of
  -- 'unpackM').
  indexedUnpackM :: forall b. Buffer b => idx -> Unpack b a

  -- | Name of @a@, used in the error messages produced by this module.
  -- @a@ does not occur in this type, so callers must fix it with a type
  -- application, e.g. @indexedTypeName \@idx \@a idx@.
  indexedTypeName :: idx -> String

-- | Serialize a value with its 'IndexedMemPack' instance at the given index.
--
-- The bytes are written into a pinned byte array, which is then exposed as a
-- 'ByteString' without copying the contents a second time.
indexedPackByteString ::
  forall a idx. (IndexedMemPack idx a, HasCallStack) => idx -> a -> ByteString
indexedPackByteString idx = \value ->
  pinnedByteArrayToByteString $ indexedPackByteArray True idx value
{-# INLINE indexedPackByteString #-}

-- | Serialize a value with its 'IndexedMemPack' instance at the given index
-- into a fresh 'ByteArray'.
--
-- The 'Bool' is forwarded to 'packWithByteArray' as its pinned flag ('True'
-- requests a pinned array). The type name from 'indexedTypeName' and the size
-- from 'indexedPackedByteCount' are forwarded alongside the packing action.
indexedPackByteArray ::
  forall a idx.
  (IndexedMemPack idx a, HasCallStack) =>
  Bool ->
  idx ->
  a ->
  ByteArray
indexedPackByteArray isPinned idx a =
  packWithByteArray isPinned typeName byteCount packer
  where
    -- The type application is required: @a@ does not occur in the type of
    -- 'indexedTypeName'.
    typeName = indexedTypeName @idx @a idx
    byteCount = indexedPackedByteCount idx a
    -- Kept polymorphic in @s@ because 'packWithByteArray' takes a rank-2
    -- packing action.
    packer :: forall s. Pack s ()
    packer = indexedPackM idx a
{-# INLINE indexedPackByteArray #-}

-- | Deserialize a value with its 'IndexedMemPack' instance at the given index,
-- requiring the whole buffer to be consumed.
--
-- Failures reported by 'indexedUnpackFail' (such as unconsumed trailing bytes)
-- are not returned. They are escalated through 'errorFail', whose type
-- @Fail e a -> a@ leaves no other way to signal them.
indexedUnpackError ::
  forall idx a b. (Buffer b, IndexedMemPack idx a, HasCallStack) => idx -> b -> a
indexedUnpackError idx buffer = errorFail (indexedUnpackFail idx buffer)
{-# INLINEABLE indexedUnpackError #-}

-- | Deserialize a value with its 'IndexedMemPack' instance at the given index,
-- requiring every byte of the buffer to be consumed.
--
-- If the unpacker stops early, this fails with a 'NotFullyConsumedError'
-- wrapped in 'SomeError'. Reading past the end is not reported here: it is
-- already turned into an 'error' by 'indexedUnpackLeftOver'.
indexedUnpackFail ::
  forall idx a b. (IndexedMemPack idx a, Buffer b, HasCallStack) => idx -> b -> Fail SomeError a
indexedUnpackFail idx buffer = do
  (value, used) <- indexedUnpackLeftOver idx buffer
  let available = bufferByteCount buffer
  Monad.unless (used == available) $
    unpackFailNotFullyConsumed (indexedTypeName @idx @a idx) used available
  pure value
{-# INLINEABLE indexedUnpackFail #-}

-- | Run 'indexedUnpackM' over the buffer with the consumed-byte counter
-- starting at 0.
--
-- Returns the decoded value together with the number of bytes consumed.
-- Leftover trailing bytes are allowed here. Consuming more bytes than the
-- buffer holds means the instance's unpacker is buggy, and it is reported with
-- 'errorLeftOver' (an 'error', not a 'Fail').
indexedUnpackLeftOver ::
  forall idx a b.
  (IndexedMemPack idx a, Buffer b, HasCallStack) => idx -> b -> Fail SomeError (a, Int)
indexedUnpackLeftOver idx buffer = do
  decoded@(_, used) <- runStateT (runUnpack (indexedUnpackM idx) buffer) 0
  let available = bufferByteCount buffer
  Monad.when (used > available) $
    errorLeftOver (indexedTypeName @idx @a idx) used available
  pure decoded
{-# INLINEABLE indexedUnpackLeftOver #-}

-- | Abort with 'error' after an unpacker read past the end of its buffer.
--
-- The arguments are the type name, the number of bytes consumed and the
-- buffer length. The message reports how far the read overran, using
-- 'showBytes' on the difference.
errorLeftOver :: HasCallStack => String -> Int -> Int -> a
errorLeftOver name consumedBytes len =
  error $
    concat
      [ "Potential buffer overflow. Some bug in 'unpackM' was detected while unpacking "
      , name
      , ". Consumed "
      , showBytes (consumedBytes - len)
      , " more than allowed from a buffer of length "
      , show len
      ]
{-# NOINLINE errorLeftOver #-}

-- | Fail (via 'failT', rather than throwing) with a 'NotFullyConsumedError'.
--
-- The error records the type name, the number of bytes read and the number of
-- bytes that were available.
unpackFailNotFullyConsumed :: Applicative m => String -> Int -> Int -> FailT SomeError m a
unpackFailNotFullyConsumed name consumedBytes len =
  failT (toSomeError mismatch)
  where
    mismatch =
      NotFullyConsumedError
        { notFullyConsumedTypeName = name
        , notFullyConsumedRead = consumedBytes
        , notFullyConsumedAvailable = len
        }
{-# NOINLINE unpackFailNotFullyConsumed #-}