/*
 * Part of upcompiler. Copyright (c) 2012, Urs Fässler, Licensed under the GNU General Public License, v3
 * @author: urs@bitzgi.ch
 */

package reduction;


import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Set;

import knowledge.KnowEqual;
import knowledge.KnowledgeBase;
import cfg.basicblock.BasicBlock;
import cfg.basicblock.BbEdge;
import cfg.function.FunctionGraph;
import cfg.function.PrgFunction;
import cfg.statement.Statement;

/**
 * Looks for basic blocks that jump to the same successor and end in the same sequence of
 * statements, with the goal of sharing that common tail.
 *
 * <p>
 * The actual graph rewriting is not implemented yet: the pairwise comparison only detects the
 * common tail and currently always reports that nothing was merged (see the TODO/FIXME notes).
 */
public class CommonStmtReducer {

  /**
   * Runs the reduction on the graph of every given function.
   *
   * @param functions
   *          functions whose graphs are processed
   * @param base
   *          knowledge base; its {@link KnowEqual} entry is used to compare statements
   */
  public static void process(Collection<PrgFunction> functions, KnowledgeBase base) {
    for (PrgFunction func : functions) {
      process(func.getGraph(), (KnowEqual) base.getEntry(KnowEqual.class));
    }
  }

  /**
   * For every basic block of the graph, collects the predecessors that have exactly one outgoing
   * edge and hands them to {@link #mergeCommonStmt} when there are at least two of them.
   */
  private static void process(FunctionGraph func, KnowEqual ke) {
    // iterate over a snapshot, the merge step may add vertices to / remove vertices from the graph
    ArrayList<BasicBlock> blocks = new ArrayList<BasicBlock>(func.vertexSet());
    for (BasicBlock bb : blocks) {
      Set<BbEdge> inEdges = bb.getInlist();
      ArrayList<BasicBlock> siblings = new ArrayList<BasicBlock>(inEdges.size());
      for (BbEdge edge : inEdges) {
        BasicBlock src = edge.getSrc();
        if (src.getOutlist().size() == 1) { // only consider basic blocks without branches
          siblings.add(src);
        }
      }
      if (siblings.size() >= 2) { // otherwise, there is nothing to merge :)
        mergeCommonStmt(siblings, func, ke);
      }
    }
  }

  /**
   * Tries to merge every sibling into the siblings seen before it. A sibling that was fully merged
   * is removed from the graph, every other sibling becomes a merge candidate for the following ones.
   */
  private static void mergeCommonStmt(ArrayList<BasicBlock> siblings, FunctionGraph func, KnowEqual ke) {
    ArrayList<BasicBlock> merged = new ArrayList<BasicBlock>(siblings.size());
    merged.add(siblings.get(0));
    for (int i = 1; i < siblings.size(); i++) {
      BasicBlock sibling = siblings.get(i);
      if (merge(merged, sibling, func, ke)) {
        func.removeVertex(sibling);
      } else {
        merged.add(sibling);
      }
    }
  }

  /**
   * Compares {@code bb} against every block in {@code merged} and tries to merge each pair.
   *
   * @return true if the basic block is fully merged into the list, i.e. if the list does not grow;
   *         currently always false
   */
  private static boolean merge(ArrayList<BasicBlock> merged, BasicBlock bb, FunctionGraph func, KnowEqual ke) {
    for (int i = 0; i < merged.size(); i++) {
      BasicBlock candidate = merged.get(i);
      BasicBlock larger;
      BasicBlock smaller;
      if (bb.getCode().size() > candidate.getCode().size()) {
        larger = bb;
        smaller = candidate;
      } else {
        larger = candidate;
        smaller = bb;
      }
      // it is only possible to merge the smaller code block into the larger one
      BasicBlock nbb = merge(larger, smaller, ke);
      if (nbb != null) {
        func.addVertex(nbb);
        // Only replace the entry when a merge really happened. Replacing it unconditionally let an
        // unmerged bb displace the candidate; since the caller appends an unmerged bb anyway, the
        // list then contained bb twice and the displaced sibling was never compared again.
        merged.set(i, larger);
      }
    }
    return false; // FIXME do it better
  }

  /**
   * Determines how many trailing statements of {@code lbb} and {@code rbb} are equal according to
   * {@code ke}. The last statement of each block (the jump) is skipped and counted as common
   * without being compared.
   *
   * @param lbb
   *          the block with the larger (or equally sized) code list
   * @param rbb
   *          the block with the smaller (or equally sized) code list
   * @param ke
   *          decides whether two statements are equal
   * @return the newly created basic block holding the shared tail, or null if nothing was merged;
   *         currently always null because neither merge variant is implemented
   */
  private static BasicBlock merge(BasicBlock lbb, BasicBlock rbb, KnowEqual ke) {
    int rofs = lbb.getCode().size() - rbb.getCode().size();
    assert (rofs >= 0);
    Iterator<Statement> litr = lbb.getCode().descendingIterator();
    Iterator<Statement> ritr = rbb.getCode().descendingIterator();

    // a block without any statement has no tail to share; also avoids NoSuchElementException below
    if (!litr.hasNext() || !ritr.hasNext()) {
      return null;
    }

    // do not check jump
    ritr.next();
    litr.next();
    int common = 1;

    // walk backwards as long as both blocks have statements left and those statements are equal
    while (litr.hasNext() && ritr.hasNext() && ke.equal(litr.next(), ritr.next())) {
      common++;
    }

    if (common <= 1) {
      return null; // only the (unchecked) jump is shared
    }

    if (common == rbb.getCode().size()) {
      // rbb is entirely a tail of lbb; rofs is the index in lbb where the shared tail starts.
      // rofs == 0 is a legitimate input (both blocks have the same length and all statements
      // before the jump compared equal), so it must not trip an assertion; there is simply no
      // head to split off in that case.
      // TODO whole merge, for rofs > 0 something like: BasicBlock.splitAtStmt( lbb, rofs )
      return null;
    } else {
      // TODO partial merge
      System.out.println("partial merge");
      return null;
    }
  }

}
