{-# LANGUAGE Safe #-}
{-# LANGUAGE PatternGuards, ViewPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
module Cryptol.TypeCheck.Unify where
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Subst
import Control.Monad.Writer (Writer, writer, runWriter)
import Data.Ord(comparing)
import Data.List(sortBy)
import qualified Data.Set as Set
import Prelude ()
import Prelude.Compat
-- | A most general unifier: the substitution computed so far, paired with
-- residual equality constraints that are deferred instead of being solved
-- by substitution.
type MGU = (Subst, [Prop])

-- | Unification computations: a 'Writer' that logs every
-- 'UnificationError' encountered alongside the computed value.
type Result = Writer [UnificationError]
-- | Run a unification computation, returning its result together with
-- every 'UnificationError' logged along the way. An empty error list means
-- no failure was recorded.
runResult :: Result a -> (a, [UnificationError])
runResult = runWriter
-- | Reasons why unification can fail.
data UnificationError
  = UniTypeMismatch Type Type
    -- ^ The two types have the same kind but no unifier.
  | UniKindMismatch Kind Kind
    -- ^ The types being unified have different kinds.
  | UniTypeLenMismatch Int Int
    -- ^ Two lists of types of different lengths (the two lengths).
  | UniRecursive TVar Type
    -- ^ The variable (of kind 'KType') occurs in the type it is being
    -- unified with, so binding it would create an infinite type.
  | UniNonPolyDepends TVar [TParam]
    -- ^ The variable would be bound to a type that mentions the listed
    -- type parameters, which are not in the variable's scope.
  | UniNonPoly TVar Type
    -- ^ A bound type variable of non-numeric kind cannot be unified with
    -- the given type.
-- | Record a unification failure.
--
-- The error is appended to the 'Writer' log and the trivial unifier
-- 'emptyMGU' is returned, so the surrounding computation keeps running;
-- callers detect failure by inspecting the logged errors (see 'runResult').
uniError :: UnificationError -> Result MGU
uniError e = writer (emptyMGU, [e])
-- | The trivial unifier: the empty substitution and no residual
-- constraints.
emptyMGU :: MGU
emptyMGU = (emptySubst, [])
-- | Compute the most general unifier of two types.
--
-- On success the result is a substitution plus residual equality
-- constraints ('=#=') for the cases that are not solved syntactically:
-- numeric type-function applications here, and numeric bound variables
-- (via 'bindVar'). On failure a 'UnificationError' is logged (see
-- 'uniError') and the empty unifier is returned.
--
-- The order of the equations matters: identical type-synonym uses are
-- accepted before unfolding, and variables are bound before synonyms are
-- expanded.
mgu :: Type -> Type -> Result MGU

-- The same synonym applied to equal arguments needs no unfolding.
mgu (TUser c1 ts1 _) (TUser c2 ts2 _)
  | c1 == c2 && ts1 == ts2 = return emptyMGU

mgu (TVar x) t = bindVar x t
mgu t (TVar x) = bindVar x t

-- Otherwise, type synonyms are unfolded to their definitions.
mgu (TUser _ _ t1) t2 = mgu t1 t2
mgu t1 (TUser _ _ t2) = mgu t1 t2

-- Matching type constructors: unify the arguments pairwise.
mgu (TCon (TC tc1) ts1) (TCon (TC tc2) ts2)
  | tc1 == tc2 = mguMany ts1 ts2

-- Equal ('==') type-function applications unify trivially.
mgu (TCon (TF f1) ts1) (TCon (TF f2) ts2)
  | f1 == f2 && ts1 == ts2 = return emptyMGU

-- A numeric type-function application on either side is not solved by
-- substitution; it is deferred as an equality constraint.
mgu t1 t2
  | TCon (TF _) _ <- t1, isNum, k1 == k2 = return (emptySubst, [t1 =#= t2])
  | TCon (TF _) _ <- t2, isNum, k1 == k2 = return (emptySubst, [t1 =#= t2])
  where
    k1    = kindOf t1
    k2    = kindOf t2
    isNum = k1 == KNum

-- Records unify when they have the same field names. Field order is
-- irrelevant, so both sides are sorted by field name first.
mgu (TRec fs1) (TRec fs2)
  | ns1 == ns2 = mguMany ts1 ts2
  where
    (ns1, ts1) = sortFields fs1
    (ns2, ts2) = sortFields fs2
    sortFields = unzip . sortBy (comparing fst)

-- Anything else fails: a kind mismatch when the kinds differ, a type
-- mismatch otherwise.
mgu t1 t2
  | k1 /= k2  = uniError $ UniKindMismatch k1 k2
  | otherwise = uniError $ UniTypeMismatch t1 t2
  where
    k1 = kindOf t1
    k2 = kindOf t2
-- | Unify two lists of types pointwise.
--
-- The substitution found for each pair is applied to the remaining
-- elements before they are unified. Substitutions are composed with '@@'
-- (later substitution on the left), and residual constraints are
-- concatenated in order.
--
-- Lists of different lengths are rejected up front with
-- 'UniTypeLenMismatch'. The error reports the lengths of the full input
-- lists; previously it reported the lengths of whatever suffixes remained
-- after the common prefix was consumed.
mguMany :: [Type] -> [Type] -> Result MGU
mguMany ts1 ts2
  | n1 /= n2  = uniError $ UniTypeLenMismatch n1 n2
  | otherwise = go ts1 ts2
  where
    n1 = length ts1
    n2 = length ts2

    go [] [] = return emptyMGU
    go (t1 : rest1) (t2 : rest2) =
      do (su1, ps1) <- mgu t1 t2
         (su2, ps2) <- go (apSubst su1 rest1) (apSubst su1 rest2)
         return (su2 @@ su1, ps1 ++ ps2)
    -- Not expected to be reached once the lengths agree (presumably
    -- 'apSubst' preserves list length -- TODO confirm); kept so 'go'
    -- stays total.
    go rest1 rest2 =
      uniError $ UniTypeLenMismatch (length rest1) (length rest2)
-- | Unify a type variable with a type.
--
-- * Free variables ('TVFree') are solved by substitution, subject to:
--   - a kind check;
--   - an occurs check for variables of kind 'KType';
--   - a scope check ensuring the type mentions only type parameters in
--     the variable's scope.
-- * Bound variables ('TVBound') are never the target of a substitution. A
--   bound variable of kind 'KNum' yields a residual equality constraint;
--   any other kind is reported as 'UniNonPoly'.
bindVar :: TVar -> Type -> Result MGU

-- A variable unifies with itself, possibly behind a type synonym.
bindVar x (tNoUser -> TVar y)
  | x == y = return emptyMGU

-- When a bound variable meets a free one, solve the free one.
bindVar v@(TVBound {}) (tNoUser -> TVar v1@(TVFree {})) =
  bindVar v1 (TVar v)

bindVar v@(TVBound {}) t
  | k == kindOf t = if k == KNum
                      then return (emptySubst, [TVar v =#= t])
                      else uniError $ UniNonPoly v t
  | otherwise     = uniError $ UniKindMismatch k (kindOf t)
  where
    k = kindOf v

-- Two free variables: when x's scope is strictly smaller, bind y (the
-- larger scope) to x, so that no type parameter escapes.
--
-- This equation looks through type synonyms, like the ones above. Without
-- that, a synonym-wrapped variable would fall into the general case and
-- fail the scope check spuriously.
--
-- The kinds must also agree. If they differ, evaluation falls through to
-- the general case, which reports the kind mismatch instead of producing
-- an ill-kinded substitution.
bindVar x@(TVFree _ xk xscope _) (tNoUser -> TVar y@(TVFree _ yk yscope _))
  | xscope `Set.isProperSubsetOf` yscope, xk == yk =
    return (singleSubst y (TVar x), [])

bindVar x@(TVFree _ k inScope _d) t
  | k /= kindOf t          = uniError $ UniKindMismatch k (kindOf t)
  | recTy && k == KType    = uniError $ UniRecursive x t
  | not (Set.null escaped) = uniError $ UniNonPolyDepends x $ Set.toList escaped
  -- A numeric variable occurring in its own definition becomes a
  -- constraint rather than an error.
  | recTy                  = return (emptySubst, [TVar x =#= t])
  | otherwise              = return (singleSubst x t, [])
  where
    -- Type parameters the type depends on that are outside x's scope.
    escaped = freeParams t `Set.difference` inScope
    -- Occurs check: does x appear free in t?
    recTy   = x `Set.member` fvs t
-- | The type parameters a value depends on through its free type
-- variables.
--
-- Each free variable ('TVFree') contributes the parameters in its scope;
-- each bound variable ('TVBound') contributes its own parameter. The
-- contributions are unioned via the 'Set' monoid.
freeParams :: FVS t => t -> Set.Set TParam
freeParams x = foldMap params (fvs x)
  where
    params :: TVar -> Set.Set TParam
    params (TVFree _ _ tps _) = tps
    params (TVBound tp)       = Set.singleton tp