/*
 * Part of upcompiler. Copyright (c) 2012, Urs Fässler, Licensed under the GNU Genera Public License, v3
 * @author: urs@bitzgi.ch
 */

package knowledge;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

import ast.Ast;
import ast.PrgFunc;
import ast.expression.ArithmeticExpr;
import ast.expression.BooleanConstant;
import ast.expression.CallExpr;
import ast.expression.CompareExpr;
import ast.expression.Expression;
import ast.expression.FunctionRef;
import ast.expression.FunctionRefLinked;
import ast.expression.FunctionRefUnlinked;
import ast.expression.IfExpr;
import ast.expression.IntConstant;
import ast.expression.NullConstExpr;
import ast.expression.NullExpr;
import ast.expression.StringConstant;
import ast.expression.UnaryExpression;
import ast.expression.VariablePtrDeref;
import ast.expression.VariablePtrOf;
import ast.statement.AssignmentStmt;
import ast.statement.CallStmt;
import ast.statement.CaseStmt;
import ast.statement.DoWhile;
import ast.statement.IfStmt;
import ast.statement.RetStmt;
import ast.statement.Statement;
import ast.statement.VarDef;
import ast.statement.WhileStmt;
import ast.traverser.AstExpressionTraverser;
import ast.traverser.AstTraverser;
import ast.traverser.NopAstStatementTraverser;
import ast.type.Type;
import ast.variable.ArrayAccess;
import ast.variable.Variable;
import ast.variable.VariableRefLinked;
import ast.variable.VariableRefUnlinked;

/**
 * Knowledge entry that infers a {@link Type} for AST nodes of the program held by the
 * {@link KnowledgeBase}.
 *
 * <p>Inference is a fixpoint iteration: passes over the whole program are repeated until a pass
 * narrows no entry of the type map any further.
 */
public class KnowTypes extends KnowledgeEntry {
  KnowledgeBase  base;
  Map<Ast, Type> types = new HashMap<Ast, Type>();

  @Override
  public void init(KnowledgeBase base) {
    this.base = base;
  }

  /**
   * Runs inference passes over the program until one pass reports no changes.
   *
   * @return the total number of type-map changes made over all passes
   */
  public int process() {
    int total = 0;
    int changesInPass;
    do {
      // A fresh finder per pass so its change counter starts at zero.
      AstTypeFinder pass = new AstTypeFinder(types);
      pass.visit(base.getPrg(), null);
      changesInPass = pass.getChanges();
      total += changesInPass;
    } while (changesInPass > 0);
    return total;
  }

  /**
   * Returns the inferred type of {@code obj}, running inference first if no type is known yet.
   *
   * @throws RuntimeException if inference still produced no type for {@code obj}
   */
  public Type getTypeOf(Ast obj) {
    if (!types.containsKey(obj)) {
      process();
      if (!types.containsKey(obj)) {
        throw new RuntimeException("No type found for: " + obj + "(" + obj.hashCode() + ")");
      }
    }
    return types.get(obj);
  }
}

/**
 * Program-level driver for one type-inference pass. Function bodies are delegated to a
 * {@link StatementTypeFinder} that shares the given type map.
 */
class AstTypeFinder extends AstTraverser<Type> {

  public AstTypeFinder(Map<Ast, Type> types) {
    super(new StatementTypeFinder(types));
  }

  /** Returns the number of type-map changes made during this pass. */
  public int getChanges() {
    StatementTypeFinder stmtFinder = getStmt();
    return stmtFinder.getChanges();
  }

  /** Returns the statement traverser, typed as the {@link StatementTypeFinder} it was built with. */
  @Override
  public StatementTypeFinder getStmt() {
    return (StatementTypeFinder) super.getStmt();
  }

  /**
   * Seeds the function's type with its declared return type, then infers types in its body.
   */
  @Override
  protected Ast visitPrgFunc(PrgFunc obj, Type param) {
    StatementTypeFinder stmtFinder = getStmt();
    stmtFinder.setType(obj, obj.getReturnType());
    stmtFinder.visit(obj.getBody(), null);
    return obj;
  }

}

