/*

 LCPlus is a super-toy lambda-calculus-based language.

 Author: Matthew Might
 Site:   http://matt.might.net/

 */

package languages.lcplus ;

import languages.sexp._ ;


object LCPlusSyntax {
  import SExpSyntax._ ;


  // Custom deconstructor pattern matchers for S-Expressions.
  // Each unapply() is the inverse of an apply(): it recognizes one special
  // form and returns its components, or None if sx is not that form.

  // Matches (lambda (v ...) body); yields the parameter symbols and the body.
  // NOTE: a parameter that is not a symbol raises ClassCastException.
  object SLambda {
    def unapply(sx : SExp) : Option[(List[S], SExp)] = sx match {
      case L(S("lambda") :: L(sxvars) :: List(sxbody)) =>
        Some((sxvars.map(_.asInstanceOf[S]), sxbody))
      case _ => None
    }
  }

  // Matches (if-zero cond cons alt); yields the three sub-expressions.
  object SIfZero {
    def unapply(sx : SExp) : Option[(SExp,SExp,SExp)] = sx match {
      case L(S("if-zero") :: cond :: cons :: alt :: List()) =>
        Some((cond,cons,alt))
      case _ => None
    }
  }

  // Matches (let* ((name value) ...) body); yields the bound names, their
  // value expressions (in the same order), and the body.
  // NOTE: a clause that is not a two-element list raises MatchError, and a
  // name that is not a symbol raises ClassCastException.
  object SLets {
    def unapply(sx : SExp) : Option[(List[S], List[SExp], SExp)] = sx match {
      case L(S("let*") :: L(clauses) :: body :: List()) => {
        val namesXvalues : List[(S,SExp)] =
          clauses map ({case L(List(name,value)) => (name.asInstanceOf[S],value)})
        // Split the pairs by hand: the List.unzip companion method is
        // deprecated and absent from newer Scala versions.
        val names  = namesXvalues map (_._1)
        val values = namesXvalues map (_._2)
        Some((names,values,body))
      }
      case _ => None
    }
  }



  /* Terms. */

  // Source of unique term ids (an unsynchronized counter).
  private object Term {
    private var maxId = 0

    // Returns the next id: 1, 2, 3, ...
    def nextId() : Int = { maxId = maxId + 1 ; maxId }
  }

  // Base of all syntax nodes. Unless a subclass overrides them, equality and
  // hashing are by the unique id assigned at construction, so two
  // structurally identical terms are still distinct.
  abstract class Term {
    val tid = Term.nextId()

    // The S-Expression this term was parsed from; null if it was built
    // directly rather than by the parser.
    var sx : SExp = null

    // Records the originating S-Expression and returns this term, so it can
    // be chained onto a constructor call.
    def from (sx : SExp) : Term = {
      this.sx = sx
      this
    }

    def toSExp : SExp
    override def toString = toSExp.toString

    override def hashCode() = tid
    // Equal only to a Term carrying the same id; anything else (including
    // null) is unequal rather than raising an exception.
    override def equals(a : Any) = a match {
      case t : Term => t.tid == tid
      case _ => false
    }
  }

  /* Denoters. */

  // A variable name. Unlike other terms, variables compare, order and hash
  // by name, so two occurrences of "x" are the same variable.
  case class Var(val name : String) extends Term with Ordered[Var] {
    override val toString = name
    override def compare (y : Var) = name compare (y.name)
    override def equals (a : Any) = a match {
      case v : Var => v.name equals name
      case _ => false
    }
    override def hashCode() = name.hashCode()
    override def from (sx : SExp) : Var = super.from(sx).asInstanceOf[Var]
    override def toSExp : SExp = S(name)
  }


  /* Expressions */

  // Expressions are ordered by creation order (their tid).
  abstract class Exp extends Term with Ordered[Exp] {
    override def from (sx : SExp) : Exp = super.from(sx).asInstanceOf[Exp]
    // Replaces free occurrences of v with e. Closed-term substitution: e is
    // assumed to have no free variables, so no renaming is performed.
    def subst (v : Var, e : Exp) : Exp ;
    def compare (t2 : Exp) = tid compare t2.tid
  }

  // Literals:

  // Integer literal; contains no variables, so substitution is the identity.
  case class IntLit(z : BigInt) extends Exp {
    def subst (v : Var, e : Exp) : Exp = this
    def toSExp = Z(z)
  }

  // Primitive operator, named by its symbol (e.g. "succ", "pred", "halt").
  case class PrimOp(s : String) extends Exp {
    def subst (v : Var, e : Exp) : Exp = this
    def toSExp = S(s)
  }


  // Core:

