import Jama.Matrix;
/**
 * Fits sampled data to a two-dimensional Gaussian using a linearized
 * least-squares method.
 *
 * <p>The model whose parameters are recovered is
 * <pre>
 *   z = amp * exp( -q / (2*(1 - r*r)) )
 *   q = ((x-x0)/sx)^2 + ((y-y0)/sy)^2 - 2*r*(x-x0)*(y-y0)/(sx*sy)
 * </pre>
 * Taking the logarithm makes the model linear in six coefficients:
 * <pre>
 *   ln(z) = c0 + c1*x + c2*y + c3*x*x + c4*y*y + c5*x*y
 * </pre>
 * Each equation is multiplied by z_i, so the system solved is
 * {@code z_i*ln(z_i) = z_i*(c0 + c1*x_i + ...)}. The Gaussian parameters are
 * then computed from the coefficients c0..c5.
 *
 * <p>Note from the original author: the method is valid for small values of
 * the correlation coefficient r (~.7). For larger values use an exact
 * nonlinear fitter.
 */
public class LeastSquares2DGaussianFit
{
	/** The iteration on r stops once two successive estimates differ by less than this. */
	private static final double R_TOLERANCE = 1e-8;

	/** Index of the last allowed refinement pass (passes are numbered from 0). */
	private static final int MAX_PASS_INDEX = 100;

	/**
	 * Fits the first {@code npoints} samples to a two-dimensional Gaussian.
	 *
	 * <p>Samples with {@code z[i] <= 0} (or NaN) have no logarithm. They are
	 * entered as all-zero equations, so they do not influence the fit. At
	 * least six samples with {@code z[i] > 0} are needed to determine the six
	 * coefficients.
	 *
	 * <p>If the fitted quadratic coefficients are not negative, or the
	 * estimated |r| reaches 1, the square roots below are taken of negative
	 * numbers and the affected outputs are NaN.
	 *
	 * @param x       x positions of the samples
	 * @param y       y positions of the samples
	 * @param z       values measured at the (x,y) positions
	 * @param npoints number of leading entries of x, y and z to use
	 * @return a six-element array {@code {amplitude, x0, y0, sx, sy, rxy}}
	 */
	public static double[] fit(double[] x, double[] y, double[] z, int npoints)
	{
		// Measurement vector b (entries z_i*ln(z_i)) and the npoints-by-6
		// matrix of coefficients. Both start out filled with zeros.
		Matrix b = new Matrix(npoints,1);
		double[][] B = new double[npoints][6];
		for(int i = 0; i<npoints; ++i)
		{
			// ln(z) is -Infinity for z == 0 and NaN for z < 0; either one
			// turns z*ln(z) into NaN and would poison the whole solution.
			// Such samples keep a zero row AND a zero right-hand side.
			if (z[i]>0.)
			{
				b.set(i,0, z[i]*Math.log(z[i]));

				B[i][0] = z[i];
				B[i][1] = z[i]*x[i];
				B[i][2] = z[i]*y[i];
				B[i][3] = z[i]*x[i]*x[i];
				B[i][4] = z[i]*y[i]*y[i];
				B[i][5] = z[i]*x[i]*y[i];
			}
		}

		Matrix A = new Matrix(B);

		// Solve the over-determined system A*c = b for the six coefficients.
		Matrix c = A.solve(b);
		double c0 = c.get(0,0);
		double c1 = c.get(1,0);
		double c2 = c.get(2,0);
		double c3 = c.get(3,0);
		double c4 = c.get(4,0);
		double c5 = c.get(5,0);

		// First estimate of the widths, ignoring the correlation (r = 0).
		double sx = Math.sqrt(-1/(2.*c3));
		double sy = Math.sqrt(-1/(2.*c4));

		// Seed for r. The refinement below only uses r*r when it computes
		// sx, sy and rnew, so the sign of this seed does not reach the result.
		double r = -c5*sx*sy;
		double rnew = 0.;

		// The widths depend on r and r depends on the widths, so refine both
		// until r stops changing or the pass limit is reached.
		for (int pass = 0; pass <= MAX_PASS_INDEX; ++pass)
		{
			double oneMinusR2 = 1-r*r;
			sx = Math.sqrt(-1/(2*oneMinusR2*c3));
			sy = Math.sqrt(-1/(2*oneMinusR2*c4));
			rnew = c5*sx*sy*oneMinusR2;
			if(Math.abs(r-rnew) < R_TOLERANCE) break;
			r = rnew;
		}
		r = rnew;

		// Centre of the Gaussian from the linear coefficients c1 and c2.
		double x0 = sx*sx*(c1 + c2*r*sy/sx );
		double y0 = sy*sy*(c2 + c1*r*sx/sy);

		// c0 is ln(amp) minus the constant part of the exponent at the
		// origin; add that part back before exponentiating.
		double arg = x0*x0/(sx*sx) + y0*y0/(sy*sy)-2*r*(x0*y0)/(sx*sy);
		arg*=1/(2*(1-r*r));
		arg = c0+arg;
		double amp = Math.exp( arg );

		double[] args = new double[6];
		args[0] = amp;
		args[1] = x0;
		args[2] = y0;
		args[3] = sx;
		args[4] = sy;
		args[5] = r;

		return args;
	}
}