/**
 * Statement-level part of the type-inference pass.
 *
 * <p>Visits statements and hands each contained expression to the wrapped
 * {@link ExpressionTypeFinder}, together with the type the statement expects for it. Type-map
 * access and change counting are forwarded to that expression finder.
 */
class StatementTypeFinder extends NopAstStatementTraverser<Type> {

  public StatementTypeFinder(Map<Ast, Type> types) {
    super(new ExpressionTypeFinder(types));
  }

  /** Returns the wrapped expression traverser as an {@link ExpressionTypeFinder}. */
  private ExpressionTypeFinder exprFinder() {
    return (ExpressionTypeFinder) getExpr();
  }

  /** Returns the number of type-map changes counted by the expression finder. */
  public int getChanges() {
    return exprFinder().getChanges();
  }

  /** Records {@code type} for {@code obj} if it is new or narrower than the known type. */
  public void setType(Ast obj, Type type) {
    exprFinder().setType(obj, type);
  }

  /** Returns the narrowest type compatible with all of {@code values}. */
  public Type getNarrowType(Collection<? extends Ast> values) {
    return exprFinder().getNarrowType(values);
  }

  /** Returns the known type of {@code obj} ({@code Generic} if none is recorded yet). */
  public Type getType(Ast obj) {
    return exprFinder().getType(obj);
  }

  /** A definition narrows the variable with its declared type. */
  @Override
  protected Statement visitVarDef(VarDef obj, Type param) {
    assert (param == null);
    Type declared = Type.getNarrowType(getType(obj.getVariable()), obj.getVariable().getType());
    getExpr().visit(obj.getVariable(), declared);
    return obj;
  }

  /** No expectation is placed on the result of a call used as a statement. */
  @Override
  protected Statement visitCallStmt(CallStmt obj, Type param) {
    assert (param == null);
    getExpr().visit(obj.getCall(), Type.Generic);
    return obj;
  }

  /** The condition must be boolean; both optional branches are visited. */
  @Override
  protected Statement visitIfStmt(IfStmt obj, Type param) {
    assert (param == null);
    getExpr().visit(obj.getCondition(), Type.Boolean);
    if (obj.getThenBranch() != null) {
      visit(obj.getThenBranch(), null);
    }
    if (obj.getElseBranch() != null) {
      visit(obj.getElseBranch(), null);
    }
    return obj;
  }

  /** The loop condition must be boolean. */
  @Override
  protected Statement visitWhileStmt(WhileStmt obj, Type param) {
    assert (param == null);
    getExpr().visit(obj.getCondition(), Type.Boolean);
    visit(obj.getBody(), null);
    return obj;
  }

  /** The body is visited before the (boolean) condition. */
  @Override
  protected Statement visitDoWhileStmt(DoWhile obj, Type param) {
    assert (param == null);
    visit(obj.getBody(), null);
    getExpr().visit(obj.getCondition(), Type.Boolean);
    return obj;
  }

  /**
   * The selector must be an integer. Every option's statement is visited and stored back
   * under its key.
   */
  @Override
  protected Statement visitCaseStmt(CaseStmt obj, Type param) {
    assert (param == null);
    getExpr().visit(obj.getCondition(), Type.Integer);
    for (int option : obj.getOption().keySet()) {
      Statement branch = obj.getOption().get(option);
      obj.getOption().put(option, visit(branch, param));
    }
    if (obj.getOther() != null) {
      visit(obj.getOther(), null);
    }
    return obj;
  }

