/*
 * Part of upcompiler. Copyright (c) 2012, Urs Fässler, Licensed under the GNU General Public License, v3
 * @author: urs@bitzgi.ch
 */

package ast.dataflow;

import java.util.Set;

import ast.Ast;
import ast.PrgFunc;
import ast.expression.ArithmeticExpr;
import ast.expression.CompareExpr;
import ast.expression.Expression;
import ast.expression.IntConstant;
import ast.expression.IntegerOp;
import ast.expression.UnaryExpression;
import ast.expression.UnaryOp;
import ast.statement.AssignmentStmt;
import ast.statement.NullStmt;
import ast.statement.Statement;
import ast.statement.VarDef;
import ast.traverser.AstExpressionTraverser;
import ast.traverser.AstStatementTraverser;
import ast.traverser.AstTraverser;
import ast.traverser.NopAstStatementTraverser;
import ast.traverser.VariableAccessCollector;
import ast.variable.Variable;
import ast.variable.VariableRefLinked;

public class ExpressionReduction extends AstTraverser<Void> {

  /**
   * Creates a traverser whose statement pass is a fresh {@link StmtReduction}.
   */
  public ExpressionReduction() {
    super(new StmtReduction());
  }

  /**
   * Runs expression/statement reduction over the given AST.
   *
   * <p>The AST returned by the traversal is discarded; only the effects of the traversal on
   * {@code ast} remain.
   *
   * @param ast the tree to reduce
   */
  public static void normalize(Ast ast) {
    new ExpressionReduction().visit(ast, null);
  }

  /**
   * Returns the statement traverser passed to the constructor, typed as {@link StmtReduction}.
   */
  @Override
  public StmtReduction getStmt() {
    return (StmtReduction) super.getStmt();
  }

  /**
   * Before reducing a function, records which variables it references so that the statement
   * pass can drop definitions of variables that are never referenced.
   */
  @Override
  protected Ast visitPrgFunc(PrgFunc obj, Void param) {
    Set<Variable> referenced = collectReferencedVariables(obj);
    getStmt().setUsedVariables(referenced);
    return super.visitPrgFunc(obj, null);
  }

  /** Walks {@code func} with a {@link VariableAccessCollector} and returns what it found. */
  private static Set<Variable> collectReferencedVariables(PrgFunc func) {
    VariableAccessCollector collector = new VariableAccessCollector();
    new AstTraverser<Void>(new NopAstStatementTraverser<Void>(collector)).visit(func, null);
    return collector.getReferenced();
  }

}

class StmtReduction extends AstStatementTraverser<Void> {
  /** Delegate that reduces every expression and variable found inside statements. */
  private ExprReduction exprReducer = new ExprReduction();
  /** Variables referenced in the function currently being processed (set by the caller). */
  private Set<Variable> usedVariables;

  /**
   * Installs the set of referenced variables consulted by {@link #visitVarDef}.
   *
   * @param referenced variables that are referenced somewhere in the current function
   */
  public void setUsedVariables(Set<Variable> referenced) {
    usedVariables = referenced;
  }

  /** Forwards expression reduction to the {@link ExprReduction} delegate. */
  @Override
  public Expression visit(Expression expr, Void param) {
    return exprReducer.visit(expr, param);
  }

  /** Forwards variable reduction to the {@link ExprReduction} delegate. */
  @Override
  public Variable visit(Variable expr, Void param) {
    return exprReducer.visit(expr, param);
  }

  /**
   * Replaces the definition of a variable that is not in the referenced set with a
   * {@link NullStmt}; otherwise traverses the definition normally.
   *
   * <p>NOTE(review): requires {@link #setUsedVariables} to have been called first.
   */
  @Override
  protected Statement visitVarDef(VarDef obj, Void param) {
    boolean referenced = usedVariables.contains(obj.getVariable());
    if (!referenced) {
      return new NullStmt();
    }
    return super.visitVarDef(obj, null);
  }

  /** Replaces an assignment of a variable to itself ({@code a := a}) with a {@link NullStmt}. */
  @Override
  protected Statement visitAssignmentStmt(AssignmentStmt obj, Void param) {
    if (isSelfAssignment(obj)) {
      return new NullStmt();
    }
    return super.visitAssignmentStmt(obj, null);
  }

  /** True when both sides are linked references to the identical variable object. */
  private static boolean isSelfAssignment(AssignmentStmt stmt) {
    if (!(stmt.getSource() instanceof VariableRefLinked) || !(stmt.getDestination() instanceof VariableRefLinked)) {
      return false;
    }
    VariableRefLinked src = (VariableRefLinked) stmt.getSource();
    VariableRefLinked dst = (VariableRefLinked) stmt.getDestination();
    return src.getReference() == dst.getReference();
  }
}

class ExprReduction extends AstExpressionTraverser<Void> {

