module Cryptol.TypeCheck.Solve
( simplifyAllConstraints
, proveImplication
, wfType
, wfTypeFunction
, improveByDefaultingWith
, defaultReplExpr
) where
import Cryptol.TypeCheck.PP(pp)
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Monad
import Cryptol.TypeCheck.Subst
(apSubst, singleSubst, isEmptySubst, substToList,
emptySubst,Subst,listSubst, (@@), Subst,
apSubstMaybe, substBinds)
import qualified Cryptol.TypeCheck.SimpleSolver as Simplify
import Cryptol.TypeCheck.Solver.Types
import Cryptol.TypeCheck.Solver.Selector(tryHasGoal)
import Cryptol.TypeCheck.SimpType(tMax)
import Cryptol.TypeCheck.Solver.SMT(proveImp,checkUnsolvable)
import Cryptol.TypeCheck.Solver.Improve(improveProp,improveProps)
import Cryptol.TypeCheck.Solver.Numeric.Interval
import qualified Cryptol.TypeCheck.Solver.Numeric.AST as Num
import qualified Cryptol.TypeCheck.Solver.Numeric.ImportExport as Num
import qualified Cryptol.TypeCheck.Solver.CrySAT as Num
import Cryptol.TypeCheck.Solver.CrySAT
import Cryptol.Utils.PP (text,vcat,(<+>))
import Cryptol.Utils.Panic(panic)
import Cryptol.Utils.Patterns(matchMaybe)
import Control.Monad (guard, mzero)
import Control.Applicative ((<|>))
import Data.Either(partitionEithers)
import Data.Maybe(catMaybes)
import Data.Map ( Map )
import qualified Data.Map as Map
import Data.Set ( Set )
import qualified Data.Set as Set
-- | Side conditions that must hold for an application of a type function
-- to be well formed.  Functions (or arities) that are not listed impose
-- no extra conditions.
wfTypeFunction :: TFun -> [Type] -> [Prop]
wfTypeFunction fun args =
  case (fun, args) of
    (TCSub, [a,b])             -> [ a >== b, pFin b ]
    (TCDiv, [a,b])             -> [ b >== tOne, pFin a ]
    (TCMod, [a,b])             -> [ b >== tOne, pFin a ]
    (TCLenFromThen, [a,b,w])   ->
      [ pFin a, pFin b, pFin w, a =/= b, w >== tWidth a ]
    (TCLenFromThenTo, [a,b,c]) -> [ pFin a, pFin b, pFin c, a =/= b ]
    _                          -> []
-- | All well-formedness constraints of a type: the side conditions of every
-- type-function application occurring anywhere inside it.  The conditions
-- for an outer application come before those of its arguments.
wfType :: Type -> [Prop]
wfType (TVar _)      = []
wfType (TUser _ _ s) = wfType s
wfType (TRec fs)     = concatMap (wfType . snd) fs
wfType (TCon c ts)   = here ++ concatMap wfType ts
  where
  here = case c of
           TF f -> wfTypeFunction f ts
           _    -> []
-- | 'quickSolver' run in 'IO', with a debugging hook that is handed a
-- description of the outcome.  The hook is currently a no-op.
quickSolverIO :: Ctxt -> [Goal] -> IO (Either Goal (Subst,[Goal]))
quickSolverIO _ [] = return (Right (emptySubst, []))
quickSolverIO ctxt gs =
  do let result = quickSolver ctxt gs
     case result of
       Left bad         -> msg (text "Contradiction:" <+> pp (goal bad))
       Right (su, rest) -> msg (vcat (map (pp . goal) rest ++ [pp su]))
     return result
  where
  -- Debug output is disabled; swap in a printer to trace the quick solver.
  msg _ = return ()
-- | A fast solver that needs no SMT.  It repeatedly applies
-- 'Simplify.simplifyStep' to the goals.  When nothing more simplifies, it
-- looks for an improving substitution among the goals that are left.
--
-- Returns the first goal found to be unsolvable.  Otherwise it returns the
-- accumulated substitution together with the goals that remain unsolved.
quickSolver :: Ctxt
            -> [Goal]
            -> Either Goal (Subst,[Goal])