  /**
   * A return narrows the enclosing function's type. A {@link NullExpr} return value makes it
   * {@code Void}; otherwise the value's type is used.
   */
  @Override
  protected Statement visitRetStmt(RetStmt obj, Type param) {
    Type funcType = getType(obj.getFunction());
    assert (obj.getRetval() != null);
    if (!(obj.getRetval() instanceof NullExpr)) {
      getExpr().visit(obj.getRetval(), Type.Scalar);
      funcType = Type.getNarrowType(funcType, getType(obj.getRetval()));
    } else {
      funcType = Type.Void;
    }
    setType(obj.getFunction(), funcType);
    return obj;
  }

  /**
   * Visits the assignment with its currently known type. Afterwards, its type is narrowed to
   * the common narrow type of destination and source.
   */
  @Override
  protected Statement visitAssignmentStmt(AssignmentStmt obj, Type param) {
    assert (param == null);
    super.visitAssignmentStmt(obj, getType(obj));
    setType(obj, Type.getNarrowType(getType(obj.getDestination()), getType(obj.getSource())));
    return obj;
  }

}

/**
 * Expression-level part of the type-inference pass.
 *
 * <p>Each visit method receives, as {@code param}, the type the enclosing construct expects. It
 * narrows the types recorded for the visited node (and its children) in the shared type map.
 * Every entry that is added or narrowed is counted in {@link #getChanges()}.
 */
class ExpressionTypeFinder extends AstExpressionTraverser<Type> {
  /** Shared map from AST node to its currently inferred type. */
  private final Map<Ast, Type> type;
  /** Number of entries added or narrowed through {@link #setType}. */
  private int                  change = 0;

  public ExpressionTypeFinder(Map<Ast, Type> types) {
    super();
    this.type = types;
  }

  /** Returns how many times {@link #setType} added or narrowed an entry. */
  public int getChanges() {
    return change;
  }

  /**
   * Records {@code type} for {@code elem} if no type is known yet or {@code type} is narrower
   * than the known one. A wider type is ignored.
   *
   * @param elem node to type
   * @param type candidate type, must not be {@code null}
   * @return the type recorded for {@code elem} after the call
   */
  public Type setType(Ast elem, Type type) {
    assert (type != null);
    Type varType = this.type.get(elem);
    if (varType == null) {
      this.type.put(elem, type);
      change++;
      return type;
    } else if (varType == type) {
      return varType;
    } else if (varType.isSupertypeOf(type)) {
      this.type.put(elem, type);
      change++;
      return type;
    } else {
      assert (type.isSupertypeOf(varType));
      return varType;
    }
  }

  /** Folds the known types of {@code values} into their narrow type, starting from Generic. */
  public Type getNarrowType(Collection<? extends Ast> values) {
    Type type = Type.Generic;
    for (Ast elem : values) {
      type = Type.getNarrowType(getType(elem), type);
    }
    return type;
  }

  /**
   * Returns the known type of {@code elem}. Unseen nodes are recorded as {@code Generic}; this
   * does not count as a change.
   */
  public Type getType(Ast elem) {
    if (!type.containsKey(elem)) {
      type.put(elem, Type.Generic);
    }
    return type.get(elem);
  }

  /**
   * A null constant is always a pointer. The expected type only has to admit a pointer (it may
   * still be {@code Generic}), matching {@link #visitVariablePtrOf}.
   */
  @Override
  protected Expression visitNullConstExpr(NullConstExpr obj, Type param) {
    assert (param.isSupertypeOf(Type.Pointer));
    setType(obj, Type.Pointer);
    return obj;
  }

  @Override
  protected FunctionRef visit(FunctionRef obj, Type param) {
    throw new RuntimeException("Not yet implemented");
  }

  @Override
  protected FunctionRef visitFunctionRefUnlinked(FunctionRefUnlinked obj, Type param) {
    throw new RuntimeException("Not yet implemented");
  }

  @Override
  protected FunctionRef visitFunctionRefLinked(FunctionRefLinked obj, Type param) {
    throw new RuntimeException("Not yet implemented");
  }

  /** A variable takes the expected type (if narrower than what is known). */
  @Override
  public Variable visit(Variable obj, Type param) {
    setType(obj, param);
    return obj;
  }

