module Data.Collections.Newtype.TH
( derive
)
where
import Control.Applicative hiding (empty)
import Control.Arrow
import Control.Monad.Unicode
import Data.Collections
import Data.Collections.BaseInstances ()
import Data.Data
import Data.Generics.Aliases
import Data.Generics.Schemes
import Data.Maybe
import Language.Haskell.TH.Lib
import Language.Haskell.TH.Ppr
import Language.Haskell.TH.Syntax
import Prelude hiding ( concat, concatMap, exp, filter
, foldl, foldr, foldl1, foldr1
, lookup, null
)
import Prelude.Unicode
-- | Common shape of every instance generator in this module.
--
-- Given the instance context, the instance head, an expression that wraps a
-- value of the underlying type into the newtype and an expression that
-- unwraps it, a 'Deriver' produces the instance declaration.
type Deriver
  = Q Cxt     -- instance context
  → Q Type    -- instance head
  → Q Exp     -- wrapping function (underlying type → newtype)
  → Q Exp     -- unwrapping function (newtype → underlying type)
  → Q Dec
-- | Expand a quoted list of body-less instance declarations into full
-- instances for newtypes, plus any helper functions they need.
--
-- Every declaration in the quote must be an instance declaration; its body
-- (if any) is ignored. Any other kind of declaration fails in 'Q' with a
-- usage message.
derive ∷ Q [Dec] → Q [Dec]
derive quoted = do
    decs   ← quoted
    groups ← mapM expand decs
    return (concat groups)
  where
    expand ∷ Dec → Q [Dec]
    expand (InstanceD ctx instTy _) = deriveInstance ctx instTy
    expand _ = fail "derive: usage: derive [d| instance A; instance B; ... |]"
-- | Generate one instance for a newtype.
--
-- The result contains the declarations of the generated @wrap@ / @unwrap@
-- helpers followed by the instance itself. A helper's declarations are
-- emitted only when the generated instance actually refers to that helper.
deriveInstance ∷ Cxt → Type → Q [Dec]
deriveInstance ctx instTy = do
    (wrapperTy, deriver) ← inspectInstance instTy
    (wrapE, wrapDecs)     ← genWrap wrapperTy
    (unwrapE, unwrapDecs) ← genUnwrap wrapperTy
    inst ← deriver (return ctx) (return instTy) (return wrapE) (return unwrapE)
    -- Drop a helper's declarations when the instance never mentions it.
    let ifUsed e ds
          | e `isUsedIn` inst = ds
          | otherwise         = []
    return (ifUsed wrapE wrapDecs ⧺ ifUsed unwrapE unwrapDecs ⧺ [inst])
-- | @needle `isUsedIn` haystack@ is 'True' when @haystack@ (or any of its
-- sub-terms of the same type as @needle@) is equal to @needle@.
isUsedIn ∷ (Eq α, Typeable α, Data β) ⇒ α → β → Bool
isUsedIn needle = everything (||) (mkQ False (≡ needle))
-- | Split an instance head into the newtype being given the instance and the
-- generator responsible for the instance's class.
--
-- Two-parameter classes are expected in the form @C wrapper a@, and
-- three-parameter classes in the form @C wrapper k v@. Anything else,
-- including an unrecognised class, fails in 'Q'.
inspectInstance ∷ Type → Q (Type, Deriver)
inspectInstance ty
  = case ty of
      AppT (AppT (ConT cls) wrapperTy) _
        | Just deriver ← binaryClass cls
        → return (wrapperTy, deriver)
      AppT (AppT (AppT (ConT cls) wrapperTy) _) _
        | Just deriver ← ternaryClass cls
        → return (wrapperTy, deriver)
      _ → fail $ "deriveInstance: unsupported type: " ⧺ pprint ty
  where
    -- Supported classes whose instance head is @C wrapper a@.
    binaryClass ∷ Name → Maybe Deriver
    binaryClass cls
      | cls ≡ ''Unfoldable        = Just deriveUnfoldable
      | cls ≡ ''Foldable          = Just deriveFoldable
      | cls ≡ ''Collection        = Just deriveCollection
      | cls ≡ ''Set               = Just deriveSet
      | cls ≡ ''SortingCollection = Just deriveSortingCollection
      | otherwise                 = Nothing

    -- Supported classes whose instance head is @C wrapper k v@.
    ternaryClass ∷ Name → Maybe Deriver
    ternaryClass cls
      | cls ≡ ''Indexed = Just deriveIndexed
      | cls ≡ ''Map     = Just deriveMap
      | otherwise       = Nothing
