/* SVD.c
 *
 * Copyright (C) 1994-2003 David Weenink
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 djmw 20010719
 djmw 20020408 GPL + cosmetic changes
 djmw 20020415 +SVD_synthesize
 djmw 20030205 Latest modification
*/

#include "SVD.h"
#include "NUMlapack.h"
#include "NUMmachar.h"
#include "Collection.h"
#include "NUMclapack.h"

#include "TableOfReal.h"
#include "praat.h"

#include "oo_DESTROY.h"
#include "SVD_def.h"
#include "oo_COPY.h"
#include "SVD_def.h"
#include "oo_EQUAL.h"
#include "SVD_def.h"
#include "oo_WRITE_ASCII.h"
#include "SVD_def.h"
#include "oo_WRITE_BINARY.h"
#include "SVD_def.h"
#include "oo_READ_ASCII.h"
#include "SVD_def.h"
#include "oo_READ_BINARY.h"
#include "SVD_def.h"
#include "oo_DESCRIPTION.h"
#include "SVD_def.h"


/* File-local helpers: the larger / smaller of two values.
   Both arguments are evaluated twice, so never pass expressions with side effects. */
#define MAX(m,n) ((m) > (n) ? (m) : (n))
#define MIN(m,n) ((m) < (n) ? (m) : (n))

/* Machine floating-point characteristics; SVD_init calls NUMmachar () while this is still NULL. */
extern machar_Table NUMfpp;
/* Nonzero selects the CLAPACK dgesvd path in SVD_compute; zero selects NUMsvdcmp. */
extern int praat_USE_LAPACK;
/* In-place transpose of a square 1-based matrix; defined further down. */
static void NUMtranspose_d (double **m, long n);

static void classSVD_info (I)
{
	iam (SVD);
	Melder_information ("Number of rows = %ld\n"
		"Number of columns = %ld\n",
		my numberOfRows, my numberOfColumns);
}

/*
	Class table for SVD (derived from Data).
	The destroy/equal/copy/read/write/description methods are presumably
	the classSVD_* functions generated by including SVD_def.h after each
	oo_*.h header at the top of this file. info is the hand-written
	classSVD_info above.
*/
class_methods (SVD, Data)
	class_method_local (SVD, destroy)
	class_method_local (SVD, equal)
	class_method_local (SVD, copy)
	class_method_local (SVD, readAscii)
	class_method_local (SVD, readBinary)
	class_method_local (SVD, writeAscii)
	class_method_local (SVD, writeBinary)
	class_method_local (SVD, description)
	class_method_local (SVD, info)
class_methods_end

/*
	Initialize an SVD object for a numberOfRows x numberOfColumns matrix.
	- Stores the dimensions.
	- Sets the default relative tolerance to machine epsilon times the
	  larger dimension.
	- Allocates u (rows x cols), v (cols x cols) and d (cols), all 1-based.
	Returns 1 on success and 0 as soon as an allocation fails. Members
	allocated before the failure stay in the object; SVD_create forgets
	the whole object in that case.
*/
int SVD_init (I, long numberOfRows, long numberOfColumns)
{
	iam (SVD);
	long largerDimension = numberOfRows > numberOfColumns ? numberOfRows : numberOfColumns;

	my numberOfRows = numberOfRows;
	my numberOfColumns = numberOfColumns;

	/* Lazily determine the machine constants on first use. */
	if (NUMfpp == NULL) NUMmachar ();
	my tolerance = NUMfpp -> eps * largerDimension;

	my u = NUMdmatrix (1, numberOfRows, 1, numberOfColumns);
	if (my u == NULL) return 0;

	my v = NUMdmatrix (1, numberOfColumns, 1, numberOfColumns);
	if (my v == NULL) return 0;

	my d = NUMdvector (1, numberOfColumns);
	return my d != NULL;
}

/*
	Create an SVD object sized for a numberOfRows x numberOfColumns matrix.
	The matrices are allocated but not decomposed.
	Returns NULL if the object or any of its arrays cannot be allocated.
*/
SVD SVD_create (long numberOfRows, long numberOfColumns)
{
	SVD me = new (SVD);

	if (me == NULL) return NULL;
	if (SVD_init (me, numberOfRows, numberOfColumns)) return me;

	/* Partially initialized: release everything. */
	forget (me);
	return NULL;
}

/*
	Create an SVD object and decompose the 1-based double matrix m
	(numberOfRows x numberOfColumns) into it.
	Returns NULL if creation or decomposition fails.
*/
SVD SVD_create_d (double **m, long numberOfRows, long numberOfColumns)
{
	SVD me = SVD_create (numberOfRows, numberOfColumns);

	if (me == NULL) return NULL;
	if (! SVD_svd_d (me, m)) forget (me);
	return me;
}

