// This may look like C code, but it is really -*- C++ -*-

/*
 *  Copyright (c) University of Aizu
 */

#include <stdarg.h>
#include "Cmatrix.h"

#ifdef __GNUG__
#pragma implementation
#endif

// constructors and destructors
// Constructs a rows x cols matrix whose every element is `value`.
// Each row is an independently heap-allocated Cvector owned by this matrix.
// Both dimensions are stored before the loop so that `c` is well-defined
// even when rows <= 0 (previously it was only set inside the loop body).
Cmatrix::Cmatrix(int rows, int cols, const complex& value)
{
  r = rows;
  c = cols;
  m = new Cvector*[r];
  for (int i = 0; i < r; i++)
    m[i] = new Cvector(c, value);
}

// Copy constructor: the new matrix receives its own deep copy of every row
// of `source` (each row is copy-constructed into a fresh Cvector).
Cmatrix::Cmatrix(const Cmatrix& source)
{
  r = source.r;
  c = source.c;
  Cvector** rows_copy = new Cvector*[r];
  for (int row = 0; row < r; ++row)
    rows_copy[row] = new Cvector(*source.m[row]);
  m = rows_copy;
}

// Destructor: releases every owned row vector, then the row-pointer array.
Cmatrix::~Cmatrix()
{
  int row = 0;
  while (row < r)
    delete m[row++];
  delete [] m;
}


// elementary operations
// Row access: returns a mutable reference to row i (no bounds checking).
Cvector& Cmatrix::operator [](int i) const
{
  Cvector* row = m[i];
  return *row;
}

int      Cmatrix::rows() const
{
  return r;
}

int      Cmatrix::cols() const
{
  return c;
}

// Fills the matrix in row-major order from a variadic list of doubles.
// `initial` becomes element (0,0); exactly rows()*cols()-1 further arguments
// must follow, and they must really be doubles (passing ints here is
// undefined behavior, because va_arg reads them as double).
// Does nothing for an empty matrix.
void Cmatrix::set(double initial ...) 
{
  if (r <= 0 || c <= 0)
    return;

  va_list ap;
  va_start(ap, initial);
  // Index through the dereferenced row: `m[i][j]` would be pointer
  // arithmetic on the Cvector* row pointer, not element access.
  (*(m[0]))[0] = complex(initial);
  for (int i = 0; i < r; i++)
    for (int j = 0; j < c; j++)
      if (i || j)
        (*(m[i]))[j] = complex(va_arg(ap, double));
  va_end(ap);
}

// Fills the matrix in row-major order from a variadic list of ints.
// `initial` becomes element (0,0); exactly rows()*cols()-1 further int
// arguments must follow.  Each value is stored as a complex with zero
// imaginary part.  Does nothing for an empty matrix.
void Cmatrix::set(int initial ...) 
{
  if (r <= 0 || c <= 0)
    return;

  va_list ap;
  va_start(ap, initial);
  // Write a single element; `m[0][0]` would assign the whole row 0.
  (*(m[0]))[0] = complex(double(initial));
  for (int i = 0; i < r; i++)
    for (int j = 0; j < c; j++)
      if (i || j)
        (*(m[i]))[j] = complex(double(va_arg(ap, int)));
  va_end(ap);
}

// Fills the matrix in row-major order from a variadic list of floating
// values.  `initial` becomes element (0,0); exactly rows()*cols()-1 further
// arguments must follow.  Variadic float arguments arrive promoted to
// double, so they are read (and stored) as double without narrowing.
// Does nothing for an empty matrix.
//
// NOTE(review): `initial` is a float, a type altered by default argument
// promotion, so using it as the va_start anchor is formally undefined
// behavior (clang warns with -Wvarargs).  It works on common ABIs; fixing
// it properly requires changing the declared signature.
void Cmatrix::set(float initial ...) 
{
  if (r <= 0 || c <= 0)
    return;

  va_list ap;
  va_start(ap, initial);
  // Write a single element; `m[0][0]` would assign the whole row 0.
  (*(m[0]))[0] = complex(double(initial));
  for (int i = 0; i < r; i++)
    for (int j = 0; j < c; j++)
      if (i || j)
        (*(m[i]))[j] = complex(va_arg(ap, double));
  va_end(ap);
}

// Assignment
// Copy assignment: makes *this an independent deep copy of y, adopting y's
// dimensions (mirroring the copy constructor).  The previous version ignored
// y and merely zeroed *this.  New rows are built before the old ones are
// released, so self-assignment is a no-op and y is never read after free.
Cmatrix& Cmatrix::operator = (const Cmatrix& y)
{
  if (this == &y)
    return *this;

  Cvector** fresh = new Cvector*[y.r];
  for (int i = 0; i < y.r; i++)
    fresh[i] = new Cvector(*(y.m[i]));

  for (int i = 0; i < r; i++)
    delete m[i];
  delete [] m;

  m = fresh;
  r = y.r;
  c = y.c;
  return *this;
}

