{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE CPP #-}
module Data.Conduit.Serialization.Binary
  ( conduitDecode
  , conduitEncode
  , conduitMsgEncode
  , conduitGet
  , conduitPut
  , conduitPutList
  , conduitPutLBS
  , conduitPutMany
  , sourcePut
  , sinkGet
  , ParseError(..)
  )
  where

import           Control.Exception
import           Control.Monad (unless)
import           Data.Binary
import           Data.Binary.Get
import           Data.Binary.Put
import           Data.ByteString      as BS
import qualified Data.ByteString.Lazy as LBS

import           Data.Conduit
import qualified Data.Conduit.List    as CL
import           Data.Foldable
import           Data.Typeable
import qualified Data.Vector          as V
import           Control.Monad.Catch (MonadThrow(..))


-- | Error thrown (via 'throwM') when a 'Get' decoder fails on the input
-- stream. The fields are copied from the decoder's 'Fail' result.
data ParseError = ParseError
      { unconsumed :: ByteString
        -- ^ Input that had been fed to the failing decoder but was not
        -- consumed by it.

      , offset     :: ByteOffset
        -- ^ Number of bytes consumed by the failing decoder before the error.
        -- 'conduitGet' restarts the decoder for every value, so there this
        -- counts from the start of the value being decoded; for 'sinkGet' it
        -- counts from where 'sinkGet' started reading.

      , content    :: String
        -- ^ Error message reported by the decoder.
      } deriving (Show, Typeable)

instance Exception ParseError

-- | Runs the default 'Binary' decoder ('get') repeatedly on an input stream,
-- yielding every decoded value.
--
-- A 'ParseError' is thrown with 'throwM' if decoding fails, including when
-- the stream ends in the middle of a value.
conduitDecode :: (Binary b, MonadThrow m) => ConduitT ByteString b m ()
conduitDecode = conduitGet get

-- | Runs the default 'Binary' encoder ('put') on every value of an input
-- stream.
--
-- The result is a plain stream of bytes: each input value contributes any
-- number of 'ByteString' chunks, and no boundary is kept between values.
-- Use 'conduitMsgEncode' to get exactly one 'ByteString' per value.
--
-- Only 'Monad' is required, because encoding cannot fail. Any 'MonadThrow'
-- context still satisfies this constraint.
conduitEncode :: (Binary b, Monad m) => ConduitT b ByteString m ()
conduitEncode = CL.map put .| conduitPut


-- | Runs the default 'Binary' encoder on every value of an input stream and
-- yields exactly one strict 'ByteString' per input value. This can be
-- useful for datagram-based protocols.
--
-- The function satisfies the following property:
--
-- > conduitMsgEncode == CL.map (LBS.toStrict . Data.Binary.encode)
--
-- The property costs an additional copy of the data. If your packets can
-- be serialized into large chunks, or you need incremental packet
-- serialization, consider 'conduitPutList' or 'conduitPutMany' instead.
conduitMsgEncode :: (Monad m, Binary b) => ConduitT b ByteString m ()
conduitMsgEncode = CL.map put .| conduitMsg

-- | Runs a getter repeatedly on an input stream, yielding every decoded value.
--
-- Bytes left over after one value has been decoded are fed to the next run
-- of the getter. Empty chunks that arrive between values are skipped. If they
-- were not, an empty chunk followed by end of stream would start a new
-- decode that can only fail.
--
-- The conduit finishes when upstream is exhausted on a value boundary. If the
-- getter fails, or upstream ends in the middle of a value, a 'ParseError' is
-- thrown with 'throwM'.
conduitGet :: MonadThrow m => Get b -> ConduitT ByteString b m ()
conduitGet g = start
  where
    -- Wait for the first non-empty chunk of the next value.
    start = do
      mx <- await
      case mx of
        Nothing -> return ()
        Just x
          | BS.null x -> start
          | otherwise -> decodeFrom x

    -- Begin decoding a fresh value, seeding the decoder with the given bytes.
    decodeFrom chunk = go (runGetIncremental g `pushChunk` chunk)

    go (Done rest _ v) = do
      yield v
      if BS.null rest
        then start
        else decodeFrom rest
    go (Fail u o e) = throwM (ParseError u o e)
    -- 'await' yields 'Nothing' at end of stream; passing it on tells the
    -- decoder that no more input will arrive.
    go (Partial k) = await >>= go . k

-- | Runs a putter on every value of an input stream and yields the resulting
-- strict 'ByteString' chunks. Each 'Put' may contribute any number of chunks,
-- and no boundary is kept between values.
conduitPut :: Monad m => ConduitT Put ByteString m ()
conduitPut = awaitForever (traverse_ yield . LBS.toChunks . runPut)

-- | Runs a putter on every value of an input stream and yields exactly one
-- strict 'ByteString' (packet) per input value.
conduitMsg :: Monad m => ConduitT Put ByteString m ()
conduitMsg = awaitForever (yield . LBS.toStrict . runPut)

-- | Runs a putter on every value of an input stream and yields one lazy
-- 'LBS.ByteString' per value. This allows vectorized IO on the result,
-- either by calling 'LBS.toChunks' or by calling
-- 'Network.Socket.ByteString.Lazy.send'.
conduitPutLBS :: Monad m => ConduitT Put LBS.ByteString m ()
conduitPutLBS = awaitForever (yield . runPut)

-- | Vectorized variant of 'conduitPut'. For every input value, yields a list
-- containing all chunks of that value's serialized representation.
conduitPutList :: Monad m => ConduitT Put [ByteString] m ()
conduitPutList = awaitForever (yield . LBS.toChunks . runPut)

-- | Vectorized variant of 'conduitPut'. For every input value, yields a
-- vector containing all chunks of that value's serialized representation.
conduitPutMany :: Monad m => ConduitT Put (V.Vector ByteString) m ()
conduitPutMany = awaitForever (yield . V.fromList . LBS.toChunks . runPut)

-- | Creates a stream of strict 'ByteString's from a 'Put' value. Each chunk
-- of the serialized lazy 'LBS.ByteString' is emitted as its own element.
sourcePut :: Monad m => Put -> ConduitT z ByteString m ()
sourcePut = CL.sourceList . LBS.toChunks . runPut

-- | Decodes a single value from the input stream with the given getter.
--
-- Input that was fed to the getter but not consumed by it is returned to the
-- stream with 'leftover', so a following consumer sees it. If the getter
-- fails, or the stream ends before a complete value was read, a 'ParseError'
-- is thrown with 'throwM'.
sinkGet :: MonadThrow m => Get b -> ConduitT ByteString z m b
sinkGet f = sink (runGetIncremental f)
  where
    sink (Done rest _ v) = do
      unless (BS.null rest) $ leftover rest
      return v
    sink (Fail u o e) = throwM (ParseError u o e)
    -- 'await' yields 'Nothing' at end of stream; passing it on tells the
    -- decoder that no more input will arrive.
    sink (Partial next) = await >>= sink . next