/*
 * Andrea Di Biagio
 * Politecnico di Milano, 2007
 * 
 * axe_reg_alloc.c
 * Formal Languages & Compilers Machine, 2007/2008
 * 
 */

#include "axe_reg_alloc.h"
#include "reg_alloc_constants.h"
#include "axe_debug.h"
#include "axe_errors.h"

/* Global error code, defined in another translation unit.
 * NOTE(review): not referenced by any function in this file - candidate
 * for removal once confirmed. */
extern int errorcode;

/* Prototypes of the file-local helpers of the linear-scan register
 * allocator. Several of the definitions below omit `static'; they still
 * get internal linkage from these earlier `static' declarations. */
static int compareIntervalIDs(void *varA, void *varB);
static int compareStartPoints(void *varA, void *varB);
static int compareEndPoints(void *varA, void *varB);
static t_list * updateListOfIntervals(t_list *result
            , t_cflow_Node *current_node, int counter);
static t_list * allocFreeRegisters(int regNum);
static t_list * addFreeRegister
      (t_list *registers, int regID, int position);
static int assignRegister(t_reg_allocator *RA);
static t_list * expireOldIntervals(t_reg_allocator *RA
            , t_list *active_intervals, t_live_interval *interval);
static t_list * getLiveIntervals(t_cflow_Graph *graph);
static int insertListOfIntervals(t_reg_allocator *RA, t_list *intervals);
static int insertLiveInterval(t_reg_allocator *RA, t_live_interval *interval);
static void finalizeLiveInterval (t_live_interval *interval);
static t_live_interval * allocLiveInterval(int varID, int startPoint, int endPoint);
static t_list * spillAtInterval(t_reg_allocator *RA
      , t_list *active_intervals, t_live_interval *interval);


/*
 * Handles an interval for which no register is available.
 *
 * `active_intervals' must be sorted by increasing end point, so its tail is
 * the active interval that ends last. That interval is compared with
 * `interval':
 *  - if it ends later, it is marked RA_SPILL_REQUIRED, its register is
 *    handed over to `interval', and `interval' replaces it in the list of
 *    active intervals (keeping the end-point order);
 *  - otherwise `interval' itself is marked RA_SPILL_REQUIRED.
 * When no interval is active at all, `interval' is spilled.
 *
 * Returns the (possibly updated) list of active intervals.
 */
t_list * spillAtInterval(t_reg_allocator *RA
      , t_list *active_intervals, t_live_interval *interval)
{
   t_list *last_element;
   t_live_interval *last_interval;

   /* get the active interval with the furthest end point */
   last_element = (active_intervals == NULL)
         ? NULL : getLastElement(active_intervals);

   /* nothing to steal a register from (e.g. no allocatable register at
    * all): the only option is to spill the current interval */
   if (last_element == NULL)
   {
      RA->bindings[interval->varID] = RA_SPILL_REQUIRED;
      return active_intervals;
   }

   last_interval = (t_live_interval *) LDATA(last_element);

   if (last_interval->endPoint > interval->endPoint)
   {
      /* `interval' ends earlier: it takes over the register of
       * `last_interval', which is spilled instead */
      RA->bindings[interval->varID]
            = RA->bindings[last_interval->varID];
      RA->bindings[last_interval->varID] = RA_SPILL_REQUIRED;

      active_intervals = removeElement(active_intervals, last_interval);
      active_intervals = addSorted(active_intervals
                  , interval, compareEndPoints);
   }
   else
   {
      /* `interval' lives at least as long as every active interval */
      RA->bindings[interval->varID] = RA_SPILL_REQUIRED;
   }

   return active_intervals;
}

/*
 * Ordering callback for live intervals, by start point.
 * Returns varA->startPoint - varB->startPoint: negative when varA starts
 * first, zero when both start together, positive otherwise.
 * If either argument is NULL the two are reported as equal (0).
 */
int compareStartPoints(void *varA, void *varB)
{
   const t_live_interval *first = varA;
   const t_live_interval *second = varB;

   if (first == NULL || second == NULL)
      return 0;

   return first->startPoint - second->startPoint;
}

/*
 * Ordering callback for live intervals, by end point.
 * Returns varA->endPoint - varB->endPoint: negative when varA ends first,
 * zero when both end together, positive otherwise.
 * If either argument is NULL the two are reported as equal (0).
 */
