/* 
   COPYRIGHT U.S. GOVERNMENT 
   
   This software is distributed without charge and comes with no
   warranty.

   Please feel free to send questions, comments, and problem reports
   to prism@super.org.  */

/* PURPOSE
   =======
   Node program for generalized broadcast-multiply-roll algorithm 

   This function computes the product C = (r_alpha)op(A) * op(B) +
   (r_beta)C, where A = [a(i,j)] is an i_g_m x i_g_k matrix, B =
   [b(i,j)] is an i_g_k x i_g_n matrix, and C = [c(i,j)] is an i_g_m x
   i_g_n matrix on a 2-dimensional cartesian mesh.  op(A) can either
   be A or its transpose, A'. op(B) can either be B or its transpose,
   B'.  The operation A'B' is not supported yet.
       
   The mesh is passed as an MPI communicator.  The matrix-matrix
   multiplication will be treated as if the submesh is really i_v_dim
   x i_v_dim, where i_v_dim must be a multiple of lcm(# rows of mesh,
   # cols of mesh).  Each matrix is assumed to be composed of
   i_v_dim**2 blocks (A(i,j), B(i,j), C(i,j)). For blocked column or
   panelled torus wrap across the virtual blocks, the routine i_ldim.c
   can be used for i_lin_dim.

   The code supports the 2 cases:  i_msh_rows >= i_msh_cols and 
                                   i_msh_rows < i_msh_cols.

   Within each node, the layout of op(A), op(B), and C must be
   conformal.  The exact correspondence between the individual rows
   and columns of the blocks and the original matrices does not need
   to be known for the purposes of performing the multiplication, but
   the **number** of rows and columns is needed.
   */

/* CONDITIONAL COMPILATION */
#if PRISM_TIME
/* causes global variables with timers to be used */
#endif

/* INCLUDE FILES */
#include <stdlib.h>
#include <math.h>
#include "stdeig.h"
#include "mm.h"

/* GLOBAL VARIABLES */
#if PRISM_TIME
/* Timer variables, declared here only when PRISM_TIME is nonzero.  They
   are defined in some other translation unit (extern) and are not
   referenced by prism_v_bimmer itself.
   NOTE(review): the b/r/m prefixes presumably stand for broadcast, roll
   and multiply timings -- confirm against the file that defines them. */
extern double
  prism_d_b0_time, prism_d_b1_time, prism_d_b_time, 
  prism_d_r0_time, prism_d_r1_time, prism_d_r_time,
  prism_d_m0_time, prism_d_m1_time, prism_d_m_time;
#endif

