module Network.TLS.Crypto.ECDH
(
ECDHParams(..)
, ECDHPublic
, ECDHPrivate(..)
, ecdhPublic
, ecdhPrivate
, ecdhParams
, ecdhGenerateKeyPair
, ecdhGetShared
, ecdhUnwrap
, ecdhUnwrapPublic
) where
import Network.TLS.Util.Serialization (lengthBytes)
import Network.TLS.Extension.EC
import qualified Crypto.PubKey.ECC.DH as ECDH
import qualified Crypto.PubKey.ECC.Types as ECDH
import qualified Crypto.PubKey.ECC.Prim as ECC (isPointValid)
import Network.TLS.RNG
import Data.Word (Word16)
-- | A peer's (or our own) ECDH public value: a curve point paired with the
-- size value produced by 'pointSize' for its curve (presumably the byte
-- length of the field prime — confirm against 'lengthBytes').
data ECDHPublic = ECDHPublic ECDH.PublicPoint Int
    deriving (Eq, Show)

-- | An ECDH private scalar, wrapping cryptonite's 'ECDH.PrivateNumber'.
newtype ECDHPrivate = ECDHPrivate ECDH.PrivateNumber
    deriving (Eq, Show)

-- | The curve to perform ECDH on, carried alongside its cryptonite name
-- (both are derived from the same TLS curve identifier by 'ecdhParams').
data ECDHParams = ECDHParams ECDH.Curve ECDH.CurveName
    deriving (Eq, Show)

-- | The shared secret resulting from an ECDH exchange.
type ECDHKey = ECDH.SharedKey
-- | Build an 'ECDHPublic' from affine coordinates @px@, @py@ and the
-- accompanying size value (supplied as the final argument).
ecdhPublic :: Integer -> Integer -> Int -> ECDHPublic
ecdhPublic px py = ECDHPublic (ECDH.Point px py)
-- | Wrap a raw private scalar as an 'ECDHPrivate'.
ecdhPrivate :: Integer -> ECDHPrivate
ecdhPrivate secret = ECDHPrivate secret
-- | Look up ECDH parameters for a TLS named-curve identifier.
--
-- The identifier is mapped with 'toCurveName', and the cryptonite curve is
-- obtained from that name with 'ECDH.getCurveByName'.
--
-- This function is partial: an identifier that 'toCurveName' does not
-- recognise results in an 'error' (raised lazily, when either field of the
-- returned 'ECDHParams' is demanded). The error message names the offending
-- identifier, so failures are diagnosable, unlike a bare irrefutable-pattern
-- failure.
ecdhParams :: Word16 -> ECDHParams
ecdhParams w16 = ECDHParams curve name
  where
    -- Explicit case instead of an irrefutable `Just name = ...` binding:
    -- same laziness, but a descriptive failure and no incomplete-pattern
    -- warning.
    name = case toCurveName w16 of
        Just n  -> n
        Nothing -> error ("ecdhParams: unsupported curve identifier " ++ show w16)
    curve = ECDH.getCurveByName name
-- | Generate a fresh ECDH key pair on the curve held in the parameters.
--
-- The private scalar comes from 'ECDH.generatePrivate'. The public point is
-- derived from it with 'ECDH.calculatePublic' and tagged with 'pointSize'
-- of the curve. The size is not forced here: it stays a lazy field, as
-- before.
ecdhGenerateKeyPair :: MonadRandom r => ECDHParams -> r (ECDHPrivate, ECDHPublic)
ecdhGenerateKeyPair (ECDHParams crv _) = do
    secret <- ECDH.generatePrivate crv
    let public = ECDHPublic (ECDH.calculatePublic crv secret) (pointSize crv)
    return (ECDHPrivate secret, public)
-- | Compute the ECDH shared secret from our private scalar and the peer's
-- public point.
--
-- Returns 'Nothing' when the peer's point is unusable:
--
-- * the point at infinity ('ECDH.PointO'): cryptonite's 'ECC.isPointValid'
--   accepts it, but multiplying it by any scalar gives the point at infinity
--   again, which has no x-coordinate, so 'ECDH.getShared' would fail;
--
-- * any affine point that 'ECC.isPointValid' rejects for this curve.
ecdhGetShared :: ECDHParams -> ECDHPrivate -> ECDHPublic -> Maybe ECDHKey
ecdhGetShared (ECDHParams curve _) (ECDHPrivate priv) (ECDHPublic point _) =
    case point of
        ECDH.PointO -> Nothing
        _ | ECC.isPointValid curve point -> Just $ ECDH.getShared curve priv point
          | otherwise                    -> Nothing
-- | Decompose parameters and a public value into
-- @(tls curve id, x, y, size)@ for serialisation.
--
-- Errors (lazily, per component) if the curve name has no TLS identifier
-- according to 'fromCurveName', or if the public point is not an affine
-- point (see 'ecdhUnwrapPublic').
ecdhUnwrap :: ECDHParams -> ECDHPublic -> (Word16,Integer,Integer,Int)
ecdhUnwrap (ECDHParams _ curveName) pub =
    let curveId          = maybe (error "ecdhUnwrap") id (fromCurveName curveName)
        -- lazy pattern binding: the point is only inspected when a
        -- coordinate or the size is demanded
        (px, py, bytes)  = ecdhUnwrapPublic pub
    in (curveId, px, py, bytes)
-- | Extract the affine coordinates and size value from a public value.
--
-- Errors if the point is not an affine 'ECDH.Point' (e.g. the point at
-- infinity).
ecdhUnwrapPublic :: ECDHPublic -> (Integer,Integer,Int)
ecdhUnwrapPublic (ECDHPublic pt bytes) =
    case pt of
        ECDH.Point px py -> (px, py, bytes)
        _                -> error "ecdhUnwrapPublic"
-- | Size value attached to public points on a curve: 'lengthBytes' of the
-- curve's prime @p@.
--
-- Only prime-field ('ECDH.CurveFP') curves are supported; any other curve
-- kind is an 'error'.
pointSize :: ECDH.Curve -> Int
pointSize crv =
    case crv of
        ECDH.CurveFP prime -> lengthBytes (ECDH.ecc_p prime)
        _                  -> error "pointSize"