/*
   bit_matrix_fns.c

   Functions for bit matrices.
*/

/* $Id: bit_matrix_fns.c,v 1.2 1997/05/06 22:01:39 thc Exp $
 * $Log: bit_matrix_fns.c,v $
 * Revision 1.2  1997/05/06 22:01:39  thc
 * Added copyright notice.
 *
 * Revision 1.1  1997/04/18 18:23:59  thc
 * Initial revision
 *
 */

/*
 * Copyright (C) 1997, Thomas H. Cormen, thc@cs.dartmouth.edu
 *
 * This software may be freely copied, modified, and redistributed,
 * provided that this copyright notice is preserved on all copies.
 *
 * There is no warranty or other guarantee of fitness for this
 * software, and it is provided solely "as is".  Bug reports or fixes
 * may be sent to the author, who may or may not act on them as he
 * desires.
 *
 * Rights are granted to use this software in any non-commercial
 * enterprise.  For commercial rights to this software, please contact
 * the author.
 */

#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include "bit_matrix_types.h"
#include "bit_matrix_fns.h"

static void swap_cols(bit_matrix A, int i, int j);

/**********************************************************************/

bit_matrix identity_matrix(bit_matrix A, int n)
/*
   Overwrite the first n columns of A so that column c holds a single
   1 bit, in row c, i.e. A becomes an m x n identity matrix for any
   row count m large enough to hold row n-1.
   Returns A so calls can be chained.
*/
{
  matrix_column *col = A;
  int c = 0;

  while (c < n)
    {
      *col++ = ((matrix_column) 1) << c;	/* only row c is set */
      c++;
    }

  return A;
}

/**********************************************************************/

int is_identity_matrix(bit_matrix A, int n)
/*
   Test whether the n x n bit matrix A equals the identity: column c
   must contain exactly the bit for row c and nothing else.
   Returns 1 if every column matches, 0 at the first one that differs.
*/
{
  int c;

  for (c = 0; c < n; c++)
    {
      matrix_column expected = ((matrix_column) 1) << c;

      if (A[c] != expected)
	return 0;
    }

  return 1;
}

/**********************************************************************/

bit_matrix bit_matrix_multiply(bit_matrix C, bit_matrix A, bit_matrix B,
			       int n)
/*
   Multiply two bit matrices over GF(2) to form the product C = A x B.
   n is the number of columns of matrices C and B.
   Returns C.

   Column i of C is the XOR of every column A[j] for which bit j of
   B's column i is 1.  Columns of A are read only up to the highest
   set bit of B[i].

   C and B may point to the same matrix: B[i] is read in full before
   C[i] is written, and no other column of C depends on B[i].
   C and A must be distinct: overwriting C[i] would corrupt columns of
   A that later columns of C still read.

   NOTE(review): the loop relies on B_col >>= 1 eventually reaching 0,
   which presumes matrix_column is an unsigned type -- confirm in
   bit_matrix_types.h.
*/
{
  int i;

  for (i = 0; i < n; i++)
    {
      const matrix_column *A_col = A;	/* column of A for current bit */
      matrix_column B_col = B[i];	/* bits of B's column i not yet used */
      matrix_column C_col = 0;		/* accumulates C's column i */

      /* Consume B_col one bit at a time; bit j selects column A[j]. */
      for (; B_col != 0; B_col >>= 1, A_col++)
	if (B_col & 1)		/* row j of B's column i is 1 */
	  C_col ^= *A_col;	/* so add column j of A into C's column i */

      C[i] = C_col;
    }

  return C;
}

/**********************************************************************/

matrix_column bit_matrix_vector_multiply(bit_matrix A, matrix_column x, int n)
/*
   Compute the GF(2) matrix-vector product Ax.  n is the length of the
   vector x, which is also the number of columns of A.  The result is
   the XOR of every column A[col] (col < n) whose bit col in x is set;
   scanning stops early once no set bits of x remain.
*/
{
  matrix_column product = 0;
  matrix_column bits = x;
  int col = 0;

  while (col < n && bits != 0)
    {
      if (bits & 1)
	product ^= A[col];
      bits >>= 1;
      col++;
    }

  return product;
}

