-- Internal.hs: private utility functions and such
-- Copyright © 2012-2016  Clint Adams
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).

{-# LANGUAGE OverloadedStrings #-}

module Codec.Encryption.OpenPGP.Internal (
   countBits
 , PktStreamContext(..)
 , issuer
 , emptyPSC
 , pubkeyToMPIs
 , multiplicativeInverse
 , sigType
 , sigPKA
 , sigHA
 , sigCT
 , curveoidBSToCurve
 , curveToCurveoidBS
 , point2BS
) where

import Crypto.Number.Serialize (i2osp, os2ip)
import qualified Crypto.PubKey.DSA as DSA
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.PubKey.ECC.Types as ECCT
import qualified Crypto.PubKey.RSA as RSA

import Data.Bits (testBit)
import qualified Data.ByteString as B
import Data.ByteString.Lazy (ByteString)
import qualified Data.ByteString.Lazy as BL
import Data.List (find)
import Data.Maybe (fromJust)
import Data.Word (Word8, Word16)

import Codec.Encryption.OpenPGP.Types
import Codec.Encryption.OpenPGP.Ontology (isIssuerSSP, isSigCreationTime)

-- | Bit length of a big-endian byte string: eight bits per byte, minus the
-- leading zero bits of the first byte (presumably for an MPI length header).
--
-- Only the first byte is inspected, and at most 7 of its bits are treated
-- as leading zeros. A first byte of 0x00 is therefore counted as if it
-- were 0x01. An empty string has bit length 0.
countBits :: ByteString -> Word16
countBits bs = case BL.uncons bs of
    Nothing -> 0
    Just (firstByte, _) ->
        fromIntegral (BL.length bs * 8) - fromIntegral (leadingZeros firstByte)
  where
    -- Count unset bits from the top (bit 7) down to bit 1, stopping at the
    -- first set bit; bit 0 is never examined, which caps the result at 7.
    leadingZeros :: Word8 -> Int
    leadingZeros w = length (takeWhile (not . testBit w) [7, 6 .. 1])

-- | The most recently seen packet of each tracked kind. This is presumably
-- threaded through a pass over a packet stream (NOTE(review): confirm
-- against callers). The field meanings below are inferred from their names.
data PktStreamContext = PktStreamContext
    { lastLD         :: Pkt  -- ^ presumably the last literal-data packet
    , lastUIDorUAt   :: Pkt  -- ^ last user ID or user attribute packet
    , lastSig        :: Pkt  -- ^ last signature packet
    , lastPrimaryKey :: Pkt  -- ^ last primary key packet
    , lastSubkey     :: Pkt  -- ^ last subkey packet
    }

-- | A 'PktStreamContext' in which nothing has been recorded yet. Every
-- field holds an @OtherPacketPkt 0@ placeholder whose payload names the
-- field it stands in for.
emptyPSC :: PktStreamContext
emptyPSC = PktStreamContext
    { lastLD = OtherPacketPkt 0 "lastLD placeholder"
    , lastUIDorUAt = OtherPacketPkt 0 "lastUIDorUAt placeholder"
    , lastSig = OtherPacketPkt 0 "lastSig placeholder"
    , lastPrimaryKey = OtherPacketPkt 0 "lastPrimaryKey placeholder"
    , lastSubkey = OtherPacketPkt 0 "lastSubkey placeholder"
    }

-- | The eight-octet key ID of the key that issued a signature packet, if
-- one can be found.
--
-- For a version 4 signature, both subpacket areas are searched for an
-- Issuer subpacket. The hashed area is searched first, since its contents
-- are covered by the signature; the unhashed area is searched second. Some
-- implementations put the Issuer subpacket in the hashed area, so looking
-- only at the unhashed area would miss it.
--
-- Every other packet, including version 3 signatures, yields 'Nothing'.
issuer :: Pkt -> Maybe EightOctetKeyId
issuer (SignaturePkt (SigV4 _ _ _ hsubs usubs _ _)) =
    case find isIssuerSSP (hsubs ++ usubs) of
        Just (SigSubPacket _ (Issuer i)) -> Just i
        -- unreachable when isIssuerSSP only matches Issuer payloads, but
        -- keeps this function total instead of relying on a partial lambda
        _ -> Nothing
issuer _ = Nothing

-- | The public-key MPIs of a key, in this order:
--
-- * RSA: @n, e@
-- * DSA: @p, q, g, y@
-- * ElGamal: @p, g, y@
-- * ECDH and ECDSA: a single MPI holding the encoded public point (as
--   produced by 'point2BS').
pubkeyToMPIs :: PKey -> [MPI]
pubkeyToMPIs (RSAPubKey (RSA_PublicKey k)) = map MPI [RSA.public_n k, RSA.public_e k]
pubkeyToMPIs (DSAPubKey (DSA_PublicKey k)) =
    map MPI [DSA.params_p ps, DSA.params_q ps, DSA.params_g ps, DSA.public_y k]
  where
    ps = DSA.public_params k
pubkeyToMPIs (ElGamalPubKey p g y) = map MPI [p, g, y]
pubkeyToMPIs (ECDHPubKey (ECDSA_PublicKey (ECDSA.PublicKey _ q)) _ _) = [MPI . os2ip $ point2BS q]
pubkeyToMPIs (ECDSAPubKey (ECDSA_PublicKey (ECDSA.PublicKey _ q))) = [MPI . os2ip $ point2BS q]

-- | @multiplicativeInverse m a@ is the inverse of @a@ modulo @m@, i.e. a
-- value @r@ with @a * r ≡ 1 (mod m)@. It is computed with a recursive form
-- of the extended Euclidean algorithm.
--
-- The arguments are assumed to be positive and coprime. If they are not
-- coprime, the recursion reaches a zero divisor and fails with a
-- division-by-zero error.
multiplicativeInverse :: Integral a => a -> a -> a
multiplicativeInverse modulus value
    | value == 1 = 1
    | otherwise = (k * modulus + 1) `div` value
  where
    -- inverse of the modulus with respect to the (smaller) value, flipped
    -- so that k * modulus + 1 is an exact multiple of value
    k = value - multiplicativeInverse value (modulus `mod` value)

-- | The signature type of a v3 or v4 signature payload.
--
-- Any other payload yields 'Nothing'. This includes v2 signatures, which
-- do not seem to be specified in the RFCs but do exist in the wild.
sigType :: SignaturePayload -> Maybe SigType
sigType payload = case payload of
    SigV3 t _ _ _ _ _ _ -> Just t
    SigV4 t _ _ _ _ _ _ -> Just t
    _ -> Nothing

-- | The public-key algorithm of a v3 or v4 signature payload.
--
-- Any other payload yields 'Nothing'. This includes v2 signatures, which
-- do not seem to be specified in the RFCs but do exist in the wild.
sigPKA :: SignaturePayload -> Maybe PubKeyAlgorithm
sigPKA payload = case payload of
    SigV3 _ _ _ alg _ _ _ -> Just alg
    SigV4 _ alg _ _ _ _ _ -> Just alg
    _ -> Nothing

-- | The hash algorithm of a v3 or v4 signature payload.
--
-- Any other payload yields 'Nothing'. This includes v2 signatures, which
-- do not seem to be specified in the RFCs but do exist in the wild.
sigHA :: SignaturePayload -> Maybe HashAlgorithm
sigHA payload = case payload of
    SigV3 _ _ _ _ hashAlg _ _ -> Just hashAlg
    SigV4 _ _ hashAlg _ _ _ _ -> Just hashAlg
    _ -> Nothing

-- | The creation time of a signature payload, if it can be determined.
--
-- * A v3 signature carries the time directly.
-- * A v4 signature is searched for a Signature Creation Time subpacket, in
--   its hashed subpacket area only.
-- * Anything else yields 'Nothing'.
sigCT :: SignaturePayload -> Maybe ThirtyTwoBitTimeStamp
sigCT (SigV3 _ createdAt _ _ _ _ _) = Just createdAt
sigCT (SigV4 _ _ _ hashedSubs _ _ _) = creationTime <$> find isSigCreationTime hashedSubs
  where
    -- only applied to subpackets that isSigCreationTime accepted
    creationTime (SigSubPacket _ (SigCreationTime t)) = t
sigCT _ = Nothing

-- | The elliptic curves this module can translate to and from OpenPGP curve
-- OIDs. Each curve is paired with its OID written as raw DER content octets
-- (no tag or length byte), the form RFC 6637 uses in key material.
--
-- Ed25519 is not supported yet. Its OID is 1.3.6.1.4.1.11591.15.1, with
-- octets 0x2B,0x06,0x01,0x04,0x01,0xDA,0x47,0x0F,0x01.
--
-- Both 'curveoidBSToCurve' and 'curveToCurveoidBS' consult this one table,
-- so a new curve only has to be added in one place.
knownCurveOIDs :: [(ECCT.CurveName, B.ByteString)]
knownCurveOIDs =
    [ (ECCT.SEC_p256r1, B.pack [0x2A,0x86,0x48,0xCE,0x3D,0x03,0x01,0x07])
    , (ECCT.SEC_p384r1, B.pack [0x2B,0x81,0x04,0x00,0x22])
    , (ECCT.SEC_p521r1, B.pack [0x2B,0x81,0x04,0x00,0x23])
    ]

-- | Look up the curve identified by an OpenPGP curve OID.
--
-- Fails with @"unknown curve OID"@ when the OID is not in 'knownCurveOIDs'.
curveoidBSToCurve :: B.ByteString -> Either String ECCT.Curve
curveoidBSToCurve oidbs =
    maybe (Left "unknown curve OID") (Right . ECCT.getCurveByName . fst)
          (find ((== oidbs) . snd) knownCurveOIDs)

-- | The OpenPGP curve OID of a curve.
--
-- Fails with @"unknown curve"@ when the curve is not in 'knownCurveOIDs'.
curveToCurveoidBS :: ECCT.Curve -> Either String B.ByteString
curveToCurveoidBS curve =
    maybe (Left "unknown curve") (Right . snd)
          (find ((== curve) . ECCT.getCurveByName . fst) knownCurveOIDs)

-- | Encode an elliptic-curve point in uncompressed form: the octet 0x04
-- followed by the X coordinate and then the Y coordinate.
--
-- The two coordinates must occupy the same number of octets, but 'i2osp'
-- drops leading zero octets. A coordinate with a leading zero byte would
-- therefore come out short, leaving an encoding that cannot be split back
-- into X and Y. To prevent this, both coordinates are left-padded with
-- zeros to the longer of the two lengths.
--
-- The curve is not known here, so its true field size is unavailable. If
-- both coordinates are short, the result is still shorter than the curve's
-- canonical encoding.
--
-- The point at infinity has no such encoding and raises an error.
point2BS :: ECCT.PublicPoint -> B.ByteString
point2BS (ECCT.Point x y) = B.concat [B.singleton 0x04, padTo width xbs, padTo width ybs]
  where
    xbs, ybs :: B.ByteString
    xbs = i2osp x
    ybs = i2osp y
    width = max (B.length xbs) (B.length ybs)
    -- prepend zero octets until the string is n octets long
    padTo n bs = B.append (B.replicate (n - B.length bs) 0) bs
point2BS ECCT.PointO = error "FIXME: point at infinity"