quickSolver ctxt = loop emptySubst []
  where
  -- loop su stuck todo
  --   stuck: goals the simplifier could not make progress on
  --   todo:  goals not yet examined
  loop su stuck todo =
    case todo of
      g : rest ->
        case Simplify.simplifyStep ctxt (goal g) of
          Unsolvable _  -> Left g
          Unsolved      -> loop su (g : stuck) rest
          SolvedIf subs -> loop su stuck ([ g { goal = p } | p <- subs ] ++ rest)
      []
        | null stuck -> Right (su, [])
        | otherwise  ->
            case matchMaybe (improveFirst stuck) of
              Nothing -> Right (su, stuck)
              Just (impSu, extra) ->
                -- An improvement may enable further simplification, so
                -- every stuck goal is re-examined under the new substitution.
                loop (impSu @@ su) [] (extra ++ apSubst impSu stuck)

  -- The first goal (in list order) that yields an improvement wins.
  improveFirst []       = mzero
  improveFirst (g : gs) = tryOne g <|> improveFirst gs

  tryOne g =
    do (su, ps) <- improveProp False ctxt (goal g)
       return (su, [ g { goal = p } | p <- ps ])
-- | Simplify the pending goals using only the quick solver, with no SMT.
-- 'Has' goals are processed first.  An unsolvable goal is recorded as an
-- error.  Otherwise the discovered substitution is recorded and the
-- residual goals are added back.
simplifyAllConstraints :: InferM ()
simplifyAllConstraints =
  do simpHasGoals
     gs <- getGoals
     if null gs
       then return ()
       else either reportBad keepResult (quickSolver Map.empty gs)
  where
  reportBad badG = recordError (UnsolvedGoals True [badG])

  keepResult (su, gs1) =
    do extendSubst su
       addGoals gs1
-- | Repeatedly try to solve the pending 'Has' (selector) goals with
-- 'tryHasGoal'.  While a pass reports changes, the goals it left unsolved
-- are retried.  Once a pass makes no changes, the leftovers are put back
-- with 'addHasGoal'.
simpHasGoals :: InferM ()
simpHasGoals = go False [] =<< getHasGoals
  where
  -- go changed unsolved todo
  go _ [] [] = return ()
  go True unsolved [] = go False [] unsolved
  go False unsolved [] = mapM_ addHasGoal unsolved
  go changes unsolved (g : todo) =
    do (ch,solved) <- tryHasGoal g
       let changes'  = ch || changes
           unsolved' = if solved then unsolved else g : unsolved
       -- Force the *new* accumulators (previously the old `unsolved` was
       -- forced by mistake), so the `if solved ...` decision is made now
       -- instead of being deferred as a thunk.
       changes' `seq` unsolved' `seq` go changes' unsolved' todo
-- | Try to prove the goals from the assumptions, for the given type
-- parameters, using the current solver and the variables returned by
-- 'varsWithAsmps'.  A failure is recorded as an error.  On success, each
-- resulting warning is recorded.  The computed substitution is returned
-- in both cases.
proveImplication :: Name -> [TParam] -> [Prop] -> [Goal] -> InferM Subst
proveImplication lnam as ps gs =
  do evars  <- varsWithAsmps
     solver <- getSolver
     (outcome, su) <- io (proveImplicationIO solver lnam evars as ps gs)
     either recordError (mapM_ recordWarning) outcome
     return su
-- | The 'IO' worker behind 'proveImplication'.
--
-- Strategy, as implemented below:
--
--   1. Preprocess the goals and assumptions (see @(asmps,gs)@ in the where
--      clause).  Conjunctive goals are split.  Split goals that occur
--      literally among the original assumptions are dropped.  If the
--      assumptions yield an improvement, it is applied to the goals and
--      recorded as equality assumptions.
--   2. Run the quick solver, with intervals computed from the assumptions.
--   3. Give any leftover goals to 'proveImp'.
--   4. If goals still remain, try to default those of their free variables
--      that are not in @varsInEnv@.
--        - If defaulting yields no substitution, report the remaining goals
--          as an unsolved delayed constraint.
--        - Otherwise recurse on the defaulted goals.
--
-- Returns either an error or the accumulated warnings, paired with the
-- substitution computed so far.  On a quick-solver contradiction the
-- substitution returned is 'emptySubst'.
proveImplicationIO :: Num.Solver
                   -> Name      -- ^ Used as the source in error reports
                   -> Set TVar  -- ^ Variables that are never defaulted here
                   -> [TParam]  -- ^ Quantified type parameters
                   -> [Prop]    -- ^ Assumptions
                   -> [Goal]    -- ^ Goals to prove
                   -> IO (Either Error [Warning], Subst)
