/*
 * Part of upcompiler. Copyright (c) 2012, Urs Fässler, Licensed under the GNU General Public License, v3
 * @author: urs@bitzgi.ch
 */

package cfg.linker;


import java.io.PrintStream;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;

import cfg.IrTraverser;
import cfg.basicblock.BasicBlock;
import cfg.basicblock.BbEdge;
import cfg.expression.Expression;
import cfg.expression.VariableRef;
import cfg.expression.VariableRefLinked;
import cfg.expression.VariableRefUnlinked;
import cfg.statement.PhiStmt;
import cfg.variable.SsaVariable;
import cfg.variable.VariableName;

/**
 * Traverser that records, per basic block, which variable names are used,
 * defined and killed, and, per variable name, in which basic blocks it is
 * defined.
 *
 * The basic block currently being processed is handed down the traversal as
 * the parameter of every visit method.
 */
public class DefUseKillVisitor extends IrTraverser<Void, BasicBlock> {
  /** Names referenced in a block while not (yet) recorded in that block's def set. */
  private final HashMap<BasicBlock, HashSet<VariableName>> use    = new HashMap<BasicBlock, HashSet<VariableName>>();
  /** Names of the SSA variables visited within a block. */
  private final HashMap<BasicBlock, HashSet<VariableName>> def    = new HashMap<BasicBlock, HashSet<VariableName>>();
  /** Filled together with {@link #def}; receives the same names. */
  private final HashMap<BasicBlock, HashSet<VariableName>> kill   = new HashMap<BasicBlock, HashSet<VariableName>>();
  /** For every name seen: the blocks in which an SSA variable of that name was visited. */
  private final HashMap<VariableName, HashSet<BasicBlock>> blocks = new HashMap<VariableName, HashSet<BasicBlock>>();

  /**
   * @return the internal map from basic block to the names used in it
   */
  public HashMap<BasicBlock, HashSet<VariableName>> getUse() {
    return use;
  }

  /**
   * @return the internal map from basic block to the names defined in it
   */
  public HashMap<BasicBlock, HashSet<VariableName>> getDef() {
    return def;
  }

  /**
   * @return the internal map from basic block to the names killed in it
   */
  public HashMap<BasicBlock, HashSet<VariableName>> getKill() {
    return kill;
  }

  /**
   * @return the internal map from variable name to the blocks defining it; a
   *         name that was only used may map to an empty set
   */
  public HashMap<VariableName, HashSet<BasicBlock>> getBlocks() {
    return blocks;
  }

  /**
   * Creates fresh use/def/kill sets for the block, traverses its content with
   * the block itself as parameter and finally accounts for the phi functions
   * of the successor blocks.
   *
   * @param obj   the basic block to analyse
   * @param param expected to be {@code null} (checked by an assertion)
   * @return always {@code null}
   */
  @Override
  protected Void visitBasicBlock(BasicBlock obj, BasicBlock param) {
    assert (param == null);

    def.put(obj, new HashSet<VariableName>());
    kill.put(obj, new HashSet<VariableName>());
    use.put(obj, new HashSet<VariableName>());

    super.visitBasicBlock(obj, obj);
    visitFollowingPhi(obj);

    // For debugging, the collected sets can be dumped with print(obj, System.out).

    return null;
  }

  /**
   * Treats the phi functions of every successor as if they were part of the
   * given block: the phi operand selected by this block's id and the variable
   * the phi assigns to are both visited with this block as parameter.
   */
  private void visitFollowingPhi(BasicBlock bb) {
    for (BbEdge out : bb.getOutlist()) {
      for (PhiStmt phi : out.getDst().getPhis()) {
        Expression operand = phi.getOption().get(bb.getId());
        assert (operand != null);
        // the operand first, so it is classified before the phi's own variable is recorded
        visit(operand, bb);
        visit(phi.getVarname(), bb);
      }
    }
  }

  /**
   * Phi statements are deliberately not traversed here; they are handled from
   * the predecessor blocks in {@link #visitFollowingPhi(BasicBlock)}.
   */
  @Override
  protected Void visitPhiStmt(PhiStmt obj, BasicBlock param) {
    return null;
  }

  /**
   * Records the variable's name as defined and killed in the current block
   * and registers the block as a defining block of that name.
   */
  @Override
  protected Void visitSsaVariable(SsaVariable obj, BasicBlock param) {
    VariableName name = obj.getName();
    def.get(param).add(name);
    kill.get(param).add(name);
    definingBlocksOf(name).add(param);
    return super.visitSsaVariable(obj, param);
  }

  @Override
  protected Void visitVariableRefUnlinked(VariableRefUnlinked obj, BasicBlock param) {
    visitVarRef(obj, param);
    return super.visitVariableRefUnlinked(obj, param);
  }

  @Override
  protected Void visitVariableRefLinked(VariableRefLinked obj, BasicBlock param) {
    visitVarRef(obj, param);
    return super.visitVariableRefLinked(obj, param);
  }

  /**
   * Records a reference as a use of the block, unless the referenced name is
   * already in the block's def set at the time the reference is visited. A
   * recorded use also guarantees that the name has an entry in {@link #blocks}.
   */
  private void visitVarRef(VariableRef ref, BasicBlock bb) {
    VariableName name = ref.getName();
    if (def.get(bb).contains(name)) {
      return;
    }
    use.get(bb).add(name);
    definingBlocksOf(name);
  }

  /**
   * Returns the set of defining blocks stored for the name, inserting an
   * empty set first when the name has no entry yet.
   */
  private HashSet<BasicBlock> definingBlocksOf(VariableName name) {
    if (!blocks.containsKey(name)) {
      blocks.put(name, new HashSet<BasicBlock>());
    }
    return blocks.get(name);
  }

  /**
   * Debug helper: writes the block followed by its def, use and kill sets,
   * one set per line.
   */
  @SuppressWarnings("unused")
  private void print(BasicBlock obj, PrintStream st) {
    st.println(obj);

    st.print("def: ");
    printVarList(def.get(obj), st);

    st.print("use: ");
    printVarList(use.get(obj), st);

    st.print("kill: ");
    printVarList(kill.get(obj), st);
  }

  /**
   * Writes all names on one line, each followed by ", ", and ends the line.
   */
  private void printVarList(Collection<VariableName> vars, PrintStream st) {
    StringBuilder line = new StringBuilder();
    for (VariableName name : vars) {
      line.append(name).append(", ");
    }
    st.println(line.toString());
  }

}
