{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Ouroboros.Consensus.Util.IndexedMemPack
( IndexedMemPack (..)
, MemPack (..)
, indexedPackByteString
, indexedUnpackError
) where
import qualified Control.Monad as Monad
import Control.Monad.Trans.Fail (Fail, errorFail, failT)
import Data.Array.Byte (ByteArray (..))
import Data.ByteString
import Data.MemPack
import Data.MemPack.Buffer
import Data.MemPack.Error
import GHC.Stack
-- | Indexed variant of 'MemPack': every operation additionally receives a
-- value of type @idx@ that is available while sizing, packing and unpacking
-- a value of type @a@.
class IndexedMemPack idx a where
  -- | Number of bytes to reserve when packing the given value; used as the
  -- buffer size for 'indexedPackM' by the packing helpers in this module.
  indexedPackedByteCount :: idx -> a -> Int

  -- | Write the given value into the 'Pack' action.
  indexedPackM :: idx -> a -> Pack s ()

  -- | Read a value of type @a@ from any 'Buffer'.
  indexedUnpackM :: Buffer b => idx -> Unpack b a

  -- | Name of the type @a@, used in the diagnostics produced by this module.
  indexedTypeName :: idx -> String
-- | Serialise a value into a strict 'ByteString' using the index @idx@.
--
-- The value is packed into a /pinned/ 'ByteArray' (via
-- 'indexedPackByteArray' with @True@), which is then exposed as a
-- 'ByteString' through 'pinnedByteArrayToByteString'.
indexedPackByteString ::
  forall a idx. (IndexedMemPack idx a, HasCallStack) => idx -> a -> ByteString
indexedPackByteString idx value =
  pinnedByteArrayToByteString (indexedPackByteArray True idx value)
{-# INLINE indexedPackByteString #-}
-- | Serialise a value into a 'ByteArray' using the index @idx@.
--
-- The buffer size comes from 'indexedPackedByteCount' and its contents are
-- written by 'indexedPackM'. 'indexedTypeName' is handed along for
-- diagnostics. The 'Bool' argument (@isPinned@) is forwarded unchanged to
-- 'packWithByteArray' and selects whether the array is pinned.
indexedPackByteArray ::
  forall a idx.
  (IndexedMemPack idx a, HasCallStack) =>
  Bool ->
  idx ->
  a ->
  ByteArray
indexedPackByteArray isPinned idx value =
  -- The packer is passed inline because 'packWithByteArray' expects a
  -- rank-2 @forall s. Pack s ()@ argument.
  packWithByteArray isPinned typeName byteCount (indexedPackM idx value)
  where
    -- Type applications are required: 'indexedTypeName' never mentions @a@.
    typeName = indexedTypeName @idx @a idx
    byteCount = indexedPackedByteCount idx value
{-# INLINE indexedPackByteArray #-}
-- | Deserialise a value from a buffer using the index @idx@.
--
-- This is 'indexedUnpackFail' with its 'Fail' result discharged by
-- 'errorFail'. A decoding failure, or a buffer that is not consumed
-- exactly, therefore surfaces as an 'error' rather than a 'Fail' value.
indexedUnpackError ::
  forall idx a b. (Buffer b, IndexedMemPack idx a, HasCallStack) => idx -> b -> a
indexedUnpackError idx buffer = errorFail (indexedUnpackFail idx buffer)
{-# INLINEABLE indexedUnpackError #-}
-- | Deserialise a value from a buffer, requiring that decoding consumes the
-- buffer exactly.
--
-- On success, returns the decoded value. If 'indexedUnpackLeftOver' reports
-- a consumed byte count different from 'bufferByteCount', this fails with a
-- 'NotFullyConsumedError' (wrapped in 'SomeError'). Over-consumption is
-- already turned into an 'error' by 'indexedUnpackLeftOver', so in practice
-- this branch fires when bytes are left over.
indexedUnpackFail ::
  forall idx a b. (IndexedMemPack idx a, Buffer b, HasCallStack) => idx -> b -> Fail SomeError a
indexedUnpackFail idx b = do
  (value, consumed) <- indexedUnpackLeftOver idx b
  let available :: Int
      available = bufferByteCount b
  if consumed == available
    then pure value
    -- 'failT' short-circuits, so this can be returned directly at type @a@.
    else unpackFailNotFullyConsumed (indexedTypeName @idx @a idx) consumed available
{-# INLINEABLE indexedUnpackFail #-}
-- | Run 'indexedUnpackM' over a buffer, starting from offset @0@.
--
-- Returns the decoded value paired with the final offset, i.e. the number
-- of bytes consumed. Unconsumed trailing bytes are permitted here; it is up
-- to the caller to reject them.
--
-- If the decoder reports consuming more bytes than 'bufferByteCount', this
-- is treated as a bug in the decoder and aborts via 'errorLeftOver'.
indexedUnpackLeftOver ::
  forall idx a b.
  (IndexedMemPack idx a, Buffer b, HasCallStack) => idx -> b -> Fail SomeError (a, Int)
indexedUnpackLeftOver idx b = do
  outcome@(_, consumed) <- runStateT (runUnpack (indexedUnpackM idx) b) 0
  let available :: Int
      available = bufferByteCount b
  if consumed > available
    then errorLeftOver (indexedTypeName @idx @a idx) consumed available
    else pure outcome
{-# INLINEABLE indexedUnpackLeftOver #-}
-- | Abort with an 'error' reporting that unpacking @name@ read past the end
-- of its buffer.
--
-- @consumedBytes@ is the total number of bytes the decoder claims to have
-- read. @len@ is the buffer length. The message reports the excess
-- (@consumedBytes - len@) via 'showBytes'. Kept @NOINLINE@ so this cold
-- path stays out of callers.
errorLeftOver :: HasCallStack => String -> Int -> Int -> a
errorLeftOver name consumedBytes len = error message
  where
    excess = consumedBytes - len
    message =
      concat
        [ "Potential buffer overflow. Some bug in 'unpackM' was detected while unpacking "
        , name
        , ". Consumed "
        , showBytes excess
        , " more than allowed from a buffer of length "
        , show len
        ]
{-# NOINLINE errorLeftOver #-}
-- | Fail with a 'NotFullyConsumedError' (wrapped in 'SomeError') for the
-- type named @name@.
--
-- @consumedBytes@ is the number of bytes read and @len@ is the number that
-- were available. Kept @NOINLINE@ so this cold path stays out of callers.
unpackFailNotFullyConsumed :: Applicative m => String -> Int -> Int -> FailT SomeError m a
unpackFailNotFullyConsumed name consumedBytes len = failT (toSomeError leftover)
  where
    leftover =
      NotFullyConsumedError
        { notFullyConsumedRead = consumedBytes
        , notFullyConsumedAvailable = len
        , notFullyConsumedTypeName = name
        }
{-# NOINLINE unpackFailNotFullyConsumed #-}