int compareEndPoints(void *varA, void *varB)
{
   const t_live_interval *first = varA;
   const t_live_interval *second = varB;

   if (first == NULL || second == NULL)
      return 0;

   return first->endPoint - second->endPoint;
}

/*
 * Equality predicate (not an ordering) used with CustomfindElement to look
 * up the live interval of a variable.
 * Returns 1 when both intervals refer to the same variable ID, 0 when they
 * differ or when either argument is NULL.
 */
int compareIntervalIDs(void *varA, void *varB)
{
   const t_live_interval *left = varA;
   const t_live_interval *right = varB;

   if (left != NULL && right != NULL)
      return (left->varID == right->varID);

   return 0;
}

/*
 * Adds `interval' to RA->live_intervals, keeping that list ordered by
 * increasing start point.
 *
 * Returns:
 *  - RA_OK on success;
 *  - RA_INVALID_ALLOCATOR if RA is NULL;
 *  - RA_INVALID_INTERVAL if interval is NULL;
 *  - RA_INTERVAL_ALREADY_INSERTED if RA already holds an interval for the
 *    same variable (nothing is inserted in that case).
 */
int insertLiveInterval(t_reg_allocator *RA, t_live_interval *interval)
{
   t_list *duplicate;

   if (RA == NULL)
      return RA_INVALID_ALLOCATOR;
   if (interval == NULL)
      return RA_INVALID_INTERVAL;

   /* at most one live interval per variable */
   duplicate = CustomfindElement(RA->live_intervals
               , interval, compareIntervalIDs);
   if (duplicate != NULL)
      return RA_INTERVAL_ALREADY_INSERTED;

   RA->live_intervals = addSorted(RA->live_intervals
            , interval, compareStartPoints);
   return RA_OK;
}

/*
 * Builds the initial pool of free registers: one element per register ID
 * 1, 2, ..., regNum, each added through addFreeRegister with position -1
 * (presumably "append", giving increasing order - confirm against
 * addElement). Returns NULL when regNum < 1.
 */
t_list * allocFreeRegisters(int regNum)
{
   t_list *pool = NULL;
   int reg;

   for (reg = 1; reg <= regNum; reg++)
      pool = addFreeRegister(pool, reg, -1);

   return pool;
}

/*
 * Inserts register `regID' into the list `registers' at `position' (same
 * position convention as addElement). Each register ID lives in its own
 * block obtained from _AXE_ALLOC_FUNCTION; that block is released by
 * assignRegister or finalizeRegAlloc.
 *
 * Returns the updated list. If the allocation fails, AXE_OUT_OF_MEMORY is
 * notified and `registers' is returned unchanged.
 */
t_list * addFreeRegister(t_list *registers, int regID, int position)
{
   int *element;

   element = (int *) _AXE_ALLOC_FUNCTION(sizeof(*element));
   if (element == NULL)
   {
      notifyError(AXE_OUT_OF_MEMORY);
      /* never write through NULL, should notifyError return here */
      return registers;
   }

   /* initialize element */
   (* element) = regID;

   /* update and return the list of free registers */
   return addElement(registers, element, position);
}

/*
 * Takes the register at the head of RA->freeRegisters: returns its ID,
 * frees the block that stored it and drops the list node.
 * Returns RA_SPILL_REQUIRED when no register is free.
 */
int assignRegister(t_reg_allocator *RA)
{
   int *head_data;
   int reg;

   if (RA->freeRegisters == NULL)
      return RA_SPILL_REQUIRED;

   head_data = (int *) LDATA(RA->freeRegisters);
   reg = *head_data;
   _AXE_FREE_FUNCTION(head_data);
   RA->freeRegisters = removeFirst(RA->freeRegisters);

   return reg;
}

/*
 * Creates a register allocator for the control flow graph `graph'.
 *
 * The new allocator holds:
 *  - regNum:         NUM_REGISTERS - 3 allocatable registers;
 *  - varNum:         the number of variables in graph->cflow_variables;
 *  - bindings:       varNum + 1 entries set to RA_REGISTER_INVALID, or NULL
 *                    when the graph has no variables;
 *  - live_intervals: the live intervals of `graph', sorted by start point;
 *  - freeRegisters:  the pool of free register IDs 1 .. regNum.
 *
 * Returns the allocator, or NULL when `graph' is NULL or when an error has
 * been notified through notifyError (should notifyError return at all).
 */