void prism_v_bimmer(int i_v_dim, int i_rblk_nw, int i_cblk_nw, int i_g_m,
		    int i_g_n, int i_g_k, int i_m_panelwidth, int i_n_panelwidth,
		    int i_k_panelwidth, int i_m_offset, int i_n_offset,
		    int i_k_offset, int i_panel_spc, E_BOOLEAN e_transa,
		    E_BOOLEAN e_transb, double r_alpha, P_R_MATRIX m_r_a,
		    int i_lda, P_R_MATRIX m_r_b, int i_ldb, double r_beta,
		    P_R_MATRIX m_r_c, int i_ldc, P_R_MATRIX m_r_bcbuf,
		    int i_bcbuf_sz, P_R_MATRIX m_r_rollbuf, int i_rollbuf_sz,
		    MPI_Comm comm_row, MPI_Comm comm_col, 
		    int (*i_lin_dim)(int i_v_dim, int i_g_dim, int i_blk_nw,
				     int i_first_blk, int i_last_blk,
				     int i_panelwidth, int i_offset,
				     int i_panel_spc))

     /* Driver for C = (r_alpha)op(A)*op(B) + (r_beta)C on a 2-d mesh.
        It queries the mesh shape from the row and column communicators,
        handles the single-node case directly with dgemm_, and otherwise
        dispatches to one of six kernels chosen by (e_transa, e_transb)
        and by whether the mesh has at least as many rows as columns.

        PARAMETERS
        ==========
        i_v_dim:  virtual dimension used for blocking of matrices; must be
                  divisible by both the number of mesh rows and mesh cols
        i_rblk_nw:  row index of block in nw corner of submesh
        i_cblk_nw:  col index of block in nw corner of submesh
        i_g_m:  number of rows of op(A) and of C
        i_g_n:  number of cols of op(B) and of C
        i_g_k:  number of cols of op(A) and rows of op(B)
        i_m_panelwidth:  panel width for torus wrap of rows of op(A) and of C
        i_n_panelwidth:  panel width for torus wrap of op(B) and of C
        i_k_panelwidth:  panel width for torus wrap of cols of op(A) and
                         rows of op(B)
        i_m_offset:  row offset for size of first logical block of op(A) and
                     of C (has the opposite effect on the last logical block)
        i_n_offset:  col offset for size of first logical block of op(B) and
                     of C (has the opposite effect on the last logical block)
        i_k_offset:  offset for size of first logical col block of op(A) and
                     row block of op(B) (has the opposite effect on the last
                     logical block)
        i_panel_spc:  number of virtual nodes to step to next contiguous panel
        e_transa:  = false if op(A) = A, = true if op(A) = A'
        e_transb:  = false if op(B) = B, = true if op(B) = B'; the
                   combination e_transa = e_transb = true (A'B') is rejected
                   on a multi-node mesh
        r_alpha:  constant to multiply op(A) by in matrix-matrix multiplication
        r_beta:  constant to multiply C by in matrix-matrix multiplication
        m_r_a:  pointer to pointer to subblocks of matrix A; buffer for A MUST
                be a contiguous block in memory
        i_lda:  leading dimension of "arrays" containing blocks of A
        m_r_b:  pointer to pointer to subblocks of matrix B; buffer for B MUST
                be a contiguous block in memory
        i_ldb:  leading dimension of "arrays" containing blocks of B; must be
                large enough to accommodate block of B with max number of rows
        m_r_c:  pointer to pointer to subblocks of matrix C; buffer for C MUST
                be a contiguous block in memory
        i_ldc:  leading dimension of "array" containing blocks of C; must be
                large enough to accommodate block of C with max number of rows
        m_r_bcbuf:  pointer to pointer to buffered copy of subblocks
                    of matrix A; if = NULL, then routine will allocate buffer;
                    buffer must have same leading dimension as A; buffer MUST
                    be a contiguous block in memory; must be large enough to
                    accommodate block with most cols of A
        i_bcbuf_sz:  size of buffer for storing bcast matrix
        m_r_rollbuf:  pointer to pointer to buffered copy of subblocks of
                      rolled matrix; if = NULL, then routine will allocate
                      buffer; buffer must have same leading dimension as
                      rolled matrix and at least as many cols; buffer MUST be
                      a contiguous block in memory
        i_rollbuf_sz:  size of buffer for storing rolled matrix
        comm_row:  communicator for row containing calling node
        comm_col:  communicator for col containing calling node
        i_lin_dim:   function which computes number of cols (rows) in blocks
                     i_first_blk through i_last_blk.  i_first_blk and
                     i_last_blk are calculated within the kernels.  If NULL,
                     prism_i_ldim is used.
                     */
{

  /* --------------- */
  /* Local variables */
  /* --------------- */
  char
    c_transa,		   /* needed for call to dgemm_ */
    c_transb		   /* needed for call to dgemm_ */
      ;
  int
    i_msh_rows,		   /* # rows in mesh */
    i_msh_cols,		   /* # cols in mesh */
    i_node_row,		   /* rank of node in column */
    i_node_col,		   /* rank of node in row */
    i_blk_rows,		   /* number of rows of blocks per node */
    i_blk_cols		   /* number of cols of blocks per node */
      ;
  fint
    f_m, f_n, f_k,	   /* dgemm_ dimensions as Fortran integers */
    f_lda, f_ldb, f_ldc	   /* dgemm_ leading dimensions as Fortran integers */
      ;

  /* global initializations */
  prism_v_init_var();

  /* --------------- */
  /* Initializations */
  /* --------------- */

  /* mesh dimensions: the row communicator spans the mesh columns and the
     column communicator spans the mesh rows */
  MPI_Comm_size(comm_row, &i_msh_cols);
  MPI_Comm_size(comm_col, &i_msh_rows);

  /* node row and col coordinates within mesh */
  MPI_Comm_rank(comm_col, &i_node_row);
  if (i_node_row == MPI_UNDEFINED) {
    prism_v_generror("prism_v_bimmer: node not in col communicator given",
		     brief);
  }
  MPI_Comm_rank(comm_row, &i_node_col);
  if (i_node_col == MPI_UNDEFINED) {
    prism_v_generror("prism_v_bimmer: node not in row communicator given",
		     brief);
  }

  /* just do multiplication if only one node involved */
  if (i_msh_rows*i_msh_cols == 1) {
    c_transa = (e_transa == true) ? 't' : 'n';
    c_transb = (e_transb == true) ? 't' : 'n';

    /* Copy the integer arguments into real fint objects instead of
       casting int pointers to fint pointers; the cast reads the wrong
       number of bytes whenever fint is not the same size as int. */
    f_m = (fint) i_g_m;
    f_n = (fint) i_g_n;
    f_k = (fint) i_g_k;
    f_lda = (fint) i_lda;
    f_ldb = (fint) i_ldb;
    f_ldc = (fint) i_ldc;

    dgemm_(&c_transa, &c_transb, &f_m, &f_n, &f_k, &r_alpha, m_r_a[0],
	   &f_lda, m_r_b[0], &f_ldb, &r_beta, m_r_c[0], &f_ldc);
    return;
  }

  /* block configuration on each node */
  i_blk_rows = i_v_dim/i_msh_rows;
  i_blk_cols = i_v_dim/i_msh_cols;
  if ((i_blk_rows*i_msh_rows != i_v_dim) ||
      (i_blk_cols*i_msh_cols != i_v_dim)) {
    prism_v_generror("prism_v_bimmer: illegal virtual dimension", brief);
  }

  /* Honor the caller's block-size function.  It used to be ignored in
     favor of a hard-coded prism_i_ldim; keep that routine as the default
     so callers passing NULL still work. */
  if (i_lin_dim == NULL) {
    i_lin_dim = prism_i_ldim;
  }

  if (e_transa == false) {
    if (e_transb == false) {
      /* C = alpha*A*B + beta*C */
      if (i_msh_rows >= i_msh_cols) {
	/* broadcast A, roll B */
	prism_v_bca_rb(i_v_dim, i_msh_rows, i_msh_cols, i_node_row, i_node_col,
		       i_rblk_nw, i_cblk_nw, i_blk_rows, i_blk_cols, i_g_m,
		       i_g_n, i_g_k, i_m_panelwidth, i_n_panelwidth,
		       i_k_panelwidth, i_m_offset, i_n_offset, i_k_offset,
		       i_panel_spc, r_alpha, m_r_a, i_lda, m_r_b, i_ldb,
		       r_beta, m_r_c, i_ldc, m_r_bcbuf, i_bcbuf_sz, m_r_rollbuf,
		       i_rollbuf_sz, comm_row, comm_col, i_lin_dim);
      }
      else {
	/* roll A, broadcast B */
	prism_v_ra_bcb(i_v_dim, i_msh_rows, i_msh_cols, i_node_row, i_node_col,
		       i_rblk_nw, i_cblk_nw, i_blk_rows, i_blk_cols, i_g_m,
		       i_g_n, i_g_k, i_m_panelwidth, i_n_panelwidth,
		       i_k_panelwidth, i_m_offset, i_n_offset, i_k_offset,
		       i_panel_spc, r_alpha, m_r_a, i_lda, m_r_b, i_ldb,
		       r_beta, m_r_c, i_ldc, m_r_bcbuf, i_bcbuf_sz, m_r_rollbuf,
		       i_rollbuf_sz, comm_row, comm_col, i_lin_dim);
      }
    }
    else {
      /* C = alpha*A*B' + beta*C */
      if (i_msh_rows >= i_msh_cols) {
	/* roll B, broadcast C */
	prism_v_rb_bcc(i_v_dim, i_msh_rows, i_msh_cols, i_node_row, i_node_col,
		       i_rblk_nw, i_cblk_nw, i_blk_rows, i_blk_cols, i_g_m,
		       i_g_n, i_g_k, i_m_panelwidth, i_n_panelwidth,
		       i_k_panelwidth, i_m_offset, i_n_offset, i_k_offset,
		       i_panel_spc, r_alpha, m_r_a, i_lda, m_r_b, i_ldb,
		       r_beta, m_r_c, i_ldc, m_r_bcbuf, i_bcbuf_sz, m_r_rollbuf,
		       i_rollbuf_sz, comm_row, comm_col, i_lin_dim);
      }
      else {
	/* broadcast B, roll C */
	prism_v_bcb_rc(i_v_dim, i_msh_rows, i_msh_cols, i_node_row, i_node_col,
		       i_rblk_nw, i_cblk_nw, i_blk_rows, i_blk_cols, i_g_m,
		       i_g_n, i_g_k, i_m_panelwidth, i_n_panelwidth,
		       i_k_panelwidth, i_m_offset, i_n_offset, i_k_offset,
		       i_panel_spc, r_alpha, m_r_a, i_lda, m_r_b, i_ldb,
		       r_beta, m_r_c, i_ldc, m_r_bcbuf, i_bcbuf_sz, m_r_rollbuf,
		       i_rollbuf_sz, comm_row, comm_col, i_lin_dim);
      }
    }
  }
  else {
    if (e_transb == false) {
      /* C = alpha*A'*B + beta*C */
      if (i_msh_rows >= i_msh_cols) {
	/* broadcast A, roll C */
	prism_v_bca_rc(i_v_dim, i_msh_rows, i_msh_cols, i_node_row, i_node_col,
		       i_rblk_nw, i_cblk_nw, i_blk_rows, i_blk_cols, i_g_m,
		       i_g_n, i_g_k, i_m_panelwidth, i_n_panelwidth,
		       i_k_panelwidth, i_m_offset, i_n_offset, i_k_offset,
		       i_panel_spc, r_alpha, m_r_a, i_lda, m_r_b, i_ldb,
		       r_beta, m_r_c, i_ldc, m_r_bcbuf, i_bcbuf_sz, m_r_rollbuf,
		       i_rollbuf_sz, comm_row, comm_col, i_lin_dim);
      }
      else {
	/* roll A, broadcast C */
	prism_v_ra_bcc(i_v_dim, i_msh_rows, i_msh_cols, i_node_row, i_node_col,
		       i_rblk_nw, i_cblk_nw, i_blk_rows, i_blk_cols, i_g_m,
		       i_g_n, i_g_k, i_m_panelwidth, i_n_panelwidth,
		       i_k_panelwidth, i_m_offset, i_n_offset, i_k_offset,
		       i_panel_spc, r_alpha, m_r_a, i_lda, m_r_b, i_ldb,
		       r_beta, m_r_c, i_ldc, m_r_bcbuf, i_bcbuf_sz, m_r_rollbuf,
		       i_rollbuf_sz, comm_row, comm_col, i_lin_dim);
      }
    }
    else {
      /* don't allow A'B' */
      prism_v_generror("prism_v_bimmer: A'B' not implemented", brief);
    }
  }

  return;
}