  /** The index must be an integer; the access itself takes the expected type. */
  @Override
  protected Expression visitArrayAccess(ArrayAccess obj, Type param) {
    visit(obj.getIndex(), Type.Integer);
    setType(obj, param);
    return obj;
  }

  /** Taking the address of a variable yields a pointer. */
  @Override
  protected Expression visitVariablePtrOf(VariablePtrOf obj, Type param) {
    assert (param.isSupertypeOf(Type.Pointer));
    setType(obj, Type.Pointer);
    visit(obj.getVar(), Type.Generic);
    return obj;
  }

  /** The dereferenced expression must be a pointer; the result takes the expected type. */
  @Override
  protected Expression visitVariablePtrDeref(VariablePtrDeref obj, Type param) {
    visit(obj.getExpr(), Type.Pointer);
    setType(obj, param);
    return obj;
  }

  /**
   * Unifies each actual argument with the callee's formal parameter, and the call's type with
   * the callee's return type and the expected type. Calls whose function reference is not
   * linked are left untouched.
   */
  @Override
  protected Expression visitCallExpr(CallExpr obj, Type param) {
    if (!(obj.getFunction() instanceof FunctionRefLinked)) {
      return obj;
    }

    FunctionRefLinked func = (FunctionRefLinked) obj.getFunction();

    assert (obj.getParam().size() == func.getFunc().getParam().size());
    for (int i = 0; i < obj.getParam().size(); i++) {
      Type type = Type.getNarrowType(getType(obj.getParam().get(i)), getType(func.getFunc().getParam().get(i)));
      type = Type.getNarrowType(type, func.getFunc().getParam().get(i).getType());
      visit(obj.getParam().get(i), type);
      setType(func.getFunc().getParam().get(i), getType(obj.getParam().get(i)));
    }

    Type narrow = Type.getNarrowType(getType(obj), func.getFunc().getReturnType());
    narrow = Type.getNarrowType(narrow, param);
    narrow = Type.getNarrowType(narrow, getType(func.getFunc()));
    setType(obj, narrow);
    setType(func.getFunc(), narrow);
    return obj;
  }

  /** Negation works on numbers; logical not works on booleans. */
  @Override
  protected Expression visitUnaryExpression(UnaryExpression obj, Type param) {
    switch (obj.getOp()) {
      case Neg: {
        param = Type.getNarrowType(param, Type.Number);
        param = Type.getNarrowType(param, getType(obj));
        setType(obj, param);
        visit(obj.getExpr(), param);
        return obj;
      }
      case Not: {
        setType(obj, Type.Boolean);
        visit(obj.getExpr(), Type.Boolean);
        return obj;
      }
    }
    throw new RuntimeException("Not yet implemented");
  }

  @Override
  protected Expression visitVariableRefUnlinked(VariableRefUnlinked obj, Type param) {
    throw new RuntimeException("Not yet implemented");
  }

  /**
   * Unifies a reference with the variable it refers to. Three types are combined: the type
   * already recorded on the reference node, the variable's type and the expected type. The
   * result is stored on the variable, and then mirrored onto the reference.
   */
  @Override
  protected Expression visitVariableRefLinked(VariableRefLinked obj, Type param) {
    Type narrow = Type.getNarrowType(getType(obj), getType(obj.getReference()));
    narrow = Type.getNarrowType(narrow, param);
    // Use the combined type, not just param, so that a type recorded directly on the
    // reference node (e.g. by the Add case below) also reaches the variable.
    setType(obj, setType(obj.getReference(), narrow));
    return obj;
  }

  @Override
  protected Expression visitBooleanConstant(BooleanConstant obj, Type param) {
    assert (param.isSupertypeOf(Type.Boolean));
    setType(obj, Type.Boolean);
    return obj;
  }

