/* 
   COPYRIGHT U.S. GOVERNMENT 
   
   This software is distributed without charge and comes with no
   warranty.

   Please feel free to send questions, comments, and problem reports
   to prism@super.org.  */

/* PURPOSE
   =======
   Node program for generalized broadcast-multiply-roll algorithm. 

   This function computes the product C = (r_alpha)A * B' + (r_beta)C,
   where A = [a(i,j)] is an i_g_m x i_g_k matrix, B = [b(i,j)] is an
   i_g_k x i_g_n matrix, and C = [c(i,j)] is an i_g_m x i_g_n matrix
   on a 2-dimensional cartesian mesh.
       
   NOTE: One additional array is required for B', which is formed
   explicitly and then passed on to DGEMM.  This is required in order 
   to have matrix B in contiguous memory for the broadcast operation.  

   The mesh is passed as two MPI communicators, specifying the rows
   and columns of the mesh. The matrix-matrix multiplication will be
   treated as if the mesh is really i_v_dim x i_v_dim, where i_v_dim
   must be a multiple of lcm(# rows of mesh, # cols of mesh).  Each
   matrix is assumed to be composed of i_v_dim**2 blocks (A(i,j),
   B(i,j), C(i,j)). For blocked column or panelled torus wrap across
   the virtual blocks, the routine i_ldim.c can be used for i_lin_dim.

   This code supports the case i_msh_rows < i_msh_cols, with rows of B
   being broadcast and the columns of C being rolled.
   
   Within each node, the layout of A, B, and C must be conformal.  The
   exact correspondence between the individual rows and columns of the
   blocks and the original matrices does not need to be known for the
   purposes of performing the multiplication, but the **number** of
   rows and columns is needed.  */

/* CONDITIONAL COMPILATION */
#if PRISM_TIME
/* Causes global variables with timers to be used */
#endif

/* INCLUDE FILES */
#include <stdlib.h>
#include <math.h>
#include "stdeig.h"
#include "mm.h"

