/*
 * Part of upcompiler. Copyright (c) 2012, Urs Fässler, Licensed under the GNU General Public License, v3
 * @author: urs@bitzgi.ch
 */

package ast.traverser;

import knowledge.KnowTypes;
import ast.expression.ArithmeticExpr;
import ast.expression.Expression;
import ast.expression.IntConstant;
import ast.expression.IntegerOp;
import ast.expression.VariablePtrOf;
import ast.type.Type;
import ast.variable.ArrayAccess;
import ast.variable.VariableRef;

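/*
 * Examples of matched expressions. Each group lists the input expression,
 * the resulting pointer, and the resulting displacement (in units of 4 bytes):
 *
 * var5 + -0x8
 * var5
 * -2
 * 
 * (var9 * 0x4) + (var10 + -0x4)
 * var10
 * var9-1
 * 
 * (var21 * 0x4) + &var6[0x0]
 * var6
 * var21
 * 
 */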

/**
 * Decomposes a pointer-valued expression into a base pointer variable and a displacement.
 *
 * <p>Byte offsets found in the expression are divided by {@link #ELEMENT_SIZE}, so the resulting
 * displacement is expressed in units of 4 bytes. The index of an array access
 * ({@code &base[index]}) is taken as-is. Recognised shapes:
 * <ul>
 * <li>{@code ptr}</li>
 * <li>{@code &ptr[index]}</li>
 * <li>{@code ptr + offset} and {@code offset + ptr}</li>
 * <li>{@code base + scaled} (either operand order), where {@code base} is recursively one of
 * these shapes and {@code scaled} is an arithmetic byte offset such as {@code index * 4}</li>
 * </ul>
 * Any other shape raises a {@link RuntimeException}.
 */
public class PointerExprMatcher {
  /** Divisor applied to byte offsets; presumably the pointee element size in bytes — TODO confirm. */
  private static final int ELEMENT_SIZE = 4;

  private Expression     displacement;
  private VariableRef    pointer;
  private final KnowTypes types;

  /**
   * Matches {@code expr} immediately; read the result via {@link #getPointer()} and
   * {@link #getDisplacement()}.
   *
   * @param expr  the pointer expression to decompose
   * @param types type lookup used to tell the pointer operand apart from the offset operand
   * @throws RuntimeException if {@code expr} does not have a recognised shape
   */
  public PointerExprMatcher(Expression expr, KnowTypes types) {
    this.types = types;
    parse(expr);
  }

  /** @return the displacement from {@link #getPointer()}, in units of {@link #ELEMENT_SIZE} bytes */
  public Expression getDisplacement() {
    return displacement;
  }

  /** @return the base pointer variable of the matched expression */
  public VariableRef getPointer() {
    return pointer;
  }

  /** Dispatches on the top-level expression kind and sets {@code pointer} and {@code displacement}. */
  private void parse(Expression expr) {
    if (expr instanceof ArithmeticExpr) {
      parseArithmeticExpr((ArithmeticExpr) expr);
    } else if (expr instanceof VariableRef) {
      pointer = (VariableRef) expr;
      assert (types.getTypeOf(pointer) == Type.Pointer);
      displacement = new IntConstant(0);
    } else if (expr instanceof VariablePtrOf) {
      Expression target = ((VariablePtrOf) expr).getVar();
      if (!(target instanceof ArrayAccess)) {
        throw new RuntimeException("Unknown pointer of: " + target);
      }
      // &base[index]: the array index is already an element index, so it is not rescaled
      pointer = ((ArrayAccess) target).getBase();
      displacement = ((ArrayAccess) target).getIndex();
    } else {
      throw new RuntimeException("Unknown pointer access: " + expr);
    }
  }

  /** Handles {@code base + scaled}, {@code ptr + offset} and {@code offset + ptr}. */
  private void parseArithmeticExpr(ArithmeticExpr expr) {
    Expression left = expr.getLeft();
    Expression right = expr.getRight();

    if ((right instanceof ArithmeticExpr) && (left instanceof ArithmeticExpr)) {
      assert (expr.getOp() == IntegerOp.Add);
      ArithmeticExpr larh = (ArithmeticExpr) left;
      ArithmeticExpr rarh = (ArithmeticExpr) right;
      // exactly one side is expected to be the scaled index (a multiplication)
      if (rarh.getOp() == IntegerOp.Mul) {
        assert (larh.getOp() != IntegerOp.Mul);
        parseBasePlusScaledIndex(larh, rarh);
      } else {
        assert (larh.getOp() == IntegerOp.Mul);
        parseBasePlusScaledIndex(rarh, larh);
      }
    } else if (right instanceof ArithmeticExpr) {
      assert (expr.getOp() == IntegerOp.Add);
      parseBasePlusScaledIndex(left, (ArithmeticExpr) right);
    } else if (left instanceof ArithmeticExpr) {
      assert (expr.getOp() == IntegerOp.Add);
      parseBasePlusScaledIndex(right, (ArithmeticExpr) left);
    } else if ((left instanceof VariableRef) || (right instanceof VariableRef)) {
      // either operand order is accepted: "ptr + offset" as well as "offset + ptr"
      assert (expr.getOp() == IntegerOp.Add);
      parsePointerPlusOffset(left, right);
    } else {
      throw new RuntimeException("Not yet implemented: " + expr);
    }
  }

  /** Parses {@code base}, then adds the element index recovered from {@code scaledIndex} to its displacement. */
  private void parseBasePlusScaledIndex(Expression base, ArithmeticExpr scaledIndex) {
    parse(base);
    Expression index = parseIndexVar(scaledIndex);

    assert (displacement != null);
    assert (types != null);

    displacement = new ArithmeticExpr(displacement, index, IntegerOp.Add);
  }

  /**
   * Splits a two-operand addition into pointer and byte offset, then converts the offset into a
   * displacement in units of {@link #ELEMENT_SIZE} bytes.
   */
  private void parsePointerPlusOffset(Expression left, Expression right) {
    Expression offset;
    if (left instanceof IntConstant) {
      pointer = (VariableRef) right;
      offset = left;
    } else if (right instanceof IntConstant) {
      pointer = (VariableRef) left;
      offset = right;
    } else if ((left instanceof VariableRef) && (types.getTypeOf(left) == Type.Pointer)) {
      pointer = (VariableRef) left;
      offset = right;
    } else {
      pointer = (VariableRef) right;
      offset = left;
    }
    assert (types.getTypeOf(pointer) == Type.Pointer);

    if (offset instanceof IntConstant) {
      // constant byte offsets are folded directly; they must be a whole number of elements
      long bytes = ((IntConstant) offset).getValue();
      assert (bytes % ELEMENT_SIZE == 0) : "unaligned pointer offset " + bytes;
      displacement = new IntConstant(bytes / ELEMENT_SIZE);
    } else {
      displacement = new ArithmeticExpr(offset, new IntConstant(ELEMENT_SIZE), IntegerOp.Div);
    }
  }

  /**
   * Recovers an element index from a byte-offset expression.
   *
   * @param expr a byte offset such as {@code index * 4} or {@code 4 * index}
   * @return the unscaled index when {@code expr} is a multiplication by {@link #ELEMENT_SIZE},
   *         otherwise {@code expr / ELEMENT_SIZE}
   */
  private Expression parseIndexVar(ArithmeticExpr expr) {
    if (expr.getOp() == IntegerOp.Mul) {
      if (isIntConstant(expr.getRight(), ELEMENT_SIZE)) {
        return expr.getLeft();
      }
      if (isIntConstant(expr.getLeft(), ELEMENT_SIZE)) {
        return expr.getRight();
      }
    }
    return new ArithmeticExpr(expr, new IntConstant(ELEMENT_SIZE), IntegerOp.Div);
  }

  /** Returns true if {@code expr} is an integer constant equal to {@code value}. */
  private static boolean isIntConstant(Expression expr, long value) {
    return (expr instanceof IntConstant) && (((IntConstant) expr).getValue() == value);
  }

}
