/*
 * Part of upcompiler. Copyright (c) 2012, Urs Fässler, Licensed under the GNU General Public License, v3
 * @author: urs@bitzgi.ch
 */

package cfg.function.library.stdio;


import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import knowledge.KnowOwner;
import knowledge.KnowledgeBase;
import reduction.ExprCopy;
import util.Pair;
import cfg.Application;
import cfg.Assignable;
import cfg.IntConstant;
import cfg.IrTraverser;
import cfg.IrType;
import cfg.basicblock.BasicBlock;
import cfg.expression.CallExprLinked;
import cfg.expression.Expression;
import cfg.function.LibFunction;
import cfg.function.system.SystemFunctions;
import cfg.statement.AssignmentStmt;
import cfg.statement.Statement;
import cfg.variable.VariablePtrDeref;
import elfreader.ElfReader;

/**
 * IR pass that replaces calls to the library function {@code __isoc99_scanf} with explicit
 * {@link SystemFunctions#readInt} reads.
 *
 * <p>For each matching call:
 * <ul>
 * <li>the format string is read from the ELF image and its conversions are counted (only integer
 * conversions are supported);</li>
 * <li>one {@code *dst = readInt()} assignment per conversion is queued for insertion directly after
 * the statement that contains the call;</li>
 * <li>the call itself is replaced by the constant conversion count.</li>
 * </ul>
 *
 * <p>The traversal parameter is the list of pending (index, statement) insertions for the basic
 * block currently being visited. The insertions are applied only after the whole block has been
 * traversed, presumably so that the statement list is not modified mid-traversal.
 */
public class ScanfReplacer extends IrTraverser<Void, LinkedList<Pair<Integer, Statement>>> {
  /** Library symbol name used to recognize scanf calls. */
  private static final String SCANF_SYMBOL = "__isoc99_scanf";

  private final ElfReader elf;
  private final KnowOwner ko;

  /**
   * Runs the scanf replacement over every function of {@code app}.
   *
   * @param app application providing the knowledge base, the ELF image and the functions to rewrite
   */
  public static void process(Application app) {
    ScanfReplacer replacer = new ScanfReplacer(app.getKb(), app.getElfReader());
    replacer.visit(app.getFunctions(), null);
  }

  /**
   * @param kb knowledge base; its {@link KnowOwner} entry is used to find owning statements/blocks
   * @param elfReader reader for the binary image, used to fetch scanf format strings
   */
  public ScanfReplacer(KnowledgeBase kb, ElfReader elfReader) {
    super();
    this.elf = elfReader;
    this.ko = (KnowOwner) kb.getEntry(KnowOwner.class);
  }

  /**
   * Visits a basic block, collecting the readInt assignments produced for its scanf calls, then
   * inserts them into the block's code.
   *
   * @param obj the basic block to process
   * @param param must be {@code null}; a fresh insertion list is created for every block
   * @return always {@code null}
   */
  @Override
  protected Void visitBasicBlock(BasicBlock obj, LinkedList<Pair<Integer, Statement>> param) {
    assert (param == null);
    LinkedList<Pair<Integer, Statement>> pending = new LinkedList<Pair<Integer, Statement>>();
    super.visitBasicBlock(obj, pending);

    // Apply the insertions in reverse collection order. Assuming the entries were collected in
    // ascending statement order, this keeps the earlier indices valid. Entries that share an index
    // end up in the same order in which they were collected.
    Iterator<Pair<Integer, Statement>> itr = pending.descendingIterator();
    while (itr.hasNext()) {
      Pair<Integer, Statement> insertion = itr.next();
      obj.getCode().add(insertion.getFirst(), insertion.getSecond());
    }

    return null;
  }

  /**
   * Rewrites a call to {@code __isoc99_scanf}. All other calls are left untouched, and their
   * children are not traversed.
   *
   * @param obj the call expression being visited
   * @param param pending insertions of the enclosing basic block; one readInt assignment per scanf
   *     conversion is appended, positioned right after the statement containing the call
   * @return always {@code null}
   */
  @Override
  protected Void visitCallExprLinked(CallExprLinked obj, LinkedList<Pair<Integer, Statement>> param) {
    if (obj.getFunc().getIrType() != IrType.FuncLibrary) {
      return null;
    }
    LibFunction func = (LibFunction) obj.getFunc();
    if (!SCANF_SYMBOL.equals(func.getName())) {
      return null;
    }

    long addr = StdioStringParser.getFormatStringAddr(obj, ko);
    int scancount = countScanfArguments(addr);

    // NOTE(review): assumes the scanf result is always used as the source of an assignment. A
    // call whose owner is any other statement type would fail this cast.
    AssignmentStmt stmt = (AssignmentStmt) ko.getExprOwner(obj);
    BasicBlock bb = ko.getStmtOwner(stmt);

    int pos = bb.getCode().indexOf(stmt);
    assert pos >= 0 : "scanf statement not found in its owning basic block";

    for (int i = 0; i < scancount; i++) {
      // Look up call argument i + 1, starting at the statement just before the scanf assignment.
      // Presumably argument 0 is the format string, so the destinations start at index 1.
      Expression expr = StdioStringParser.getArrayWriteExpr(bb.getCode(), pos - 1, i + 1);
      expr = ExprCopy.copy(expr);

      List<Assignable> dst = new LinkedList<Assignable>();
      dst.add(new VariablePtrDeref(expr));
      // The leading 0 is passed through unchanged; its meaning is defined by AssignmentStmt.
      AssignmentStmt read = new AssignmentStmt(0, dst, new CallExprLinked(SystemFunctions.readInt));
      param.add(new Pair<Integer, Statement>(pos + 1, read));
    }

    // FIXME: C's scanf returns the number of successfully converted items (or EOF). This assumes
    // that every conversion succeeds.
    stmt.setSource(new IntConstant(scancount));

    return null;
  }

  /**
   * Reads the scanf format string at {@code addr} from the ELF image and counts its conversions.
   *
   * @param addr address of the NUL-terminated format string inside the ELF image
   * @return number of integer conversions in the format string
   * @throws UnsupportedOperationException if the format contains any token other than an integer
   *     conversion
   */
  private int countScanfArguments(long addr) {
    elf.seek(addr);
    String format = elf.readString();

    List<FormatToken> parsed = StdioStringParser.parseFormat(format);

    int argnr = 0;
    for (FormatToken fmt : parsed) {
      switch (fmt.getType()) {
        case Integer:
          argnr++;
          break;
        default:
          // String, Newline and any other token types are not supported.
          throw new UnsupportedOperationException("Unhandled format token type: " + fmt.getType());
      }
    }

    return argnr;
  }

}
