/*
 * Part of upcompiler. Copyright (c) 2012, Urs Fässler, Licensed under the GNU General Public License, v3
 * @author: urs@bitzgi.ch
 */

package structuring.matcher;


import java.util.ArrayList;
import java.util.LinkedList;
import java.util.Map;

import structuring.Translator;
import util.Combinations;
import util.StmtUtil;
import ast.expression.Expression;
import ast.expression.UnaryExpression;
import ast.expression.UnaryOp;
import ast.statement.BlockStmt;
import ast.statement.Statement;
import ast.statement.WhileStmt;
import cfg.basicblock.BasicBlock;
import cfg.basicblock.BbEdge;
import cfg.statement.JumpStmt;

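/*
 *  0-+
 *  |\|
 *  | 1
 *  |
 *  2
 *
 * 0: loop header (its last statement is the conditional jump), 1: loop body
 * (exactly one predecessor and one successor, jumping back to 0), 2: exit block.
 */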

/**
 * Recognizes a while-loop shape in the control-flow graph.
 *
 * <p>A match consists of a header block with exactly two out-edges. One out-edge leads to
 * a body block that has a single predecessor and a single successor, and that successor
 * is the header again. The other out-edge leads to the exit block.
 *
 * <p>Each match is stored as a vertex configuration {@code [header, body, exit]}. It is
 * paired, at the same index, with an edge list {@code [header->body, body->header]}.
 */
public class WhileMatcher implements Pattern {
  /** Index of the loop header inside a vertex configuration. */
  private static final int HEADER = 0;
  /** Index of the loop body inside a vertex configuration. */
  private static final int BODY = 1;
  /** Index of the exit block inside a vertex configuration. */
  private static final int EXIT = 2;

  /** Recorded configurations, each {@code [header, body, exit]}. */
  private final ArrayList<ArrayList<BasicBlock>> vertices = new ArrayList<ArrayList<BasicBlock>>();
  /** Recorded edge pairs, each {@code [header->body, body->header]}; parallel to {@link #vertices}. */
  private final ArrayList<ArrayList<BbEdge>>     edges    = new ArrayList<ArrayList<BbEdge>>();


  /**
   * Returns the recorded edge pairs, index-aligned with {@link #vertices()}.
   *
   * @return the live internal list (not a copy)
   */
  public ArrayList<ArrayList<BbEdge>> edges() {
    return edges;
  }


  /**
   * Returns the number of edges per match that belong to the pattern itself.
   *
   * @return 2, the length of each recorded edge pair
   */
  public int internalEdges() {
    return 2;
  }


  /**
   * Returns the recorded vertex configurations, each {@code [header, body, exit]}.
   *
   * @return the live internal list (not a copy)
   */
  public ArrayList<ArrayList<BasicBlock>> vertices() {
    return vertices;
  }


  /**
   * Returns the number of vertices per match that belong to the pattern itself.
   *
   * @return 2; presumably header and body, with the exit block (third entry) outside the
   *     pattern. TODO confirm against the Pattern contract.
   */
  public int internalVertices() {
    return 2;
  }

  /**
   * Tries to match a while loop headed by {@code pred} and records every match found.
   *
   * <p>Returns {@code false} immediately unless {@code pred} has exactly two out-edges.
   * Otherwise it iterates over the index selections produced by
   * {@code Combinations(2, 2)} (presumably both orderings; TODO confirm). For each
   * selection, the first index names the candidate body edge and the second the exit.
   *
   * @param pred candidate loop header
   * @return {@code false} if {@code pred} does not have exactly two out-edges; otherwise
   *     whether any match is recorded. The result lists are never cleared, so matches from
   *     earlier calls on this instance also count.
   */
  public boolean match(BasicBlock pred) {
    ArrayList<BbEdge> successors = new ArrayList<BbEdge>(pred.getOutlist());
    if (successors.size() != 2) {
      return false;
    }

    Combinations selections = new Combinations(2, successors.size());
    while (selections.hasNext()) {
      ArrayList<Integer> pick = selections.next();
      BbEdge forward = successors.get(pick.get(0));
      BasicBlock body = forward.getDst();
      BasicBlock exit = successors.get(pick.get(1)).getDst();
      matchBody(pred, forward, body, exit);
    }
    return !vertices.isEmpty();
  }

  /**
   * Records a match for every out-edge of {@code body} that returns to {@code pred}.
   *
   * <p>Nothing is recorded if {@code body} is {@code pred} itself, or if {@code body}
   * does not have exactly one predecessor and one successor.
   */
  private void matchBody(BasicBlock pred, BbEdge forward, BasicBlock body, BasicBlock exit) {
    boolean rejected = pred == body
        || body.getInlist().size() != 1
        || body.getOutlist().size() != 1; // TODO remove if additional edges are allowed
    if (rejected) {
      return;
    }

    ArrayList<BbEdge> bodySuccessors = new ArrayList<BbEdge>(body.getOutlist());
    Combinations backSelections = new Combinations(1, bodySuccessors.size());
    while (backSelections.hasNext()) {
      ArrayList<Integer> pick = backSelections.next();
      BbEdge back = bodySuccessors.get(pick.get(0));
      if (back.getDst() == pred) {
        record(pred, body, exit, forward, back);
      }
    }
  }

  /** Appends one match: vertices {@code [pred, body, exit]} and edges {@code [forward, back]}. */
  private void record(BasicBlock pred, BasicBlock body, BasicBlock exit, BbEdge forward, BbEdge back) {
    ArrayList<BasicBlock> triple = new ArrayList<BasicBlock>(3);
    triple.add(pred);
    triple.add(body);
    triple.add(exit);
    vertices.add(triple);

    ArrayList<BbEdge> pair = new ArrayList<BbEdge>(2);
    pair.add(forward);
    pair.add(back);
    edges.add(pair);
  }


  /**
   * Creates a jump to the exit block of a matched configuration.
   *
   * @param config a configuration {@code [header, body, exit]}
   * @return a jump targeting the id of {@code config}'s third entry
   */
  public JumpStmt getNewJump(ArrayList<BasicBlock> config) {
    BasicBlock exit = config.get(EXIT);
    return new JumpStmt(exit.getId());
  }


  /**
   * Builds the structured code for a matched loop.
   *
   * <p>The result has the shape
   *
   * <pre>
   * { header; while (cond) { body; header; } }
   * </pre>
   *
   * <p>{@code cond} is the translated expression of the header's last (jump) statement. It
   * is negated unless {@link #isInverted} holds. The header's mapped statement is emitted
   * both before the loop and at the end of the loop body.
   *
   * @param config a configuration previously recorded by {@link #match}
   * @param mapping already-structured statements for the header and body blocks
   * @return a block holding the header statement followed by the while loop
   */
  public Statement getStmtCode(ArrayList<BasicBlock> config, Map<BasicBlock, Statement> mapping) {
    assert (vertices.contains(config));
    BasicBlock header = config.get(HEADER);
    LinkedList<cfg.statement.Statement> headerCode = header.getCode();
    assert (StmtUtil.getNrOfStmts(headerCode) == 1);

    Expression jumpCondition = Translator.translate(StmtUtil.toJump(headerCode.getLast()).getExpression());
    Expression loopCondition = isInverted(config)
        ? jumpCondition
        : new UnaryExpression(jumpCondition, UnaryOp.Not);

    BasicBlock bodyBlock = config.get(BODY);
    assert (bodyBlock.getCode().getLast() instanceof JumpStmt);

    BlockStmt loopBody = new BlockStmt();
    loopBody.addCode(mapping.get(bodyBlock));
    loopBody.addCode(mapping.get(header));

    BlockStmt result = new BlockStmt();
    result.addCode(mapping.get(header));
    result.addCode(new WhileStmt(loopCondition, loopBody));
    return result;
  }

  /**
   * Reports whether {@link StmtUtil#getEdgeValue} gives the header-to-body edge the value 1.
   * In that case the jump condition is used unnegated as the loop condition.
   * NOTE(review): presumably value 1 marks the edge taken when the jump condition holds;
   * confirm against StmtUtil.
   */
  private boolean isInverted(ArrayList<BasicBlock> config) {
    return StmtUtil.getEdgeValue(config.get(HEADER), config.get(BODY)) == 1;
  }

}
