package multinomad.tools.kmeans;

import java.util.Iterator;

import multinomad.config.Configuration;
import multinomad.config.LoadProperties;
import multinomad.config.Logger;
import multinomad.individuals.DoubleIndividual;
import multinomad.individuals.Population;
import multinomad.tools.CommonState;
import multinomad.tools.kmeansML.Centroids;

/**
 * K-means clustering of a {@link Population} of {@link DoubleIndividual}s.
 *
 * <p>Centroids are held in a {@link Population} whose elements are
 * {@link Centroid} instances; each centroid keeps track of the data items
 * currently assigned to it. The clustering loop alternates between
 * re-assigning every data item to its nearest centroid and moving every
 * centroid to the arithmetic mean of its members.
 *
 * <p>Iterating a {@code Kmeans} instance iterates its centroid population.
 */
public class Kmeans implements Iterable<DoubleIndividual>{
	
	/** Current centroids; every element is cast to {@link Centroid} by this class. */
	private final Population centroids;
	/** The data items being clustered. */
	private final Population rawdata;
	/** Number of centroids created during initialisation. */
	private final int k;
	
	/**
	 * Builds a clusterer that seeds its own centroids from the data, using
	 * the strategy selected by {@code Configuration.kmeansinit}.
	 *
	 * @param poprawdata  data items to cluster
	 * @param numClusters requested number of clusters; capped at the number
	 *                    of data items
	 */
	public Kmeans(Population poprawdata, int numClusters) {
		centroids = new Population(false);
		k = Math.min(poprawdata.size(), numClusters);
		rawdata = poprawdata;
		initialize();
	}
	
	/**
	 * Builds a clusterer that starts from an existing set of centroids; no
	 * initialisation strategy is run.
	 *
	 * @param poprawdata data items to cluster
	 * @param centroids  starting centroids; elements must be {@link Centroid}
	 *                   instances because the clustering steps cast to that type
	 */
	public Kmeans(Population poprawdata, Population centroids) {
		this.centroids = centroids;
		this.k = centroids.size();
		rawdata = poprawdata;
	}
	
	/**
	 * Dispatches to the seeding strategy named by {@code Configuration.kmeansinit}.
	 * Any value other than {@code KMEANSRANDOM} or {@code KMEANSPLUS} leaves the
	 * centroid population empty.
	 */
	private void initialize() {
		if (Configuration.kmeansinit.equals(Configuration.KMEANSRANDOM)) {
			randomInit();
		} else if (Configuration.kmeansinit.equals(Configuration.KMEANSPLUS)) {
			kmeansPlusInit();
		}
	}
	
	/**
	 * Seeds {@code k} centroids from data items drawn uniformly at random,
	 * redrawing whenever {@code centroids.contains(...)} reports the drawn
	 * item as already present.
	 */
	private void randomInit() {
		for (int i = 0; i < k; i++) {
			// NOTE(review): termination relies on Population.contains() eventually
			// returning false; if it compares by value and the data holds fewer
			// than k distinct items this loop would not end -- confirm.
			DoubleIndividual candidate;
			do {
				candidate = rawdata.getP().get(CommonState.r.nextInt(rawdata.size()));
			} while (centroids.contains(candidate));
			centroids.add(new Centroid(candidate));
		}
	}
	
	/**
	 * Seeds {@code k} centroids in the style of k-means++: the first one is a
	 * uniformly random data item, each further one is drawn by roulette-wheel
	 * selection with weight equal to the item's distance to its nearest
	 * already-chosen centroid.
	 */
	private void kmeansPlusInit() {
		// Nothing to seed; this also avoids asking the generator for an index
		// into an empty data set (k > 0 implies rawdata is non-empty).
		if (k <= 0) {
			return;
		}
		final int n = rawdata.size();

		// First centroid is chosen uniformly at random.
		centroids.add(new Centroid(rawdata.getP().get(CommonState.r.nextInt(n))));
		
		// Remaining k-1 centroids.
		for (int i = 1; i < k; i++) {

			// Distance from every item to its nearest centroid so far, plus the total.
			// NOTE(review): canonical k-means++ weights by squared distance; the
			// weights here are whatever distance() returns -- confirm which is intended.
			double[] distances = new double[n];
			double sum = 0;
			int itemIndex = 0;
			for (DoubleIndividual item : rawdata) {
				double mindistance = Double.MAX_VALUE;
				for (DoubleIndividual chosen : centroids) {
					double distance = chosen.distance(item);
					if (distance < mindistance) {
						mindistance = distance;
					}
				}
				distances[itemIndex] = mindistance;
				sum += mindistance;
				itemIndex++;
			}
			
			// Roulette wheel: first item whose running total reaches the threshold.
			// The last item is the fallback so that a centroid is always added,
			// even if no comparison succeeds (e.g. NaN distances).
			double roulette = CommonState.r.nextDouble() * sum;
			int selected = n - 1;
			double running = 0;
			for (int j = 0; j < n; j++) {
				running += distances[j];
				if (running >= roulette) {
					selected = j;
					break;
				}
			}
			centroids.add(new Centroid(rawdata.getP().get(selected)));
		}
	}
	
