{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | A compact bitmap representation using serialisation-ready ByteStrings.
--
-- Adapted from @Cardano.Leios.BitMapPV@ in the @leios-wfa-ls-demo@ package.
--
-- NOTE: this module is meant to be imported qualified.
module Ouroboros.Consensus.Util.Bitmap
  ( Bitmap
  , fromIndices
  , toIndices
  , logicalUpperBound
  , rawSerialise
  , rawDeserialise
  ) where

import Cardano.Binary (FromCBOR (..), ToCBOR (..))
import qualified Codec.CBOR.Decoding as CBOR
import qualified Codec.CBOR.Encoding as CBOR
import Control.Monad (forM_, when)
import Data.Bits
  ( countTrailingZeros
  , popCount
  , unsafeShiftL
  , (.&.)
  , (.|.)
  )
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Internal as ByteString
import Data.Word (Word8)
import Foreign.Marshal.Utils (fillBytes)
import Foreign.Storable (peekByteOff, pokeByteOff)

-- | A compact bitmap over an index type @a@: an inclusive logical upper bound
-- paired with a payload of packed bits.
--
-- The upper bound is stored next to the bytes so that serialisation
-- round-trips exactly. The derived 'Eq' compares both the bound and the raw
-- payload bytes.
data Bitmap a
  = Bitmap
      !a -- ^ Logical upper bound (inclusive maximum index)
      !ByteString -- ^ Payload bytes
  deriving (Eq)

-- Debug rendering: the upper bound, the payload size in bytes, and the total
-- number of 1 bits in the payload (every payload byte is counted).
instance Show a => Show (Bitmap a) where
  show (Bitmap upper payload) =
    concat
      [ "Bitmap{maxIx="
      , show upper
      , ",bytes="
      , show (ByteString.length payload)
      , ",set="
      , show setBitCount
      , "}"
      ]
   where
    -- Population count summed over all payload bytes.
    setBitCount :: Int
    setBitCount = ByteString.foldl' (\acc byte -> acc + popCount byte) 0 payload

-- | Construct a 'Bitmap' from a list of indexes that should be set (flipped to
-- 1) and a maximum index (inclusive logical upper bound).
--
-- Index @i@ is stored in byte @i \`quot\` 8@ at bit position @i \`rem\` 8@
-- (bit 0 = least significant). The payload is @maxIx \`quot\` 8 + 1@ bytes,
-- zero-filled first, so bits past @maxIx@ in the last byte are always 0.
--
-- Indexes outside @[0, maxIx]@ are ignored. So are indexes whose value
-- cannot be represented exactly as an 'Int'. Without that check, a large
-- 'Integer' (or a 64-bit index on a 32-bit platform) would wrap around and
-- set a small, unrelated bit.
--
-- @maxIx@ is expected to be non-negative and representable as an 'Int'.
-- Sufficiently negative values make the computed byte count negative, which
-- is not a valid size for 'ByteString.unsafeCreate'.
fromIndices :: Integral a => a -> [a] -> Bitmap a
fromIndices maxIx flipped =
  Bitmap maxIx $
    ByteString.unsafeCreate nBytes $ \ptr -> do
      fillBytes ptr 0 nBytes
      forM_ flipped $ \ix -> do
        let !i = fromIntegral ix :: Int
        -- Both conditions are needed:
        --   * 0 <= i <= maxI keeps the write below inside the nBytes buffer;
        --   * converting i back must reproduce ix, otherwise the narrowing
        --     wrapped and i is not the index the caller asked for.
        when (i >= 0 && i <= maxI && fromIntegral i == ix) $ do
          let !byteIx = i `quot` 8
              !mask = bitMask (i `rem` 8)
          w <- peekByteOff ptr byteIx :: IO Word8
          pokeByteOff ptr byteIx (w .|. mask)
 where
  !maxI = fromIntegral maxIx :: Int
  !nBytes = (maxI `quot` 8) + 1

  -- Single-bit mask for bit position k; only called with k in [0, 7].
  bitMask :: Int -> Word8
  bitMask k = 1 `unsafeShiftL` k

-- | Retrieve all indexes that are set (flipped to 1) in the bitmap, in
-- ascending order.
--
-- Byte @n@ of the payload covers indexes @8*n .. 8*n+7@, least significant
-- bit first. Set bits whose index exceeds the logical upper bound are not
-- reported.
toIndices :: Integral a => Bitmap a -> [a]
toIndices (Bitmap maxIx bitmap) =
  concatMap indicesInByte (zip [0 ..] (ByteString.unpack bitmap))
 where
  !limit = fromIntegral maxIx :: Int

  -- All reportable indexes contributed by one payload byte.
  indicesInByte (byteIx, byte) = walk (byteIx * 8) byte

  -- Peel off the lowest set bit each step (x .&. (x - 1) clears it), so the
  -- indexes come out in ascending order. Stop at the first one past the
  -- limit, since every later bit in this byte is higher still.
  walk !offset !byte
    | byte == 0 = []
    | pos > limit = []
    | otherwise = fromIntegral pos : walk offset (byte .&. (byte - 1))
   where
    pos = offset + countTrailingZeros byte

-- | The inclusive logical upper bound (maximum index) the bitmap was built
-- or decoded with.
logicalUpperBound :: Bitmap a -> a
logicalUpperBound bitmap = case bitmap of
  Bitmap upper _payload -> upper

-- | Raw serialisation of the bitmap: exactly the underlying payload bytes.
-- The logical upper bound is not included, so callers must transmit it
-- separately to decode with 'rawDeserialise'.
rawSerialise :: Bitmap a -> ByteString
rawSerialise bitmap = case bitmap of
  Bitmap _upper payload -> payload

-- | Raw deserialisation of a bitmap from a logical upper bound and a
-- 'ByteString'.
--
-- Returns 'Nothing' if
--
-- * the byte string length does not match the expected size for the given
--   upper bound (@maxIx \`quot\` 8 + 1@ bytes, the size 'fromIndices'
--   allocates), or
-- * any padding bit is set, i.e. a bit in the final byte whose index lies
--   beyond @maxIx@.
--
-- The padding check keeps accepted payloads canonical. 'fromIndices' never
-- sets padding bits and 'toIndices' ignores them. Without the check, two
-- payloads describing the same set of indexes could decode to values the
-- derived 'Eq' considers different, and 'rawSerialise' would reproduce the
-- stray bits.
rawDeserialise :: Integral a => a -> ByteString -> Maybe (Bitmap a)
rawDeserialise maxIx bs
  | ByteString.length bs /= expectedBytes = Nothing
  | not paddingIsZero = Nothing
  | otherwise = Just (Bitmap maxIx bs)
 where
  maxI :: Int
  maxI = fromIntegral maxIx

  expectedBytes :: Int
  expectedBytes = (maxI `quot` 8) + 1

  -- Bits of the final byte that lie beyond maxIx. For a non-negative bound
  -- the low (maxI `rem` 8 + 1) bits are meaningful; the shift is done in Int
  -- so a shift of 8 is well-defined, and truncation to Word8 yields 0x00.
  -- For a negative bound no index is meaningful, so the whole byte is
  -- padding (this also avoids shifting by a negative amount).
  paddingMask :: Word8
  paddingMask
    | maxI < 0 = 0xFF
    | otherwise = fromIntegral ((0xFF :: Int) `unsafeShiftL` (maxI `rem` 8 + 1))

  -- Only the last byte can contain padding once the length check passed.
  paddingIsZero :: Bool
  paddingIsZero = case ByteString.unsnoc bs of
    Nothing -> True
    Just (_, lastByte) -> lastByte .&. paddingMask == 0

-- Encoded as a 2-element CBOR list: the logical upper bound (via its own
-- 'ToCBOR' instance), followed by the raw payload as a CBOR byte string.
instance ToCBOR a => ToCBOR (Bitmap a) where
  toCBOR (Bitmap upper payload) =
    mconcat
      [ CBOR.encodeListLen 2
      , toCBOR upper
      , CBOR.encodeBytes payload
      ]

-- Decodes the 2-element list written by the 'ToCBOR' instance. The pair is
-- validated with 'rawDeserialise'; decoding fails if it rejects the pair.
instance (Integral a, FromCBOR a) => FromCBOR (Bitmap a) where
  fromCBOR = do
    CBOR.decodeListLenOf 2
    upper <- fromCBOR
    payload <- CBOR.decodeBytes
    maybe
      (fail "Bitmap: invalid bitmap data or size mismatch")
      pure
      (rawDeserialise upper payload)