/*
 * Part of upcompiler. Copyright (c) 2012, Urs Fässler, Licensed under the GNU General Public License, v3
 * @author: urs@bitzgi.ch
 */

package cfg.function.library.stdio;


import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import knowledge.KnowOwner;
import knowledge.KnowledgeBase;
import cfg.Application;
import cfg.Assignable;
import cfg.IntConstant;
import cfg.IrReplaceExprTraverser;
import cfg.IrType;
import cfg.StringConstant;
import cfg.basicblock.BasicBlock;
import cfg.expression.CallExprLinked;
import cfg.expression.Expression;
import cfg.expression.VariableRefLinked;
import cfg.function.Function;
import cfg.function.LibFunction;
import cfg.function.PrgFunction;
import cfg.function.system.SystemFunctions;
import cfg.statement.AssignmentStmt;
import cfg.statement.RetStmt;
import cfg.variable.Array;
import cfg.variable.ArrayAccess;
import cfg.variable.VariableName;
import elfreader.ElfReader;

/**
 * Replaces calls to the library functions {@code printf}, {@code puts} and {@code putchar}.
 *
 * <p>For {@code printf} and {@code puts} a dedicated {@link PrgFunction} is generated per string
 * address. It writes the literal parts of the string via {@link SystemFunctions}, plus the integer
 * arguments for {@code printf}. Generated functions are collected in the map handed to the
 * constructor. {@code putchar} is mapped directly onto {@code SystemFunctions.writeNl} or
 * {@code SystemFunctions.writeStr}.
 */
public class PrintfReplacer extends IrReplaceExprTraverser<Void> {
  /** Generated printf/puts functions, keyed by the address of the string they print. */
  private final Map<Long, PrgFunction> printf;
  private final ElfReader              elf;
  private final KnowOwner              ko;
  /** Function currently being traversed; {@code null} outside of {@link #visitPrgFunction}. */
  private PrgFunction                  acfunc = null;

  /**
   * Rewrites all printf/puts/putchar calls in the functions of {@code app} and adds the generated
   * functions to the application.
   *
   * @param app application whose functions are transformed in place
   */
  public static void process(Application app) {
    Map<Long, PrgFunction> printf = new HashMap<Long, PrgFunction>();

    PrintfReplacer replacer = new PrintfReplacer(app.getKb(), printf, app.getElfReader());
    replacer.visitCollection(app.getFunctions(), null);

    // Generated functions are added only after the traversal of app.getFunctions() has finished.
    for (PrgFunction func : printf.values()) {
      app.getFunctions().add(func);
    }
  }

  /**
   * @param kb        knowledge base; its {@link KnowOwner} entry is passed to
   *                  {@link StdioStringParser} to resolve the call arguments
   * @param printf    cache of generated functions keyed by string address; filled by this replacer
   * @param elfReader reader used to fetch the string constants from the binary
   */
  public PrintfReplacer(KnowledgeBase kb, Map<Long, PrgFunction> printf, ElfReader elfReader) {
    super();
    this.printf = printf;
    this.elf = elfReader;
    this.ko = (KnowOwner) kb.getEntry(KnowOwner.class);
  }

  @Override
  protected void visitPrgFunction(PrgFunction obj, Void param) {
    assert (acfunc == null);
    acfunc = obj;
    try {
      super.visitPrgFunction(obj, param);
      assert (acfunc == obj);
    } finally {
      // Reset even if the traversal throws, so the replacer is not left pointing at a stale function.
      acfunc = null;
    }
  }

  @Override
  protected Expression visitCallExprLinked(CallExprLinked obj, Void param) {
    if (obj.getFunc().getIrType() != IrType.FuncLibrary) {
      return obj;
    }
    String name = ((LibFunction) obj.getFunc()).getName();
    if ("printf".equals(name)) {
      return replacePrintf(obj);
    } else if ("puts".equals(name)) {
      return replacePuts(obj);
    } else if ("putchar".equals(name)) {
      return replacePutchar(obj);
    } else {
      return obj;
    }
  }

  /** Replaces a printf call by a call to the generated function for its format string. */
  private Expression replacePrintf(CallExprLinked obj) {
    long addr = StdioStringParser.getFormatStringAddr(obj, ko);

    PrgFunction pfunc = printf.get(addr);
    if (pfunc == null) {
      pfunc = createPrintfFunc(addr);
      printf.put(pfunc.getAddr(), pfunc);
    }

    // The i-th value argument is taken from element 0 of array (i + 1) of the calling function.
    // NOTE(review): presumably these arrays model the argument registers; confirm.
    List<Expression> actualparam = new ArrayList<Expression>(pfunc.getParamCount());
    for (int i = 0; i < pfunc.getParamCount(); i++) {
      Array array = acfunc.getArray(i + 1);
      assert (array != null);
      actualparam.add(new ArrayAccess(array, new IntConstant(0)));
    }

    return newCall(pfunc, actualparam);
  }

