module Data.SBV.Provers.Z3(z3) where
import qualified Control.Exception as C
import Data.Char (toLower)
import Data.Function (on)
import Data.List (sortBy, intercalate, isPrefixOf, groupBy)
import System.Environment (getEnv)
import qualified System.Info as S(os)
import Data.SBV.BitVectors.AlgReals
import Data.SBV.BitVectors.Data
import Data.SBV.BitVectors.PrettyNum
import Data.SBV.SMT.SMT
import Data.SBV.SMT.SMTLib
import Data.SBV.Utils.Lib (splitArgs)
-- | The character that introduces a command-line option for the Z3 executable.
--
-- Native Windows builds are driven with @/@-style options (e.g. @/smt2@); every
-- other platform (Linux, macOS, the BSDs, Solaris, ...) expects @-@-style
-- options. GHC reports native Windows as @\"mingw32\"@ in 'S.os'. All other
-- hosts therefore get @-@, rather than only Linux and Darwin, so that Unix
-- systems such as FreeBSD are not handed @/in@-style arguments that a Unix Z3
-- build would not treat as options.
optionPrefix :: Char
optionPrefix
  | map toLower S.os `elem` ["mingw32", "windows"] = '/'
  | otherwise                                      = '-'
-- | The description of the Z3 SMT solver.
--
-- The executable defaults to @z3@ and the options to @in@ and @smt2@, each
-- prefixed with 'optionPrefix'. Both can be overridden at run time through the
-- @SBV_Z3@ and @SBV_Z3_OPTIONS@ environment variables; the latter is split into
-- arguments with 'splitArgs'. A configured timeout is appended as a @T:<n>@
-- option.
z3 :: SMTSolver
z3 = SMTSolver {
           name          = Z3
         , executable    = "z3"
         , options       = map (optionPrefix:) ["in", "smt2"]
         , engine        = \cfg isSat qinps modelMap skolemMap pgm -> do
                -- Environment overrides. Only a missing variable (reported by
                -- getEnv as an IOException) selects the configured default;
                -- asynchronous exceptions such as user interrupts propagate.
                execName <- getEnv "SBV_Z3"                            `orIfMissing` executable (solver cfg)
                execOpts <- (splitArgs `fmap` getEnv "SBV_Z3_OPTIONS") `orIfMissing` options (solver cfg)
                let cfg'     = cfg { solver = (solver cfg) {executable = execName, options = addTimeOut (timeOut cfg) execOpts} }
                    -- User supplied tweaks are emitted verbatim at the top of the script.
                    tweaks   = case solverTweaks cfg' of
                                 [] -> ""
                                 ts -> unlines $ "; --- user given solver tweaks ---" : ts ++ ["; --- end of user given tweaks ---"]
                    dlim     = printRealPrec cfg'
                    ppDecLim = "(set-option :pp.decimal_precision " ++ show dlim ++ ")\n"
                    script   = SMTScript {scriptBody = tweaks ++ ppDecLim ++ pgm, scriptModel = Just (cont (roundingMode cfg) skolemMap)}
                -- Reject a non-positive decimal precision before running the solver.
                if dlim < 1
                   then error $ "SBV.Z3: printRealPrec value should be at least 1, invalid value received: " ++ show dlim
                   else standardSolver cfg' script cleanErrs (ProofError cfg') (interpretSolverOutput cfg' (extractMap isSat qinps modelMap))
         , xformExitCode = id
         , capabilities  = SolverCapabilities {
                                 capSolverName              = "Z3"
                               , mbDefaultLogic             = Nothing
                               , supportsMacros             = True
                               , supportsProduceModels      = True
                               , supportsQuantifiers        = True
                               , supportsUninterpretedSorts = True
                               , supportsUnboundedInts      = True
                               , supportsReals              = True
                               , supportsFloats             = True
                               , supportsDoubles            = True
                               }
         }
  where
    -- Remove lines starting with "WARNING:" from the solver's error output.
    cleanErrs = intercalate "\n" . filter (not . junk) . lines
    junk      = ("WARNING:" `isPrefixOf`)

    -- Run an IO action, returning the given default if it throws an IOException
    -- (e.g. getEnv on an unset variable). Other exceptions are not caught.
    orIfMissing :: IO a -> a -> IO a
    orIfMissing act def = act `C.catch` (\e -> const (return def) (e :: C.IOException))

    -- Build the model-query part of the script from the skolem map:
    --   * a Left entry echoes the value produced by mkSkolemZero for its kind;
    --   * a Right entry issues a get-value, applied to skolem-zero arguments
    --     when it has dependencies.
    cont rm skolems = intercalate "\n" $ concatMap extract skolems
      where extract (Left s)        = ["(echo \"((" ++ show s ++ " " ++ mkSkolemZero rm (kindOf s) ++ "))\")"]
            extract (Right (s, [])) = let g = "(get-value (" ++ show s ++ "))" in getVal (kindOf s) g
            extract (Right (s, ss)) = let g = "(get-value ((" ++ show s ++ concat [' ' : mkSkolemZero rm (kindOf a) | a <- ss] ++ ")))" in getVal (kindOf s) g
            -- Real-valued queries are issued twice, once with pp.decimal off and
            -- once with it on; extractMap merges such pairs of answers back together.
            getVal KReal g = ["(set-option :pp.decimal false) " ++ g, "(set-option :pp.decimal true) " ++ g]
            getVal _     g = [g]

    -- Append Z3's timeout option (T:<i>) when a timeout is configured.
    addTimeOut Nothing  o = o
    addTimeOut (Just i) o
      | i < 0     = error $ "Z3: Timeout value must be non-negative, received: " ++ show i
      | otherwise = o ++ [optionPrefix : "T:" ++ show i]