  // (lambda (params ...) body). A parameter named v shadows v, so
  // substitution stops at this lambda in that case.
  case class Lambda (params : List[Var], body : Exp) extends Exp {
    def subst (v : Var, e : Exp) : Exp =
      if (params contains v)
        this
      else
        Lambda(params,body subst (v,e))

    def toSExp = L(List(S("lambda"), L(params map (v => S(v.name))), body.toSExp))
  }

  // A reference to variable v.
  case class Ref (v : Var) extends Exp {
    // The parameter v shadows the field; this.v is the referenced variable.
    def subst (v : Var, e : Exp) : Exp = if (v equals this.v) { e } else { this }
    def toSExp = S(v.name)
  }

  // Application (f args ...).
  case class App (f : Exp, args : List[Exp]) extends Exp {
    def subst (v : Var, e : Exp) : Exp = App(f subst (v,e), args map (a => a subst (v,e)))
    def toSExp = L(f.toSExp :: (args map (_.toSExp)))
  }

  // (if-zero cond ifTrue ifFalse).
  case class IfZero (cond : Exp, ifTrue : Exp, ifFalse : Exp) extends Exp {
    def subst (v : Var, e : Exp) : Exp = IfZero(cond subst (v,e), ifTrue subst (v,e), ifFalse subst (v,e))
    def toSExp = L(S("if-zero") :: cond.toSExp :: ifTrue.toSExp :: ifFalse.toSExp :: List())
  }


  // Sugar: each form below desugars immediately into core terms.

  // Single binding: (let ((name exp)) body).
  object Let1 {
    def apply(name : Var, exp : Exp, body : Exp) : Exp =
      Let(List(name), List(exp), body)
  }

  // Simultaneous bindings, desugared to ((lambda (names ...) body) exps ...).
  object Let {
    def apply(names : List[Var], exps : List[Exp], body : Exp) : Exp =
      App(Lambda(names,body), exps)
  }

  // Sequential bindings (let*): each binding is in scope for the ones after
  // it. Desugared to nested Let1 forms. names and exps must have equal length.
  object Lets {
    def apply(names : List[Var], exps : List[Exp], body : Exp) : Exp =
      (names,exps) match {
        case (name :: names, exp :: exps) => Let1(name,exp,Lets(names,exps,body))
        case (List(),List()) => body
        case _ => throw new IllegalArgumentException(
          "Lets: " + names.length + " names but " + exps.length + " expressions")
      }
  }

  // For CPS conversion: an opaque wrapper around exp. Substitution does not
  // descend into it.
  case class CPS(exp : Exp) extends Exp {
    def subst (v : Var, e : Exp) : Exp = this
    def toSExp = L(List(S("cps"), exp.toSExp))
  }



  // Global syntax-related objects:

  // Generator for fresh names, numbered from a single shared counter.
  object SymbolTable {
    private var current = 0
    private def next() = {
      current += 1
      current
    }

    // Assume $ not used in symbols, so generated names cannot collide with
    // user-written ones.
    def fresh() = "$" + next()

    // Returns a fresh name built from prefix, with its Var and a Ref to it.
    def fresh(prefix : String) = {
      val s = prefix + "$" + next()
      val v = Var(s)
      val r = Ref(v)
      (s,v,r)
    }
  }


  // parseVar: Converts an S-Expression symbol into a Var.
  // Throws if sx is not a symbol.
  def parseVar (sx : SExp) : Var = (sx match {
    case S(name) => Var(name)
    case _ => throw new Exception("Not a variable: " + sx)
  }).from(sx)



  // parseExp: Converts an S-Expression into an Exp. The result records sx
  // via from(). Throws on an empty list.
  def parseExp (sx : SExp) : Exp = (sx match {

    // primitives (checked before the general symbol case):
    case S("succ") => PrimOp("succ")
    case S("pred") => PrimOp("pred")
    case S("halt") => PrimOp("halt")

    case Z(z) => IntLit(z)

    case S(v) => Ref(Var(v))

    // conditionals:
    case SIfZero(cond,ifTrue,ifFalse) =>
      IfZero(parseExp(cond),parseExp(ifTrue),parseExp(ifFalse))

    // lambda terms:
    case SLambda(names, body) =>
      Lambda(names map parseVar, parseExp(body))

    // sequential let: desugared into right-nested Let1 forms, so each
    // binding scopes over the later bindings and the body.
    case SLets(names,values,body) => {
      def parse (names : List[S],values : List[SExp],body : SExp) : Exp =
        if (names.isEmpty)
          parseExp(body)
        else
          Let1(parseVar(names.head), parseExp(values.head), parse (names.tail,values.tail,body))
      parse (names,values,body)
    }

    // CPS literals:
    case L(List(S("cps"),sexp)) => CPS(parseExp(sexp))

    // function application (any other non-empty list):
    case L(f :: args) => App(parseExp(f), args map parseExp)

    // error:
    case L(List()) => throw new Exception("Empty list!")

  }).from(sx)
}