/**********************************************************************/

int invert_bit_matrix(bit_matrix A_inv, bit_matrix A, int n)
/*
   Invert a square bit matrix A, which is n x n, writing the inverse
   into A_inv.
   Return 1 if the matrix is invertible, 0 if singular.  When 0 is
   returned, A_inv holds partially reduced data and must not be used
   as an inverse.
   A itself is never modified.  A_inv may be the same matrix as A,
   because A is copied before A_inv is overwritten.
*/
{
  int i, j;
  bit_matrix A_copy = dup_bit_matrix(A, n); /* working copy; reduced to I */

  identity_matrix(A_inv, n);	/* start with A_inv = I */

  /* We attempt to reduce the matrix to the identity matrix using
     elementary column operations (swaps and column additions over
     GF(2)), performing the same operations in the same order on an
     identity matrix.  Column operations right-multiply by elementary
     matrices, so once A_copy becomes I, A_inv holds A^-1. */

  for (i = 0; i < n; i++)
    {
      /* Find a column j >= i s.t. A_copy[i][j] is 1 (the pivot). */
      for (j = i; j < n && ((A_copy[j] >> i) & 1) == 0; j++)
	;
      if (j >= n)
	{
	  /* No pivot in row i: A is singular.  Release the working
	     copy first so the failure path does not leak it. */
	  free_bit_matrix(A_copy);
	  return 0;
	}
      if (j > i)
	{
	  /* Move the pivot into column i in both matrices. */
	  swap_cols(A_copy, i, j);
	  swap_cols(A_inv, i, j);
	}

      /* Clear the rest of row i by adding column i to every other
	 column with a 1 in row i of A_copy.  Mirror each addition in
	 A_inv. */
      for (j = 0; j < n; j++)
	if (j != i && ((A_copy[j] >> i) & 1))
	  {
	    A_copy[j] ^= A_copy[i];
	    A_inv[j] ^= A_inv[i];
	  }
    }

  free_bit_matrix(A_copy);

  return 1;			/* Success! */
}

/**********************************************************************/

bit_matrix allocate_bit_matrix(int n)
/*
   Allocate an n x n bit matrix, that is n zero-filled columns, and
   return it.  A request for n <= 0 still yields one column, so the
   result is never NULL.  Running out of memory is fatal: a message is
   written to stderr and the process exits with status 1.
*/
{
  int ncols = (n > 0) ? n : 1;
  bit_matrix M = calloc(ncols, sizeof(matrix_column));

  if (M == NULL)
    {
      fprintf(stderr, "allocate_bit_matrix: out of space, exiting\n");
      exit(1);
    }

  return M;
}

/**********************************************************************/

bit_matrix copy_bit_matrix(bit_matrix target, bit_matrix source, int n)
/*
   Copy the n columns of bit matrix source into target, which must
   already be allocated with at least n columns, and return target.
   Overlapping source and target regions are handled correctly.
   n <= 0 copies nothing.
*/
{
  /* A negative n would otherwise convert to an enormous byte count. */
  if (n <= 0)
    return target;

  /* bcopy() is a legacy BSD call that POSIX.1-2008 removed.  memmove()
     is the standard C equivalent with the same overlap guarantee
     (note the destination-first argument order). */
  memmove(target, source, (size_t) n * sizeof(matrix_column));

  return target;
}

/**********************************************************************/

bit_matrix dup_bit_matrix(bit_matrix A, int n)
/*
   Return a newly allocated n x n bit matrix holding the same columns
   as A.  The allocation comes from allocate_bit_matrix(); callers in
   this file release it with free_bit_matrix().
*/
{
  bit_matrix clone = allocate_bit_matrix(n);

  copy_bit_matrix(clone, A, n);
  return clone;
}

/**********************************************************************/

void free_bit_matrix(bit_matrix A)
/*
   Free a bit matrix obtained from allocate_bit_matrix() or
   dup_bit_matrix().  Passing NULL is allowed and does nothing, since
   the C standard defines free(NULL) as a no-op.
*/
{
  free(A);
}