// Assigns diag(y): every element becomes zero except the main diagonal,
// where element (i,i) = y[i] for i < min(rows, cols).  All rows are zeroed,
// including those past the diagonal when rows > cols (previously they kept
// stale values).  NOTE(review): y is assumed to have at least min(rows, cols)
// elements; no size check is done here.
Cmatrix& Cmatrix::operator = (const Cvector& y)
{
  for (int i = 0; i < r; i++)
    *(m[i]) = complex(0);
  const int diag = (r > c ? c : r);
  for (int i = 0; i < diag; i++)
    (*(m[i]))[i] = y[i];
  return *this;
}

// Assigns y*I: every element becomes zero except the main diagonal, which
// is set to y.  All rows are zeroed, including those past the diagonal when
// rows > cols (previously they kept stale values).
Cmatrix& Cmatrix::operator = (const complex& y)
{
  for (int i = 0; i < r; i++)
    *(m[i]) = complex(0);
  const int diag = (r > c ? c : r);
  for (int i = 0; i < diag; i++)
    (*(m[i]))[i] = y;
  return *this;
}

// Basic Math
// Element-wise matrix addition: row i of y is added to row i of *this.
// Every row is processed (previously only min(rows, cols) rows were, which
// silently skipped rows of tall matrices).  If y has fewer rows, only the
// common rows are updated, which avoids reading past y's row array.
// Column-length mismatches are left to Cvector::operator+=.
Cmatrix& Cmatrix::operator += (const Cmatrix& y)
{
  const int n = (r < y.r) ? r : y.r;
  for (int i = 0; i < n; i++)
    *(m[i]) += *(y.m[i]);
  return *this;
}

// A += diag(y): y[k] is added to element (k,k) for k < min(rows, cols);
// off-diagonal elements are untouched.
Cmatrix& Cmatrix::operator += (const Cvector& y)
{
  const int diagonal = (r < c) ? r : c;
  for (int k = 0; k < diagonal; ++k) {
    Cvector& row = *m[k];
    row[k] += y[k];
  }
  return *this;
}

// A += zI: y is added to every main-diagonal element only.
Cmatrix& Cmatrix::operator += (const complex& y)
{
  const int diagonal = (r < c) ? r : c;
  for (int k = 0; k < diagonal; ++k) {
    Cvector& row = *m[k];
    row[k] += y;
  }
  return *this;
}

// Element-wise matrix subtraction: row i of y is subtracted from row i of
// *this.  Every row is processed (previously only min(rows, cols) rows were).
// If y has fewer rows, only the common rows are updated, which avoids reading
// past y's row array.  Column-length mismatches are left to
// Cvector::operator-=.
Cmatrix& Cmatrix::operator -= (const Cmatrix& y)
{
  const int n = (r < y.r) ? r : y.r;
  for (int i = 0; i < n; i++)
    *(m[i]) -= *(y.m[i]);
  return *this;
}

// A -= diag(y): y[k] is subtracted from element (k,k) for k < min(rows, cols).
Cmatrix& Cmatrix::operator -= (const Cvector& y)
{
  const int diagonal = (r < c) ? r : c;
  for (int k = 0; k < diagonal; ++k) {
    Cvector& row = *m[k];
    row[k] -= y[k];
  }
  return *this;
}

// A -= zI: y is subtracted from every main-diagonal element only.
Cmatrix& Cmatrix::operator -= (const complex& y)
{
  const int diagonal = (r < c) ? r : c;
  for (int k = 0; k < diagonal; ++k) {
    Cvector& row = *m[k];
    row[k] -= y;
  }
  return *this;
}

// Scales every element of the matrix by y.  All rows are scaled; previously
// only the first min(rows, cols) rows were, leaving tall matrices partially
// unscaled.
Cmatrix& Cmatrix::operator *= (const complex& y)
{
  for (int i = 0; i < r; i++)
    *(m[i]) *= y;
  return *this;
}

// Divides every element of the matrix by y.  All rows are divided;
// previously only the first min(rows, cols) rows were.
Cmatrix& Cmatrix::operator /= (const complex& y)
{
  for (int i = 0; i < r; i++)
    *(m[i]) /= y;
  return *this;
}

