/*
 * Part of upcompiler. Copyright (c) 2012, Urs Fässler, Licensed under the GNU General Public License, v3
 * @author: urs@bitzgi.ch
 */

package structuring.matcher;


import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

import structuring.Translator;
import util.StmtUtil;
import ast.expression.Expression;
import ast.statement.BlockStmt;
import ast.statement.CaseStmt;
import ast.statement.NullStmt;
import ast.statement.Statement;
import cfg.basicblock.BasicBlock;
import cfg.basicblock.BbEdge;
import cfg.statement.JumpStmt;

/**
 * Matches a multi-way branch region in the control flow graph and turns it
 * into a {@link CaseStmt}.
 *
 * <pre>
 *    0
 *  /   \
 * 1 ... n-1
 *  \   /
 *    n
 * </pre>
 *
 * Node 0 is the branching predecessor, nodes 1..n-1 are the case bodies and
 * node n is the common join node. After a successful {@link #match}, the
 * vertex list holds [0, 1, ..., n-1, n] in exactly that order.
 */
public class CaseMatcher implements Pattern {
  /** Matched vertices: predecessor, case bodies (in out-edge order), join node. */
  private ArrayList<BasicBlock> vertices  = new ArrayList<BasicBlock>();
  /** Edges from the predecessor to each case body. */
  private ArrayList<BbEdge>     outedges  = new ArrayList<BbEdge>();
  /** Edges from each case body to the join node, parallel to {@code outedges}. */
  private ArrayList<BbEdge>     joinedges = new ArrayList<BbEdge>();


  /**
   * Returns the matched edges as a single configuration: all predecessor
   * out-edges followed by all join edges.
   *
   * @return a list containing exactly one edge configuration
   */
  public ArrayList<ArrayList<BbEdge>> edges() {
    ArrayList<BbEdge> config = new ArrayList<BbEdge>(outedges.size() + joinedges.size());
    config.addAll(outedges);
    config.addAll(joinedges);

    ArrayList<ArrayList<BbEdge>> ret = new ArrayList<ArrayList<BbEdge>>(1);
    ret.add(config);
    return ret;
  }


  /**
   * Returns the number of matched edges (out-edges plus join edges) minus one.
   * NOTE(review): the meaning of the "- 1" is defined by the Pattern contract,
   * which is not visible here — confirm against the other matchers.
   *
   * @return total matched edge count minus one
   */
  public int internalEdges() {
    return outedges.size() + joinedges.size() - 1;
  }


  /**
   * Returns the matched vertices as a single configuration. The contained
   * list is the internal {@code vertices} list itself, not a copy;
   * {@link #getStmtCode} relies on receiving that same instance back.
   *
   * @return a list containing exactly one vertex configuration
   */
  public ArrayList<ArrayList<BasicBlock>> vertices() {
    ArrayList<ArrayList<BasicBlock>> ret = new ArrayList<ArrayList<BasicBlock>>(1);
    ret.add(vertices);
    return ret;
  }


  /**
   * Returns the number of matched vertices minus one.
   *
   * @return matched vertex count minus one
   */
  public int internalVertices() {
    return vertices.size() - 1;
  }

  /**
   * Tries to match the case pattern rooted at {@code pred}.
   *
   * The pattern matches when {@code pred} has more than two successors, every
   * successor has exactly one outgoing edge, and all those edges lead to the
   * same join node. On success the matcher state (vertices, out-edges, join
   * edges) is replaced. On failure the state is left untouched.
   *
   * NOTE(review): this does not check that each case body has {@code pred} as
   * its only predecessor, or that the join node differs from {@code pred} and
   * from the case bodies. Confirm that callers/other matchers guarantee this.
   *
   * @param pred candidate branching block (node 0 of the pattern)
   * @return true if the pattern matched
   */
  public boolean match(BasicBlock pred) {
    ArrayList<BbEdge> outlist = new ArrayList<BbEdge>(pred.getOutlist());
    if (outlist.size() <= 2) {
      return false;
    }

    ArrayList<BasicBlock> midlist = new ArrayList<BasicBlock>(outlist.size());
    ArrayList<BbEdge> joinlist = new ArrayList<BbEdge>(outlist.size());

    BasicBlock join = null;
    for (BbEdge mid : outlist) {
      BasicBlock nodeX = mid.getDst();
      // A middle node needs exactly one exit. With several exits it is not a
      // plain case body. With none there is no join edge, and asking the
      // iterator for one would throw NoSuchElementException.
      if (nodeX.getOutlist().size() != 1) {
        return false;
      }
      midlist.add(nodeX);
      BbEdge joinedge = nodeX.getOutlist().iterator().next();
      joinlist.add(joinedge);
      if (join == null) {
        join = joinedge.getDst();
      } else if (join != joinedge.getDst()) {
        // all case bodies must converge on the same join node
        return false;
      }
    }

    addToVertices(pred, join, midlist);
    outedges = outlist;
    joinedges = joinlist;

    return true;
  }

  /**
   * Rebuilds the vertex list as [pred, mids..., join].
   *
   * @param pred    branching predecessor
   * @param join    common join node
   * @param midlist case bodies in out-edge order
   */
  private void addToVertices(BasicBlock pred, BasicBlock join, ArrayList<BasicBlock> midlist) {
    vertices = new ArrayList<BasicBlock>(2 + midlist.size());
    vertices.add(pred);
    vertices.addAll(midlist);
    vertices.add(join);
  }


  /**
   * Creates the jump that leaves the collapsed region. The jump targets the id
   * of the last vertex of {@code config}, which is the join node for the
   * configuration produced by {@link #vertices()}.
   *
   * @param config vertex configuration
   * @return jump to the last vertex of {@code config}
   */
  public JumpStmt getNewJump(ArrayList<BasicBlock> config) {
    return new JumpStmt(config.get(config.size() - 1).getId());
  }


  /**
   * Builds the structured code for the matched region. The result is a block
   * with the predecessor's statement, followed by a {@link CaseStmt}. The case
   * statement switches on the translated expression of the jump that ends the
   * predecessor's code.
   *
   * Option {@code k} holds the statement of the k-th case body, in
   * {@code pred.getOutlist()} iteration order. The default branch is a
   * {@link NullStmt}. NOTE(review): this assumes the out-edge order
   * corresponds to the jump expression's values — confirm.
   *
   * @param config  vertex configuration; must be the list from {@link #vertices()}
   * @param mapping already-structured statements per basic block
   * @return the combined statement
   */
  public Statement getStmtCode(ArrayList<BasicBlock> config, Map<BasicBlock, Statement> mapping) {
    // Reference equality on purpose: vertices() hands out the internal list itself.
    assert (vertices == config);

    Expression expr = Translator.translate(StmtUtil.toJump(vertices.get(0).getCode().getLast()).getExpression());

    // Case bodies sit between the predecessor (first) and the join node (last).
    HashMap<Integer, Statement> option = new HashMap<Integer, Statement>();
    for (int i = 1; i < vertices.size() - 1; i++) {
      option.put(i - 1, mapping.get(vertices.get(i)));
    }

    CaseStmt ifs = new CaseStmt(expr, option, new NullStmt());
    BlockStmt block = new BlockStmt();
    block.addCode(mapping.get(vertices.get(0)));
    block.addCode(ifs);
    return block;
  }

}