/**********************************************************************/

bit_matrix extract_bit_submatrix(bit_matrix target, bit_matrix source,
				 int start_row, int start_col,
				 int rows, int cols)
/*
   Extract the rows x cols submatrix of source whose upper-left entry
   is at (start_row, start_col) into target, which must already be
   allocated with at least cols columns.  Column j of target receives
   bits start_row .. start_row+rows-1 of source column start_col+j,
   shifted down to rows 0 .. rows-1; higher bits of target's columns
   are cleared.
   Return target.

   Shifting a matrix_column by its full bit width or more is undefined
   behavior in C.  A full-height extraction (rows equal to the column
   width) and a start_row past the last row are therefore handled
   explicitly instead of being passed to << or >>.
   start_row, start_col and cols are expected to be non-negative.
*/
{
  const int col_bits = (int) (sizeof(matrix_column) * CHAR_BIT);
  matrix_column mask;		/* low `rows` bits set */
  int j;

  if (rows <= 0)
    mask = 0;
  else if (rows >= col_bits)
    mask = ~(matrix_column) 0;	/* keep every bit */
  else
    mask = ~((~(matrix_column) 0) << rows);

  for (j = 0; j < cols; j++)
    target[j] = (start_row >= col_bits)
      ? 0
      : (source[start_col + j] >> start_row) & mask;

  return target;
}

/**********************************************************************/

bit_matrix remove_bit_matrix_rows(bit_matrix target, bit_matrix source,
				  int cols, int start_row, int rows)
/*
   Remove `rows` contiguous rows, starting at row start_row, from each
   of the first cols columns of source, and store the result in target.
   Rows 0 .. start_row-1 stay in place.  Rows from start_row+rows
   onward move up by `rows` positions.
   Return target.
   OK for source and target to be the same matrix: each target[j] is
   computed only from source[j].

   Shifts by the full width of matrix_column or more are undefined in
   C, so a start_row or rows value at or beyond that width is handled
   without shifting.  start_row, rows and cols are expected to be
   non-negative.
   NOTE(review): the vacated high rows become 0 only if matrix_column
   is unsigned -- confirm in bit_matrix_types.h.
*/
{
  const int col_bits = (int) (sizeof(matrix_column) * CHAR_BIT);
  matrix_column bottom_rows_mask;	/* rows >= start_row */
  matrix_column top_rows_mask;		/* rows <  start_row */
  int j;

  bottom_rows_mask = (start_row >= col_bits)
    ? 0
    : (~(matrix_column) 0) << start_row;
  top_rows_mask = ~bottom_rows_mask;

  for (j = 0; j < cols; j++)
    {
      /* Rows below the removed band, moved up into its place. */
      matrix_column moved = (rows >= col_bits) ? 0 : source[j] >> rows;

      target[j] = (source[j] & top_rows_mask) | (moved & bottom_rows_mask);
    }

  return target;
}

/**********************************************************************/

int find_bit_matrix_basis(matrix_column *basis, bit_matrix A,
			  int rows, int cols)
/*
   Choose a set of columns of the rows x cols bit matrix A that forms a
   basis for its column space.  On return, bit j of *basis is 1 iff
   column j was chosen.  The return value is the number of chosen
   columns, i.e. rank A.  A is not modified; elimination runs on a
   private copy.

   Every column basis has the same size, but bases favoring columns on
   the right are more useful in BMMC factoring algorithms, so each pivot
   search scans from the rightmost column leftwards.
*/
{
  bit_matrix work = dup_bit_matrix(A, cols); /* scratch copy we reduce */
  int rank = 0;
  int row;

  *basis = 0;

  for (row = 0; row < rows; row++)
    {
      matrix_column pivot;
      int p, c;

      /* Rightmost column with a 1 in this row, or -1 if none. */
      p = cols - 1;
      while (p >= 0 && !((work[p] >> row) & 1))
	p--;

      if (p < 0)
	continue;		/* no pivot in this row */

      /* XOR the pivot column into every column with a 1 in this row.
	 This includes the pivot itself, which zeroes it so it can
	 never be picked again. */
      pivot = work[p];
      for (c = cols - 1; c >= 0; c--)
	if ((work[c] >> row) & 1)
	  work[c] ^= pivot;

      *basis |= ((matrix_column) 1) << p;
      rank++;
    }

  free_bit_matrix(work);

  return rank;
}