void prism_v_bcb_rc(int i_v_dim, int i_msh_rows, int i_msh_cols, int i_node_row,
		    int i_node_col, int i_rblk_nw, int i_cblk_nw, int i_blk_rows,
		    int i_blk_cols, int i_g_m, int i_g_n, int i_g_k,
		    int i_m_panelwidth, int i_n_panelwidth, int i_k_panelwidth,
		    int i_m_offset, int i_n_offset, int i_k_offset,
		    int i_panel_spc, double r_alpha, P_R_MATRIX m_r_a, int i_lda,
		    P_R_MATRIX m_r_b, int i_ldb, double r_beta, P_R_MATRIX m_r_c,
		    int i_ldc, P_R_MATRIX m_r_bcbuf, int i_bcbuf_sz,
		    P_R_MATRIX m_r_rollbuf, int i_rollbuf_sz, MPI_Comm comm_row,
		    MPI_Comm comm_col, 
		    int (*i_lin_dim)(int i_v_dim, int i_g_dim, int i_blk_nw,
				     int i_first_blk, int i_last_blk,
				     int i_panelwidth, int i_offset,
				     int i_panel_spc))
   
     /* SUMMARY
	=======
	Node program for one broadcast-multiply-roll sweep.  The routine
	  1. allocates the broadcast and roll work buffers when the caller
	     did not supply adequate ones,
	  2. forms an explicit transpose of the local B,
	  3. scales the local C by r_beta,
	  4. repeats i_msh_cols times: broadcast the needed panel of the
	     transposed B within comm_col, call dgemm_ to accumulate
	     r_alpha * A * (panel) into the current C buffer, then roll the
	     C buffer within comm_row,
	  5. frees whatever it allocated and, when i_msh_rows == 1 and the
	     caller supplied m_r_bcbuf, restores the caller's m_r_bcbuf[0].

	PARAMETERS
	==========
	i_v_dim:  virtual dimension used for blocking of matrices 
	i_msh_rows:  # rows in mesh
	i_msh_cols:  # cols in mesh
	i_node_col:  node rank in col 
	i_node_row:  node rank in row
	i_rblk_nw:  row index of block in nw corner of submesh 
	i_cblk_nw:  col index of block in nw corner of submesh 
	i_blk_rows:  # rows of blocks for each matrix
	i_blk_cols:  # cols of blocks for each matrix
	i_g_m:  number of rows of A and of C
	i_g_n:  number of cols of B and of C 
	i_g_k:  number of cols of A and rows of B
	i_m_panelwidth:  panel width for torus wrap of rows of A and of C
	i_n_panelwidth:  panel width for torus wrap of cols of B and of C 
	i_k_panelwidth:  panel width for torus wrap of cols of A and rows of B
	i_m_offset:  row offset for size of first logical block of A and of C
                     (has the opposite effect on the last logical block)
	i_n_offset:  col offset for size of first logical block of B and of C
                     (has the opposite effect on the last logical block)
	i_k_offset:  offset for size of first logical col block of A and row 
	             block of B (has the opposite effect on the last logical
		     block)
	i_panel_spc:  number of virtual nodes to step to next contiguous panel
	r_alpha:  constant to multiply A by in matrix-matrix multiplication
	r_beta:  constant to multiply C by in matrix-matrix multiplication
	m_r_a:  pointer to pointer to subblocks of matrix A; buffer for A MUST
        	be a contiguous block in memory
	i_lda:  leading dimension of "arrays" containing blocks of A
 	m_r_b:  pointer to pointer to subblocks of matrix B; buffer for B MUST
	        be a contiguous block in memory 
	i_ldb:  leading dimension of "arrays" containing blocks of B; must be
        	large enough to accommodate block of B with max number of rows
	m_r_c:  pointer to pointer to subblocks of matrix C; buffer for C MUST
        	be a contiguous block in memory 
	i_ldc:  leading dimension of "array" containing blocks of C; must be
        	large enough to accommodate block of C with max number of rows 
        m_r_bcbuf:  pointer to pointer to buffered copy of subblocks
                    of matrix B; if = NULL, then routine will allocate buffer;
                    buffer must have same leading dimension as B; if mesh has
                    more than 1 col, buffer MUST have row dimension i_ldb and
                    col dimension of at least (max # cols of C within col of mesh)
        i_bcbuf_sz:  size of buffer for storing bcast matrix; must be at least
                     i_ldb*(max # cols of C within col of mesh)
        m_r_rollbuf:  pointer to pointer to buffered copy of subblocks of
                      rolled matrix; if = NULL, then routine will allocate 
                      buffer; buffer provided MUST have row dimension i_ldc and
                      col dimension at least equal to # rows of A in node
        i_rollbuf_sz:  size of buffer for storing rolled matrix; must be at least
                       i_ldc*(# rows of B in node)
	comm_row:  communicator for row containing calling node
	comm_col:  communicator for col containing calling node
	i_lin_dim:   function which computes number of cols (rows) in blocks
                     i_first_block through i_last_blk.  i_first_blk and
		     i_last_blk are calculated within BiMMeR routine
	*/
{
  /* --------------- */
  /* Local variables */
  /* --------------- */
  int
    i_node_bc,		   /* y-coordinate of node broadcasting */
    i_bc_1st_blk,	   /* col index of first matrix block of B to be bcast */
    i_bc_last_blk,	   /* col index of last matrix block of B to be bcast */
			   /* before multiply can occur */
    i_cols_bc_sent,        /* # cols already bc in current stage */
    i_bc_last_blk_sent,	   /* col index of last matrix block of B to be bcast */
    i_bc_1st_indx,	   /* index of first col of B to be bcast */
    i_bc_1st_indxbuf,	   /* index of first col of buffered B where data is sent */
    i_1st_rblk_roll,	   /* first block row index of C in node */
    i_1st_cblk_roll,	   /* first block col index of C's in node */
    i_cols_bc,		   /* number of cols of B in block being broadcast */
    i_l_k,		   /* number of rows of B in block being multiplied */
    i_l_m,		   /* number of rows of C in node */
    i_l_n,		   /* number of cols of C in node */ 
    i_ldbc,		   /* ld of matrix being broadcast */
    i_ldroll,		   /* ld of matrix being rolled */
    i_cols_rollbuf,	   /* number of cols in roll buffer */
    i_cols_bcbuf,	   /* number of cols in bc buffer */
    i_world_id             /* rank in MPI_COMM_WORLD (used only with PRT_M) */
      ;
  int
    i_roll_send,	   /* rank of sending node in roll */
    i_roll_recv,	   /* rank of receiving node in roll */
    i_index,		   /* rank of node in bcast process group */
    i_rolls,		   /* counter for number of rolls performed */ 
    i,
    j
      ;
  double
    *p_bcbuf = NULL,	   /* caller's original m_r_bcbuf[0], saved once so */
			   /* it can be restored on exit */
    **p_a,		   /* current copy of A */
    **p_b,		   /* current copy of B */
    **p_c,		   /* current copy of C */
    **p_roll,		   /* roll buffer */
    **p_bc,		   /* bc buffer */
    **p_temp,		   /* temporary pointer */
    **m_r_bt		   /* transposed matrix */
      ;
  int
    i_alloc_bcbuf,	   /* set to 1 if buffer for B needs to be allocated */
    i_alloc_rollbuf	   /* set to 1 if buffer for C needs to be allocated */
      ;
  char
    c_uplo,		   /* needed for call to dlacpy_ */
    c_transa = 'n',	   /* needed for call to dgemm_ */
    c_transb = 'n',	   /* needed for call to dgemm_ */
    ctemp[80]              /* hold filename for printing matrices */
      ;
  MPI_Comm
    comm_bc_direction,	   /* communicator for bcast */
    comm_roll_direction	   /* communicator for roll */
      ;
  FILE *debug_file;	   /* per-rank dump file (used only with PRT_M) */

#if PRISM_MPI_COLL
  MPI_Status
    status		   /* status for sendrecv */
      ;
#endif

  /* ------------------ */
  /* External functions */
  /* ------------------ */
  /* --------------- */
  /* Initializations */
  /* --------------- */
  /* Roll within row, bc within col */
  comm_bc_direction = comm_col;
  comm_roll_direction = comm_row;

/* Define PRT_M if you want to see matrices */
#ifdef PRT_M
  MPI_Comm_rank(MPI_COMM_WORLD, &i_world_id);
  sprintf(ctemp,"mm%d.d", i_world_id);
  debug_file = fopen(ctemp,"w");
#endif

#if PRISM_MPI_COLL
  /* Neighbours for the explicit MPI_Sendrecv roll (wraps around the row) */
  i_roll_send = (i_node_col - 1 + i_msh_cols) % i_msh_cols;
  i_roll_recv = (i_node_col + 1 + i_msh_cols) % i_msh_cols;
#endif

  /* Info about blocks of A, B, C */
  i_1st_cblk_roll = (i_rblk_nw + i_node_row*i_blk_rows)%i_v_dim;
  i_1st_rblk_roll = 
     i_bc_1st_blk = (i_cblk_nw + i_node_col*i_blk_cols)%i_v_dim;

  /* Number of rows of A & C is fixed */
  i_l_m = (*i_lin_dim)(i_v_dim, i_g_m, i_cblk_nw, i_1st_cblk_roll,
		       (i_1st_cblk_roll+i_blk_rows-1)%i_v_dim, i_m_panelwidth,
		       i_m_offset, i_panel_spc);

  /* Number of cols of B in node is fixed */
  i_l_k = (*i_lin_dim)(i_v_dim, i_g_k, i_rblk_nw, i_1st_rblk_roll,
		       (i_1st_rblk_roll+i_blk_cols-1)%i_v_dim, i_k_panelwidth,
		       i_k_offset, i_panel_spc);

#ifdef PRT_M
  i_l_n = (*i_lin_dim)(i_v_dim, i_g_n, i_rblk_nw, i_1st_rblk_roll,
		       (i_1st_rblk_roll+i_blk_cols-1)%i_v_dim, i_n_panelwidth,
		       i_n_offset, i_panel_spc);
  fprintf (debug_file, "Initial matrix A\n");
  prism_v_prt_m(m_r_a, i_lda, i_l_k, debug_file);
  fprintf (debug_file, "Initial matrix B\n");
  prism_v_prt_m(m_r_b, i_ldb, i_l_k, debug_file);
  fprintf (debug_file, "Initial matrix C\n");
  prism_v_prt_m(m_r_c, i_ldc, i_l_n, debug_file);
#endif

  /* Info needed to set up communication buffers */
  i_ldbc = i_l_k;  
  i_cols_rollbuf = prism_i_ldwk(i_v_dim, i_g_n, i_cblk_nw, i_n_panelwidth,
                              i_n_offset, i_panel_spc, i_v_dim/i_msh_cols,
                              i_msh_cols);
  i_cols_bcbuf = prism_i_ldwk(i_v_dim, i_g_k, i_cblk_nw, i_k_panelwidth,
                              i_k_offset, i_panel_spc, i_v_dim/i_msh_cols,
                              i_msh_cols);
  i_ldroll = i_ldc;

  /* Allocate work space, if needed.  The bc buffer size is only checked */
  /* when there is more than one mesh row, because with a single row the */
  /* buffer is used purely as a pointer slot (see the broadcast step).   */
  i_alloc_bcbuf = ((m_r_bcbuf == NULL) ||
		   ((i_msh_rows > 1) && (i_bcbuf_sz < i_ldbc*i_cols_bcbuf)));
  i_alloc_rollbuf = ((m_r_rollbuf == NULL) ||
		     (i_rollbuf_sz < i_ldroll*i_cols_rollbuf));

  if (i_alloc_bcbuf) {
    if (i_msh_rows > 1) 
      m_r_bcbuf = prism_m_d_alloc_matrix(i_ldbc, i_cols_bcbuf);
    else {
      /* Only a one-entry pointer table is needed in the single-row case */
      m_r_bcbuf = (double **) malloc(sizeof(double *));
      if (m_r_bcbuf == NULL) 
	prism_v_generror("prism_v_bcb_rc: couldn't allocate work space", brief);
    }
  }
  if (i_alloc_rollbuf) 
    m_r_rollbuf = prism_m_d_alloc_matrix(i_ldroll, i_cols_rollbuf);

  /* Form B^t so that bcast data will be in contiguous memory locations by col */

  if (i_l_k*i_ldb == 0) {
    /* Degenerate (empty) B: allocate only a placeholder pointer table */
    m_r_bt = (double **) malloc(sizeof(double *));
    if (m_r_bt == NULL)
      prism_v_generror("prism_v_bcb_rc: couldn't allocate work space", brief);
  }
  else
    m_r_bt = prism_m_d_alloc_matrix(i_l_k, i_ldb);

  for (i = 0; i < i_ldb; i++) 
    for (j = 0; j < i_l_k; j++) 
      m_r_bt[i][j] = m_r_b[j][i];

  /* Initialize pointers */
  p_b = m_r_bcbuf;
  p_bc = m_r_bt;

  /* With a single mesh row the broadcast step overwrites m_r_bcbuf[0]    */
  /* with a pointer into B^t on every pass.  Remember the caller's        */
  /* original pointer exactly once, before any overwrite, so the restore  */
  /* at the end does not hand back a pointer into the freed B^t.          */
  if (!i_alloc_bcbuf && i_msh_rows == 1)
    p_bcbuf = m_r_bcbuf[0];

  /* Calculate cols in rolled matrix C */
  i_l_n = (*i_lin_dim)(i_v_dim, i_g_n, i_rblk_nw, i_1st_rblk_roll,
		       (i_1st_rblk_roll+i_blk_cols-1)%i_v_dim, i_n_panelwidth,
		       i_n_offset, i_panel_spc);

  /* Multiply C by beta */
  prism_v_scl_mtrx(i_ldc, i_l_n, m_r_c, r_beta);

#ifdef PRT_M
  fprintf (debug_file, "Scaled matrix C:  r_beta = %e\n", r_beta);
  prism_v_prt_m(m_r_c, i_ldc, i_l_n, debug_file);
  fprintf (debug_file, "Transposed matrix B\n");
  prism_v_prt_m(m_r_bt, i_l_k, i_ldb, debug_file);
#endif

  /* Fix pointers to point to buffer containing blocks of rolled matrix. */
  /* Multiplication is always performed with blocks starting at          */
  /* bcbuf[0] and rollbuf[0].                                            */

  /* p_c and p_roll are swapped after every roll.  With an odd number of */
  /* rolls the data would finish in the "other" buffer, so start in the  */
  /* roll buffer to make the final result land in m_r_c.                 */
  if (i_msh_cols%2 != 0) { /* do copy if needed to make rolled matrix */
			   /* end up in place */
    c_uplo = 'a'; 
    p_roll = m_r_c;
    prism_v_dlacpy_(&c_uplo, (fint *) &i_ldroll, (fint *) &i_cols_rollbuf,
		    p_roll[0], (fint *) &i_ldroll, m_r_rollbuf[0],
		    (fint *) &i_ldroll);

    /* Switch pointers */
    p_c = m_r_rollbuf;
    p_a = m_r_a;
  }
  else {
    p_a = m_r_a;
    p_c = m_r_c;
    p_roll = m_r_rollbuf;
  }

  /* --- */
  /* BMR */
  /* --- */

  /* Need to know who you are in your list to know if you will need to
     do a copy before the broadcast */
  MPI_Comm_rank(comm_bc_direction, &i_index);

  i_rolls = 0;
  do {/* Stages of BMR until blocks being rolled return to nodes where */
      /* they started */

    /* --------- */
    /* Broadcast */
    /* --------- */

    i_cols_bc_sent = 0;
    /* Last block needed to be bcast */
    i_bc_last_blk = (i_bc_1st_blk + i_blk_cols - 1)%i_v_dim;
    do {/* Until all blocks needed are broadcast */
      i_node_bc = ((i_bc_1st_blk - i_rblk_nw + i_v_dim)%i_v_dim)/i_blk_rows;
      /* If last block needed is in bc node send rest of blocks needed
	 else send remaining blocks on bc node */
      i_bc_last_blk_sent
	= (((i_bc_last_blk - i_rblk_nw + i_v_dim)%i_v_dim/i_blk_rows)
	   != i_node_bc) ?
	     (i_rblk_nw + i_blk_rows*(i_node_bc + 1) - 1)%i_v_dim
	       : i_bc_last_blk;

      i_cols_bc = (*i_lin_dim)(i_v_dim, i_g_n, i_rblk_nw, i_bc_1st_blk,
			       i_bc_last_blk_sent, i_n_panelwidth, i_n_offset,
			       i_panel_spc);

      /* Compute starting column indices for send and receive buffers */

      /* If bcast data does not start with first column of B, starting */
      /* index for send buffer = column index of first col of B to be bcast */
      i_bc_1st_indx = (i_1st_cblk_roll == i_bc_1st_blk) ?
	0 : (*i_lin_dim)(i_v_dim, i_g_n, i_cblk_nw, i_1st_cblk_roll,
			 (i_bc_1st_blk-1+i_v_dim)%i_v_dim, i_n_panelwidth,
			 i_n_offset, i_panel_spc);

      /* If first round of bcast, start receive at beginning of receive buffer */
      /* otherwise, concatenate onto end of data received in first round */
      i_bc_1st_indxbuf = (i_1st_rblk_roll == i_bc_1st_blk) ?
	0 : i_cols_bc_sent;

      /* Perform broadcast, if necessary */
#if PRISM_TIME
      /* Two back-to-back readings so the timer's own cost can be */
      /* subtracted from the measured interval below */
      prism_d_b0_time = MPI_Wtime();
      prism_d_b1_time = MPI_Wtime();
#endif
      if (i_msh_rows > 1 && i_cols_bc > 0) { 
	/* If you are the root node of the bcast then you need to copy */
	/* data on your own node before the broadcast */
	if (i_index == i_node_bc)
	  prism_v_copy(p_bc[i_bc_1st_indx], 
		       m_r_bcbuf[0]+i_bc_1st_indxbuf*i_l_k,
		       i_ldbc*i_cols_bc, MPI_DOUBLE);

	PRISM_BCAST(m_r_bcbuf[0]+i_bc_1st_indxbuf*i_l_k, i_ldbc*i_cols_bc,
		    MPI_DOUBLE, i_node_bc, comm_bc_direction);
      }
      else if (i_msh_rows == 1) { 
	/* Single mesh row: nothing to broadcast, so just aim the buffer */
	/* slot at the panel inside B^t.  The caller's original pointer  */
	/* was saved before the loop.                                    */
	m_r_bcbuf[0] = p_bc[i_bc_1st_indx];
      }
#if PRISM_TIME
      prism_d_b_time += MPI_Wtime() - prism_d_b1_time -
	(prism_d_b1_time - prism_d_b0_time);
#endif

      i_bc_1st_blk = (i_bc_last_blk_sent + 1)%i_v_dim;
      /* NOTE(review): assigned, not accumulated; this matches the "first */
      /* round" wording above and presumes a stage needs at most two      */
      /* broadcast rounds -- confirm.                                     */
      i_cols_bc_sent = i_cols_bc;
    } while ((i_bc_last_blk_sent - i_1st_rblk_roll + 1 + i_v_dim)%i_v_dim
	     < i_blk_cols);

    /* -------- */
    /* Multiply */
    /* -------- */

    /* Calculate cols in rolled matrix C */
    i_l_n = (*i_lin_dim)(i_v_dim, i_g_n, i_rblk_nw, i_1st_rblk_roll,
			 (i_1st_rblk_roll+i_blk_cols-1)%i_v_dim, i_n_panelwidth,
			 i_n_offset, i_panel_spc);

#if PRISM_TIME
    prism_d_m0_time = MPI_Wtime();
    prism_d_m1_time = MPI_Wtime();
#endif

    /* Accumulate into the current C buffer (the C scale factor passed */
    /* to dgemm_ is r_one because r_beta was already applied above)    */
    if (i_l_k != 0)   /* Leading dim of B must be nonzero */
      dgemm_(&c_transa, &c_transb, (fint *) &i_l_m, (fint *) &i_l_n,
	     (fint *) &i_l_k, &r_alpha, p_a[0], (fint *) &i_lda, p_b[0],
	     (fint *) &i_l_k, &r_one, p_c[0], (fint *) &i_ldc);

#if PRISM_TIME
    prism_d_m_time += MPI_Wtime() - prism_d_m1_time -
      (prism_d_m1_time - prism_d_m0_time);
#endif

#ifdef PRT_M
    fprintf (debug_file, "Matrix C after DGEMM: i_rolls = %d\n", i_rolls);
    prism_v_prt_m(m_r_c, i_ldc, i_l_n, debug_file);
#endif
    
    /* ---- */
    /* Roll */  
    /* ---- */

    i_1st_rblk_roll = (i_1st_rblk_roll + i_blk_cols)%i_v_dim;
    p_temp = p_c;

#ifdef PRT_M
    fprintf (debug_file, "Temp/Rollbuffer (C) before Sendrec\n");
    prism_v_prt_m(p_temp, i_ldroll, i_cols_rollbuf, debug_file);
#endif

#if PRISM_TIME
    prism_d_r0_time = MPI_Wtime();
    prism_d_r1_time = MPI_Wtime();
#endif

#if PRISM_MPI_COLL
    MPI_Sendrecv(p_temp[0], i_cols_rollbuf*i_ldroll, MPI_DOUBLE, i_roll_send, 17,
		 p_roll[0], i_cols_rollbuf*i_ldroll, MPI_DOUBLE, i_roll_recv, 17,
		 comm_roll_direction, &status);
#else
    prism_v_skew(p_temp[0], p_roll[0], i_cols_rollbuf*i_ldroll,
		 i_cols_rollbuf*i_ldroll, MPI_DOUBLE, (int)-1,
		 comm_roll_direction);
#endif
#if PRISM_TIME
    prism_d_r_time += MPI_Wtime() - prism_d_r1_time -
      (prism_d_r1_time - prism_d_r0_time);
#endif

#ifdef PRT_M
    fprintf (debug_file, "Rollbuffer (C) after Sendrec\n");
    prism_v_prt_m(p_roll, i_ldroll, i_cols_rollbuf, debug_file);
#endif

    i_rolls ++;

    /* Switch pointers for rolled matrix and buffer */
    p_c = p_roll;
    p_roll = p_temp;

  } while (i_rolls < i_msh_cols);

  /* Free up buffers for A and B if allocated by BiMMeR routine */
  if (i_l_k*i_ldb == 0)
    free(m_r_bt);
  else
    prism_v_free_matrix(m_r_bt);
  if (i_alloc_bcbuf) {
    if (i_msh_rows > 1)
      prism_v_free_matrix(m_r_bcbuf);
    else 
      free(m_r_bcbuf);
  }
  /* Restore user's pointer to work space */
  else if (i_msh_rows == 1)
    m_r_bcbuf[0] = p_bcbuf;
  if (i_alloc_rollbuf)
    prism_v_free_matrix(m_r_rollbuf);

#ifdef PRT_M
  /* Close the per-rank debug dump opened during initialization */
  fclose(debug_file);
#endif
  
  return;
}