-- | Build an 'SMTModel' from the lines Z3 printed in answer to the model queries.
--
-- Only the inputs relevant to the query are used when interpreting the lines:
--
--   * for a satisfiability query: all inputs, minus any trailing universally
--     quantified ones (unless every input is universal, in which case all are
--     kept);
--
--   * otherwise: only the leading run of universally quantified inputs.
--
-- Parsed results are ordered by node id. Any group of exactly two results that
-- share a node id is collapsed into one with 'mergeReals'; both values must be
-- algebraic reals, or an error is raised. Uninterpreted values and arrays are
-- always left empty.
extractMap :: Bool -> [(Quantifier, NamedSymVar)] -> [(String, UnintKind)] -> [String] -> SMTModel
extractMap isSat qinps _modelMap solverLines =
    SMTModel { modelAssocs    = assocs
             , modelUninterps = []
             , modelArrays    = []
             }
  where
    -- Parse every line, order by node id, collapse paired reals, drop the ids.
    parsed  = concatMap (interpretSolverModelLine inps) solverLines
    ordered = sortByNodeId parsed
    merged  = squashReals ordered
    assocs  = map snd merged

    sortByNodeId :: [(Int, a)] -> [(Int, a)]
    sortByNodeId = sortBy (compare `on` fst)

    inps = map snd relevant

    relevant
      | not isSat           = takeWhile universal qinps
      | all universal qinps = qinps
      | otherwise           = reverse (dropWhile universal (reverse qinps))

    universal :: (Quantifier, NamedSymVar) -> Bool
    universal (q, _) = q == ALL

    squashReals :: [(Int, (String, CW))] -> [(Int, (String, CW))]
    squashReals rs = concat [ collapse grp | grp <- groupBy ((==) `on` fst) rs ]
      where collapse [(i, (n, lhs)), (_, (_, rhs))] = [(i, (n, mergeReals n lhs rhs))]
            collapse grp                           = grp

    mergeReals :: String -> CW -> CW -> CW
    mergeReals n lhs rhs =
      case (lhs, rhs) of
        (CW KReal (CWAlgReal a), CW KReal (CWAlgReal b)) -> CW KReal (CWAlgReal (mergeAlgReals (bad n a b) a b))
        _                                                -> bad n lhs rhs

    bad :: (Show x, Show y) => String -> x -> y -> z
    bad n a b = error $ "SBV.Z3: Cannot merge reals for variable: " ++ n ++ " received: " ++ show (a, b)