{-# LANGUAGE CPP #-}

-- | This module provides rate-limiting facilities built on top of the lazy bucket algorithm heavily inspired by
-- <http://ksdlck.com/post/17418037348/rate-limiting-at-webscale-lazy-leaky-buckets "Rate Limiting at Webscale: Lazy Leaky Buckets">
--
-- See also Wikipedia's <http://en.wikipedia.org/wiki/Token_bucket Token Bucket> article for general information about token bucket algorithms and their properties.
module Control.Concurrent.TokenBucket
    ( -- * The 'TokenBucket' type
      TokenBucket
    , newTokenBucket

      -- * Operations on 'TokenBucket'
      --
      -- | The following operations take two parameters, a burst-size and an average token rate.
      --
      -- === Average token rate
      --
      -- The average rate is expressed as inverse rate in terms of
      -- microseconds-per-token (i.e. one token every
      -- @n@ microseconds). This representation exposes the time
      -- granularity of the underlying implementation using integer
      -- arithmetic.
      --
      -- So in order to convert a token-rate @r@ expressed in
      -- tokens-per-second (i.e. @Hertz@) to microseconds-per-token the
      -- simple function below can be used:
      --
      -- @
      -- toInvRate :: Double -> Word64
      -- toInvRate r = round (1e6 / r)
      -- @
      --
      -- An inverse-rate @0@ denotes an infinite average rate, which
      -- will let token allocation always succeed (regardless of the
      -- burst-size parameter).
      --
      -- === Burst size
      --
      -- The burst-size parameter denotes the depth of the token
      -- bucket, and allows for temporarily exceeding the average
      -- token rate. The burst-size parameter should be at least as
      -- large as the maximum amount of tokens that need to be
      -- allocated at once, since an allocation-size larger than the
      -- current burst-size will always fail unless an infinite token
      -- rate is used.

    , tokenBucketTryAlloc
    , tokenBucketTryAlloc1
    , tokenBucketWait
    ) where

import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.IORef
#if !defined(USE_CBITS)
import Data.Time.Clock.POSIX (getPOSIXTime)
#endif
import Data.Word (Word64)

-- | Opaque handle to the mutable state of a token bucket.
--
-- Obtain one with 'newTokenBucket'. The constructor is not exported, so
-- the state can only be changed through this module's operations, which
-- update it with 'atomicModifyIORef''.
newtype TokenBucket
    = TB (IORef TBData)

-- Internal bucket state. Both fields are strict, so building a new state
-- (as done inside 'atomicModifyIORef'') forces the field values too.
data TBData
    = TBData
        !Word64          -- current fill level (tokens)
        !PosixTimeUsecs  -- timestamp (usecs) up to which the level has been drained
    deriving Show

-- | Points in time, expressed as whole microseconds since the POSIX
-- epoch (as produced by 'getPosixTimeUsecs').
type PosixTimeUsecs
    = Word64

-- getTBData :: TokenBucket -> IO TBData
-- getTBData (TB lbd) = readIORef lbd

-- Current time as whole microseconds since the POSIX epoch.
--
-- With USE_CBITS the reading comes from a C helper (its implementation is
-- not part of this module). Otherwise it is derived from 'getPOSIXTime',
-- rounding down to whole microseconds.
--
-- NOTE(review): 'getPOSIXTime' is wall-clock time, so clock adjustments
-- (NTP steps, manual changes) can distort how fast the bucket drains.
-- Consider a monotonic clock; confirm what the C helper uses.
#if defined(USE_CBITS)
foreign import ccall unsafe "hs_token_bucket_get_posix_time_usecs"
    getPosixTimeUsecs :: IO PosixTimeUsecs
#else
getPosixTimeUsecs :: IO PosixTimeUsecs
getPosixTimeUsecs = do
    secs <- getPOSIXTime
    return (floor (secs * 1e6))
#endif

-- | Create new 'TokenBucket' instance
--
-- The bucket starts out empty (level 0), with its timestamp set to the
-- current time. A full burst is therefore available immediately.
newTokenBucket :: IO TokenBucket
newTokenBucket = do
    start <- getPosixTimeUsecs
    let initial = TBData 0 start
    ref <- initial `seq` newIORef initial
    evaluate (TB ref)

-- | Attempt to allocate a given amount of tokens from the 'TokenBucket'
--
-- This operation either succeeds in allocating the requested amount
-- of tokens (and returns 'True'), or else, if allocation fails the
-- 'TokenBucket' remains in its previous allocation state. The time-based
-- drain is still recorded in that case.
--
-- Before the check, the bucket level is drained by one token for every
-- full @invRate@ microseconds elapsed since the last update. The request
-- succeeds iff the drained level plus the requested amount does not
-- exceed the burst-size. A request larger than the burst-size therefore
-- always fails, unless the inverse rate is @0@ (infinite rate). In that
-- case every request succeeds without touching the bucket.
tokenBucketTryAlloc :: TokenBucket
                    -> Word64  -- ^ burst-size (tokens)
                    -> Word64  -- ^ avg. inverse rate (usec/token)
                    -> Word64  -- ^ amount of tokens to allocate
                    -> IO Bool -- ^ 'True' if allocation succeeded
tokenBucketTryAlloc (TB ref) burst invRate alloc
  | invRate == 0  = return True   -- infinite rate, no-op
  | alloc > burst = return False  -- can never fit; no need to read the clock
  | otherwise     = do
      now <- getPosixTimeUsecs
      atomicModifyIORef' ref (step now)
  where
    -- Pure state transition, applied atomically to the stored state.
    --
    -- The elapsed time uses saturating subtraction, so a clock reading
    -- earlier than the stored stamp counts as "no time passed". The
    -- partial-token remainder ('spare') is subtracted from the new stamp,
    -- so that fraction of a token is not lost on the next call.
    step now (TBData level stamp) =
      let (drained, spare) = (now ∸ stamp) `quotRem` invRate
          level'           = level ∸ drained
          stamp'           = now ∸ spare
          wanted           = level' ∔ alloc
      in if wanted <= burst
           then (TBData wanted stamp', True)
           else (TBData level' stamp', False)

-- | Try to allocate a single token from the token bucket.
--
-- Returns 0 if successful (i.e. a token was successfully allocated from
-- the token bucket).
--
-- On failure, i.e. if token bucket budget was exhausted, the minimum
-- non-zero amount of microseconds to wait till allocation /may/
-- succeed is returned. Concretely, this is the time until the next token
-- drains out of the bucket, which always lies between 1 and the inverse
-- rate (inclusive).
--
-- This function does not block. See 'tokenBucketWait' for wrapper
-- around this function which blocks until a token could be allocated.
tokenBucketTryAlloc1 :: TokenBucket
                     -> Word64     -- ^ burst-size (tokens)
                     -> Word64     -- ^ avg. inverse rate (usec/token)
                     -> IO Word64  -- ^ retry-time (usecs)
tokenBucketTryAlloc1 (TB ref) burst invRate
  | invRate == 0 = return 0  -- infinite rate, no-op
  | otherwise    = getPosixTimeUsecs >>= atomicModifyIORef' ref . step
  where
    -- Pure state transition, applied atomically to the stored state.
    -- Draining works exactly as in 'tokenBucketTryAlloc', with a request
    -- size of one token.
    step now (TBData level stamp)
      | next <= burst = (TBData next stamp', 0)
        -- 0 <= spare < invRate, so the retry time is always >= 1
      | otherwise     = (TBData level' stamp', invRate - spare)
      where
        (drained, spare) = (now ∸ stamp) `quotRem` invRate
        level'           = level ∸ drained
        stamp'           = now ∸ spare
        next             = level' ∔ 1

-- | Blocking wrapper around 'tokenBucketTryAlloc1'. Uses 'threadDelay' when blocking.
--
-- This is effectively implemented as
--
-- @
-- 'tokenBucketWait' tb burst invRate = do
--   delay <- 'tokenBucketTryAlloc1' tb burst invRate
--   unless (delay == 0) $ do
--     threadDelay (fromIntegral delay)
--     'tokenBucketWait' tb burst invRate
-- @
--
-- A retry delay too large for 'threadDelay''s 'Int' argument is capped at
-- @'maxBound' :: 'Int'@ microseconds. The loop then simply retries, so the
-- total wait is unaffected.
--
-- Note that this never returns if a single token can never be allocated,
-- e.g. with a burst-size of @0@ and a non-zero inverse rate.
tokenBucketWait :: TokenBucket
                -> Word64  -- ^ burst-size (tokens)
                -> Word64  -- ^ avg. inverse rate (usec/token)
                -> IO ()
tokenBucketWait tb burst invRate = do
    delay <- tokenBucketTryAlloc1 tb burst invRate
    unless (delay == 0) $ do
        threadDelay (toDelay delay)
        tokenBucketWait tb burst invRate
  where
    -- A plain 'fromIntegral' would silently wrap a Word64 above
    -- 'maxBound :: Int' (only ~35.8 minutes on 32-bit platforms) into a
    -- negative or too-small delay, turning the wait into a busy loop.
    -- Saturate instead.
    toDelay :: Word64 -> Int
    toDelay d = fromIntegral (min d (fromIntegral (maxBound :: Int)))

-- Saturating arithmetic on 'Word64', used by the bucket bookkeeping so
-- that unsigned values never wrap around. For example, a clock reading
-- earlier than the stored timestamp yields an elapsed time of 0.
--
-- @x ∸ y@ is truncated subtraction: it yields 0 instead of underflowing
-- when @y >= x@.
--
-- @x ∔ y@ is saturating addition: it yields 'maxBound' instead of
-- overflowing.
(∸), (∔) :: Word64 -> Word64 -> Word64
x ∸ y
  | x > y     = x - y
  | otherwise = 0
{-# INLINE (∸) #-}

x ∔ y
  | total >= x = total
  | otherwise  = maxBound
  where
    -- unsigned addition wrapped around iff the sum is below an operand
    total = x + y
{-# INLINE (∔) #-}