-- | Declare a fresh, inlined function @wrap ∷ inner → wrapperTy@ that applies
-- the newtype's data constructor.
--
-- Returns a reference to the function together with its signature, INLINE
-- pragma and definition.
genWrap ∷ Type → Q (Exp, [Dec])
genWrap wrapperTy = do
    fn ← newName "wrap"
    (con, innerTy) ← wrapperConTy wrapperTy
    let signature  = sigD fn [t| $(return innerTy) → $(return wrapperTy) |]
        inlining   = pragInlD fn (inlineSpecNoPhase True True)
        definition = funD fn [clause [] (normalB (conE con)) []]
    (,) (VarE fn) <$> sequence [signature, inlining, definition]
-- | Declare a fresh, inlined function @unwrap ∷ wrapperTy → inner@ that
-- pattern-matches the newtype's data constructor and returns its field.
--
-- Returns a reference to the function together with its signature, INLINE
-- pragma and definition.
genUnwrap ∷ Type → Q (Exp, [Dec])
genUnwrap wrapperTy = do
    fn    ← newName "unwrap"
    field ← newName "i"
    (con, innerTy) ← wrapperConTy wrapperTy
    let signature  = sigD fn [t| $(return wrapperTy) → $(return innerTy) |]
        inlining   = pragInlD fn (inlineSpecNoPhase True True)
        definition = funD fn [ clause [conP con [varP field]]
                                      (normalB (varE field))
                                      [] ]
    (,) (VarE fn) <$> sequence [signature, inlining, definition]