  /**
   * Integer literals: a negative literal is always an Integer. A non-negative literal is a
   * Number, or a Pointer when a pointer is expected, or an Integer when an integer is expected.
   */
  @Override
  protected Expression visitConstant(IntConstant obj, Type param) {
    switch (param) {
      case Generic:
      case Scalar:
      case Number:
        if (obj.getValue() < 0) {
          setType(obj, Type.Integer);
        } else {
          setType(obj, Type.Number);
        }
        break;
      case Integer:
        setType(obj, Type.Integer);
        break;
      case Pointer:
        if (obj.getValue() < 0) {
          setType(obj, Type.Integer);
        } else {
          setType(obj, Type.Pointer);
        }
        break;
      default:
        throw new RuntimeException("Unhandled case: " + param);
    }
    return obj;
  }

  @Override
  protected Expression visitStringConstant(StringConstant obj, Type param) {
    setType(obj, Type.String);
    return obj;
  }

  /**
   * The condition must be boolean. Both branches are restricted to numbers and narrowed
   * together with the expected type. The result is the common type of both branches.
   */
  @Override
  protected Expression visitIfExpr(IfExpr obj, Type param) {
    visit(obj.getCondition(), Type.Boolean);

    Type narrow = Type.getNarrowType(getType(obj.getLeft()), getType(obj.getRight()));
    narrow = Type.getNarrowType(param, narrow);
    narrow = Type.getNarrowType(Type.Number, narrow);
    visit(obj.getLeft(), narrow);
    visit(obj.getRight(), narrow);
    Type lt = getType(obj.getLeft());
    Type rt = getType(obj.getRight());
    Type ret = Type.getCommonType(lt, rt);
    assert (param.isSupertypeOf(ret));
    setType(obj, ret);

    return obj;
  }

  /**
   * Both operands are visited with their common narrow type; ordering comparisons further
   * restrict it to numbers. The comparison itself is boolean.
   */
  @Override
  protected Expression visitCompareExpr(CompareExpr obj, Type param) {
    assert (param.isSupertypeOf(Type.Boolean));
    Type type;
    switch (obj.getOperand()) {
      case EQUAL:
      case NOT_EQUAL:
        type = Type.getNarrowType(getType(obj.getLeft()), getType(obj.getRight()));
        break;
      case GREATER:
      case GREATER_EQUAL:
      case LOWER:
      case LOWER_EQUAL:
        type = Type.getNarrowType(getType(obj.getLeft()), getType(obj.getRight()));
        type = Type.getNarrowType(type, Type.Number);
        break;
      default:
        throw new RuntimeException("Unhandled compare op: " + obj.getOperand());
    }
    super.visitCompareExpr(obj, type);
    setType(obj, Type.Boolean);
    return obj;
  }