/*
	Create an SVD object and decompose the 1-based float matrix m
	(numberOfRows x numberOfColumns) into it.
	Returns NULL if creation or decomposition fails.
*/
SVD SVD_create_f (float **m, long numberOfRows, long numberOfColumns)
{
	SVD me = SVD_create (numberOfRows, numberOfColumns);

	if (me != NULL && SVD_svd_f (me, m) == 0) forget (me);
	return me;
}

/*
	Load the 1-based double matrix m into my u, then decompose it in place
	with SVD_compute. Returns the result of SVD_compute.
*/
int SVD_svd_d (I, double **m)
{
	iam (SVD);
	long nrows = my numberOfRows, ncols = my numberOfColumns;

	NUMdmatrix_copyElements (m, my u, 1, nrows, 1, ncols);
	return SVD_compute (me);
}

/*
	Load the 1-based float matrix m into my u, widening each element to
	double, then decompose it in place with SVD_compute.
	Returns the result of SVD_compute.
*/
int SVD_svd_f (I, float **m)
{
	iam (SVD);
	long row;

	for (row = 1; row <= my numberOfRows; row++)
	{
		float *source = m[row];
		double *target = my u[row];
		long col;
		for (col = 1; col <= my numberOfColumns; col++) target[col] = source[col];
	}
	return SVD_compute (me);
}

/*
	Store the relative tolerance that SVD_zeroSmallSingularValues uses
	when it is called with a tolerance of 0.
*/
void SVD_setTolerance (I, double tolerance)
{
	iam (SVD);
	me -> tolerance = tolerance;
}

double SVD_getTolerance (I)
{
	iam (SVD);
	return my tolerance;
}

/*
	In-place transpose of the square n x n matrix m, indexed 1..n in both
	dimensions. Each strictly lower-triangular element is exchanged with its
	mirror image above the diagonal; the diagonal is untouched.
	For n <= 1 nothing happens.
*/
static void NUMtranspose_d (double **m, long n)
{
	long row;
	for (row = 2; row <= n; row++)
	{
		double *lower = m[row];
		long col;
		for (col = 1; col < row; col++)
		{
			double swap = lower[col];
			lower[col] = m[col][row];
			m[col][row] = swap;
		}
	}
}


/*
	Decompose, in place, the matrix currently stored in my u
	(numberOfRows x numberOfColumns, 1-based, row-major).
	On success:
	- my u holds the left singular vectors;
	- my d holds the singular values;
	- my v (numberOfColumns x numberOfColumns) holds the right singular vectors.
	Returns 1 on success and 0 on failure. LAPACK failures are reported
	with Melder_error, so callers that test Melder_hasError () see them.
*/
int SVD_compute (I)
{
	iam (SVD);
	char jobu = 'S', jobvt = 'O';
	long m = my numberOfColumns, n = my numberOfRows;
	long lda = m, ldu = m, ldvt = m, lwork = -1, info;
	double *work = NULL, wt[2];

	if (! praat_USE_LAPACK)
	{
		/* Non-LAPACK path: NUMsvdcmp, followed by SVD_sort to order the result. */
		return NUMsvdcmp (my u, my numberOfRows, my numberOfColumns, my d, my v) &&
			SVD_sort (me);
	}

	/*
		Compute svd(A) = U D Vt.
		dgesvd uses (Fortran) column-major storage, so it sees the row-major
		array my u as At (numberOfColumns x numberOfRows). Instead of
		transposing the data we solve the transposed problem svd(At) = V D Ut:
		- jobu = 'S': the left singular vectors of At (V of A) are written
		  column-major into my v, so my v is transposed afterwards.
		- jobvt = 'O': the first min(m,n) rows of Ut are written over the
		  input array; read row-major, that is U of A.
		dgesvd returns the singular values sorted in descending order.
		NOTE(review): &my u[1][1] / &my v[1][1] assume NUMdmatrix stores the
		rows contiguously; the shapes only fit when
		numberOfRows >= numberOfColumns. Confirm.
	*/

	/* Workspace query: with lwork == -1 only the optimal size is returned, in wt[0]. */
	(void) NUMlapack_dgesvd (&jobu, &jobvt, &m, &n, &my u[1][1], &lda, &my d[1],
		&my v[1][1], &ldu, NULL, &ldvt, wt, &lwork, &info);
	if (info != 0)
		return Melder_error ("SVD_compute: LAPACK workspace query failed (info = %ld).", info);

	lwork = (long) wt[0];
	work = NUMdvector (1, lwork);
	if (work == NULL) return 0;

	(void) NUMlapack_dgesvd (&jobu, &jobvt, &m, &n, &my u[1][1], &lda, &my d[1],
		&my v[1][1], &ldu, NULL, &ldvt, &work[1], &lwork, &info);
	NUMdvector_free (work, 1);
	if (info != 0)
		return Melder_error ("SVD_compute: LAPACK dgesvd failed (info = %ld).", info);

	/* Convert my v from LAPACK's column-major layout to row-major. */
	NUMtranspose_d (my v, MIN (m, n));
	return 1;
}

