module Curry.Syntax.Utils
( hasLanguageExtension, knownExtensions
, isTypeSig, infixOp, isTypeDecl, isValueDecl, isInfixDecl
, isFunctionDecl, isExternalDecl, patchModuleId
, flatLhs, mkInt, fieldLabel, fieldTerm, field2Tuple, opName
, addSrcRefs
, constrId, nconstrId
, recordLabels, nrecordLabels
) where
import Control.Monad.State
import Data.Generics
import Curry.Base.Ident
import Curry.Base.Position
import Curry.Files.Filenames (takeBaseName)
import Curry.Syntax.Extension
import Curry.Syntax.Type
-- | Check whether the given extension occurs among the extensions
-- reported by 'knownExtensions' for the module.
hasLanguageExtension :: Module -> KnownExtension -> Bool
hasLanguageExtension mdl ext = ext `elem` enabled
  where
    enabled = knownExtensions mdl
-- | Collect every known extension mentioned in the @LanguagePragma@
-- entries of a module's pragma list, in order of appearance.
-- Entries that are not @LanguagePragma@s, and extensions that are not
-- wrapped in @KnownExtension@, are skipped.
knownExtensions :: Module -> [KnownExtension]
knownExtensions (Module pragmas _ _ _ _) = concatMap fromPragma pragmas
  where
    -- Known extensions of a single pragma; anything else contributes nothing.
    fromPragma (LanguagePragma _ exts) = [ e | KnownExtension _ e <- exts ]
    fromPragma _                       = []
-- | Replace the module identifier by one built from the base name of the
-- given file path, but only when the module's current identifier equals
-- 'mainMIdent'. Any other module is returned unchanged.
patchModuleId :: FilePath -> Module -> Module
patchModuleId path original@(Module pragmas name exports imports decls) =
  if name == mainMIdent
    then Module pragmas (mkMIdent [takeBaseName path]) exports imports decls
    else original
-- | 'True' exactly for declarations built with the @InfixDecl@ constructor.
isInfixDecl :: Decl -> Bool
isInfixDecl decl = case decl of
  InfixDecl {} -> True
  _            -> False
-- | 'True' for @DataDecl@, @NewtypeDecl@ and @TypeDecl@ declarations,
-- 'False' for every other declaration.
isTypeDecl :: Decl -> Bool
isTypeDecl decl = case decl of
  DataDecl    {} -> True
  NewtypeDecl {} -> True
  TypeDecl    {} -> True
  _              -> False
-- | 'True' for @TypeSig@ and @ForeignDecl@ declarations, 'False' for
-- every other declaration.
isTypeSig :: Decl -> Bool
isTypeSig decl = case decl of
  TypeSig     {} -> True
  ForeignDecl {} -> True
  _              -> False
-- | 'True' for @FunctionDecl@, @ForeignDecl@, @ExternalDecl@,
-- @PatternDecl@ and @FreeDecl@ declarations, 'False' otherwise.
isValueDecl :: Decl -> Bool
isValueDecl decl = case decl of
  FunctionDecl {} -> True
  ForeignDecl  {} -> True
  ExternalDecl {} -> True
  PatternDecl  {} -> True
  FreeDecl     {} -> True
  _               -> False
-- | 'True' exactly for declarations built with the @FunctionDecl@
-- constructor.
isFunctionDecl :: Decl -> Bool
isFunctionDecl decl = case decl of
  FunctionDecl {} -> True
  _               -> False
-- | 'True' for @ForeignDecl@ and @ExternalDecl@ declarations, 'False'
-- for every other declaration.
isExternalDecl :: Decl -> Bool
isExternalDecl ForeignDecl  {} = True
isExternalDecl ExternalDecl {} = True
isExternalDecl _               = False
-- | Turn an infix operator into an expression: an @InfixOp@ becomes a
-- @Variable@, an @InfixConstr@ becomes a @Constructor@, each carrying
-- the same identifier.
infixOp :: InfixOp -> Expression
infixOp operator = case operator of
  InfixOp     name -> Variable name
  InfixConstr name -> Constructor name
-- | Flatten a left-hand side into the defined identifier and the full
-- list of argument patterns. Nested @ApLhs@ applications are unwound so
-- that inner arguments come before outer ones; an @OpLhs@ contributes
-- its left and right operands as the first two patterns.
flatLhs :: Lhs -> (Ident, [Pattern])
flatLhs = go []
  where
    -- @rest@ holds the patterns collected from enclosing applications.
    go rest (FunLhs f pats)      = (f, pats ++ rest)
    go rest (OpLhs left op right) = (op, left : right : rest)
    go rest (ApLhs inner pats)   = go (pats ++ rest) inner
-- | Build an @Int@ literal for the given value. The literal carries
-- 'anonId' decorated, via 'addPositionIdent', with an @AST@ position
-- made from the source reference that 'mk' supplies.
mkInt :: Integer -> Literal
mkInt value = mk build
  where
    build ref = Int (addPositionIdent (AST ref) anonId) value
