/*
 * Part of upcompiler. Copyright (c) 2012, Urs Fässler, Licensed under the GNU General Public License, v3
 * @author: urs@bitzgi.ch
 */

package ast;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import ast.expression.Expression;
import ast.statement.NullStmt;
import ast.statement.Statement;
import ast.statement.VarDef;
import ast.traverser.AstExpressionTraverser;
import ast.traverser.AstStatementTraverser;
import ast.traverser.AstTraverser;
import ast.variable.Variable;
import ast.variable.VariableRefLinked;
import ast.variable.VariableRefUnlinked;

/**
 * Merges groups of variables of a function into a single representative variable per group.
 *
 * <p>The traversal itself is delegated to {@link StmtVariableMerger}; the parameter passed down
 * the traversal maps every member of a merged group to that group's ordered list.
 */
public class VariableMerger extends AstTraverser<Map<Variable, List<Variable>>> {

  /** Creates a traverser whose statement handling is done by a {@link StmtVariableMerger}. */
  public VariableMerger() {
    super(new StmtVariableMerger());
  }

  /**
   * Merges every group in {@code color} that has more than one member and rewrites {@code func}
   * accordingly.
   *
   * <p>Presumably each group is a set of variables that may share storage (e.g. the result of a
   * coloring pass) — NOTE(review): confirm against the caller.
   *
   * @param func the function to rewrite in place
   * @param color groups of variables to merge; groups of size 0 or 1 are ignored
   * @return the number of variables eliminated, i.e. the sum of {@code size - 1} over all
   *     groups with more than one member
   */
  public static int merge(PrgFunc func, Set<List<Variable>> color) {
    Map<Variable, List<Variable>> groupOf = new HashMap<Variable, List<Variable>>();
    int eliminated = 0;

    for (List<Variable> group : color) {
      if (group.size() <= 1) {
        continue; // nothing to merge
      }
      eliminated += group.size() - 1;
      List<Variable> ordered = createList(group, func.getParam());
      for (Variable member : ordered) {
        groupOf.put(member, ordered);
      }
    }

    if (groupOf.isEmpty()) {
      // Only singleton groups were seen, so nothing can have been eliminated.
      assert (eliminated == 0);
    } else {
      new VariableMerger().visit(func, groupOf);
    }
    return eliminated;
  }

  /**
   * Orders a group so that its (at most one) function parameter comes first, followed by the
   * remaining members in natural sort order.
   *
   * @param group the variables of one merge group
   * @param params the parameters of the function being rewritten
   * @return a new, mutable list with the parameter (if any) at index 0
   */
  private static List<Variable> createList(List<Variable> group, List<Variable> params) {
    // Members of the group that are also function parameters; at most one is allowed.
    List<Variable> paramMembers = new ArrayList<Variable>(group);
    paramMembers.retainAll(params);
    assert (paramMembers.size() <= 1);

    List<Variable> others = new ArrayList<Variable>(group);
    others.removeAll(paramMembers);
    Collections.sort(others);

    List<Variable> ordered = new ArrayList<Variable>(paramMembers.size() + others.size());
    ordered.addAll(paramMembers);
    ordered.addAll(others);
    return ordered;
  }

}

/**
 * Statement traverser that removes definitions of merged-away variables and delegates all
 * expression and variable rewriting to an {@link ExprVariableMerger}.
 */
class StmtVariableMerger extends AstStatementTraverser<Map<Variable, List<Variable>>> {
  /** Delegate that performs the actual variable substitution. */
  private ExprVariableMerger exprMerger = new ExprVariableMerger();

  /** Rewrites variable references inside {@code expr} via the expression delegate. */
  @Override
  public Expression visit(Expression expr, Map<Variable, List<Variable>> param) {
    return exprMerger.visit(expr, param);
  }

  /** Rewrites a variable via the expression delegate. */
  @Override
  public Variable visit(Variable expr, Map<Variable, List<Variable>> param) {
    return exprMerger.visit(expr, param);
  }

  /**
   * Keeps a variable definition only if the defined variable is its own replacement; otherwise
   * the definition is dropped by returning a {@link NullStmt}.
   */
  @Override
  protected Statement visitVarDef(VarDef obj, Map<Variable, List<Variable>> param) {
    Variable defined = obj.getVariable();
    if (exprMerger.getReplacement(defined, param) == defined) {
      return obj;
    }
    // The variable was merged into another one, so its own definition is no longer needed.
    return new NullStmt();
  }

}

/**
 * Expression traverser that replaces every linked variable reference by the representative of
 * the merge group the variable belongs to.
 */
class ExprVariableMerger extends AstExpressionTraverser<Map<Variable, List<Variable>>> {

  /**
   * Returns the representative for {@code var}.
   *
   * @param var the variable to look up
   * @param param maps each merged variable to its group; the group's first element is the
   *     representative
   * @return the first element of {@code var}'s group, or {@code var} itself if it is not part of
   *     any merged group
   */
  Variable getReplacement(Variable var, Map<Variable, List<Variable>> param) {
    // Single lookup instead of containsKey + get; the map never holds null groups.
    List<Variable> group = param.get(var);
    if (group == null) {
      return var;
    }
    return group.get(0);
  }

  /** Returns a new reference pointing to the representative of the referenced variable. */
  @Override
  protected Expression visitVariableRefLinked(VariableRefLinked obj, Map<Variable, List<Variable>> param) {
    return new VariableRefLinked(getReplacement(obj.getReference(), param));
  }

  /**
   * Unlinked references are not supported by this pass.
   *
   * @throws UnsupportedOperationException always (a subclass of {@link RuntimeException}, so
   *     existing handlers still apply)
   */
  @Override
  protected Expression visitVariableRefUnlinked(VariableRefUnlinked obj, Map<Variable, List<Variable>> param) {
    throw new UnsupportedOperationException("Not yet implemented");
  }

}