	/**
	 * Runs the clustering loop, re-assigning items first and recomputing the
	 * centroid means second in every iteration.
	 *
	 * @return the population of centroids after the loop has stopped
	 */
	public Population cluster() {
		return runLloyd(false);
	}
	
	/**
	 * Runs the clustering loop, recomputing the centroid means first and
	 * re-assigning items second in every iteration.
	 *
	 * @return the population of centroids after the loop has stopped
	 */
	public Population clusterRecalculate() {
		return runLloyd(true);
	}
	
	/**
	 * Shared loop behind {@link #cluster()} and {@link #clusterRecalculate()}
	 * (technically Lloyd's algorithm). Stops when an iteration changes no
	 * assignment, when the mean update reports failure, or after
	 * {@code rawdata.size() * 10} iterations as a sanity bound.
	 *
	 * @param meansFirst {@code true} to update means before assignments,
	 *                   {@code false} for the opposite order
	 * @return the centroid population
	 */
	private Population runLloyd(boolean meansFirst) {
		final int maxCount = rawdata.size() * 10; // sanity check
		boolean changed = true; // change in at least one cluster assignment
		boolean success = true; // all means computed?
		
		for (int ct = 0; changed && success && ct < maxCount; ct++) {
			if (meansFirst) {
				success = updateMeansLloyd();
				changed = updateClustering();
			} else {
				changed = updateClustering();
				success = updateMeansLloyd();
			}
		}
		return centroids;
	}
	
	/**
	 * Assigns every data item to the centroid at the smallest distance,
	 * moving it out of the cluster that currently holds it if that differs.
	 *
	 * @return {@code true} if at least one item changed cluster
	 */
	private boolean updateClustering() {
		boolean changed = false;
		
		for (DoubleIndividual ind : rawdata) {
			double bestDistance = Double.MAX_VALUE;
			int nearest = -1; // position of the closest centroid
			int current = -1; // position of the centroid currently holding ind
			int position = 0;
			for (DoubleIndividual cent : centroids) {
				double distanceToCent = cent.distance(ind);
				if (distanceToCent < bestDistance) {
					nearest = position;
					bestDistance = distanceToCent;
				}
				if (((Centroid) cent).isInCluster(ind)) {
					current = position;
				}
				position++;
			}
			
			// No centroid produced a comparable distance (no centroids at all,
			// or only NaN/overflowing distances): leave the item where it is
			// instead of indexing the centroid list with -1.
			if (nearest == -1) {
				continue;
			}
			
			if (nearest != current) { // the individual has changed cluster
				if (current != -1) {
					((Centroid) centroids.getP().get(current)).removeFromCluster(ind);
				}
				((Centroid) centroids.getP().get(nearest)).addToCluster(ind);
				changed = true;
			}
		}
		
		return changed;
	}
	
	/**
	 * Dispatches to the mean-update rule named by
	 * {@code Configuration.kmeansconvergence}. Not called by the clustering
	 * loop, which uses {@link #updateMeansLloyd()} directly.
	 *
	 * @return the result of the selected update, or {@code false} when the
	 *         configured value matches neither rule
	 */
	private boolean updateMeans() {
		if (Configuration.kmeansconvergence.equals(Configuration.KMEANSLLOYD)) {
			return updateMeansLloyd();
		} else if (Configuration.kmeansconvergence.equals(Configuration.KMEANSBIASED)) {
			return updateMeansBiased();
		}
		return false;
	}
	