t_reg_allocator * initializeRegAlloc(t_cflow_Graph *graph)
{
   t_reg_allocator *result;
   t_list *intervals;
   t_list *current_elem;

   /* preconditions: a control flow graph is required */
   if (graph == NULL)
      return NULL;

   /* allocate memory for a new instance of `t_reg_allocator' */
   result = _AXE_ALLOC_FUNCTION(sizeof(t_reg_allocator));
   if (result == NULL)
   {
      notifyError(AXE_OUT_OF_MEMORY);
      return NULL;
   }

   /* finalizeRegAlloc reads all three pointers, so they must be valid
    * before any error path below can call it */
   result->bindings = NULL;
   result->live_intervals = NULL;
   result->freeRegisters = NULL;

   /* initialize the register allocator informations */
   result->regNum = NUM_REGISTERS - 3;
   result->varNum = getLength(graph->cflow_variables);

   if (result->varNum > 0)
   {
      int counter;

      /* One entry per variable, indexed as bindings[varID].
       * NOTE(review): this sizing assumes every variable ID lies in
       * [0, varNum]; nothing in this file enforces it - confirm.
       * Use the same allocator that finalizeRegAlloc releases it with
       * (_AXE_FREE_FUNCTION). */
      result->bindings = (int *) _AXE_ALLOC_FUNCTION
            (sizeof(int) * (result->varNum + 1));
      if (result->bindings == NULL)
      {
         finalizeRegAlloc(result);
         notifyError(AXE_OUT_OF_MEMORY);
         return NULL;
      }

      /* initialize the array of bindings */
      for (counter = 0; counter < (result->varNum + 1); counter++)
         result->bindings[counter] = RA_REGISTER_INVALID;
   }

   /* compute the list of live intervals (an empty list is accepted by
    * insertListOfIntervals and yields RA_OK) */
   intervals = getLiveIntervals(graph);
   if (insertListOfIntervals(result, intervals) != RA_OK)
   {
      /* The intervals inserted so far are shared between `intervals' and
       * result->live_intervals. Detach them from the allocator, then
       * release every interval exactly once through `intervals', so the
       * ones that were not inserted are not leaked either. */
      freeList(result->live_intervals);
      result->live_intervals = NULL;

      for (current_elem = intervals; current_elem != NULL
            ; current_elem = LNEXT(current_elem))
      {
         finalizeLiveInterval((t_live_interval *) LDATA(current_elem));
      }
      freeList(intervals);

      finalizeRegAlloc(result);
      notifyError(AXE_REG_ALLOC_ERROR);
      return NULL;
   }

   /* the intervals are now owned by result->live_intervals: only the
    * temporary list nodes are released here */
   freeList(intervals);

   /* create a list of freeRegisters */
   result->freeRegisters = allocFreeRegisters(result->regNum);

   /* return the new register allocator */
   return result;
}

/*
 * Releases an allocator created by initializeRegAlloc. It frees every live
 * interval and its list nodes, the bindings array, every free-register
 * block with its list nodes, and finally the allocator itself.
 * Does nothing when RA is NULL.
 */
void finalizeRegAlloc(t_reg_allocator *RA)
{
   t_list *node;

   if (RA == NULL)
      return;

   /* live intervals: free each payload first, then the list nodes */
   if (RA->live_intervals != NULL)
   {
      for (node = RA->live_intervals; node != NULL; node = LNEXT(node))
      {
         t_live_interval *li = (t_live_interval *) LDATA(node);

         if (li != NULL)
            finalizeLiveInterval(li);
      }
      freeList(RA->live_intervals);
   }

   if (RA->bindings != NULL)
      _AXE_FREE_FUNCTION(RA->bindings);

   /* free-register pool: every element owns an int block */
   if (RA->freeRegisters != NULL)
   {
      for (node = RA->freeRegisters; node != NULL; node = LNEXT(node))
         _AXE_FREE_FUNCTION(LDATA(node));
      freeList(RA->freeRegisters);
   }

   _AXE_FREE_FUNCTION(RA);
}

/*
 * Allocates a live interval for variable `varID' covering the node
 * positions startPoint .. endPoint.
 *
 * Returns the new interval. If the allocation fails, AXE_OUT_OF_MEMORY is
 * notified and NULL is returned; callers in this file test for NULL.
 */