// Unary minus: negates every element (all rows; previously only the first
// min(rows, cols) rows were negated).
// NOTE: this negates *this IN PLACE and returns a reference to it, so an
// expression like `B = -A` also changes A.  A non-mutating version would
// require changing the declared return type.
Cmatrix& Cmatrix::operator -  ()
{
  for (int i = 0; i < r; i++)
    *(m[i]) *= -1;
  return *this;
}

// Unary plus: identity; yields the matrix itself (by reference, no copy).
Cmatrix& Cmatrix::operator +  ()
{
  Cmatrix& self = *this;
  return self;
}

// other functions
// Reads rows()*cols() complex values from s into x, in row-major order.
// The matrix must already have its final dimensions.
istream&  operator >> (istream& s, Cmatrix& x)
{
  for (int row = 0; row < x.r; ++row) {
    Cvector& line = x[row];
    for (int col = 0; col < x.c; ++col)
      s >> line[col];
  }
  return s;
}

// Writes x one row per line, framed as "      | e0\te1\t...|", with each
// line terminated by endl.
ostream&  operator << (ostream& s, const Cmatrix& x)
{
  for (int row = 0; row < x.r; ++row) {
    Cvector& line = x[row];
    s << "      | ";
    for (int col = 0; col < x.c; ++col)
      s << line[col] << "\t";
    s << "|" << endl;
  }
  return s;
}


// adding and subtracting
// Matrix sum x + y, computed on a copy of x via Cmatrix::operator+=.
Cmatrix operator  + (const Cmatrix& x, const Cmatrix& y)
{
  Cmatrix sum = x;
  sum += y;
  return sum;
}

// Matrix difference x - y, computed on a copy of x via Cmatrix::operator-=.
Cmatrix operator  - (const Cmatrix& x, const Cmatrix& y)
{
  Cmatrix difference = x;
  difference -= y;
  return difference;
}

// adding/subtracting a scalar z to/from a matrix A treats z as the matrix zI
// (I = identity), e.g. A + z means A + zI
// z + A == A + zI: returns a copy of y with x added along the diagonal.
Cmatrix operator  + (const complex& x, const Cmatrix& y)
{
  Cmatrix shifted = y;
  shifted += x;
  return shifted;
}

// A + z == A + zI: returns a copy of y with x added along the diagonal.
Cmatrix operator  + (const Cmatrix& y, const complex& x)
{
  Cmatrix shifted = y;
  shifted += x;
  return shifted;
}

// z - A == zI - A: every element of y is negated, then x is added along the
// diagonal.  Previously this returned A - zI, i.e. the wrong sign.
// Rows are negated here directly rather than via Cmatrix::operator*=, so
// the result is correct for non-square matrices too.
Cmatrix operator  - (const complex& x, const Cmatrix& y)
{
  Cmatrix result(y);
  for (int i = 0; i < result.rows(); i++)
    result[i] *= complex(-1.0);
  result += x;
  return result;
}

// A - z == A - zI: returns a copy of y with x subtracted along the diagonal.
Cmatrix operator  - (const Cmatrix& y, const complex& x)
{
  Cmatrix shifted = y;
  shifted -= x;
  return shifted;
}


// adding/subtracting a vector v to/from a matrix treats v as the diagonal
// matrix diag(v), so only the main diagonal of the matrix is combined with v
// v + A == A + diag(v): returns a copy of y with x added along the diagonal.
Cmatrix operator  + (const Cvector& x, const Cmatrix& y)
{
  Cmatrix shifted = y;
  shifted += x;
  return shifted;
}

// A + v == A + diag(v): returns a copy of y with x added along the diagonal.
// Previously the result started from a freshly constructed y.r x y.c matrix
// instead of a copy of y, which discarded y's contents.
Cmatrix operator  + (const Cmatrix& y, const Cvector& x)
{
  Cmatrix result(y);
  result += x;
  return result;
}

// v - A == diag(v) - A: every element of y is negated, then x is added along
// the diagonal.  Previously this returned A - diag(v), i.e. the wrong sign.
Cmatrix operator  - (const Cvector& x, const Cmatrix& y)
{
  Cmatrix result(y);
  for (int i = 0; i < result.rows(); i++)
    result[i] *= complex(-1.0);
  result += x;
  return result;
}

// A - v == A - diag(v): returns a copy of y with x subtracted along the diagonal.
Cmatrix operator  - (const Cmatrix& y, const Cvector& x)
{
  Cmatrix shifted = y;
  shifted -= x;
  return shifted;
}


// multiplication and division by scalar values
// Scalar product z * A, computed on a copy of x via Cmatrix::operator*=.
Cmatrix operator  * (const complex& z, const Cmatrix& x)
{
  Cmatrix scaled = x;
  scaled *= z;
  return scaled;
}