  /**
   * Infers operand and result types of a binary arithmetic expression. The rules depend on the
   * operator; pointer arithmetic is handled for Add and Sub.
   */
  @Override
  protected Expression visitArithmeticExpr(ArithmeticExpr obj, Type param) {
    Type common = Type.getCommonType(getType(obj.getLeft()), getType(obj.getRight()));
    switch (obj.getOp()) {
      case Or:
      case Xor:
      case And: {
        common = Type.getNarrowType(Type.Scalar, common);
        common = Type.getNarrowType(param, common); // may not work for pointer
        visit(obj.getLeft(), common);
        visit(obj.getRight(), common);
        Type lt = getType(obj.getLeft());
        Type rt = getType(obj.getRight());
        Type ret = Type.getCommonType(lt, rt);
        assert (param.isSupertypeOf(ret));
        setType(obj, ret);
        return obj;
      }
      case Sub: {
        // The operands' own common type is not used here: per the switch below,
        // pointer - pointer yields Integer, so operand and result types can differ.
        common = Type.getNarrowType(Type.Number, param);
        visit(obj.getLeft(), common);
        visit(obj.getRight(), common);
        Type lt = getType(obj.getLeft());
        Type rt = getType(obj.getRight());
        Type ret = Type.getCommonType(lt, rt);
        switch (ret) {
          case Integer:
            break;
          case Pointer:
            ret = Type.Integer;
            break;
          case Number:
            if ((Type.Pointer == lt) || (Type.Pointer == rt)) {
              ret = Type.Pointer;
            }
            break;
          default:
            throw new RuntimeException("Unhandled case: " + ret);
        }
        assert (param.isSupertypeOf(ret));
        setType(obj, ret);
        return obj;
      }
      case Add: {
        // When a pointer is expected, the operands only have to be numbers; the branches
        // below decide which operand is the pointer and which the integer offset.
        if (param == Type.Pointer) {
          common = Type.Number;
        } else {
          common = Type.getNarrowType(Type.Number, param);
        }

        visit(obj.getLeft(), common);
        visit(obj.getRight(), common);
        Type lt = getType(obj.getLeft());
        Type rt = getType(obj.getRight());
        Type ret = Type.getCommonType(lt, rt);
        assert (Type.Number.isSupertypeOf(ret));
        if (ret == Type.Number) {
          if (param == Type.Pointer) {
            if (lt == Type.Integer) {
              setType(obj.getRight(), Type.Pointer);
              rt = Type.Pointer;
              ret = Type.Pointer;
            } else if (rt == Type.Integer) {
              setType(obj.getLeft(), Type.Pointer);
              lt = Type.Pointer;
              ret = Type.Pointer;
            } else if (lt == Type.Pointer) {
              setType(obj.getRight(), Type.Integer);
              rt = Type.Integer;
              ret = Type.Pointer;
            } else if (rt == Type.Pointer) {
              setType(obj.getLeft(), Type.Integer);
              lt = Type.Integer;
              ret = Type.Pointer;
            } else {
              // Neither side is decided yet: treat an integer literal as the offset.
              if (obj.getLeft() instanceof IntConstant) {
                lt = Type.Integer;
                ret = Type.Pointer;
              } else if (obj.getRight() instanceof IntConstant) {
                rt = Type.Integer;
                ret = Type.Pointer;
              } else {
                throw new RuntimeException("Unhandled case");
              }
            }
            setType(obj.getLeft(), lt);
            setType(obj.getRight(), rt);
          } else if ((Type.Number == lt) && (Type.Number == rt)) {
            ret = Type.Number;
          } else if ((Type.Pointer == lt) || (Type.Pointer == rt)) {
            ret = Type.Pointer;
          } else if ((Type.Integer == lt) || (Type.Integer == rt)) {
            ret = Type.Number;
          } else {
            throw new RuntimeException("Unhandled case");
          }
        }
        assert (param.isSupertypeOf(ret));
        setType(obj, ret);
        return obj;
      }
      case ShiftLeft:
      case ShiftRight:
      case Div: {
        // The right operand must be an integer; when the common type is only Number, the
        // result takes the left operand's type.
        common = Type.getNarrowType(Type.Number, common);
        visit(obj.getLeft(), Type.getNarrowType(common, param));
        visit(obj.getRight(), Type.Integer);
        Type lt = getType(obj.getLeft());
        Type rt = getType(obj.getRight());
        Type ret = Type.getCommonType(lt, rt);

        if (ret == Type.Number) {
          assert (Type.Integer == rt);
          ret = lt;
        }

        assert (param.isSupertypeOf(ret));
        setType(obj, ret);
        return obj;
      }
      case Mul: {
        // Multiplication is restricted to integers.
        common = Type.getNarrowType(Type.Integer, common);
        visit(obj.getLeft(), common);
        visit(obj.getRight(), common);
        Type lt = getType(obj.getLeft());
        Type rt = getType(obj.getRight());
        Type ret = Type.getCommonType(lt, rt);
        assert (Type.Integer.isSupertypeOf(ret));
        setType(obj, ret);
        return obj;
      }
      default: {
        throw new RuntimeException("Not yet implemented: " + obj.getOp());
      }
    }
  }

}