-- | Find the data constructor of a newtype and the type of its single field.
--
-- The argument is the newtype as it appears in an instance head: a bare type
-- constructor such as @W@, possibly applied to arguments (@W Int@) and/or
-- wrapped in kind signatures. The newtype may have a datatype context and a
-- @deriving@ clause. Its field may use positional or record syntax.
--
-- When the newtype has type parameters, the head's arguments are substituted
-- for them in the field type. For example, @newtype W a = W (S a)@ seen at
-- @W Int@ yields @S Int@.
--
-- Fails in 'Q' when the head is not a type constructor application, when it
-- does not name a newtype, or when the number of arguments differs from the
-- number of type parameters.
wrapperConTy ∷ Type → Q (Name, Type)
wrapperConTy = spine []
  where
    -- Peel off applications and kind signatures, collecting the arguments
    -- left to right, until the type constructor itself is reached.
    spine ∷ [Type] → Type → Q (Name, Type)
    spine args (ConT name)   = conTy args =≪ reify name
    spine args (AppT ty arg) = spine (arg : args) ty
    spine args (SigT ty _)   = spine args ty
    spine _ ty
      = fail $ "wrapperConTy: unsupported type: " ⧺ pprint ty

    -- Context and deriving clause are irrelevant for wrapping / unwrapping,
    -- so they are ignored rather than required to be empty.
    conTy ∷ [Type] → Info → Q (Name, Type)
    conTy args (TyConI (NewtypeD _ _ binders con _))
      | Just (conName, fieldTy) ← conField con
      , Just env ← bind binders args
      = return (conName, everywhere (mkT (substVar env)) fieldTy)
    conTy _ info
      = fail $ "wrapperConTy: unsupported type: " ⧺ pprint info

    -- The single field of a newtype constructor, in either syntax.
    conField ∷ Con → Maybe (Name, Type)
    conField (NormalC name [(_, ty)])   = Just (name, ty)
    conField (RecC name [(_, _, ty)])   = Just (name, ty)
    conField _                          = Nothing

    -- Pair each type parameter with its argument; Nothing on arity mismatch.
    bind ∷ [TyVarBndr] → [Type] → Maybe [(Name, Type)]
    bind []       []       = Just []
    bind (b : bs) (t : ts) = ((binderName b, t) :) <$> bind bs ts
    bind _        _        = Nothing

    binderName ∷ TyVarBndr → Name
    binderName (PlainTV n)    = n
    binderName (KindedTV n _) = n

    -- Replace a type variable by its argument. This is a plain replacement:
    -- variables bound by a forall inside the field type are not renamed.
    substVar ∷ [(Name, Type)] → Type → Type
    substVar env t@(VarT n)
      = case [ t' | (n', t') ← env, n' ≡ n ] of
          (t' : _) → t'
          []       → t
    substVar _ t = t
-- | Names of the methods declared (by type signature) in a class.
--
-- Fails in 'Q' when the name does not reify to a class.
methodNames ∷ Name → Q [Name]
methodNames cls = do
    info ← reify cls
    case info of
      ClassI (ClassD _ _ _ _ decs) _ → return (mapMaybe sigName decs)
      _ → fail $ "methodNames: not a class: " ⧺ pprint info
  where
    -- Only signatures name methods; default definitions and pragmas are skipped.
    sigName ∷ Dec → Maybe Name
    sigName (SigD n _) = Just n
    sigName _          = Nothing
-- | A single argument-less method binding @method = <mkBody method>@.
pointfreeMethod ∷ (Name → Q Exp) → Name → [Q Dec]
pointfreeMethod mkBody method = [binding]
  where
    binding = funD method [clause [] (normalB (mkBody method)) []]
-- | Generate an 'Unfoldable' instance for a newtype.
--
-- Each method delegates to the underlying type: collection arguments are
-- unwrapped, and collection results are wrapped again. A class method not
-- listed below makes generation fail in 'Q'.
deriveUnfoldable ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
deriveUnfoldable c ty wrap unwrap
  = instanceD c ty ∘ concatMap (pointfreeMethod method) =≪ methodNames ''Unfoldable
  where
    method ∷ Name → Q Exp
    method name
      = case [ body | (n, body) ← bodies, n ≡ name ] of
          (body : _) → body
          []         → fail $ "deriveUnfoldable: unknown method: " ⧺ pprint name

    -- Delegating body for every supported method.
    bodies ∷ [(Name, Q Exp)]
    bodies
      = [ ('insert,           [| ($wrap ∘) ∘ (∘ $unwrap) ∘ insert |])
        , ('empty,            [| $wrap empty |])
        , ('singleton,        [| $wrap ∘ singleton |])
        , ('insertMany,       [| ($wrap ∘) ∘ (∘ $unwrap) ∘ insertMany |])
        , ('insertManySorted, [| ($wrap ∘) ∘ (∘ $unwrap) ∘ insertManySorted |])
        ]
-- | Generate a 'Foldable' instance for a newtype.
--
-- Each method unwraps its collection argument and delegates to the
-- underlying type. No method returns the collection, so the wrapping
-- function is not used. A class method not listed below makes generation
-- fail in 'Q'.
deriveFoldable ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
deriveFoldable c ty _ unwrap
  = instanceD c ty ∘ concatMap (pointfreeMethod method) =≪ methodNames ''Foldable
  where
    method ∷ Name → Q Exp
    method name
      = case [ body | (n, body) ← bodies, n ≡ name ] of
          (body : _) → body
          []         → fail $ "deriveFoldable: unknown method: " ⧺ pprint name

    -- Delegating body for every supported method.
    -- @flip flip $unwrap ∘ ((∘) ∘) ∘ m@ is @\f z → m f z ∘ $unwrap@, i.e. it
    -- unwraps the collection given as the method's third argument.
    bodies ∷ [(Name, Q Exp)]
    bodies
      = [ ('fold,        [| fold ∘ $unwrap |])
        , ('foldMap,     [| (∘ $unwrap) ∘ foldMap |])
        , ('foldr,       [| flip flip $unwrap ∘ ((∘) ∘) ∘ foldr |])
        , ('foldl,       [| flip flip $unwrap ∘ ((∘) ∘) ∘ foldl |])
        , ('foldr1,      [| (∘ $unwrap) ∘ foldr1 |])
        , ('foldl1,      [| (∘ $unwrap) ∘ foldl1 |])
        , ('null,        [| null ∘ $unwrap |])
        , ('size,        [| size ∘ $unwrap |])
        , ('isSingleton, [| isSingleton ∘ $unwrap |])
        ]
-- | Generate a 'Collection' instance for a newtype.
--
-- The collection argument is unwrapped, the underlying method is applied,
-- and the result is wrapped again. A class method not listed below makes
-- generation fail in 'Q'.
deriveCollection ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
deriveCollection c ty wrap unwrap
  = instanceD c ty ∘ concatMap (pointfreeMethod method) =≪ methodNames ''Collection
  where
    method ∷ Name → Q Exp
    method name
      = case [ body | (n, body) ← bodies, n ≡ name ] of
          (body : _) → body
          []         → fail $ "deriveCollection: unknown method: " ⧺ pprint name

    -- Delegating body for every supported method.
    bodies ∷ [(Name, Q Exp)]
    bodies
      = [ ('filter, [| ($wrap ∘) ∘ (∘ $unwrap) ∘ filter |])
        ]
-- | Generate an 'Indexed' instance for a newtype.
--
-- Collection arguments are unwrapped, and collection results are wrapped
-- again. A class method not listed below makes generation fail in 'Q'.
deriveIndexed ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
deriveIndexed c ty wrap unwrap
  = instanceD c ty ∘ concatMap (pointfreeMethod method) =≪ methodNames ''Indexed
  where
    method ∷ Name → Q Exp
    method name
      = case [ body | (n, body) ← bodies, n ≡ name ] of
          (body : _) → body
          []         → fail $ "deriveIndexed: unknown method: " ⧺ pprint name

    -- Delegating body for every supported method.
    bodies ∷ [(Name, Q Exp)]
    bodies
      = [ ('index,    [| (∘ $unwrap) ∘ index |])
        , ('adjust,   [| (($wrap ∘) ∘) ∘ flip flip $unwrap ∘ ((∘) ∘) ∘ adjust |])
        , ('inDomain, [| (∘ $unwrap) ∘ inDomain |])
        , ('(//),     [| ($wrap ∘) ∘ (//) ∘ $unwrap |])
        , ('accum,    [| (($wrap ∘) ∘) ∘ (∘ $unwrap) ∘ accum |])
        ]
-- | Generate a 'Map' instance for a newtype.
--
-- Every collection argument is unwrapped, and every collection result is
-- wrapped again. A class method not listed below makes generation fail in
-- 'Q'.
deriveMap ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
deriveMap c ty wrap unwrap
  = instanceD c ty ∘ concatMap (pointfreeMethod method) =≪ methodNames ''Map
  where
    method ∷ Name → Q Exp
    method name
      = case [ body | (n, body) ← bodies, n ≡ name ] of
          (body : _) → body
          []         → fail $ "deriveMap: unknown method: " ⧺ pprint name

    -- Delegating body for every supported method.
    -- The @flip flip $unwrap ∘ ((∘) ∘) ∘ …@ prefix composes @$unwrap@ onto
    -- the argument after the one already handled, so each collection
    -- argument of a multi-argument method is unwrapped.
    bodies ∷ [(Name, Q Exp)]
    bodies
      = [ ('delete,           [| ($wrap ∘) ∘ (∘ $unwrap) ∘ delete |])
        , ('member,           [| (∘ $unwrap) ∘ member |])
        , ('union,            [| ($wrap ∘) ∘ (∘ $unwrap) ∘ union ∘ $unwrap |])
        , ('intersection,     [| ($wrap ∘) ∘ (∘ $unwrap) ∘ intersection ∘ $unwrap |])
        , ('difference,       [| ($wrap ∘) ∘ (∘ $unwrap) ∘ difference ∘ $unwrap |])
        , ('isSubset,         [| (∘ $unwrap) ∘ isSubset ∘ $unwrap |])
        , ('isProperSubset,   [| (∘ $unwrap) ∘ isProperSubset ∘ $unwrap |])
        , ('lookup,           [| (∘ $unwrap) ∘ lookup |])
        , ('alter,            [| (($wrap ∘) ∘) ∘ flip flip $unwrap ∘ ((∘) ∘) ∘ alter |])
        , ('insertWith,       [| ((($wrap ∘) ∘) ∘) ∘ flip flip $unwrap ∘ ((flip ∘ ((∘) ∘)) ∘) ∘ insertWith |])
        , ('fromFoldableWith, [| ($wrap ∘) ∘ fromFoldableWith |])
        , ('foldGroups,       [| (($wrap ∘) ∘) ∘ foldGroups |])
        , ('mapWithKey,       [| ($wrap ∘) ∘ (∘ $unwrap) ∘ mapWithKey |])
        , ('unionWith,        [| (($wrap ∘) ∘) ∘ flip flip $unwrap ∘ ((∘) ∘) ∘ (∘ $unwrap) ∘ unionWith |])
        , ('intersectionWith, [| (($wrap ∘) ∘) ∘ flip flip $unwrap ∘ ((∘) ∘) ∘ (∘ $unwrap) ∘ intersectionWith |])
        , ('differenceWith,   [| (($wrap ∘) ∘) ∘ flip flip $unwrap ∘ ((∘) ∘) ∘ (∘ $unwrap) ∘ differenceWith |])
        , ('isSubmapBy,       [| flip flip $unwrap ∘ ((∘) ∘) ∘ (∘ $unwrap) ∘ isSubmapBy |])
        , ('isProperSubmapBy, [| flip flip $unwrap ∘ ((∘) ∘) ∘ (∘ $unwrap) ∘ isProperSubmapBy |])
        ]
-- | Generate a 'Set' instance for a newtype.
--
-- Neither the wrapping nor the unwrapping function is used. A class method
-- not listed below makes generation fail in 'Q'.
deriveSet ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
deriveSet c ty _ _
  = instanceD c ty ∘ concatMap (pointfreeMethod method) =≪ methodNames ''Set
  where
    method ∷ Name → Q Exp
    method name
      = case [ body | (n, body) ← bodies, n ≡ name ] of
          (body : _) → body
          []         → fail $ "deriveSet: unknown method: " ⧺ pprint name

    -- NOTE(review): the generated body refers to the class method itself.
    -- At the newtype's type this appears to resolve back to the instance
    -- being defined, so evaluating it would not terminate. Presumably this
    -- is acceptable because the method is only a placeholder; confirm.
    bodies ∷ [(Name, Q Exp)]
    bodies
      = [ ('haddock_candy, [| haddock_candy |])
        ]
-- | Generate a 'SortingCollection' instance for a newtype.
--
-- The collection is unwrapped before the underlying method is applied. The
-- remaining collection inside the (optional) result pair is wrapped again.
-- A class method not listed below makes generation fail in 'Q'.
deriveSortingCollection ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
deriveSortingCollection c ty wrap unwrap
  = instanceD c ty ∘ concatMap (pointfreeMethod method) =≪ methodNames ''SortingCollection
  where
    method ∷ Name → Q Exp
    method name
      = case [ body | (n, body) ← bodies, n ≡ name ] of
          (body : _) → body
          []         → fail $ "deriveSortingCollection: unknown method: " ⧺ pprint name

    -- Delegating body for every supported method.
    bodies ∷ [(Name, Q Exp)]
    bodies
      = [ ('minView, [| (second $wrap <$>) ∘ minView ∘ $unwrap |])
        ]