{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}

-- | This module implements a “leaky bucket”. One defines a bucket with a
-- capacity and a leaking rate; a race (in the sense of Async) starts against
-- the bucket which leaks at the given rate. The user is provided with a
-- function to refill the bucket by a certain amount. If the bucket ever goes
-- empty, the configured 'onEmpty' action is run; any exception it raises is
-- rethrown to the thread running the user's action.
--
-- This can be used for instance to enforce a minimal rate of a peer: they race
-- against the bucket and refill the bucket by a certain amount whenever they do
-- a “good” action.
--
-- NOTE: Even though the imagery is the same, this is different from what is
-- usually called a \“token bucket\” or \“leaky bucket\” in the literature,
-- where it is mostly used for rate limiting.
--
-- REVIEW: This could also be used as a leaky bucket for rate-limiting
-- algorithms. All the infrastructure is here (set 'onEmpty' to @pure ()@ and
-- you're good to go) but it has not been tested with that purpose in mind.
--
-- $leakyBucketDesign
module Ouroboros.Consensus.Util.LeakyBucket (
    Config (..)
  , Handlers (..)
  , State (..)
  , atomicallyWithMonotonicTime
  , diffTimeToSecondsRational
  , dummyConfig
  , evalAgainstBucket
  , execAgainstBucket
  , execAgainstBucket'
  , fill'
  , microsecondsPerSecond
  , picosecondsPerSecond
  , runAgainstBucket
  , secondsRationalToDiffTime
  , setPaused'
  , updateConfig'
  ) where

import           Control.Exception (assert)
import           Control.Monad (forever, void, when)
import qualified Control.Monad.Class.MonadSTM.Internal as TVar
import           Control.Monad.Class.MonadTimer (MonadTimer, registerDelay)
import           Control.Monad.Class.MonadTimer.SI (diffTimeToMicrosecondsAsInt)
import           Data.Ratio ((%))
import           Data.Time.Clock (diffTimeToPicoseconds)
import           GHC.Generics (Generic)
import           Ouroboros.Consensus.Util.IOLike hiding (killThread)
import           Ouroboros.Consensus.Util.STM (blockUntilChanged)
import           Prelude hiding (init)

-- | Configuration of a leaky bucket.
data Config m = Config
  { capacity :: !Rational
  -- ^ Initial and maximal capacity of the bucket, in number of tokens.
  , rate :: !Rational
  -- ^ Tokens per second leaking off the bucket.
  , fillOnOverflow :: !Bool
  -- ^ On overflow: 'True' fills the bucket to capacity, 'False' does nothing.
  , onEmpty :: !(m ())
  -- ^ Monadic action triggered whenever the bucket becomes empty.
  }
  deriving (Generic)

deriving instance NoThunks (m ()) => NoThunks (Config m)

-- | A configuration for a bucket that does nothing: zero capacity, zero rate
-- (so it never leaks), and a no-op 'onEmpty'.
dummyConfig :: Applicative m => Config m
dummyConfig = Config
  { capacity = 0
  , rate = 0
  , fillOnOverflow = True
  , onEmpty = pure ()
  }

-- | State of a leaky bucket: a level together with the time it refers to, plus
-- the pause flag and the configuration currently in effect.
data State m = State
  { level :: !Rational
  -- ^ Number of tokens in the bucket at 'time'.
  , time :: !Time
  -- ^ The time at which 'level' was measured.
  , paused :: !Bool
  -- ^ Whether leaking is currently paused.
  , configGeneration :: !Int
  -- ^ Counter incremented on every configuration update, used to detect
  -- configuration changes.
  , config :: !(Config m)
  -- ^ The configuration currently in effect.
  }
  deriving (Generic)

deriving instance NoThunks (m ()) => NoThunks (State m)

-- | A bucket is nothing more than a strict TVar holding its 'State'. Since the
-- state embeds the 'Config' along with a generation counter, configuration
-- changes can be detected by watching that counter.
type Bucket m = StrictTVar m (State m)

-- | Outcome of filling the bucket: whether it overflew.
data FillResult
  = -- | Filling overflew the bucket.
    Overflew
  | -- | Filling did not overflow the bucket.
    DidNotOverflow

-- | The handlers to a bucket: the API for interacting with a running bucket.
-- Every endpoint is an STM transaction that needs the current time; the easy
-- way to supply it is 'atomicallyWithMonotonicTime'.
data Handlers m = Handlers
  { fill :: !(Rational -> Time -> STM m FillResult)
  -- ^ Refill the bucket by the given amount and return whether the bucket
  -- overflew. On overflow, the bucket is either silently filled to full
  -- capacity or left unfilled, depending on 'fillOnOverflow'.
  , setPaused :: !(Bool -> Time -> STM m ())
  -- ^ Pause or resume the bucket. Pausing stops the bucket from leaking until
  -- it is resumed; it can still be filled in the meantime. @setPaused True@
  -- and @setPaused False@ are idempotent.
  , updateConfig :: !(((Rational, Config m) -> (Rational, Config m)) -> Time -> STM m ())
  -- ^ Dynamically update the level and configuration of the bucket. Updating
  -- the level matters in particular when the capacity changes; the new level
  -- is clamped between 0 and the new capacity. If the update leaves the bucket
  -- empty and the new rate is strictly positive, 'onEmpty' is triggered
  -- promptly; with a non-positive rate the bucket does not leak at all.
  }

-- | Like 'fill', but reads the current monotonic time and runs the transaction
-- itself through 'atomicallyWithMonotonicTime'.
fill' ::
  ( MonadMonotonicTime m,
    MonadSTM m
  ) =>
  Handlers m ->
  Rational ->
  m FillResult
fill' handlers amount = atomicallyWithMonotonicTime (fill handlers amount)

-- | Like 'setPaused', but reads the current monotonic time and runs the
-- transaction itself through 'atomicallyWithMonotonicTime'.
setPaused' ::
  ( MonadMonotonicTime m,
    MonadSTM m
  ) =>
  Handlers m ->
  Bool ->
  m ()
setPaused' handlers shouldPause =
  atomicallyWithMonotonicTime (setPaused handlers shouldPause)

-- | Like 'updateConfig', but reads the current monotonic time and runs the
-- transaction itself through 'atomicallyWithMonotonicTime'.
updateConfig' ::
  ( MonadMonotonicTime m,
    MonadSTM m
  ) =>
  Handlers m ->
  ((Rational, Config m) -> (Rational, Config m)) ->
  m ()
updateConfig' handlers modify =
  atomicallyWithMonotonicTime (updateConfig handlers modify)

-- | Create a bucket with the given configuration, run the action against that
-- bucket, and return the action's result.
--
-- The bucket becoming empty does not by itself stop the action or produce a
-- special return value. Instead, the configuration's 'onEmpty' action runs in
-- the leak thread, and any exception it raises is rethrown to the thread
-- running the action. The leak thread is cancelled once the action returns.
execAgainstBucket ::
  ( MonadDelay m,
    MonadAsync m,
    MonadFork m,
    MonadMask m,
    MonadTimer m,
    NoThunks (m ())
  ) =>
  Config m ->
  (Handlers m -> m a) ->
  m a
execAgainstBucket config action = fmap snd (runAgainstBucket config action)

-- | 'execAgainstBucket' specialised to 'dummyConfig'. Since that configuration
-- has a zero rate, the bucket does not leak until the action installs a real
-- configuration through 'updateConfig'; this is only useful for such actions.
execAgainstBucket' ::
  ( MonadDelay m,
    MonadAsync m,
    MonadFork m,
    MonadMask m,
    MonadTimer m,
    NoThunks (m ())
  ) =>
  (Handlers m -> m a) ->
  m a
execAgainstBucket' = execAgainstBucket dummyConfig

-- | Like 'execAgainstBucket', but discards the action's result and returns the
-- 'State' of the bucket at the moment the action terminates. Exposed for
-- testing purposes.
evalAgainstBucket ::
  (MonadDelay m, MonadAsync m, MonadFork m, MonadMask m, MonadTimer m, NoThunks (m ())
  ) =>
  Config m ->
  (Handlers m -> m a) ->
  m (State m)
evalAgainstBucket config action = fmap fst (runAgainstBucket config action)

-- | Same as 'execAgainstBucket' but also returns the 'State' of the bucket when
-- the action terminates. Exposed for testing purposes.
--
-- A 'leak' thread is forked with 'withAsync' for the lifetime of the action.
-- It is gated by a version 'StrictTMVar' that is full exactly when the current
-- configuration has a strictly positive rate (see note [Leaky bucket design]).
runAgainstBucket ::
  forall m a.
  ( MonadDelay m,
    MonadAsync m,
    MonadFork m,
    MonadMask m,
    MonadTimer m,
    NoThunks (m ())
  ) =>
  Config m ->
  (Handlers m -> m a) ->
  m (State m, a)
runAgainstBucket config action = do
  -- Starts empty, so the leak thread stays idle until a version is published.
  versionVar <- atomically newEmptyTMVar
  actionThread <- myThreadId
  bucket <- init config
  withAsync (leak (readTMVar versionVar) actionThread bucket) $ \_leakThread -> do
    -- Publish the first version (0) if the initial configuration leaks.
    atomicallyWithMonotonicTime (maybeStartThread Nothing versionVar bucket)
    result <- action (mkHandlers versionVar bucket)
    finalState <- atomicallyWithMonotonicTime (snapshot bucket)
    pure (finalState, result)
  where
    -- Build the public API on top of the bucket and its version variable.
    mkHandlers :: StrictTMVar m Int -> Bucket m -> Handlers m
    mkHandlers versionVar bucket =
      Handlers
        { fill = \amount now -> snd <$> snapshotFill bucket amount now,
          setPaused = setPausedSTM bucket,
          updateConfig = updateConfigSTM versionVar bucket
        }

    -- Publish a version in the version variable, but only if the current
    -- configuration has a strictly positive rate. The published version is 0
    -- when no previous version is given, otherwise the previous version plus
    -- one. Does nothing if the variable is already full.
    maybeStartThread :: Maybe Int -> StrictTMVar m Int -> Bucket m -> Time -> STM m ()
    maybeStartThread previousVersion versionVar bucket now = do
      State {config = Config {rate}} <- snapshot bucket now
      when (rate > 0) $
        void (tryPutTMVar versionVar (maybe 0 (+ 1) previousVersion))

    -- Bring the bucket up to date, then record the new pause flag.
    setPausedSTM :: Bucket m -> Bool -> Time -> STM m ()
    setPausedSTM bucket shouldPause now = do
      current <- snapshot bucket now
      writeTVar bucket current {paused = shouldPause}

    -- Apply the user's update to (level, config), clamp the new level into
    -- [0, new capacity], bump the configuration generation, and refresh the
    -- version variable so that the leak thread reconsiders its waiting time.
    updateConfigSTM ::
      StrictTMVar m Int ->
      Bucket m ->
      ((Rational, Config m) -> (Rational, Config m)) ->
      Time ->
      STM m ()
    updateConfigSTM versionVar bucket f now = do
      State
        { level = oldLevel,
          paused = wasPaused,
          configGeneration = oldGeneration,
          config = oldConfig
        } <-
        snapshot bucket now
      let (requestedLevel, newConfig) = f (oldLevel, oldConfig)
      writeTVar bucket $
        State
          { level = clamp (0, capacity newConfig) requestedLevel,
            time = now,
            paused = wasPaused,
            configGeneration = oldGeneration + 1,
            config = newConfig
          }
      -- Empty the version variable, then republish an incremented version if
      -- the new configuration leaks; a waiting 'leak' thread sees the change.
      previousVersion <- tryTakeTMVar versionVar
      maybeStartThread previousVersion versionVar bucket now

-- | Create a fresh bucket for the given configuration. The bucket starts full
-- (its level equals the configured capacity), unpaused, at configuration
-- generation 0, and timestamped with the monotonic time at which 'init' runs.
init ::
  (MonadMonotonicTime m, MonadSTM m, NoThunks (m ())) =>
  Config m ->
  m (Bucket m)
init cfg@Config {capacity = initialLevel} = do
  startTime <- getMonotonicTime
  newTVarIO $
    State
      { level = initialLevel,
        time = startTime,
        paused = False,
        configGeneration = 0,
        config = cfg
      }

-- $leakyBucketDesign
--
-- Note [Leaky bucket design]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~
--
-- The leaky bucket works by running the given action against a thread that
-- makes the bucket leak. Since it would be inefficient to actually
-- remove tokens one by one from the bucket, the 'leak' thread instead looks at
-- the current state of the bucket, computes how much time it would take for the
-- bucket to empty, and then waits that amount of time. Once the wait is over, it
-- recurses, looks at the new state of the bucket, etc. If tokens were given to
-- the bucket via the action, the bucket is not empty and the loop continues.
--
-- This description assumes that two things hold:
--
--  - the bucket must be leaking (i.e. the rate is strictly positive),
--  - the action can only increase the waiting time (e.g. by giving tokens).
--
-- Neither of those properties holds in the general case. Indeed, it is possible
-- for the bucket to have a zero rate or even a negative one (for a more
-- traditional rate limiting bucket, for instance). Conversely, it is possible
-- for the action to lower the waiting time by changing the bucket configuration
-- to one where the rate is higher.
--
-- We fix both issues with one mechanism, the @leakingPeriodVersionSTM@.
-- It is a computation returning an integer that identifies a version of the
-- configuration that controls the leaking period. If the computation blocks,
-- no leaking configuration is currently in effect (in 'runAgainstBucket', this
-- is the case whenever the rate is not strictly positive).
-- The leak thread first waits until @leakingPeriodVersionSTM@ yields a
-- value, and only then proceeds as described above.
-- Additionally, while waiting for the bucket to empty, the thread monitors
-- the version of the leaking period for changes. A change indicates either that
-- the thread should pause (if @leakingPeriodVersionSTM@ starts blocking again)
-- or that the configuration changed, so that it may have to wait for a shorter time.
--

-- | Neverending computation that runs 'onEmpty' whenever the bucket becomes
-- empty. See note [Leaky bucket design].
leak ::
  ( MonadDelay m,
    MonadCatch m,
    MonadFork m,
    MonadAsync m,
    MonadTimer m
  ) =>
  -- | A computation indicating the version of the configuration affecting the
  -- leaking period. Whenever the configuration changes, the returned integer
  -- must be incremented. While no configuration is available, the computation
  -- should block. Blocking is allowed at any time, and it will cause the
  -- leaking to pause.
  STM m Int ->
  -- | The 'ThreadId' of the action's thread, which is used to throw exceptions
  -- at it.
  ThreadId m ->
  Bucket m ->
  m ()
leak leakingPeriodVersionSTM actionThreadId bucket = forever $ do
  -- Block until we are allowed to run.
  leakingPeriodVersion <- atomically leakingPeriodVersionSTM
  -- NOTE: It is tempting to group this @atomically@ and
  -- @atomicallyWithMonotonicTime@ into one; however, because the former is
  -- blocking, the latter could get a _very_ inaccurate time, which we
  -- cannot afford.
  State
    { level,
      configGeneration = oldConfigGeneration,
      config = Config {rate, onEmpty}
    } <-
    atomicallyWithMonotonicTime $ snapshot bucket
  -- Because the two transactions above are separate, the configuration may
  -- have been updated in between to one with a non-positive rate. Computing
  -- @level / rate@ would then divide by zero (an exception that would silently
  -- kill this thread, leaving the bucket unable to leak ever again) or, for a
  -- negative rate, yield a negative wait that spuriously triggers 'onEmpty'.
  -- Skip the iteration instead: in 'runAgainstBucket' a version is only
  -- published while the rate is positive, so the next iteration blocks on
  -- @leakingPeriodVersionSTM@ until leaking is enabled again.
  when (rate > 0) $ do
    let timeToWait = secondsRationalToDiffTime (level / rate)
        timeToWaitMicroseconds = diffTimeToMicrosecondsAsInt timeToWait
    -- NOTE: It is possible that @timeToWait <= 1µs@ while @level > 0@ when
    -- @level@ is extremely small.
    if level <= 0 || timeToWaitMicroseconds <= 0
      then do
        -- Forward any exception raised by 'onEmpty' to the action's thread.
        -- NOTE(review): catching 'SomeException' also intercepts asynchronous
        -- exceptions delivered to this thread while 'onEmpty' runs (e.g. its
        -- cancellation by 'withAsync') and forwards them to the action's
        -- thread as well; confirm this is intended.
        handle (\(e :: SomeException) -> throwTo actionThreadId e) onEmpty
        -- We have run the action on empty, there is nothing left to do,
        -- unless someone changes the configuration.
        void $
          atomically $
            blockUntilChanged configGeneration oldConfigGeneration $
              readTVar bucket
      else
        -- Wait for the bucket to empty, or for the thread to be stopped or
        -- restarted. Beware not to call 'registerDelay' with argument 0, that
        -- is ensure that @timeToWaitMicroseconds > 0@.
        assert (timeToWaitMicroseconds > 0) $ do
          varTimeout <- registerDelay timeToWaitMicroseconds
          atomically $
            (check =<< TVar.readTVar varTimeout)
              `orElse` void
                (blockUntilChanged id leakingPeriodVersion leakingPeriodVersionSTM)

-- | Take a snapshot of the bucket, that is compute its state at the current
-- time.
--
-- This runs 'snapshotFill' with nothing to add. The bucket is therefore updated
-- (leaked up to @newTime@) and written back, and the 'FillResult' is dropped.
snapshot ::
  ( MonadSTM m
  ) =>
  Bucket m ->
  Time ->
  STM m (State m)
snapshot bucket newTime = do
  (st, _fillResult) <- snapshotFill bucket 0 newTime
  pure st

-- | Same as 'snapshot' but also adds the given quantity to the resulting
-- level and returns whether this action overflew the bucket.
--
-- Steps, all within one STM transaction:
--
-- 1. Read the stored state.
-- 2. Leak @elapsed * rate@ from the stored level, where @elapsed@ is the time
--    in seconds between the stored time and @newTime@. Nothing leaks while the
--    bucket is paused. The result is clamped to @[0, capacity]@.
-- 3. Add @toAdd@. If the sum exceeds the capacity, the result is 'Overflew'.
--    The new level is then the capacity when 'fillOnOverflow' is set, or the
--    leaked level (the addition is dropped) otherwise. Without overflow the
--    new level is the sum, clamped to @[0, capacity]@.
-- 4. Write back a state stamped with @newTime@. The pause flag, config
--    generation and config are unchanged.
--
-- REVIEW: What to do when 'toAdd' is negative? Currently such a fill can never
-- report 'Overflew', and the resulting level is clamped at 0.
--
-- NOTE(review): if @newTime@ precedes the stored time, @elapsed@ (and hence
-- the leak) is negative, which raises the level. Presumably callers always
-- pass non-decreasing times; confirm.
snapshotFill ::
  ( MonadSTM m
  ) =>
  Bucket m ->
  Rational ->
  Time ->
  STM m (State m, FillResult)
snapshotFill bucket toAdd newTime = do
  State {level = lastLevel, time = lastTime, paused, configGeneration, config} <-
    readTVar bucket
  let Config {rate, capacity, fillOnOverflow} = config
      clampLevel = clamp (0, capacity)
      leakedAmount
        | paused = 0
        | otherwise = diffTimeToSecondsRational (diffTime newTime lastTime) * rate
      afterLeak = clampLevel (lastLevel - leakedAmount)
      wanted = afterLeak + toAdd
      overflew = wanted > capacity
      -- Only a rejected overflow (fillOnOverflow unset) keeps the leaked level.
      finalLevel
        | overflew && not fillOnOverflow = afterLeak
        | otherwise = clampLevel wanted
      !newState =
        State
          { time = newTime
          , level = finalLevel
          , paused
          , configGeneration
          , config
          }
  writeTVar bucket newState
  pure (newState, if overflew then Overflew else DidNotOverflow)

-- | Convert a 'DiffTime' to a 'Rational' number of seconds. This is similar to
-- 'diffTimeToSeconds' but with picoseconds precision.
--
-- The duration is taken as a whole number of picoseconds and divided exactly
-- by 'picosecondsPerSecond', so no precision is lost.
diffTimeToSecondsRational :: DiffTime -> Rational
diffTimeToSecondsRational d = diffTimeToPicoseconds d % picosecondsPerSecond

-- | Convert a 'Rational' number of seconds to a 'DiffTime'.
--
-- Equivalent to 'realToFrac' at this type, since @'toRational'@ on a
-- 'Rational' is the identity. It exists as a named function so that call
-- sites state the conversion and its types explicitly.
secondsRationalToDiffTime :: Rational -> DiffTime
secondsRationalToDiffTime = fromRational

-- | Helper around 'getMonotonicTime' and 'atomically'. It reads the current
-- monotonic time, then runs the given STM action with it atomically.
--
-- The time is read once, before the transaction starts. It is not refreshed
-- if the transaction is retried.
atomicallyWithMonotonicTime ::
  ( MonadMonotonicTime m,
    MonadSTM m
  ) =>
  (Time -> STM m b) ->
  m b
atomicallyWithMonotonicTime f = do
  currentTime <- getMonotonicTime
  atomically (f currentTime)

-- | Restrict a value to the inclusive range @(low, high)@. This is a local copy
-- of 'Data.Ord.clamp', which only exists in @base >= 4.16@ (GHC 9.2). It is
-- kept so that the module still builds with GHC 8.
--
-- If @low > high@, the result is @high@.
clamp :: Ord a => (a, a) -> a -> a
clamp (lo, hi) = min hi . max lo

-- | How many microseconds make up one second, i.e. @10^6@.
microsecondsPerSecond :: Integer
microsecondsPerSecond = 10 ^ (6 :: Int)

-- | How many picoseconds make up one second, i.e. @10^12@.
picosecondsPerSecond :: Integer
picosecondsPerSecond = 10 ^ (12 :: Int)