t_live_interval * allocLiveInterval
               (int varID, int startPoint, int endPoint)
{
   t_live_interval *result;

   /* create a new instance of `t_live_interval' */
   result = _AXE_ALLOC_FUNCTION(sizeof(t_live_interval));
   if (result == NULL)
   {
      notifyError(AXE_OUT_OF_MEMORY);
      /* never write through NULL, should notifyError return here */
      return NULL;
   }

   /* initialize the new instance */
   result->varID = varID;
   result->startPoint = startPoint;
   result->endPoint = endPoint;

   return result;
}

/* Releases a live interval obtained from allocLiveInterval; NULL is
 * ignored. */
void finalizeLiveInterval (t_live_interval *interval)
{
   if (interval != NULL)
      _AXE_FREE_FUNCTION(interval);
}

/*
 * Inserts each interval of `intervals' into RA through insertLiveInterval,
 * stopping at the first failure.
 *
 * Returns:
 *  - RA_OK when every interval was inserted (or the list is empty);
 *  - RA_INVALID_ALLOCATOR if RA is NULL;
 *  - RA_INVALID_INTERVAL if the list contains a NULL interval;
 *  - otherwise the first error code returned by insertLiveInterval.
 * Intervals inserted before a failure stay in RA->live_intervals.
 */
int insertListOfIntervals(t_reg_allocator *RA, t_list *intervals)
{
   t_list *it;

   if (RA == NULL)
      return RA_INVALID_ALLOCATOR;

   for (it = intervals; it != NULL; it = LNEXT(it))
   {
      t_live_interval *li = (t_live_interval *) LDATA(it);
      int status;

      if (li == NULL)
         return RA_INVALID_INTERVAL;

      status = insertLiveInterval(RA, li);
      if (status != RA_OK)
         return status;
   }

   return RA_OK;
}

/*
 * Computes the live intervals of the variables of `graph'.
 *
 * Nodes are numbered 0, 1, 2, ... in the order in which they are met,
 * visiting the basic blocks in list order. Each node is passed to
 * updateListOfIntervals together with its number, which extends the
 * intervals of the variables in the node's `in' and `out' lists.
 *
 * Returns the new list of intervals, with one interval per variable in
 * order of first appearance. Returns NULL if `graph' is NULL or has no
 * basic blocks.
 */
t_list * getLiveIntervals(t_cflow_Graph *graph)
{
   t_list *bb_it;
   t_list *node_it;
   t_list *intervals = NULL;
   int position = 0;

   if (graph == NULL || graph->blocks == NULL)
      return NULL;

   for (bb_it = graph->blocks; bb_it != NULL; bb_it = LNEXT(bb_it))
   {
      t_basic_block *block = (t_basic_block *) LDATA(bb_it);

      for (node_it = block->nodes; node_it != NULL; node_it = LNEXT(node_it))
      {
         intervals = updateListOfIntervals(intervals
               , (t_cflow_Node *) LDATA(node_it), position);
         position++;
      }
   }

   return intervals;
}

/*
 * Makes every variable of `vars' (a list of t_cflow_var) live at node
 * position `counter'. Variables whose ID is RA_EXCLUDED_VARIABLE are
 * skipped. An existing interval is widened to include `counter'; for a
 * variable seen for the first time, a new [counter, counter] interval is
 * appended (position -1).
 *
 * Returns the updated list of intervals.
 */
static t_list * extendIntervalsWithVariables(t_list *result
         , t_list *vars, int counter)
{
   t_list *current_element;
   t_cflow_var *current_var;
   t_list *element_found;
   t_live_interval *interval_found;
   t_live_interval pattern;

   for (current_element = vars; current_element != NULL
         ; current_element = LNEXT(current_element))
   {
      current_var = (t_cflow_var *) LDATA(current_element);

      if (current_var->ID == RA_EXCLUDED_VARIABLE)
         continue;

      /* search pattern: compareIntervalIDs only reads `varID' */
      pattern.varID = current_var->ID;
      element_found = CustomfindElement
            (result, &pattern, compareIntervalIDs);

      if (element_found != NULL)
      {
         interval_found = (t_live_interval *) LDATA(element_found);

         /* widen the interval so that it covers `counter' */
         if (interval_found->startPoint > counter)
            interval_found->startPoint = counter;
         if (interval_found->endPoint < counter)
            interval_found->endPoint = counter;
      }
      else
      {
         /* first occurrence of this variable: add a new live interval */
         interval_found = allocLiveInterval(current_var->ID, counter, counter);
         if (interval_found == NULL)
            notifyError(AXE_OUT_OF_MEMORY);

         result = addElement(result, interval_found, -1);
      }
   }

   return result;
}

