-- |
-- Module      :  Cryptol.TypeCheck.Unify
-- Copyright   :  (c) 2013-2016 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable

{-# LANGUAGE Safe #-}
{-# LANGUAGE PatternGuards, ViewPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
module Cryptol.TypeCheck.Unify where

import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Subst

import Control.Monad.Writer (Writer, writer, runWriter)
import Data.Ord(comparing)
import Data.List(sortBy)
import qualified Data.Set as Set

import Prelude ()
import Prelude.Compat

-- | The most general unifier: a substitution together with the residual
-- constraints that unification did not discharge syntactically (for
-- example, equations involving bound numeric variables or numeric type
-- functions, which are returned as '=#=' propositions).
type MGU = (Subst,[Prop])

-- | A unification computation.  Errors are accumulated in a 'Writer' log
-- rather than aborting, so several failures may be reported at once.
type Result a = Writer [UnificationError] a

-- | Run a unification computation, returning its result together with
-- every 'UnificationError' that was reported along the way.
runResult :: Result a -> (a, [UnificationError])
runResult = runWriter

-- | Reasons why unification can fail.
data UnificationError
  = UniTypeMismatch Type Type
    -- ^ The two types (of equal kind) have incompatible structure.
  | UniKindMismatch Kind Kind
    -- ^ The two types have different kinds.
  | UniTypeLenMismatch Int Int
    -- ^ Two type lists that should be unified pointwise differ in length.
  | UniRecursive TVar Type
    -- ^ Binding the variable would make it occur in its own definition
    -- (occurs check failure at kind 'KType').
  | UniNonPolyDepends TVar [TParam]
    -- ^ The variable would depend on the listed type parameters, which are
    -- outside its scope.
  | UniNonPoly TVar Type
    -- ^ A bound type variable would have to equal a different type of a
    -- non-numeric kind.

-- | Report a unification failure.  The error is appended to the log and
-- 'emptyMGU' is returned, so the surrounding computation continues and
-- may report further errors.
uniError :: UnificationError -> Result MGU
uniError e = writer (emptyMGU, [e])


-- | The trivial unifier: the empty substitution and no residual constraints.
emptyMGU :: MGU
emptyMGU = (emptySubst, [])

-- | Compute the most general unifier of two types.
--
-- Applications of the same type synonym to equal arguments unify at once;
-- otherwise synonyms are expanded.  Type variables are handed to 'bindVar',
-- type constructors and records are unified structurally, and equations
-- involving numeric type functions are returned as residual '=#='
-- constraints instead of being solved here.  Failures are logged through
-- 'uniError'.
mgu :: Type -> Type -> Result MGU

-- Identical synonym applications unify trivially, without expansion.
mgu (TUser c1 ts1 _) (TUser c2 ts2 _)
  | c1 == c2 && ts1 == ts2  = return emptyMGU

mgu (TVar x) t        = bindVar x t
mgu t (TVar x)        = bindVar x t

-- Otherwise, look through synonyms to their expansions.
mgu (TUser _ _ t1) t2 = mgu t1 t2
mgu t1 (TUser _ _ t2) = mgu t1 t2

mgu (TCon (TC tc1) ts1) (TCon (TC tc2) ts2)
  | tc1 == tc2 = mguMany ts1 ts2

-- Syntactically identical type-function applications need no constraint.
mgu (TCon (TF f1) ts1) (TCon (TF f2) ts2)
  | f1 == f2 && ts1 == ts2  = return emptyMGU

-- A numeric type-function application on either side cannot be solved by
-- syntactic unification, so the equation is deferred as a constraint.
mgu t1 t2
  | TCon (TF _) _ <- t1, isNum, k1 == k2 = return (emptySubst, [t1 =#= t2])
  | TCon (TF _) _ <- t2, isNum, k1 == k2 = return (emptySubst, [t1 =#= t2])
  where
    k1 = kindOf t1
    k2 = kindOf t2

    isNum = k1 == KNum

-- Records unify when they have the same field names.  Fields are paired up
-- after sorting by name, so the order in which they are written is
-- irrelevant.
mgu (TRec fs1) (TRec fs2)
  | ns1 == ns2 = mguMany ts1 ts2
  where
    (ns1,ts1)  = sortFields fs1
    (ns2,ts2)  = sortFields fs2
    sortFields = unzip . sortBy (comparing fst)

-- Anything else is a failure; report a kind mismatch when that is the cause.
mgu t1 t2
  | not (k1 == k2)  = uniError $ UniKindMismatch k1 k2
  | otherwise       = uniError $ UniTypeMismatch t1 t2
  where
    k1 = kindOf t1
    k2 = kindOf t2


-- | Unify two lists of types pointwise.  The substitution found for each
-- pair is applied to the remaining elements before they are unified, and
-- the substitutions are composed in order.  Residual constraints from all
-- pairs are concatenated.  Lists of different lengths are reported as
-- 'UniTypeLenMismatch'.
mguMany :: [Type] -> [Type] -> Result MGU
mguMany [] [] = return emptyMGU
mguMany (t1 : ts1) (t2 : ts2) =
  do (su1,ps1) <- mgu t1 t2
     (su2,ps2) <- mguMany (apSubst su1 ts1) (apSubst su1 ts2)
     return (su2 @@ su1, ps1 ++ ps2)
mguMany t1 t2 = uniError $ UniTypeLenMismatch (length t1) (length t2)


-- | Unify a type variable with a type.
--
-- The cases, tried in order:
--
-- * A variable against itself (looking through type synonyms) needs no
--   substitution.
--
-- * A bound variable against a free variable: bind the free one instead.
--
-- * A bound variable against any other type: if the kinds agree and are
--   'KNum', the equation becomes a residual constraint; any other shared
--   kind yields 'UniNonPoly'; differing kinds yield 'UniKindMismatch'.
--
-- * Two free variables of the same kind, where the scope of @x@ is a proper
--   subset of the scope of @y@: bind @y := x@.  The reverse binding would
--   make @x@ depend on type parameters outside its scope.
--
-- * A free variable against any other type: check the kinds, the occurs
--   check (a recursive 'KType' binding is rejected, and a recursive binding
--   at any other kind becomes a residual constraint) and scope escape, then
--   bind.
bindVar :: TVar -> Type -> Result MGU

bindVar x (tNoUser -> TVar y)
  | x == y                      = return emptyMGU

bindVar v@(TVBound {}) (tNoUser -> TVar v1@(TVFree {})) = bindVar v1 (TVar v)

bindVar v@(TVBound {}) t
  | k == kindOf t = if k == KNum
                       then return (emptySubst, [TVar v =#= t])
                       else uniError $ UniNonPoly v t
  | otherwise     = uniError $ UniKindMismatch k (kindOf t)
  where
    k = kindOf v

-- Look through synonyms, as the clauses above do, so that a variable hidden
-- behind a 'TUser' still takes this path instead of failing the escape
-- check below.  Also require equal kinds: otherwise we fall through to the
-- last clause, which reports 'UniKindMismatch', rather than building an
-- ill-kinded substitution.
bindVar x@(TVFree _ xk xscope _) (tNoUser -> TVar y@(TVFree _ yk yscope _))
  | xk == yk
  , xscope `Set.isProperSubsetOf` yscope = return (singleSubst y (TVar x), [])

bindVar x@(TVFree _ k inScope _d) t
  | not (k == kindOf t)     = uniError $ UniKindMismatch k (kindOf t)
  | recTy && k == KType     = uniError $ UniRecursive x t
  | not (Set.null escaped)  = uniError $ UniNonPolyDepends x $ Set.toList escaped
  | recTy                   = return (emptySubst, [TVar x =#= t])
  | otherwise               = return (singleSubst x t, [])
  where
    -- Type parameters that t refers to but x is not allowed to depend on.
    escaped = freeParams t `Set.difference` inScope
    -- Occurs check: does x appear free in t?
    recTy   = x `Set.member` fvs t


-- | The type parameters that the free type variables of a value may refer
-- to.  Each free unification variable contributes its recorded scope, and
-- each bound variable contributes its own parameter.
freeParams :: FVS t => t -> Set.Set TParam
freeParams x = Set.unions (map params (Set.toList (fvs x)))
  where
    params (TVFree _ _ tps _) = tps
    params (TVBound tp)       = Set.singleton tp