/*

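 This object has three CPS transformers:
  (i)   a naive CPS transformer,
  (ii)  a smarter CPS transformer, and
  (iii) a continuation-based CPS transformer, which drives the
        conversion with meta-level (Scala function) continuations.

 The results of (ii) and (iii) are identical up to the names of the
 freshly generated variables, but (iii) needs less special-case
 handling than (ii).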

 These transformers only work on the pure lambda calculus: 

  e ::= (lambda (v) e)
     |  (e1 e2)
     |  v


 Author: Matthew Might
 Site:   http://matt.might.net/

 */

package languages.lcplus ;
import languages.sexp._ ;


object LambdaCalculusCPSTransformer {
  import LCPlusSyntax._ ;

  // Note: every transformer below only accepts the core lambda calculus:
  // single-parameter lambdas, single-argument applications and variable
  // references.  Any other term falls through the pattern matches and
  // raises a scala.MatchError.

  /** True for the atomic terms, i.e. references and lambda terms. */
  private def isAtom (e : Exp) : Boolean = e match {
    case _ : Ref | _ : Lambda => true
    case _                    => false
  }


  /* The _a functions convert a direct-style atomic expression (a
     reference or a lambda term) into a CPS atomic expression.

     In CPS, every function is extended with a new argument, its
     continuation.  The CPS protocol dictates that a function should
     invoke this argument on its return value. */

  /* The _t functions accept a direct-style expression (e) and a CPS
     expression (q) denoting a continuation.  They return a call site
     that, when executed, ends up passing the CPS-converted result of e
     into q. */


  // The naive translator:

  /** Converts an atom.  References are unchanged; a lambda gains a fresh
    * continuation parameter and its body is converted against it. */
  def naive_a (e : Exp) : Exp = e match {
    case Ref(_) => e

    case Lambda(List(v), body) => {
      val (_, vk, rk) = SymbolTable.fresh("k")
      Lambda(List(v, vk), naive_t (body) (rk))
    }
  }

  /** Naively converts `e`, delivering its value to the continuation `q`.
    * Both the operator and the operand of an application are always
    * bound through a continuation lambda, even when they are atoms. */
  def naive_t (e : Exp) (q : Exp) : Exp = e match {
    case a if isAtom(a) =>
      App(q, List(naive_a(a)))

    case App(f, List(arg)) => {
      val (_, vf_, rf_) = SymbolTable.fresh("f")
      val (_, ve_, re_) = SymbolTable.fresh("e")
      naive_t (f) (Lambda(List(vf_),
                          naive_t (arg) (Lambda(List(ve_),
                                                App(rf_, List(re_, q))))))
    }
  }




  // A smarter translator:

  /** Converts an atom.  References are unchanged; a lambda gains a fresh
    * continuation parameter and its body is converted with smarter_t. */
  def smarter_a (e : Exp) : Exp = e match {
    case Ref(_) => e

    case Lambda(List(v), body) => {
      val (_, vk, rk) = SymbolTable.fresh("k")
      Lambda(List(v, vk), smarter_t (body) (rk))
    }
  }


  /** Converts `e`, delivering its value to the continuation `q`.  Atomic
    * operators/operands are used in place instead of being bound through
    * an extra continuation lambda. */
  def smarter_t (e : Exp) (q : Exp) : Exp = e match {
    case a if isAtom(a) =>
      App(q, List(smarter_a(a)))

    // Use four patterns instead of one to catch the special cases:
    case App(f, List(arg)) if isAtom(f) && isAtom(arg) =>
      App(smarter_a(f), List(smarter_a(arg), q))

    case App(f, List(arg)) if isAtom(f) && !isAtom(arg) => {
      val (_, ve_, re_) = SymbolTable.fresh("e")
      smarter_t (arg) (Lambda(List(ve_),
                              App(smarter_a(f), List(re_, q))))
    }

    case App(f, List(arg)) if !isAtom(f) && isAtom(arg) => {
      val (_, vf_, rf_) = SymbolTable.fresh("f")
      smarter_t (f) (Lambda(List(vf_),
                            App(rf_, List(smarter_a(arg), q))))
    }

    case App(f, List(arg)) => {
      val (_, vf_, rf_) = SymbolTable.fresh("f")
      val (_, ve_, re_) = SymbolTable.fresh("e")
      smarter_t (f) (Lambda(List(vf_),
                            smarter_t (arg) (Lambda(List(ve_),
                                                    App(rf_, List(re_, q))))))
    }
  }




  // A continuation-based CPS translator:

  /** Converts an atom.  References are unchanged; a lambda gains a fresh
    * continuation parameter and its body is converted with cont_t. */
  def cont_a (e : Exp) : Exp = e match {
    case Ref(_) => e

    case Lambda(List(v), body) => {
      val (_, vk, rk) = SymbolTable.fresh("k")
      Lambda(List(v, vk), cont_t (body) (rk))
    }
  }

  /** Converts `e`, delivering its value to the continuation `q`.  The
    * special cases of smarter_t fall out of cont_c: atoms are plugged into
    * the meta-level context directly. */
  def cont_t (e : Exp) (q : Exp) : Exp = e match {
    case a if isAtom(a) =>
      App(q, List(cont_a(a)))

    case App(f, List(arg)) =>
      cont_c (f) (f_ =>
        cont_c (arg) (arg_ =>
          App(f_, List(arg_, q))))
  }

  /** Builds a context holding an atomic CPS expression e' that carries the
    * (CPS-converted) result of `e`, and inserts k(e') into that context.
    *
    * For an atom, e' is simply its conversion.  For an application, e' is a
    * reference to a fresh variable bound by a continuation lambda that
    * receives the value of the whole application. */
  def cont_c (e : Exp) (k : Exp => Exp) : Exp = e match {
    case a if isAtom(a) => k(cont_a(a))

    // Convert the WHOLE application `e`.  The pattern must not rebind `e`
    // to the operand, otherwise the operator would be silently dropped.
    case App(_, List(_)) => {
      val (_, vrv, rrv) = SymbolTable.fresh("rv")
      cont_t (e) (Lambda(List(vrv), k(rrv)))
    }
  }



  /** Driver: reads a term from stdin, CPS-converts it and prints the result.
    *
    * Only the first s-expression on stdin is transformed, with the `halt`
    * primitive as the top-level continuation.  `--smarter` selects
    * smarter_t, otherwise `--cont` selects cont_t, otherwise naive_t is
    * used. */
  def main (args : Array[String]) : Unit = {
    val stdin : String = scala.io.Source.fromInputStream(System.in).mkString
    val sexp = SParser.parse(stdin).head
    val exp = LCPlusSyntax.parseExp(sexp)

    val halt = PrimOp("halt")
    val cpsexp : Exp =
      if (args contains "--smarter") smarter_t (exp) (halt)
      else if (args contains "--cont") cont_t (exp) (halt)
      else naive_t (exp) (halt)

    println(cpsexp)
  }

}