/*
 * Updates the live intervals in `result' with the liveness information of
 * `current_node', which occupies position `counter'. First the node's `in'
 * list is processed, then its `out' list.
 * Returns the updated list (unchanged if current_node is NULL).
 */
t_list * updateListOfIntervals(t_list *result
         , t_cflow_Node *current_node, int counter)
{
   if (current_node == NULL)
      return result;

   result = extendIntervalsWithVariables(result, current_node->in, counter);
   result = extendIntervalsWithVariables(result, current_node->out, counter);

   return result;
}

/*
 * Removes from `active_intervals' the intervals that end before `interval'
 * starts. Each removed interval's register, RA->bindings[varID], is given
 * back to RA->freeRegisters at position 0.
 *
 * The scan walks from the head and stops at the first interval that is
 * still live at interval->startPoint. It is therefore only complete when
 * `active_intervals' is sorted by increasing end point.
 *
 * Returns the updated list. Returns NULL when `active_intervals' or RA is
 * NULL, and the list unchanged when `interval' is NULL.
 */
t_list * expireOldIntervals(t_reg_allocator *RA, t_list *active_intervals
               , t_live_interval *interval)
{
   t_list *cursor;

   if (active_intervals == NULL || RA == NULL)
      return NULL;
   if (interval == NULL)
      return active_intervals;

   cursor = active_intervals;
   while (cursor != NULL)
   {
      t_live_interval *candidate = (t_live_interval *) LDATA(cursor);
      t_list *following;

      /* still live when `interval' starts: nothing more to expire */
      if (candidate->endPoint >= interval->startPoint)
         break;

      /* fetch the successor before the node is unlinked */
      following = LNEXT(cursor);
      active_intervals = removeElement(active_intervals, candidate);

      RA->freeRegisters = addFreeRegister
            (RA->freeRegisters, RA->bindings[candidate->varID], 0);

      cursor = following;
   }

   return active_intervals;
}

/*
 * Runs linear-scan register allocation over RA->live_intervals, which
 * insertLiveInterval keeps sorted by increasing start point. For every
 * interval it stores in RA->bindings[varID] either a register ID taken
 * from the free pool or RA_SPILL_REQUIRED.
 *
 * Returns RA_OK, or RA_INVALID_ALLOCATOR when RA is NULL.
 */
int execute_linear_scan(t_reg_allocator *RA)
{
   t_list *current_element;
   t_live_interval *current_interval;
   t_list *active_intervals;

   /* test the preconditions */
   if (RA == NULL)
      return RA_INVALID_ALLOCATOR;
   if (RA->live_intervals == NULL)
      return RA_OK;

   /* Intervals currently holding a register. They are kept sorted by
    * increasing END point, because expireOldIntervals stops at the first
    * interval still live and spillAtInterval treats the tail as the
    * interval that ends last. */
   active_intervals = NULL;

   for (current_element = RA->live_intervals; current_element != NULL
         ; current_element = LNEXT(current_element))
   {
      current_interval = (t_live_interval *) LDATA(current_element);

      /* give back the registers of the intervals that already ended */
      active_intervals = expireOldIntervals
               (RA, active_intervals, current_interval);

      if (getLength(active_intervals) >= RA->regNum)
      {
         /* every register is busy: perform a spill */
         active_intervals = spillAtInterval
               (RA, active_intervals, current_interval);
      }
      else
      {
         RA->bindings[current_interval->varID] = assignRegister(RA);

         /* must use compareEndPoints (not compareStartPoints) to keep the
          * invariant described above */
         active_intervals = addSorted(active_intervals
            , current_interval, compareEndPoints);
      }
   }

   /* free the list of active intervals (the intervals themselves are
    * still owned by RA->live_intervals) */
   freeList(active_intervals);

   return RA_OK;
}
