{-# LANGUAGE
    TemplateHaskell
  , UnicodeSyntax
  #-}
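-- |Template Haskell derivation of collections-api instances for newtypes.
--
-- These instances cannot be obtained with GeneralizedNewtypeDeriving: GND
-- only works when the newtype is the /last/ argument of the class, whereas
-- the collections-api classes take the collection type as their /first/
-- argument (e.g. @Foldable c a@). This module therefore generates the
-- forwarding instances with Template Haskell instead.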
module Data.Collections.Newtype.TH
    ( derive
    )
    where
import Control.Applicative hiding (empty)
import Control.Arrow
import Control.Monad.Unicode
import Data.Collections
import Data.Collections.BaseInstances ()
import Data.Data
import Data.Generics.Aliases
import Data.Generics.Schemes
import Data.Maybe
import Language.Haskell.TH.Lib
import Language.Haskell.TH.Ppr
import Language.Haskell.TH.Syntax
import Prelude hiding ( concat, concatMap, exp, filter
                      , foldl, foldr, foldl1, foldr1
                      , lookup, null
                      )
import Prelude.Unicode

-- |Builds one instance declaration. The arguments are, in order: the
-- instance context, the instance head, an expression that wraps a value of
-- the underlying type into the newtype, and an expression that unwraps it.
type Deriver = Q Cxt -> Q Type -> Q Exp -> Q Exp -> Q Dec

-- |Automatic newtype instance deriver for type classes defined by the
-- collections-api package.
--
-- @
--   {-\# LANGUAGE TemplateHaskell \#-}
--   module Foo (T) where
--   import "Data.Collections"
--   import "Data.Collections.BaseInstances" ()
--   import qualified Data.Collections.Newtype.TH as C
--   import qualified "Data.Map" as M
--
--   newtype T = T (M.Map 'Int' 'Bool')
--
--   C.derive [d| instance 'Unfoldable' T ('Int', 'Bool')
--                instance 'Foldable'   T ('Int', 'Bool')
--                instance 'Indexed'    T  'Int'  'Bool'
--                ...
--              |]
-- @
--
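-- Only the context and head of each quoted instance are used; any method
-- definitions written inside the quotation are ignored, because the method
-- bodies are generated.
--
-- This function can derive the following instances: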
--
--   * 'Unfoldable'
--
--   * 'Foldable'
--
--   * 'Collection'
--
--   * 'Indexed'
--
--   * 'Map'
--
--   * 'Set'
--
--   * 'SortingCollection'
--
derive :: Q [Dec] -> Q [Dec]
derive quoted = concat <$> (quoted >>= mapM forInstance)
    where
      -- Every quoted declaration must be an instance; only its context and
      -- head are used, its method bodies (third field) are discarded.
      forInstance :: Dec -> Q [Dec]
      forInstance dec =
          case dec of
            InstanceD ctx ty _ -> deriveInstance ctx ty
            _                  -> fail "derive: usage: derive [d| instance A; instance B; ... |]"

-- |Derive a single instance from its context @c@ and head @ty@.
--
-- The result is the instance declaration, preceded by the declarations of
-- the generated wrap/unwrap helpers -- but only the helpers the instance
-- actually refers to, so unused ones are not emitted.
deriveInstance :: Cxt -> Type -> Q [Dec]
deriveInstance c ty = do
    (wrapperTy, deriver) <- inspectInstance ty
    (wrap  , wrapDecls  ) <- genWrap   wrapperTy
    (unwrap, unwrapDecls) <- genUnwrap wrapperTy
    inst <- deriver (return c) (return ty) (return wrap) (return unwrap)
    let whenUsed helper decls
            | helper `isUsedIn` inst = decls
            | otherwise              = []
    return (whenUsed wrap wrapDecls ++ whenUsed unwrap unwrapDecls ++ [inst])

-- |@needle \`isUsedIn\` haystack@ is 'True' iff at least one sub-term of
-- @haystack@ (including @haystack@ itself) has the type of @needle@ and is
-- equal to it.
isUsedIn :: (Eq a, Typeable a, Data b) => a -> b -> Bool
isUsedIn needle haystack = gcount (mkQ False (== needle)) haystack > 0

-- |Split an instance head into the type being wrapped (the class's first
-- argument) and the deriver for that class. Heads with two class arguments
-- and heads with three are matched separately; anything else fails.
inspectInstance :: Type -> Q (Type, Deriver)
inspectInstance ty =
    case ty of
      AppT (AppT (ConT cls) wrapperTy) _
          | cls == ''Unfoldable        -> found wrapperTy deriveUnfoldable
          | cls == ''Foldable          -> found wrapperTy deriveFoldable
          | cls == ''Collection        -> found wrapperTy deriveCollection
          | cls == ''Set               -> found wrapperTy deriveSet
          | cls == ''SortingCollection -> found wrapperTy deriveSortingCollection
      -- If none of the guards above hold, matching continues here.
      AppT (AppT (AppT (ConT cls) wrapperTy) _) _
          | cls == ''Indexed -> found wrapperTy deriveIndexed
          | cls == ''Map     -> found wrapperTy deriveMap
      _ -> fail $ "deriveInstance: unsupported type: " ++ pprint ty
  where
    found :: Type -> Deriver -> Q (Type, Deriver)
    found wrapperTy deriver = return (wrapperTy, deriver)

-- |Generate a fresh top-level function that wraps a value of the newtype's
-- underlying type into the newtype (it is just the data constructor),
-- together with its type signature and an inline pragma. Returns a
-- reference to the function and its declarations.
genWrap :: Type -> Q (Exp, [Dec])
genWrap wrapperTy = do
    fn <- newName "wrap"
    (con, innerTy) <- wrapperConTy wrapperTy
    sig    <- sigD fn [t| $(return innerTy) -> $(return wrapperTy) |]
    inline <- pragInlD fn (inlineSpecNoPhase True True)
    def    <- funD fn [clause [] (normalB (conE con)) []]
    return (VarE fn, [sig, inline, def])

-- |Generate a fresh top-level function that pattern-matches on the
-- newtype's constructor and returns the wrapped value, together with its
-- type signature and an inline pragma. Returns a reference to the function
-- and its declarations.
genUnwrap :: Type -> Q (Exp, [Dec])
genUnwrap wrapperTy = do
    fn  <- newName "unwrap"
    var <- newName "i"
    (con, innerTy) <- wrapperConTy wrapperTy
    sig    <- sigD fn [t| $(return wrapperTy) -> $(return innerTy) |]
    inline <- pragInlD fn (inlineSpecNoPhase True True)
    def    <- funD fn [clause [conP con [varP var]] (normalB (varE var)) []]
    return (VarE fn, [sig, inline, def])

-- |Find the data constructor of the newtype at the head of the given type,
-- together with the type of its single field.
--
-- Type applications and kind signatures are peeled off to reach the head
-- type constructor, which is then reified. Only newtypes with no datatype
-- context and no type parameters are supported. The constructor may be
-- written positionally (@newtype T = T U@) or with record syntax
-- (@newtype T = T { unT :: U }@). Any deriving clause is accepted.
wrapperConTy :: Type -> Q (Name, Type)
wrapperConTy ty0 = tyInfo ty0 >>= conTy
    where
      tyInfo :: Type -> Q Info
      tyInfo (ConT name) = reify name
      tyInfo (AppT ty _) = tyInfo ty
      tyInfo (SigT ty _) = tyInfo ty
      tyInfo ty
          = fail $ "wrapperConTy: unsupported type: " ++ pprint ty

      conTy :: Info -> Q (Name, Type)
      conTy (TyConI (NewtypeD [] _ [] con _))
          | Just found <- fieldOf con = return found
      conTy info
          = fail $ "wrapperConTy: unsupported type: " ++ pprint info

      -- A newtype constructor has exactly one field, whether or not it is
      -- named; the generated wrap/unwrap helpers use the constructor
      -- positionally, so both forms work the same way.
      fieldOf :: Con -> Maybe (Name, Type)
      fieldOf (NormalC con [(_, ty)])    = Just (con, ty)
      fieldOf (RecC    con [(_, _, ty)]) = Just (con, ty)
      fieldOf _                          = Nothing

-- |Names of the methods declared by the given class, i.e. the names of the
-- type signatures in its reified declaration. Fails in 'Q' if the name does
-- not refer to a class.
methodNames :: Name -> Q [Name]
methodNames cls = reify cls >>= names
    where
      names :: Info -> Q [Name]
      names (ClassI (ClassD _ _ _ _ decls) _)
              = return (mapMaybe sigName decls)
      names c = fail $ "methodNames: not a class: " ++ pprint c

      -- Only signatures declare methods; other class declarations are skipped.
      sigName :: Dec -> Maybe Name
      sigName (SigD n _) = Just n
      sigName _          = Nothing

-- |Define the method @name@ without arguments, as @name = body@, where the
-- body is obtained by applying @f@ to the method's name. The single
-- declaration is returned in a list so callers can 'concatMap' over a
-- class's methods.
pointfreeMethod :: (Name -> Q Exp) -> Name -> [Q Dec]
pointfreeMethod f name = [definition]
    where
      definition = funD name [clause [] (normalB (f name)) []]
      -- THINKME: Inserting a PragmaD into an InstanceD causes an error in
      -- at least GHC 7.0.3. Why? The pragma that was meant to go here:
      --   pragInlD name (inlineSpecNoPhase True False)

-- |Derive 'Unfoldable': each method runs on the unwrapped collection and
-- its resulting collection is wrapped again.
deriveUnfoldable :: Deriver
deriveUnfoldable c ty wrap unwrap = do
    names <- methodNames ''Unfoldable
    instanceD c ty (concatMap (pointfreeMethod bodyFor) names)
  where
    bodyFor :: Name -> Q Exp
    bodyFor name =
        case [ body | (method, body) <- bodies, method == name ] of
          body : _ -> body
          []       -> fail $ "deriveUnfoldable: unknown method: " ++ pprint name

    -- Point-free method bodies, keyed by method name.
    bodies :: [(Name, Q Exp)]
    bodies =
        [ ('insert          , [| ($wrap .) . (. $unwrap) . insert |])
        , ('empty           , [| $wrap empty |])
        , ('singleton       , [| $wrap . singleton |])
        , ('insertMany      , [| ($wrap .) . (. $unwrap) . insertMany |])
        , ('insertManySorted, [| ($wrap .) . (. $unwrap) . insertManySorted |])
        ]

-- |Derive 'Foldable': every method is applied to the unwrapped collection.
-- No re-wrapping is needed, so the wrap expression is unused.
deriveFoldable :: Deriver
deriveFoldable c ty _ unwrap = do
    names <- methodNames ''Foldable
    instanceD c ty (concatMap (pointfreeMethod bodyFor) names)
  where
    bodyFor :: Name -> Q Exp
    bodyFor name =
        case [ body | (method, body) <- bodies, method == name ] of
          body : _ -> body
          []       -> fail $ "deriveFoldable: unknown method: " ++ pprint name

    -- Point-free method bodies, keyed by method name. The folds thread the
    -- unwrap function into their final (collection) argument.
    bodies :: [(Name, Q Exp)]
    bodies =
        [ ('fold       , [| fold . $unwrap |])
        , ('foldMap    , [| (. $unwrap) . foldMap |])
        , ('foldr      , [| flip flip $unwrap . ((.) .) . foldr |])
        , ('foldl      , [| flip flip $unwrap . ((.) .) . foldl |])
        , ('foldr1     , [| (. $unwrap) . foldr1 |])
        , ('foldl1     , [| (. $unwrap) . foldl1 |])
        , ('null       , [| null . $unwrap |])
        , ('size       , [| size . $unwrap |])
        , ('isSingleton, [| isSingleton . $unwrap |])
        ]

-- |Derive 'Collection' by filtering the unwrapped collection and wrapping
-- the result.
deriveCollection :: Deriver
deriveCollection c ty wrap unwrap = do
    names <- methodNames ''Collection
    instanceD c ty (concatMap (pointfreeMethod bodyFor) names)
  where
    bodyFor :: Name -> Q Exp
    bodyFor name =
        case [ body | (method, body) <- bodies, method == name ] of
          body : _ -> body
          []       -> fail $ "deriveCollection: unknown method: " ++ pprint name

    -- Point-free method bodies, keyed by method name.
    bodies :: [(Name, Q Exp)]
    bodies =
        [ ('filter, [| ($wrap .) . (. $unwrap) . filter |])
        ]

-- |Derive 'Indexed' by forwarding every method to the unwrapped collection.
-- Methods that return a collection have their result wrapped again.
deriveIndexed :: Deriver
deriveIndexed c ty wrap unwrap = do
    names <- methodNames ''Indexed
    instanceD c ty (concatMap (pointfreeMethod bodyFor) names)
  where
    bodyFor :: Name -> Q Exp
    bodyFor name =
        case [ body | (method, body) <- bodies, method == name ] of
          body : _ -> body
          []       -> fail $ "deriveIndexed: unknown method: " ++ pprint name

    -- Point-free method bodies, keyed by method name.
    bodies :: [(Name, Q Exp)]
    bodies =
        [ ('index   , [| (. $unwrap) . index |])
        , ('adjust  , [| (($wrap .) .) . flip flip $unwrap . ((.) .) . adjust |])
        , ('inDomain, [| (. $unwrap) . inDomain |])
        , ('(//)    , [| ($wrap .) . (//) . $unwrap |])
        , ('accum   , [| (($wrap .) .) . (. $unwrap) . accum |])
        ]

-- |Derive 'Map' by forwarding every method to the unwrapped collection(s).
-- Methods that return a collection have their result wrapped again.
deriveMap :: Deriver
deriveMap c ty wrap unwrap = do
    names <- methodNames ''Map
    instanceD c ty (concatMap (pointfreeMethod bodyFor) names)
  where
    bodyFor :: Name -> Q Exp
    bodyFor name =
        case [ body | (method, body) <- bodies, method == name ] of
          body : _ -> body
          []       -> fail $ "deriveMap: unknown method: " ++ pprint name

    -- Point-free method bodies, keyed by method name. Binary operations
    -- unwrap both collection arguments.
    bodies :: [(Name, Q Exp)]
    bodies =
        [ ('delete          , [| ($wrap .) . (. $unwrap) . delete |])
        , ('member          , [| (. $unwrap) . member |])
        , ('union           , [| ($wrap .) . (. $unwrap) . union . $unwrap |])
        , ('intersection    , [| ($wrap .) . (. $unwrap) . intersection . $unwrap |])
        , ('difference      , [| ($wrap .) . (. $unwrap) . difference . $unwrap |])
        , ('isSubset        , [| (. $unwrap) . isSubset . $unwrap |])
        , ('isProperSubset  , [| (. $unwrap) . isProperSubset . $unwrap |])
        , ('lookup          , [| (. $unwrap) . lookup |])
        , ('alter           , [| (($wrap .) .) . flip flip $unwrap . ((.) .) . alter |])
        , ('insertWith      , [| ((($wrap .) .) .) . flip flip $unwrap . ((flip . ((.) .)) .) . insertWith |])
        , ('fromFoldableWith, [| ($wrap .) . fromFoldableWith |])
        , ('foldGroups      , [| (($wrap .) .) . foldGroups |])
        , ('mapWithKey      , [| ($wrap .) . (. $unwrap) . mapWithKey |])
        , ('unionWith       , [| (($wrap .) .) . flip flip $unwrap . ((.) .) . (. $unwrap) . unionWith |])
        , ('intersectionWith, [| (($wrap .) .) . flip flip $unwrap . ((.) .) . (. $unwrap) . intersectionWith |])
        , ('differenceWith  , [| (($wrap .) .) . flip flip $unwrap . ((.) .) . (. $unwrap) . differenceWith |])
        , ('isSubmapBy      , [| flip flip $unwrap . ((.) .) . (. $unwrap) . isSubmapBy |])
        , ('isProperSubmapBy, [| flip flip $unwrap . ((.) .) . (. $unwrap) . isProperSubmapBy |])
        ]

-- |Derive 'Set'. Its only method, 'haddock_candy', is defined as
-- @const ()@, so neither the wrap nor the unwrap expression is needed.
--
-- The method must not be defined as @[| haddock_candy |]@: inside the
-- generated instance that name resolves to the very method being defined,
-- giving @haddock_candy = haddock_candy@, which loops forever when forced.
--
-- NOTE(review): this assumes collections-api declares
-- @haddock_candy :: c -> ()@ -- confirm against the library version in use.
deriveSet :: Deriver
deriveSet c ty _ _ = do
    names <- methodNames ''Set
    instanceD c ty (concatMap (pointfreeMethod bodyFor) names)
  where
    bodyFor :: Name -> Q Exp
    bodyFor name
        | name == 'haddock_candy
            = [| const () |]
        | otherwise
            = fail $ "deriveSet: unknown method: " ++ pprint name

-- |Derive 'SortingCollection' by taking the minimum view of the unwrapped
-- collection and wrapping the remainder it returns.
deriveSortingCollection :: Deriver
deriveSortingCollection c ty wrap unwrap = do
    names <- methodNames ''SortingCollection
    instanceD c ty (concatMap (pointfreeMethod bodyFor) names)
  where
    bodyFor :: Name -> Q Exp
    bodyFor name =
        case [ body | (method, body) <- bodies, method == name ] of
          body : _ -> body
          []       -> fail $ "deriveSortingCollection: unknown method: " ++ pprint name

    -- Point-free method bodies, keyed by method name.
    bodies :: [(Name, Q Exp)]
    bodies =
        [ ('minView, [| (second $wrap <$>) . minView . $unwrap |])
        ]