-- CFB.hs: OpenPGP (RFC4880) CFB mode
-- Copyright © 2013-2016  Clint Adams
-- Copyright © 2013  Daniel Kahn Gillmor
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).

{-# LANGUAGE CPP #-}

module Codec.Encryption.OpenPGP.CFB (
   decrypt
 , decryptNoNonce
 , decryptOpenPGPCfb
 , encryptNoNonce
) where

import Codec.Encryption.OpenPGP.BlockCipher (withSymmetricCipher)
import Codec.Encryption.OpenPGP.Internal.HOBlockCipher
import Codec.Encryption.OpenPGP.Types
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>), (<*>))
#endif
import Control.Monad (liftM)
import qualified Data.ByteString as B

-- | Decrypt data produced by OpenPGP's resynchronizing CFB variant.
--
-- The first @blockSize + 2@ octets (the random prefix) are decrypted with an
-- all-zero IV. The remainder is decrypted with an IV taken from the
-- ciphertext shifted by two octets. The result is accepted only if the
-- prefix passes the session-key quick check.
--
-- 'Plaintext' input is returned unchanged.
decryptOpenPGPCfb :: SymmetricAlgorithm -> B.ByteString -> B.ByteString -> Either String B.ByteString
decryptOpenPGPCfb Plaintext ciphertext _ = Right ciphertext
decryptOpenPGPCfb sa ciphertext keydata = withSymmetricCipher sa keydata withCipher
    where
        withCipher :: HOBlockCipher cipher => cipher -> Either String B.ByteString
        withCipher bc = do
            let bs = blockSize bc
                (prefix, body) = B.splitAt (bs + 2) ciphertext
                -- resynchronization: the body's IV is ciphertext octets
                -- 2 .. bs+1 (0-based), not the last prefix block
                resyncIV = B.take bs (B.drop 2 ciphertext)
            nonce <- paddedCfbDecrypt bc (B.replicate bs 0) prefix
            cleartext <- paddedCfbDecrypt bc resyncIV body
            if nonceCheck bc nonce
                then Right cleartext
                else Left "Session key quickcheck failed"

-- | Decrypt a CFB stream (all-zero IV, no resynchronization) whose first
-- @blockSize + 2@ decrypted octets form the random prefix.
--
-- The prefix is stripped and checked with the session-key quick check. On
-- success the remaining cleartext is returned; on mismatch the result is a
-- 'Left'. 'Plaintext' input is returned unchanged.
decrypt :: SymmetricAlgorithm -> B.ByteString -> B.ByteString -> Either String B.ByteString
decrypt Plaintext ciphertext _ = Right ciphertext
decrypt sa ciphertext keydata = withSymmetricCipher sa keydata checked
    where
        checked :: HOBlockCipher cipher => cipher -> Either String B.ByteString
        checked bc = liftM (B.splitAt (bs + 2)) (paddedCfbDecrypt bc zeroIV ciphertext) >>= verify
            where
                bs = blockSize bc
                zeroIV = B.replicate bs 0
                -- accept the cleartext only when the prefix quick check holds
                verify (nonce, cleartext)
                    | nonceCheck bc nonce = Right cleartext
                    | otherwise = Left "Session key quickcheck failed"

-- | CFB-decrypt @ciphertext@ with a caller-supplied IV, with no random prefix
-- and no quick check. 'Plaintext' input is returned unchanged.
decryptNoNonce :: SymmetricAlgorithm -> IV -> B.ByteString -> B.ByteString -> Either String B.ByteString
decryptNoNonce sa iv ciphertext keydata = case sa of
    Plaintext -> Right ciphertext
    _ -> withSymmetricCipher sa keydata (\bc -> paddedCfbDecrypt bc (unIV iv) ciphertext)

-- | Session-key quick check on a decrypted random prefix.
--
-- Returns 'True' only when the prefix is exactly @blockSize + 2@ octets long
-- and its last two octets repeat the two octets immediately before them.
--
-- The explicit length test matters. Without it, any prefix of at most
-- @blockSize - 2@ octets (e.g. from empty or truncated ciphertext) compares
-- two empty strings and wrongly passes the check.
nonceCheck :: HOBlockCipher cipher => cipher -> B.ByteString -> Bool
nonceCheck bc nonce = B.length nonce == bs + 2 && repeatsMatch nonce
    where
        bs = blockSize bc
        -- octets bs-2, bs-1 compared against octets bs, bs+1 (0-based)
        repeatsMatch = (==) <$> B.take 2 . B.drop (bs - 2) <*> B.drop bs

-- | CFB-encrypt @payload@ with a caller-supplied IV, without prepending a
-- random prefix. 'Plaintext' input is returned unchanged.
--
-- The 'S2K' argument is accepted but not used by this function.
encryptNoNonce :: SymmetricAlgorithm -> S2K -> IV -> B.ByteString -> B.ByteString -> Either String B.ByteString
encryptNoNonce Plaintext _ _ payload _ = Right payload
encryptNoNonce sa _s2k iv payload keydata =
    withSymmetricCipher sa keydata $ \bc -> paddedCfbEncrypt bc (unIV iv) payload