/*
 * Part of upcompiler. Copyright (c) 2012, Urs Fässler, Licensed under the GNU General Public License, v3
 * @author: urs@bitzgi.ch
 */

package cfg.function.argument;


import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

import cfg.IntConstant;
import cfg.IrItr;
import cfg.IrReplaceExprTraverser;
import cfg.expression.Expression;
import cfg.expression.IntegerExpr;
import cfg.expression.IntegerOp;
import cfg.expression.VariableRefUnlinked;
import cfg.function.Function;
import cfg.function.PrgFunction;
import cfg.matcher.VarIntopConstMatcher;
import cfg.variable.SsaVariable;
import disassembler.diStorm3.Registers;

/**
 * Traverser that rewrites integer expressions of the form {@code ESP op const} and
 * {@code EBP op const} into expressions relative to {@code FuncVariables.argPtr} or
 * {@code FuncVariables.locPtr}.
 *
 * Every offset that ends up relative to {@code locPtr} is recorded and can be read back with
 * {@link #getLocPtrOffset()}.
 */
public class EspEbpReplacer extends IrReplaceExprTraverser<Void> {
  /** Function currently being traversed; {@code null} while no function is being visited. */
  private PrgFunction func;
  /** All offsets that were rewritten relative to {@code locPtr} by this instance. */
  private Set<Long>   locPtrOffset = new HashSet<Long>();

  /**
   * Runs the replacement on a single function.
   *
   * @param obj function to rewrite
   * @return the {@code locPtr} offsets found while rewriting {@code obj}
   */
  public static Set<Long> process(PrgFunction obj) {
    EspEbpReplacer pass = new EspEbpReplacer();
    pass.visit(obj, null);
    return pass.locPtrOffset;
  }

  /**
   * Runs the replacement on every function of the collection using one shared replacer. The
   * collected offsets are discarded.
   *
   * @param obj functions to rewrite
   */
  public static void process(Collection<? extends Function> obj) {
    EspEbpReplacer pass = new EspEbpReplacer();
    pass.visitCollection(obj, null);
  }

  /**
   * @return the set of offsets recorded so far (the internal set itself, not a copy)
   */
  public Set<Long> getLocPtrOffset() {
    return locPtrOffset;
  }

  /**
   * Inserts phi entries for version 0 of {@code argPtr} and {@code locPtr} into the entry of the
   * function, then traverses the function with {@link #func} set to it.
   */
  @Override
  protected void visitPrgFunction(PrgFunction obj, Void param) {
    assert (func == null);
    func = obj;

    SsaVariable argPointer = new SsaVariable(FuncVariables.argPtr, 0);
    SsaVariable locPointer = new SsaVariable(FuncVariables.locPtr, 0);
    func.getEntry().insertPhi(argPointer);
    func.getEntry().insertPhi(locPointer);

    super.visitPrgFunction(obj, param);

    assert (func == obj);
    func = null;
  }

  /**
   * Replaces {@code ESP op const} always, and {@code EBP op const} only when the current function
   * reports that it uses EBP for stack addressing. Every other expression is handed to the default
   * traversal.
   */
  @Override
  protected Expression visitIntegerExpr(IntegerExpr obj, Void param) {
    VarIntopConstMatcher pattern = new VarIntopConstMatcher();
    pattern.parse(new IrItr(obj));

    if (!pattern.hasError() && (pattern.getVar() instanceof Registers)) {
      Registers reg = (Registers) pattern.getVar();
      if (reg == Registers.ESP) {
        return rewriteEspAccess(pattern.getOp(), pattern.getValue());
      }
      if ((reg == Registers.EBP) && func.useEbpAsStackAddressing()) {
        return rewriteEbpAccess(pattern.getOp(), pattern.getValue());
      }
    }
    return super.visitIntegerExpr(obj, param);
  }

  /**
   * Translates {@code ESP + value}.
   *
   * Offsets at or above {@code backupSize + stacksize + 4} become {@code argPtr}-relative (the
   * extra 4 is presumably the return address -- confirm against the calling convention). Offsets
   * below the stack size become {@code locPtr}-relative and are recorded. Anything in between is
   * rejected.
   *
   * @throws RuntimeException if {@code op} is not an addition, {@code value} is negative, or the
   *           offset lies between the local area and the arguments
   */
  private Expression rewriteEspAccess(IntegerOp op, long value) {
    if (op != IntegerOp.Add) {
      throw new RuntimeException("Unexpected operation: " + op);
    }
    if (value < 0) {
      throw new RuntimeException("Unexpected value: " + value);
    }

    int firstArgOffset = func.getBackupSize() + func.getStacksize() + 4;
    if (value < firstArgOffset) {
      if (value >= func.getStacksize()) {
        throw new RuntimeException("Accessing restricted area: " + value);
      }
      locPtrOffset.add(value);
      return plusOffset(new VariableRefUnlinked(FuncVariables.locPtr), value);
    }
    return plusOffset(new VariableRefUnlinked(FuncVariables.argPtr), value - firstArgOffset);
  }

  /**
   * Translates {@code EBP + value} or {@code EBP - value}.
   *
   * A signed displacement of 8 or more becomes {@code argPtr}-relative (the 8 presumably skips the
   * saved EBP and the return address -- confirm). Smaller displacements are shifted by
   * {@code backupSize + stacksize - 4} and must then fall inside {@code [0, stacksize)}; they
   * become {@code locPtr}-relative and are recorded.
   *
   * @throws RuntimeException if {@code op} is neither addition nor subtraction, or the shifted
   *           offset is outside the local area
   */
  private Expression rewriteEbpAccess(IntegerOp op, long value) {
    if (!((op == IntegerOp.Add) || (op == IntegerOp.Sub))) {
      throw new RuntimeException("Unexpected operation: " + op);
    }

    // fold the operator into the sign of the displacement
    long displacement = (op == IntegerOp.Sub) ? -value : value;

    if (displacement >= 8) {
      long argOffset = displacement - 8;
      assert (argOffset >= 0);
      return plusOffset(new VariableRefUnlinked(FuncVariables.argPtr), argOffset);
    }

    long locOffset = displacement + func.getBackupSize() + func.getStacksize() - 4;
    if (locOffset < 0) {
      throw new RuntimeException("Accessing nomansland: " + locOffset);
    }
    if (locOffset >= func.getStacksize()) {
      throw new RuntimeException("Accessing restricted area: " + locOffset);
    }
    locPtrOffset.add(locOffset);
    return plusOffset(new VariableRefUnlinked(FuncVariables.locPtr), locOffset);
  }

  /** Builds the expression {@code base + offset}. */
  private static Expression plusOffset(VariableRefUnlinked base, long offset) {
    return new IntegerExpr(base, new IntConstant(offset), IntegerOp.Add);
  }

}