// Scalar product A * z (commutes with z * A), computed on a copy of x.
Cmatrix operator  * (const Cmatrix& x, const complex& z)
{
  Cmatrix scaled(x);
  scaled *= z;
  return scaled;
}

// Scalar quotient A / z, computed on a copy of x via Cmatrix::operator/=.
Cmatrix operator  / (const Cmatrix& x, const complex& z)
{
  Cmatrix quotient = x;
  quotient /= z;
  return quotient;
}


// multiplication by a vector (note the non-commutativity)
// Row-vector times matrix: result[i] = sum_j v[j] * m[j][i], a vector of
// length m.cols().  NOTE(review): assumes v has at least m.rows() elements
// and that Cvector(n) starts zero-filled; neither is checked here.
Cvector operator  * (const Cvector& v, const Cmatrix& m)
{
  Cvector result(m.c);
  // Walk m row by row; each result[col] still accumulates in increasing
  // row order, exactly as a column-outer loop would.
  for (int row = 0; row < m.r; ++row) {
    Cvector& mrow = m[row];
    for (int col = 0; col < m.c; ++col)
      result[col] += v[row]*mrow[col];
  }
  return result;
}

// Matrix times column vector: result[i] = sum_j m[i][j] * v[j], a vector of
// length m.rows().  NOTE(review): assumes v has at least m.cols() elements
// and that Cvector(n) starts zero-filled; neither is checked here.
Cvector operator  * (const Cmatrix& m, const Cvector& v)
{
  Cvector result(m.r);
  for (int row = 0; row < m.r; ++row) {
    Cvector& mrow = m[row];
    for (int col = 0; col < m.c; ++col)
      result[row] += mrow[col]*v[col];
  }
  return result;
}

// Multiplication of 2 matrices
// Matrix product x * y: an x.rows() x y.cols() matrix where
// product[i][j] = sum_k x[i][k] * y[k][j].  NOTE(review): relies on the
// two-argument Cmatrix constructor zero-filling, and does not check that
// x.cols() == y.rows().
Cmatrix operator  * (const Cmatrix& x, const Cmatrix& y)
{
  Cmatrix product(x.r, y.c);
  for (int i = 0; i < product.r; ++i) {
    Cvector& out = product[i];
    Cvector& lhs = x[i];
    for (int j = 0; j < product.c; ++j)
      for (int k = 0; k < x.c; ++k)
        out[j] += lhs[k]*y[k][j];
  }
  return product;
}

// comparisons
// Equality: 1 when x and y have the same shape and every row compares equal
// (using Cvector's operator!=), otherwise 0.
int     operator == (const Cmatrix& x, const Cmatrix& y)
{
  if (x.r != y.r || x.c != y.c)
    return 0;
  int row = 0;
  while (row < x.r && !(x[row] != y[row]))
    ++row;
  return row == x.r ? 1 : 0;
}

// Inequality: the logical negation of operator== (returns 0 or 1).
int     operator != (const Cmatrix& x, const Cmatrix& y)
{
  return (x == y) ? 0 : 1;
}


// the "infinity" norm
// Returns the largest per-row norm(Cvector) over all rows (0 for a matrix
// with no rows).  This equals the matrix infinity norm if norm(Cvector) is
// the vector 1-norm; NOTE(review): confirm against Cvector's norm.
double norm(const Cmatrix& m)
{
  double largest = 0;
  for (int row = 0; row < m.r; ++row) {
    const double row_norm = norm(m[row]);
    if (row_norm > largest)
      largest = row_norm;
  }
  return largest;
}

// Returns, among the per-row max(Cvector) values, the one with the largest
// abs().  Ties keep the earliest row.  If no row's candidate has abs() > 0,
// the result is a default-constructed complex.
complex max(const Cmatrix& m)
{
  complex best;
  double best_abs = 0;
  for (int row = 0; row < m.r; ++row) {
    complex candidate = max(m[row]);
    double magnitude = abs(candidate);
    if (magnitude > best_abs) {
      best_abs = magnitude;
      best = candidate;
    }
  }
  return best;
}

// Transposition
// Conjugate (Hermitian) transpose: returns an m.cols() x m.rows() matrix
// with result[j][i] = conj(m[i][j]).  Note the elements ARE conjugated
// despite the function name; callers presumably rely on that.
// The indices were previously swapped (result[i][j] = conj(m[j][i]) with i
// ranging over m's rows), which went out of bounds for non-square matrices.
Cmatrix transpose(const Cmatrix& m)
{
  Cmatrix result(m.c, m.r);
  for (int i = 0; i < m.r; i++)
    for (int j = 0; j < m.c; j++)
      result[j][i] = conj(m[i][j]);
  return result;
}