-- | Extract the label (second component) of a 'Field'.
fieldLabel :: Field a -> QualIdent
fieldLabel fld = case fld of
  Field _ lbl _ -> lbl
-- | Extract the payload (third component) of a 'Field'.
fieldTerm :: Field a -> a
fieldTerm fld = case fld of
  Field _ _ term -> term
-- | Convert a 'Field' into a pair of its label and its payload,
-- dropping the first component.
field2Tuple :: Field a -> (QualIdent, a)
field2Tuple fld = case fld of
  Field _ lbl term -> (lbl, term)
-- | Return the qualified identifier stored in an infix operator,
-- whether it is an @InfixOp@ or an @InfixConstr@.
opName :: InfixOp -> QualIdent
opName operator = case operator of
  InfixOp     name -> name
  InfixConstr name -> name
-- | A monadic transformation of a value of type @a@ that runs in a
-- 'State' monad whose state is an 'Int'.
type M a = a -> State Int a
-- | Assign fresh source references throughout a module.
--
-- A counter starting at 0 is threaded through a generic traversal of the
-- module. Each use of the counter yields @srcRef n@ and then increments
-- it. The traversal treats the following types specially; every other
-- type is simply traversed child by child:
--
-- * a 'SrcRef' is replaced by a fresh reference;
-- * a list of 'SrcRef's is replaced by a singleton list holding one
--   fresh reference;
-- * an 'Ident' is passed to 'addRefId' together with a fresh reference
--   (its children are not traversed);
-- * a @ListPattern@ or @List@ gets a new reference list: one reference
--   taken first, then, for each element in order, the element is
--   traversed and one further reference is taken afterwards.
addSrcRefs :: Module -> Module
addSrcRefs m = evalState (addRefs m) 0
  where
    addRefs :: Data a' => M a'
    addRefs = descend
      `extM` refsForPos
      `extM` refForSrc
      `extM` refForIdent
      `extM` refsForListPat
      `extM` refsForListExp
      where
        -- Default case: apply 'addRefs' to every immediate child.
        descend :: Data a' => M a'
        descend = gmapM addRefs

        -- Hand out the current counter value as a reference and
        -- strictly bump the counter.
        freshRef :: State Int SrcRef
        freshRef = do
          n <- get
          put $! n + 1
          return (srcRef n)

        -- The previous reference is discarded.
        refForSrc :: M SrcRef
        refForSrc = const freshRef

        -- The previous reference list is discarded.
        refsForPos :: M [SrcRef]
        refsForPos _ = do
          r <- freshRef
          return [r]

        refForIdent :: M Ident
        refForIdent ident = do
          r <- freshRef
          return (addRefId r ident)

        refsForListPat :: M Pattern
        refsForListPat pat = case pat of
          ListPattern _ elems -> liftM (uncurry ListPattern) (annotate elems)
          _                   -> descend pat

        refsForListExp :: M Expression
        refsForListExp expr = case expr of
          List _ elems -> liftM (uncurry List) (annotate elems)
          _            -> descend expr

        -- One leading reference, then per element: traverse it first,
        -- take its reference second.
        annotate :: Data a' => [a'] -> State Int ([SrcRef], [a'])
        annotate elems = do
          leading <- freshRef
          tagged  <- mapM visit elems
          let (refs, elems') = unzip tagged
          return (leading : refs, elems')
          where
            visit e = do
              e' <- addRefs e
              r  <- freshRef
              return (r, e')
-- | Return the identifier of a constructor declaration: the constructor
-- name of a @ConstrDecl@ or @RecordDecl@, or the operator of a
-- @ConOpDecl@.
constrId :: ConstrDecl -> Ident
constrId decl = case decl of
  ConstrDecl _ _ name _   -> name
  ConOpDecl  _ _ _ name _ -> name
  RecordDecl _ _ name _   -> name
-- | Return the constructor identifier of a newtype constructor
-- declaration, for both @NewConstrDecl@ and @NewRecordDecl@.
nconstrId :: NewConstrDecl -> Ident
nconstrId decl = case decl of
  NewConstrDecl _ _ name _ -> name
  NewRecordDecl _ _ name _ -> name
-- | List the field labels of a constructor declaration. Only a
-- @RecordDecl@ has labels: those of all its @FieldDecl@s, concatenated
-- in order. @ConstrDecl@ and @ConOpDecl@ yield the empty list.
recordLabels :: ConstrDecl -> [Ident]
recordLabels decl = case decl of
  ConstrDecl {}           -> []
  ConOpDecl  {}           -> []
  RecordDecl _ _ _ fields -> concat [ labels | FieldDecl _ labels _ <- fields ]
-- | List the field labels of a newtype constructor declaration: the
-- single label of a @NewRecordDecl@, or nothing for a @NewConstrDecl@.
nrecordLabels :: NewConstrDecl -> [Ident]
nrecordLabels decl = case decl of
  NewConstrDecl {}               -> []
  NewRecordDecl _ _ _ (label, _) -> [label]