  /** Replaces a puts call by a call to the generated function for its string. */
  private Expression replacePuts(CallExprLinked obj) {
    long addr = StdioStringParser.getFormatStringAddr(obj, ko);

    // NOTE(review): printf and puts share this address-keyed cache. If both are called with the
    // same string address (e.g. merged string constants), the function generated first is reused
    // for the other, so a trailing newline is gained or lost. Confirm whether this can happen.
    PrgFunction pfunc = printf.get(addr);
    if (pfunc == null) {
      pfunc = createPutsFunc(addr);
      printf.put(pfunc.getAddr(), pfunc);
    }

    // The generated function takes no parameters. Pass an explicit empty list, like every other
    // call built here, instead of relying on the CallExprLinked default.
    return newCall(pfunc, new ArrayList<Expression>());
  }

  /** Replaces a putchar call by a direct newline or single-character string write. */
  private Expression replacePutchar(CallExprLinked obj) {
    // For putchar the resolved argument is the character value itself, not a string address.
    long ch = StdioStringParser.getFormatStringAddr(obj, ko);

    if (ch == 0xa) {
      return newCall(SystemFunctions.writeNl, new ArrayList<Expression>());
    }
    List<Expression> actualparam = new ArrayList<Expression>();
    actualparam.add(new StringConstant(String.valueOf((char) ch)));
    return newCall(SystemFunctions.writeStr, actualparam);
  }

  /**
   * Generates a parameterless function that writes the string stored at {@code addr}, followed by
   * a newline.
   */
  private PrgFunction createPutsFunc(long addr) {
    BasicBlock bb = newReturnBlock(addr);
    String text = readString(addr);
    PrgFunction func = newFunction(addr, bb, 0);

    List<Expression> param = new ArrayList<Expression>();
    param.add(new StringConstant(text));
    insertCallBeforeReturn(bb, SystemFunctions.writeStr, param);
    insertCallBeforeReturn(bb, SystemFunctions.writeNl, new ArrayList<Expression>());

    return func;
  }

  /**
   * Generates a function for the printf format string stored at {@code addr}. The function has
   * one parameter per integer conversion and emits one write call per parsed format token.
   *
   * @throws IllegalStateException if the parser returns a token type not handled here
   */
  private PrgFunction createPrintfFunc(long addr) {
    BasicBlock bb = newReturnBlock(addr);
    List<FormatToken> parsed = StdioStringParser.parseFormat(readString(addr));
    PrgFunction func = newFunction(addr, bb, StdioStringParser.getArgCount(parsed));

    int argnr = 0;
    for (FormatToken fmt : parsed) {
      Function callee;
      List<Expression> param = new ArrayList<Expression>();
      switch (fmt.getType()) {
        case Integer:
          param.add(new VariableRefLinked(func.getInternParam().get(argnr)));
          argnr++;
          callee = SystemFunctions.writeInt;
          break;
        case String:
          param.add(new StringConstant(((StringFormat) fmt).getContent()));
          callee = SystemFunctions.writeStr;
          break;
        case Newline:
          callee = SystemFunctions.writeNl;
          break;
        default:
          throw new IllegalStateException("Unhandled format token type: " + fmt.getType());
      }
      insertCallBeforeReturn(bb, callee, param);
    }

    return func;
  }

  /** Reads the zero-terminated string constant at {@code addr} from the binary. */
  private String readString(long addr) {
    elf.seek(addr);
    return elf.readString();
  }

  /** Creates a basic block at {@code addr} containing only a return statement. */
  private static BasicBlock newReturnBlock(long addr) {
    BasicBlock bb = new BasicBlock(addr);
    bb.addCode(new RetStmt(0, null, new HashMap<VariableName, Expression>()));
    return bb;
  }

  /** Creates a single-block function at {@code addr} with no preserved variables. */
  private static PrgFunction newFunction(long addr, BasicBlock bb, int paramCount) {
    ArrayList<BasicBlock> bbs = new ArrayList<BasicBlock>();
    bbs.add(bb);
    PrgFunction func = new PrgFunction(addr, bbs);
    func.setParamCount(paramCount);
    func.setPreserves(new HashSet<VariableName>());
    return func;
  }

  /** Creates a call to {@code callee} with the given actual parameters. */
  private static CallExprLinked newCall(Function callee, List<Expression> params) {
    CallExprLinked call = new CallExprLinked(callee);
    call.setParam(params);
    return call;
  }

  /**
   * Inserts a call to {@code callee}, wrapped in an assignment with no targets, directly before
   * the last statement of {@code bb}. That statement is the return added by
   * {@link #newReturnBlock}.
   */
  private static void insertCallBeforeReturn(BasicBlock bb, Function callee, List<Expression> params) {
    CallExprLinked call = newCall(callee, params);
    AssignmentStmt ass = new AssignmentStmt(bb.getCode().size(), new ArrayList<Assignable>(), call);
    bb.getCode().add(bb.getCode().size() - 1, ass);
  }

}