object Interpreter {
  import LCPlusSyntax._

  // isFinal: Final terms cannot be further reduced. These are variable
  // references (including free ones), lambdas, integer literals, primitive
  // operators, and CPS-wrapped primitive operators.
  def isFinal (exp : Exp) : Boolean = exp match {
    case (Ref(_) | Lambda(_,_) | IntLit(_) | PrimOp(_) | CPS(PrimOp(_))) => true
    case _ => false
  }

  // beta: Perform function/primitive application.
  //  - (succ n) => n+1 and (pred n) => n-1 for an integer literal n;
  //  - a lambda applied to arguments consumes one parameter per step,
  //    substituting the first argument for the first parameter;
  //  - a parameterless lambda applied to no arguments yields its body.
  // Throws if exp is not such a redex (this includes arity mismatches).
  def beta(exp : Exp) : Exp = exp match {
    case App(PrimOp("succ"),List(IntLit(z))) => IntLit(z + 1)
    case App(PrimOp("pred"),List(IntLit(z))) => IntLit(z - 1)

    case App(Lambda(v::vars,body),a::args) =>
      App(Lambda(vars,body subst (v,a)),args)

    case App(Lambda(List(),body),List()) =>
      body

    case _ => throw new Exception("Not a redex: " + exp)
  }

  // beta_cbv: Perform one call-by-value step. Argument and function
  // expressions are reduced until all are final; then the application is
  // beta-reduced. Throws if exp is stuck (can never make progress).
  def beta_cbv(exp : Exp) : Exp = exp match {

    // Halt on halt:
    case App(PrimOp("halt"),List(arg)) if isFinal(arg) => arg

    // Branch on if-zero. Any final value other than the literal 0 selects
    // ifFalse:
    case IfZero(cond,ifTrue,ifFalse) if isFinal(cond) => cond match {
      case IntLit(z) if z == 0 => ifTrue
      case _ => ifFalse
    }

    // Reduce the condition if it's not final:
    case IfZero(cond,ifTrue,ifFalse) => IfZero(beta_cbv(cond),ifTrue,ifFalse)

    // If any argument is non-final, reduce them all once (final arguments
    // are returned unchanged by the last case below):
    case App(f,args) if args exists (!isFinal(_)) => App(f, args map beta_cbv)

    // Beta-reduce a beta-redex:
    case App(Lambda(vars,body),args) => beta(exp)

    // Beta-reduce a primitive application:
    case App(PrimOp(_),args) => beta(exp)

    // A final operator that is neither a lambda nor a primitive (a free
    // variable, an integer, a CPS-wrapped primitive) can never be applied.
    // Reducing it would return the same term, and run_cbv would loop forever.
    case App(f,args) if isFinal(f) => throw new Exception("Stuck term: " + exp)

    // Reduce the function expression:
    case App(f,args) => App(beta_cbv(f), args)

    // Don't reduce final terms:
    case e if isFinal(e) => e

    // No rule applies (e.g. a CPS wrapper around a non-primitive):
    case _ => throw new Exception("No reduction rule for: " + exp)
  }

  // run_cbv: Run call-by-value beta-reduction until the term is final.
  // Each intermediate term is printed when LCPlus.printIntermediateReductions
  // is set.
  def run_cbv(init : Exp) : Exp = {
    var e = init
    while (!isFinal(e)) {
      if (LCPlus.printIntermediateReductions) {
        println(e)
        println()
      }
      e = beta_cbv(e)
    }
    e
  }
}




object LCPlus {
  import LCPlusSyntax._

  // Set from the command line. When true, Interpreter.run_cbv prints every
  // intermediate term.
  var printIntermediateReductions = false

  // Entry point. Reads a program from standard input, parses it as
  // S-Expressions, and evaluates the first one under call-by-value. Any
  // further S-Expressions are ignored. The final term is printed.
  //
  // Flags:
  //   --print-intermediate   also print each intermediate reduction step.
  def main(args : Array[String]) : Unit = {
    if (args contains "--print-intermediate")
      printIntermediateReductions = true

    val stdin : String = (scala.io.Source.fromInputStream(System.in)) mkString ""
    val sxs = SParser.parse(stdin)
    // Fail with a clear message instead of "head of empty list".
    if (sxs.isEmpty)
      throw new Exception("No input expression on standard input.")
    val sexp = sxs.head
    val exp = LCPlusSyntax.parseExp(sexp)

    println(Interpreter.run_cbv(exp))
  }
}