/*
	Solve for x with the SVD factors of the matrix held in my u.
	Steps:
	1. Run SVD_compute on the current contents of my u.
	2. Zero the singular values below the default relative tolerance
	   (SVD_zeroSmallSingularValues with tolerance 0).
	3. Hand u, d, v, b and x to NUMsvbksb.
	Returns 0 if the decomposition fails, otherwise the result of NUMsvbksb.
	NOTE(review): SVD_compute overwrites my u with U. If this object was
	already decomposed (e.g. via SVD_create_d), step 1 factors U rather than
	the original matrix. Confirm the intended call sequence.
*/
int SVD_solve (I, double b[], double x[])
{
	iam (SVD);
	int ok = SVD_compute (me);

	if (ok)
	{
		(void) SVD_zeroSmallSingularValues (me, 0);
		ok = NUMsvbksb (my u, my d, my v, my numberOfRows, my numberOfColumns, b, x);
	}
	return ok;
}

int SVD_sort (I)
{
	iam (SVD); SVD thee = NULL; 
	long i, j, *index = NULL;
	
	if (((thee = Data_copy (me)) == NULL) ||
		((index = NUMlvector (1, my numberOfColumns)) == NULL)) goto end;
	
	NUMindexx_d (my d, my numberOfColumns, index);
			
	for (j = 1; j <= my numberOfColumns; j++)
	{
		long from = index[my numberOfColumns - j + 1];
		my d[j] = thy d[from];
		for (i = 1; i <= my numberOfRows; i++) my u[i][j] = thy u[i][from];
		for (i = 1; i <= my numberOfColumns; i++) my v[i][j] = thy v[i][from];
	}
end:
	forget (thee);
	NUMlvector_free (index, 1);
	return ! Melder_hasError ();
}

/*
	Set to 0 every singular value smaller than (largest singular value) * tolerance.
	A tolerance of 0 selects the tolerance stored in the object.
	The maximum is searched explicitly, so d need not be sorted.
	Returns how many values were zeroed.
*/
long SVD_zeroSmallSingularValues (I, double tolerance)
{
	iam (SVD);
	long k, zeroed = 0;
	double largest = my d[1];

	if (tolerance == 0) tolerance = my tolerance;

	for (k = 2; k <= my numberOfColumns; k++)
	{
		if (largest < my d[k]) largest = my d[k];
	}

	for (k = my numberOfColumns; k >= 1; k--)
	{
		if (my d[k] < largest * tolerance)
		{
			my d[k] = 0;
			zeroed++;
		}
	}
	return zeroed;
}


long SVD_getRank (I)
{
	iam (SVD);
	long i, rank = 0;
	for (i = 1; i <= my numberOfColumns; i++)
	{
		if (my d[i] > 0) rank++;
	}
	return rank;
}

/*
	Rebuild into m (numberOfRows x numberOfColumns, 1-based) the matrix
	formed by singular components sv_from..sv_to:
		m[i][j] = sum over k = sv_from..sv_to of d[k] * u[i][k] * v[j][k].
	sv_to == 0 means "up to the last component". A single component
	(sv_from == sv_to) is allowed and gives a rank-one term.
	Assumes m is large enough and does not alias my u or my v (not checked).
	Returns 1 on success, or 0 with a Melder error when the indices fall
	outside [1, numberOfColumns].
*/
int SVD_synthesize (I, long sv_from, long sv_to, double **m)
{
	iam (SVD);
	long i, j, k;

	if (sv_to == 0) sv_to = my numberOfColumns;

	/* sv_from == sv_to is valid (one component); only reversed or out-of-range indices are errors. */
	if (sv_from > sv_to || sv_from < 1 || sv_to > my numberOfColumns) return
		Melder_error ("SVD_synthesize: indices must be in range [1, %ld].",
			my numberOfColumns);

	for (i = 1; i <= my numberOfRows; i++)
	{
		for (j = 1; j <= my numberOfColumns; j++)
		{
			/* Accumulate in the same k order as before, but in a local. */
			double sum = 0;
			for (k = sv_from; k <= sv_to; k++)
			{
				sum += my d[k] * my u[i][k] * my v[j][k];
			}
			m[i][j] = sum;
		}
	}
	return 1;
}

/* MAX and MIN are private to this file; remove them so they cannot leak elsewhere. */
#undef MAX
#undef MIN

/* End of file SVD.c */