  /**
   * Folds a binary integer operation whose operands are both constants.
   *
   * <p>The arithmetic uses Java operator semantics for the operand type returned by
   * {@link IntConstant#getValue()}. Callers must not pass {@link IntegerOp#Div} with a zero
   * right operand, because Java integer division by zero throws {@link ArithmeticException}.
   *
   * @param op the operation to apply
   * @param left the constant left operand
   * @param right the constant right operand
   * @return a new constant holding the folded value
   * @throws RuntimeException if {@code op} is not one of the handled operations
   */
  private static Expression evaluate(IntegerOp op, IntConstant left, IntConstant right) {
    switch (op) {
      case Add:
        return new IntConstant(left.getValue() + right.getValue());
      case Sub:
        return new IntConstant(left.getValue() - right.getValue());
      case Mul:
        return new IntConstant(left.getValue() * right.getValue());
      case Div:
        return new IntConstant(left.getValue() / right.getValue());
      case ShiftLeft:
        return new IntConstant(left.getValue() << right.getValue());
      case ShiftRight:
        return new IntConstant(left.getValue() >> right.getValue());
      case And:
        return new IntConstant(left.getValue() & right.getValue());
      case Or:
        return new IntConstant(left.getValue() | right.getValue());
      case Xor:
        return new IntConstant(left.getValue() ^ right.getValue());
      default:
        throw new RuntimeException("Unhandled operation: " + op);
    }
  }

  /**
   * Simplifies an arithmetic expression after its operands have been reduced.
   *
   * <ul>
   * <li>If both operands are constant, the expression is folded to a single constant. The one
   * exception is division by a constant zero, which is left untouched.</li>
   * <li>If exactly one operand is constant, these identities are applied:
   * {@code x + 0}, {@code x ^ 0}, {@code x * 1} and {@code x / 1} become {@code x};
   * {@code x & 0} and {@code x * 0} become {@code 0}. Add, Xor, And and Mul are commutative,
   * so their constant may be on either side. The division identity only holds with the
   * constant on the right.</li>
   * <li>In every other case the (operand-reduced) expression is returned.</li>
   * </ul>
   *
   * @throws RuntimeException if the operation is not one of the handled {@link IntegerOp}s
   */
  @Override
  protected Expression visitArithmeticExpr(ArithmeticExpr obj, Void param) {
    obj = (ArithmeticExpr) super.visitArithmeticExpr(obj, param);

    IntConstant left = null;
    IntConstant right = null;

    if (obj.getLeft() instanceof IntConstant) {
      left = (IntConstant) obj.getLeft();
    }
    if (obj.getRight() instanceof IntConstant) {
      right = (IntConstant) obj.getRight();
    }
    if ((left != null) && (right != null)) {
      // Folding c / 0 would throw ArithmeticException inside the compiler itself; keep the
      // expression so the division is left to run time instead.
      if ((obj.getOp() == IntegerOp.Div) && (right.getValue() == 0)) {
        return obj;
      }
      return evaluate(obj.getOp(), left, right);
    }
    if ((left == null) && (right == null)) {
      return obj;
    }

    boolean constOnRight = (right != null);
    // Compare the full constant value. A narrowing (int) cast would map e.g. 2^32 to 0 and
    // 2^32 + 1 to 1 when getValue() is wider than int, wrongly triggering the identities below.
    long value = constOnRight ? right.getValue() : left.getValue();
    Expression var = constOnRight ? obj.getLeft() : obj.getRight();

    switch (obj.getOp()) {
      case Add:
      case Xor: {
        // x + 0 == 0 + x == x, and likewise for xor.
        if (value == 0) {
          return var;
        }
        return obj;
      }
      case And: {
        // x & 0 == 0 & x == 0.
        // NOTE(review): this discards var; only sound if evaluating var has no side effects.
        if (value == 0) {
          return new IntConstant(0);
        }
        return obj;
      }
      case Div: {
        // Only x / 1 == x. The constant-on-left form 1 / x is not x and must be kept.
        if (constOnRight && (value == 1)) {
          return var;
        }
        return obj;
      }
      case Mul: {
        // x * 0 == 0 (NOTE(review): discards var, see And above); x * 1 == 1 * x == x.
        if (value == 0) {
          return new IntConstant(0);
        }
        if (value == 1) {
          return var;
        }
        return obj;
      }
      case Or:
      case Sub:
      case ShiftLeft:
      case ShiftRight: {
        return obj;
      }
      default: {
        throw new RuntimeException("Unhandled integer expression: " + obj.getOp());
      }
    }
  }

  /**
   * Reduces the operand and then eliminates a logical negation of a comparison.
   *
   * <p>The comparison's operator is replaced in place by {@code getOperand().getInverse()}, and
   * the comparison is returned instead of the {@code Not} node. Any other unary expression is
   * returned with only its operand reduced.
   */
  @Override
  protected Expression visitUnaryExpression(UnaryExpression obj, Void param) {
    obj.setExpr(visit(obj.getExpr(), null));
    if ((obj.getOp() == UnaryOp.Not) && (obj.getExpr() instanceof CompareExpr)) {
      CompareExpr comp = (CompareExpr) obj.getExpr();
      comp.setOperand(comp.getOperand().getInverse());
      return comp;
    } else {
      return obj;
    }
  }

}