-- Nothing assumed and nothing to prove.
proveImplicationIO _ _ _ _ [] [] = return (Right [], emptySubst)
proveImplicationIO s f varsInEnv ps asmps0 gs0 =
  do let ctxt = assumptionIntervals Map.empty asmps
     res <- quickSolverIO ctxt gs
     case res of
       Left bad -> return (Left (UnsolvedGoals True [bad]), emptySubst)
       Right (su,[]) -> return (Right [], su)
       Right (su,gs1) ->
         -- The quick solver left some goals; try the SMT-backed prover.
         do gs2 <- proveImp s asmps gs1
            case gs2 of
              [] -> return (Right [], su)
              gs3 ->
                -- Only variables that do not appear in the environment
                -- are candidates for defaulting.
                do let free = Set.toList
                            $ Set.difference (fvs (map goal gs3)) varsInEnv
                   case improveByDefaultingWithPure free gs3 of
                     -- Defaulting made no progress: give up on these goals.
                     (_,_,newSu,_)
                        | isEmptySubst newSu -> return (err gs3, su)
                     -- Defaulting made progress: retry with the new goals,
                     -- applying the combined substitution to the assumptions.
                     (_,newGs,newSu,ws) ->
                       do let su1 = newSu @@ su
                          (res1,su2) <- proveImplicationIO s f varsInEnv ps
                                          (apSubst su1 asmps0) newGs
                          let su3 = su2 @@ su1
                          case res1 of
                            Left bad -> return (Left bad, su3)
                            Right ws1 -> return (Right (ws++ws1),su3)
  where
  -- Report the given goals as an unsolved delayed constraint.  The report
  -- uses the original assumptions.
  err us = Left $ cleanupError
                $ UnsolvedDelayedCt
                $ DelayedCt { dctSource = f
                            , dctForall = ps
                            , dctAsmps = asmps0
                            , dctGoals = us
                            }

  -- Preprocessed assumptions and goals (step 1 above).
  (asmps,gs) =
     let gs1 = [ g { goal = p } | g <- gs0, p <- pSplitAnd (goal g)
                                , notElem p asmps0 ]
     in case matchMaybe (improveProps True Map.empty asmps0) of
          Nothing -> (asmps0,gs1)
          Just (newSu,newAsmps) ->
             ( [ TVar x =#= t | (x,t) <- substToList newSu ]
               ++ newAsmps
             , [ g { goal = apSubst newSu (goal g) } | g <- gs1 ]
             )
-- | Tidy an error before it is reported.  For an unsolved delayed
-- constraint, only the goals that mention no free (inference) type
-- variables are kept.  If that would leave no goals at all, the constraint
-- is kept as it was.  Any other error is returned unchanged.
cleanupError :: Error -> Error
cleanupError (UnsolvedDelayedCt d) = UnsolvedDelayedCt d'
  where
  hasNoInferVars g = Set.null (Set.filter isFreeTV (fvs (goal g)))
  d' = case filter hasNoInferVars (dctGoals d) of
         []   -> d
         kept -> d { dctGoals = kept }
cleanupError err = err
-- | Simplify goals with the numeric solver until no more improvements apply.
-- The well-formedness constraints of each goal (from 'wfType') are added
-- up front.  They are added again for every goal that an improvement
-- rewrites.
--
-- Returns either the goals behind a contradiction or the residual goals.
-- Either way, the substitution accumulated from improvements is included.
simpGoals' :: Num.Solver -> Ctxt -> [Goal] -> IO (Either [Goal] [Goal], Subst)
simpGoals' s asmps gs0 = loop emptySubst [] (withWf gs0)
  where
  withWf gs     = wellFormed gs ++ gs
  wellFormed gs = [ g { goal = p } | g <- gs, p <- wfType (goal g) ]

  -- loop su done todo: `done` goals are settled for now, `todo` need work.
  loop su done []   = return (Right done, su)
  loop su done todo =
    do solved <- solveConstraints s asmps done todo
       case solved of
         Left bad -> return (Left bad, su)
         Right remaining ->
           do let current = remaining ++ done
              improved <- computeImprovements s current
              case improved of
                Left bad -> return (Left bad, su)
                Right impSu ->
                  let (same, touched) =
                        partitionEithers (map (applyImp impSu) current)
                  in loop (impSu @@ su) same (wellFormed touched ++ touched)

  -- 'Right' for goals the substitution actually changed, 'Left' otherwise.
  applyImp su g =
    maybe (Left g) (\p -> Right g { goal = p }) (apSubstMaybe su (goal g))
-- | Extend a context with the intervals computed from the given props.
-- The context is unchanged when nothing new is learned or the props give
-- an invalid interval.  Otherwise the new intervals take precedence over
-- existing ones, since 'Map.union' is left-biased.
assumptionIntervals :: Ctxt -> [Prop] -> Ctxt
assumptionIntervals as ps = update (computePropIntervals as ps)
  where
  update NoChange           = as
  update InvalidInterval {} = as
  update (NewIntervals bs)  = Map.union bs as
-- | Simplify @gs0@ with 'Simplify.simplifyStep'.  The context is the given
-- one, extended with the intervals of @otherGs@.  The goals it cannot
-- decide are split into non-numeric and numeric ones.  The numeric ones
-- are passed to 'solveNumerics', with the numeric goals of @otherGs@ as
-- extra facts.
--
-- Returns @Left [g]@ for the first goal found unsolvable, or the goals that
-- remain.
solveConstraints :: Num.Solver ->
                    Ctxt ->
                    [Goal] ->
                    [Goal] ->
                    IO (Either [Goal] [Goal])
solveConstraints s asmps otherGs gs0 =
  debugBlock s "Solving constraints" (simplify [] gs0)
  where
  ctxt = assumptionIntervals asmps (map goal otherGs)

  simplify stuck [] =
    do let (others, nums) = partitionEithers (map Num.numericRight stuck)
       nums' <- solveNumerics s otherNumerics nums
       return (Right (others ++ nums'))
  simplify stuck (g : rest) =
    case Simplify.simplifyStep ctxt (goal g) of
      Unsolvable _  -> return (Left [g])
      Unsolved      -> simplify (g : stuck) rest
      SolvedIf subs -> simplify stuck ([ g { goal = p } | p <- subs ] ++ rest)

  otherNumerics = [ g | Right g <- map Num.numericRight otherGs ]
-- | Simplify the numeric goals @solveGs@ inside a fresh solver scope.
-- The props of the @consultGs@ goals are assumed in that scope.  If there
-- is nothing to solve, the solver is not consulted at all.
solveNumerics :: Num.Solver ->
                 [(Goal,Num.Prop)] ->
                 [(Goal,Num.Prop)] ->
                 IO [Goal]
solveNumerics s consultGs solveGs
  | null solveGs = return []
  | otherwise =
      Num.withScope s $
        do _ <- Num.assumeProps s [ goal g | (g, _) <- consultGs ]
           Num.simplifyProps s (map Num.knownDefined solveGs)
-- | Ask the numeric solver for improvements implied by the numeric goals.
-- Inside a fresh scope the goals are assumed and 'Num.check' is run.
--
-- * If the check succeeds, the solver's suggestions become a substitution
--   via 'importSplitImps'.
-- * If it fails, a minimized set of contradictory goals is returned.
computeImprovements :: Num.Solver -> [Goal] -> IO (Either [Goal] Subst)
computeImprovements s gs =
  debugBlock s "Computing improvements" $
    do outcome <- Num.withScope s checkNums
       maybe contradiction (return . Right . snd) outcome
  where
  nums = [ g | Right g <- map Num.numericRight gs ]

  checkNums =
    do _ <- Num.assumeProps s [ goal g | (g, _) <- nums ]
       verdict <- Num.check s
       case verdict of
         Nothing -> return Nothing
         Just (suish, _) ->
           -- NOTE(review): the intervals are fetched but never used, and a
           -- 'Left' result makes this pattern fail in 'IO'.
           do Right ints <- Num.getIntervals s
              return (Just (ints, fst (importSplitImps suish)))

  contradiction =
    Left <$> Num.minimizeContradictionSimpDef s (map Num.knownDefined nums)
-- | Split the definitions suggested by the solver into two parts:
--
-- * definitions of free type variables, which become a substitution;
-- * definitions of bound type variables, which become equality props.
--
-- Entries that are not user variables, or whose value does not import as
-- a type, are dropped.
importSplitImps :: Map Num.Name Num.Expr -> (Subst, [Prop])
importSplitImps suggested = (listSubst (catMaybes defs), eqs)
  where
  tagged = [ importOne x e | (x, e) <- Map.toList suggested ]
  defs   = [ d | Left d <- tagged ]
  eqs    = [ p | Right p <- tagged ]

  importOne (Num.UserName tv) e
    | Just ty <- Num.importType e =
        case tv of
          TVFree {}  -> Left (Just (tv, ty))
          TVBound {} -> Right (TVar tv =#= ty)
  importOne _ _ = Left Nothing
-- | Try to default some of the given type variables using the goals.
--
-- If 'checkUnsolvable' reports the goals as unsolvable, nothing is done and
-- the substitution is 'Nothing'.  Otherwise the pure defaulting step
-- ('improveByDefaultingWithPure') runs, and its goals are simplified with
-- 'simpGoals''.  The result contains:
--
-- * the variables still not defined by the combined substitution;
-- * the simplified goals;
-- * the combined substitution;
-- * the defaulting warnings.
--
-- Panics if defaulting leads to unsolvable constraints.
improveByDefaultingWith ::
  Num.Solver ->
  [TVar] ->
  [Goal] ->
  IO ( [TVar]
     , [Goal]
     , Maybe Subst
     , [Warning]
     )
improveByDefaultingWith s as gs =
  do bad <- checkUnsolvable s gs
     if bad then return (as, gs, Nothing, []) else attempt
  where
  (leftover, defGoals, defSu, warns) = improveByDefaultingWithPure as gs

  attempt =
    do (res, su1) <- simpGoals' s Map.empty defGoals
       either contradiction (finish su1) res

  finish su1 simplified =
    let su2     = su1 @@ defSu
        defined = substBinds su2
    in return ( [ x | x <- leftover, not (x `Set.member` defined) ]
              , simplified
              , Just su2
              , warns
              )

  contradiction err =
    panic "improveByDefaultingWith" $
      [ "Defaulting resulted in unsolvable constraints."
      , "Before:"
      ] ++ [ " " ++ show (pp (goal g)) | g <- gs ] ++
      [ "After:"
      ] ++ [ " " ++ show (pp (goal g)) | g <- defGoals ] ++
      [ "Contradiction:" ] ++
      [ " " ++ show (pp (goal g)) | g <- err ]
-- | Pure defaulting of the given type variables.
--
-- First 'classify' sorts the goals into three groups:
--
--   * @leqs@: for each variable @x@ of @as@, its goals of the form
--     @x >= t@ where @x@ does not occur in @t@, together with the free
--     variables of those right-hand sides;
--   * @fins@: the @fin@ goals;
--   * @others@: everything else.
--
-- Then 'select' decides which variables to define.  A variable @x@ is
-- defaulted to the 'tMax' of its lower bounds (or to 0 if it has none)
-- only if all of these hold:
--
--   * it does not occur free in @others@;
--   * it does not occur in the right-hand sides of variables already
--     rejected;
--   * it has no mutual dependency with a variable still to be considered;
--   * it is not a bound type variable.
--
-- The @x >= t@ goals of defaulted variables are dropped.  Those of the
-- other variables are kept.
--
-- Returns, in order: the variables that were not defaulted; the remaining
-- goals, with the substitution applied to the @fin@ goals; the defaulting
-- substitution; and one warning per defaulted variable.
improveByDefaultingWithPure :: [TVar] -> [Goal] ->
    ( [TVar]    -- non-defaulted variables
    , [Goal]    -- remaining goals
    , Subst     -- substitution from defaulting
    , [Warning] -- one warning per defaulted variable
    )
improveByDefaultingWithPure as ps =
  classify (Map.fromList [ (a,([],Set.empty)) | a <- as ]) [] [] ps

  where
  -- classify leqs fins others todo
  --   leqs:   candidate definitions, keyed by variable
  --   fins:   all @fin@ goals
  --   others: any other goals
  classify leqs fins others [] =
    let
      -- Use the candidate definitions to choose what to default.
      (defs, newOthers) = select [] [] (fvs others) (Map.toList leqs)
      su = listSubst defs
      warn (x,t) =
        case x of
          TVFree _ _ _ d -> DefaultingTo d t
          TVBound {} -> panic "Crypto.TypeCheck.Infer"
                          [ "tryDefault attempted to default a quantified variable."
                          ]
      names = substBinds su
    in ( [ a | a <- as, not (a `Set.member` names) ]
       , newOthers ++ others ++ apSubst su fins
       , su
       , map warn defs
       )

  classify leqs fins others (prop : more) =
    case tNoUser (goal prop) of
      TCon (PC PFin) [ _ ] -> classify leqs (prop : fins) others more

      -- x >= t, where x is one of @as@ and does not occur in t.
      TCon (PC PGeq) [ TVar x, t ]
        | x `elem` as && x `Set.notMember` freeRHS ->
          classify leqs' fins others more
        where freeRHS = fvs t
              add (xs1,vs1) (xs2,vs2) = (xs1 ++ xs2, Set.union vs1 vs2)
              leqs' = Map.insertWith add x ([(t,prop)],freeRHS) leqs

      _ -> classify leqs fins (prop : others) more

  -- select yes no otherFree candidates
  --   yes:       definitions chosen so far
  --   no:        goals of the variables that were rejected
  --   otherFree: variables that may not be defaulted
  select yes no _ [] = ([ (x, t) | (x,t) <- yes ] ,no)
  select yes no otherFree ((x,(rhsG,vs)) : more) =
    select newYes newNo newFree newMore

    where
    (ts,gs) = unzip rhsG

    (newYes,newNo,newFree,newMore)
      | x `Set.member` otherFree = noDefaulting
      | otherwise =
          -- Reject x if a later candidate y depends on x while x depends
          -- on y (x >= S(y), y >= T(x)), or if x is a bound variable.
          let deps = [ y | (y,(_,yvs)) <- more, x `Set.member` yvs ]
              recs = filter (`Set.member` vs) deps
          in if not (null recs) || isBoundTV x
               then noDefaulting
               else yesDefaulting

      where
      -- Keep x's goals.  The variables in its bounds may no longer be
      -- defaulted.
      noDefaulting = ( yes, gs ++ no, vs `Set.union` otherFree, more )

      -- Define x as the max of its lower bounds (0 if there are none).
      -- Substitute it into the earlier definitions and into the bounds
      -- and goals of the remaining candidates.
      yesDefaulting =
        let ty = case ts of
                   [] -> tNum (0::Int)
                   _ -> foldr1 tMax ts
            su1 = singleSubst x ty
        in ( (x,ty) : [ (y,apSubst su1 t) | (y,t) <- yes ]
           , no
           , otherFree
           , [ (y, (apSubst su1 ts1, vs1)) | (y,(ts1,vs1)) <- more ]
           )
-- | Try to choose concrete types for the type parameters of a schema.  The
-- expression is then instantiated at those types.
--
-- This only applies when every parameter has kind 'KNum'.  'tryGetModel'
-- is asked for a substitution in which all parameters are finite and the
-- schema's props hold.  The substitution is accepted only if:
--
-- * 'simpGoals'' then solves every prop with no further improvement;
-- * every parameter is changed by it.
--
-- On success, returns the chosen instantiation and the expression applied
-- to those types and to one proof argument per prop.
defaultReplExpr :: Num.Solver -> Expr -> Schema
                -> IO (Maybe ([(TParam,Type)], Expr))
defaultReplExpr so e s
  | not (all (\v -> kindOf v == KNum) (sVars s)) = return Nothing
  | otherwise =
      do model <- tryGetModel so params (sProps s)
         maybe (return Nothing) checkModel model
  where
  params = map tpVar (sVars s)

  -- Accept the model only if all props simplify away and no improvement
  -- substitution is produced.
  checkModel su =
    do (res, su1) <- simpGoals' so Map.empty (map (makeGoal su) (sProps s))
       return $ case res of
                  Right [] | isEmptySubst su1 -> instantiateParams su
                  _                           -> Nothing

  -- Nothing if the substitution leaves some parameter unchanged.
  instantiateParams su =
    do tys <- mapM (bindParam su) params
       return (zip (sVars s) tys, appExpr tys)

  makeGoal su p = Goal { goalSource = error "goal source"
                       , goalRange  = error "goal range"
                       , goal       = apSubst su p
                       }

  bindParam su tp =
    do let ty  = TVar tp
           ty' = apSubst su ty
       guard (ty /= ty')
       return ty'

  -- Type-apply to the chosen types, then add one proof application per prop.
  appExpr tys = foldl (\acc _ -> EProofApp acc) (foldl ETApp e tys) (sProps s)
-- | Query 'Num.getModel' with the given props plus a @fin@ constraint for
-- each of the given type variables.
tryGetModel ::
  Num.Solver ->
  [TVar] ->
  [Prop] ->
  IO (Maybe Subst)
tryGetModel s xs ps = Num.getModel s (finConstraints ++ ps)
  where
  finConstraints = [ pFin (TVar x) | x <- xs ]