/**********************************************************************/

matrix_column find_dependencies(bit_matrix dep, bit_matrix A, int m, int n)
/* Determine, for each column of an m x n matrix A, which columns of a
   basis it depends on.  The basis is chosen greedily from right to
   left, so it favors columns on the right.
   Output is an n x n matrix dep, where dep[i][j] == 1 iff column
   j depends at least in part on column i.  In the column-major
   storage used here, that is bit i of dep[j]: for a non-basis column
   j, the XOR of the basis columns i named in dep[j] equals column j.
   Columns in the basis are all 0 in dep.
   Return value is which columns are in the basis (bit j set iff
   column j is a basis column).
   A is not modified; the elimination runs on a private copy.
   NOTE(review): m is not read by this function; the usable row count
   is bounded by the width of matrix_column.
*/
{
  int i, j, k;			/* row and column indices */
  bit_matrix A_copy = dup_bit_matrix(A, n); /* scratch copy we reduce */
  matrix_column basis = 0;	/* which columns are in the basis */

  /* No dependencies yet. */
  for (j = 0; j < n; j++)
    dep[j] = 0;

  /* Go through the columns from right to left.  A column j that is
     still nonzero after the eliminations done so far is independent of
     the basis columns to its right, so it joins the basis.  Its
     lowest-order 1 (row i) is then used as a pivot: column j is
     eliminated from every column to its left that has a 1 in row i,
     and these eliminations are recorded in dep.  Finally column j's own
     entry in dep is cleared, since basis columns are reported as 0.
     The code maintains the invariant
       A_copy[k] == A[k] ^ (XOR of A[b] for every bit b set in dep[k]),
     so a column reduced to 0 is exactly the XOR of the basis columns
     recorded in its dep entry.  */
  for (j = n-1; j >= 0; --j)
    {
      /* Has column j already been eliminated? */
      if (A_copy[j] != 0)
	{
	  /* If not, welcome column j to the set of independent columns.
	     Find the row i holding the lowest-order 1 in column j (the
	     loop terminates because A_copy[j] != 0). */
	  basis |= ((matrix_column) 1) << j;

	  for (i = 0; !((A_copy[j] >> i) & 1); i++)
	    ;

	  /* Eliminate column j from every column k to its left that has
	     A_copy[i][k] == 1, and record this dependency in dep.
	     dep[j] must still be intact here: it names the basis columns
	     already folded into A_copy[j], which column k now picks up
	     (XOR-toggled) along with column j itself. */
	  for (k = 0; k < j; k++)
	    {
	      if ((A_copy[k] >> i) & 1)
		{
		  A_copy[k] ^= A_copy[j];
		  dep[k] ^= dep[j] | (((matrix_column) 1) << j);
		}
	    }

	  /* Column j depends only on itself.  We make it 0 in dep. */
	  dep[j] = 0;
	}
    }

  free_bit_matrix(A_copy);

  return basis;
}

/**********************************************************************/

static void swap_cols(bit_matrix A, int i, int j)
/*
   Exchange columns i and j of bit matrix A in place.
   Harmless when i == j.
*/
{
  matrix_column saved = A[j];

  A[j] = A[i];
  A[i] = saved;
}

/**********************************************************************/

void print_bit_matrix(bit_matrix A, int m, int n, char *name)
/*
   Write the m x n bit matrix A to stdout under a heading giving its
   name.  Each of the m rows goes on its own line, indented three
   spaces, with every entry printed as 0 or 1 followed by a space.
*/
{
  int row;

  printf("\n %s:\n", name);

  for (row = 0; row < m; row++)
    {
      int col;

      printf("   ");
      for (col = 0; col < n; col++)
	{
	  int bit = (int) ((A[col] >> row) & 1);

	  printf("%d ", bit);
	}
      printf("\n");
    }
}