	/**
	 * Moves every centroid to the arithmetic mean of its cluster members and
	 * then calls {@code centroids.cleanNaN()}.
	 *
	 * @return always {@code true}
	 */
	private boolean updateMeansLloyd() {
		for (DoubleIndividual c : centroids) {
			Centroid cent = (Centroid) c;
			
			double[] newMean = new double[cent.getChr().getLength()];
			for (DoubleIndividual itemInCluster : cent) {
				addScaled(newMean, itemInCluster.getChr().asdouble(), 1.0);
			}
			// An empty cluster is not rejected here: the zero sum divided by a
			// zero size yields NaN coordinates, which are left to cleanNaN()
			// below (presumably it drops such centroids -- confirm).
			divideInPlace(newMean, cent.clusterSize());
			cent.updateMean(newMean);
		}
		
		centroids.cleanNaN();
		
		return true;
	}
	
	/**
	 * Moves every centroid to the fitness-weighted mean of its cluster
	 * members. Weights are the members' fitness values rescaled to [0, 1]
	 * within the cluster; when all members share one fitness value every
	 * weight is 1, i.e. the plain mean.
	 *
	 * @return {@code false} as soon as a centroid with an empty cluster is
	 *         met (earlier centroids have already been moved), {@code true}
	 *         otherwise
	 */
	private boolean updateMeansBiased() {
		for (DoubleIndividual c : centroids) {
			Centroid cent = (Centroid) c;
			if (cent.clusterSize() == 0) {
				return false;
			}
			
			// Fitness range inside this cluster.
			double minfitness = Double.POSITIVE_INFINITY;
			double maxfitness = Double.NEGATIVE_INFINITY;
			for (DoubleIndividual itemInCluster : cent) {
				double fitness = itemInCluster.getFitness();
				if (fitness > maxfitness) {
					maxfitness = fitness;
				}
				if (fitness < minfitness) {
					minfitness = fitness;
				}
			}
			
			double[] newMean = new double[cent.getChr().getLength()];
			double sumFitness = 0;
			for (DoubleIndividual itemInCluster : cent) {
				double normFitness = normalizedFitness(itemInCluster.getFitness(), minfitness, maxfitness);
				sumFitness += normFitness;
				addScaled(newMean, itemInCluster.getChr().asdouble(), normFitness);
			}
			divideInPlace(newMean, sumFitness);
			cent.updateMean(newMean);
		}
		
		return true;
	}
	
	/**
	 * Rescales {@code fitness} linearly so that {@code min} maps to 0 and
	 * {@code max} maps to 1.
	 *
	 * @return the rescaled value, or 1 when {@code max == min} (the linear
	 *         rescaling would be 0/0 there)
	 */
	private double normalizedFitness (double fitness, double min, double max) {
		if (max == min) {
			return 1.0;
		}
		return (fitness-min)/(max-min);
	}
	
	/** Adds {@code weight * v[i]} to {@code acc[i]} for every index of {@code acc}. */
	private static void addScaled(double[] acc, double[] v, double weight) {
		for (int i = 0; i < acc.length; i++) {
			acc[i] = acc[i] + v[i] * weight;
		}
	}
	
	/** Divides every element of {@code v} by {@code value}, in place. */
	private static void divideInPlace(double[] v, double value) {
		for (int i = 0; i < v.length; i++) {
			v[i] = v[i] / value;
		}
	}
	
	/**
	 * @return an iterator over the centroid population
	 */
	@Override
	public Iterator<DoubleIndividual> iterator() {
		return centroids.iterator();
	}
	
	/**
	 * Manual smoke run: loads the configuration from {@code args} and
	 * clusters nine hard-coded two-dimensional points into two clusters.
	 * The result is not printed or checked.
	 */
	public static void main(String[] args) {
		LoadProperties lp = new LoadProperties(args,null);
		Configuration.setConfiguration(lp);
		
		double[][] points = {
			{65,220},
			{67,220},
			{68,230},
			{70,220},
			{66,210},
			{59,110},
			{61,120},
			{62,130},
			{61,115},
		};
		
		Population pop = new Population(false);
		for (double[] point : points) {
			pop.add(new DoubleIndividual(point));
		}
		
		Kmeans kmeans = new Kmeans(pop, 2);
		kmeans.cluster();
	}

}
