package edu.ucla.ccb.graphshifts.image;
import java.awt.Point;
import java.awt.image.BufferedImage;
import java.awt.image.ColorModel;
import java.awt.image.ComponentSampleModel;
import java.awt.image.DataBuffer;
import java.awt.image.DataBufferInt;
import java.awt.image.IndexColorModel;
import java.awt.image.Raster;
import java.awt.image.SampleModel;
import java.awt.image.SinglePixelPackedSampleModel;
import java.awt.image.WritableRaster;
import java.util.Hashtable;

import edu.ucla.ccb.graphshifts.data.Point3D;
import edu.ucla.ccb.graphshifts.Segmentation;


// NOTE(review): this class plays two roles (a 3-D image and a Segmentation); the Segmentation
//  role might be cleaner as a separate type.
/**
 * A 3-D integer image.
 * <p>
 * Voxels are stored in a single {@link DataBufferInt} in x-fastest order
 * (index = z*width*height + y*width + x). Slices along the XY, XZ and ZY planes are exposed as
 * {@link BufferedImage}s that are reused between {@link #getSlice(int)} calls.
 * <p>
 * This integer image is also often used to store colormap indices. Java2D does not allow indexed
 * color models for int data, so color is handled in only one way: the colormap is applied directly
 * to the voxels and the image is converted in place to packed ARGB values. Use the
 * {@code applyColormap} functions for this.
 */
public class ScalarImageI3   implements Image3PixelAccess, Cloneable, Segmentation
{
	protected int width,height,depth;
	/** number of voxels in one XY slice (width*height) */
	protected int wh;
	/** the minimum voxel value; Integer.MAX_VALUE means "not computed yet" */
	private int minVal_=Integer.MAX_VALUE;  
	/** the maximum voxel value; Integer.MIN_VALUE means "not computed yet" */
	private int maxVal_=Integer.MIN_VALUE;
	
	/** the voxel data (a single bank) */
	protected DataBufferInt buffer;
	// one sample model per plane orientation; each slice of that orientation is one band
	protected ComponentSampleModel sampleModelXY;
	protected ComponentSampleModel sampleModelXZ;
	protected ComponentSampleModel sampleModelZY;
	// single-band sample models, one per slice, used to build the slice rasters
	protected SampleModel sampleModelsXY[];
	protected SampleModel sampleModelsXZ[];
	protected SampleModel sampleModelsZY[];
	// one raster per slice, all sharing `buffer`
	protected WritableRaster rastersXY[];
	protected WritableRaster rastersXZ[];
	protected WritableRaster rastersZY[];

	// reusable slice images returned by getSlice
	protected BufferedImage imageXY;
	protected BufferedImage imageXZ;
	protected BufferedImage imageZY;
	// we'll use these when the image is a scalar image
	protected WritableRaster imageXYRaster;
	protected WritableRaster imageXZRaster;
	protected WritableRaster imageZYRaster;	
	// these are used after the image becomes a packed rgb image...
	protected DataBufferInt imageXYbuffer;
	protected DataBufferInt imageXZbuffer;
	protected DataBufferInt imageZYbuffer;
	
		
	protected IndexColorModel colorModel=null;
	
	// index of the slice most recently returned by getSlice for each orientation (-1 = none)
	protected int lastXYSlice=-1;
	protected int lastXZSlice=-1;
	protected int lastZYSlice=-1;
	
	
	private boolean imageIsStaticBool = false;
	private boolean imageIsARGBPacked = false;
	private int sliceOrientation = Image3Protocol.SLICE_XY;
	protected boolean hasColormap = false;
	protected Colormap theColormap = null;
	
	/**
	 * When true, grayscale XY slices are scaled by the range of that slice alone; XZ and ZY
	 * slices always use the global range.
	 */
	protected boolean scaleBySlice=true;
	
	
	/**
	 * Creates an image backed by an existing voxel array.
	 * The array is used directly, not copied.
	 *
	 * @param w width (x extent)
	 * @param h height (y extent)
	 * @param d depth (z extent)
	 * @param data voxel values in x-fastest order; must hold at least w*h*d elements
	 * @throws IllegalArgumentException if data holds fewer than w*h*d elements
	 */
	public ScalarImageI3(int w, int h, int d, int[] data)
	{
		width = w;
		height = h;
		depth = d;
		wh = width*height;

		int size = w*h*d;
		// fail fast here instead of on some later voxel access
		if (data.length < size)
			throw new IllegalArgumentException("data has " + data.length + " elements, need " + size);
		buffer = new DataBufferInt(data,size);
		
		initializeBuffers(w, h, d);
	}

	/**
	 * Creates a zero-filled image.
	 *
	 * @param w width (x extent)
	 * @param h height (y extent)
	 * @param d depth (z extent)
	 */
	public ScalarImageI3(int w, int h, int d)
	{
		width = w;
		height = h;
		depth = d;
		wh = width*height;

		//	create the databuffer itself with only one "bank" of data
		buffer = new DataBufferInt(w*h*d);
		
		initializeBuffers(w, h, d);
	}

	/**
	 * Builds the per-plane sample models and the per-slice rasters over {@link #buffer}.
	 * Also creates the reusable grayscale slice images.
	 */
	private void initializeBuffers(int w, int h, int d)
	{
		// Each plane orientation is one ComponentSampleModel over the shared buffer in which every
		//  slice is a band. The band-offset array's length is the number of bands, so each plane
		//  needs its own array.
		sampleModelXY = new ComponentSampleModel(DataBuffer.TYPE_INT,w,h,1,w,stridedOffsets(d, wh));
		sampleModelXZ = new ComponentSampleModel(DataBuffer.TYPE_INT,w,d,1,wh,stridedOffsets(h, w));
		sampleModelZY = new ComponentSampleModel(DataBuffer.TYPE_INT,d,h,wh,w,stridedOffsets(w, 1));

		// one single-band sample model per slice is needed for raster creation
		sampleModelsXY = singleBandModels(sampleModelXY, d);
		sampleModelsXZ = singleBandModels(sampleModelXZ, h);
		sampleModelsZY = singleBandModels(sampleModelZY, w);

		rastersXY = rastersOverBuffer(sampleModelsXY);
		rastersXZ = rastersOverBuffer(sampleModelsXZ);
		rastersZY = rastersOverBuffer(sampleModelsZY);
	
		// a single BufferedImage per plane is reused by getSlice so no memory is allocated per call
		imageXY = new BufferedImage(w,h,BufferedImage.TYPE_BYTE_GRAY);
		imageXZ = new BufferedImage(w,d,BufferedImage.TYPE_BYTE_GRAY);
		imageZY = new BufferedImage(d,h,BufferedImage.TYPE_BYTE_GRAY);
		imageXYRaster = imageXY.getRaster();
		imageXZRaster = imageXZ.getRaster();
		imageZYRaster = imageZY.getRaster();
	}

	/** Returns the offsets {0, step, 2*step, ..., (n-1)*step}. */
	private static int[] stridedOffsets(int n, int step)
	{
		int offsets[] = new int[n];
		for (int i=0;i<n;i++)
			offsets[i] = i*step;
		return offsets;
	}

	/** Returns one single-band subset of {@code model} for each band index 0..n-1. */
	private static SampleModel[] singleBandModels(SampleModel model, int n)
	{
		SampleModel models[] = new SampleModel[n];
		for (int i=0;i<n;i++)
			models[i] = model.createSubsetSampleModel(new int[] {i});
		return models;
	}

	/** Creates one writable raster over {@link #buffer} for each given sample model. */
	private WritableRaster[] rastersOverBuffer(SampleModel models[])
	{
		WritableRaster rasters[] = new WritableRaster[models.length];
		for (int i=0;i<models.length;i++)
			rasters[i] = Raster.createWritableRaster(models[i],buffer,null);
		return rasters;
	}
	
	
	/**
	 * Replaces the grayscale slice images with packed ARGB images.
	 * Called when the image goes from a simple grayscale image to a packed ARGB image. Each new
	 * image is backed by its own DataBufferInt (imageXYbuffer etc.), so whole pixels can be written
	 * without touching individual color components.
	 */
	private void initializePackedARGBBuffers()
	{
		imageXYbuffer = new DataBufferInt(width*height);
		imageXY = createPackedARGBImage(imageXYbuffer, width, height);
		
		imageXZbuffer = new DataBufferInt(width*depth);
		imageXZ = createPackedARGBImage(imageXZbuffer, width, depth);
		
		imageZYbuffer = new DataBufferInt(depth*height);
		imageZY = createPackedARGBImage(imageZYbuffer, depth, height);
		
		imageXYRaster = null;
		imageXZRaster = null;
		imageZYRaster = null;
	}

	/** Wraps {@code db} in a w x h BufferedImage using the default (non-premultiplied) ARGB layout. */
	private static BufferedImage createPackedARGBImage(DataBufferInt db, int w, int h)
	{
		int bandMasks[] = new int[] {0x00FF0000,0x0000FF00,0x000000FF,0xFF000000};
		SinglePixelPackedSampleModel sppsm =
			new SinglePixelPackedSampleModel(DataBuffer.TYPE_INT, w, h, bandMasks);
		WritableRaster r = Raster.createWritableRaster(sppsm,db,new Point(0,0));
		return new BufferedImage(ColorModel.getRGBdefault(),r,false,new Hashtable<String, Object>());
	}
	
	

	/** 
	 * Takes the image in memory and converts it in place to packed INT_ARGB using the colormap.
	 * <p>
	 * Each voxel value is used as a colormap index. Negative voxels become light gray
	 * (0xCCCCCC), using the alpha of colormap entry 0. Writes go through
	 * {@link #setPixelAt}, so they are ignored while the image is static.
	 * NOTE(review): calling this twice re-maps values that are already ARGB.
	 *
	 * @param colormap the colormap to apply; it is kept and returned by {@link #getColormap()}
	 */
	public void applyColormap(Colormap colormap)
	{
		// take the alpha from the first entry in the colormap; evaluated before any state changes
		int gray = ((colormap.getColorAsIntARGB(0)>>>24)<<24) | 0x00CCCCCC;

		imageIsARGBPacked = true;
		initializePackedARGBBuffers();
		theColormap = colormap;
		hasColormap = true;

		for(int z=0;z<depth;z++)
			for(int y=0;y<height;y++)
				for(int x=0;x<width;x++)
				{
					int idx = getPixelValueAt(x,y,z);
					if (idx < 0)
						setPixelAt(x,y,z,gray);
					else
						setPixelAt(x,y,z,colormap.getColorAsIntARGB(idx));
				}
	}

	/** 
	 * Takes the image in memory and converts it in place to packed INT_ARGB.
	 * Uses the colormap given as an array of ARGB values.
	 */
	public void applyColormapARGB(int colormap[])
	{ 
		applyColormap(Colormap.createFromMap(colormap));
	}

	/**
	 * Returns a copy with its own voxel data.
	 * The copy carries over the value range, slice orientation, scaling mode, packed/colormap state
	 * and static flag. The Colormap object itself is shared, not copied.
	 */
	public Object clone()
	{
		ScalarImageI3 copy = new ScalarImageI3(width,height,depth);

		// Copy the voxels before transferring the static flag. setPixelAt ignores writes on a
		//  static image, which previously left clones of static images empty.
		int n = wh*depth;
		for (int idx=0;idx<n;idx++)
			copy.buffer.setElem(idx, buffer.getElem(idx));

		copy.minVal_ = minVal_;
		copy.maxVal_ = maxVal_;
		copy.imageIsARGBPacked = imageIsARGBPacked;
		if (imageIsARGBPacked)
			copy.initializePackedARGBBuffers();
		copy.hasColormap = hasColormap;
		copy.theColormap = theColormap;
		copy.scaleBySlice = scaleBySlice;
		copy.sliceOrientation = sliceOrientation;
		copy.imageIsStaticBool = imageIsStaticBool;

		return copy;
	}

	/** Force a computation of the min and max values in the buffer */
	public void computeMinMax()
	{
		int lo = Integer.MAX_VALUE;
		int hi = Integer.MIN_VALUE;

		int n = wh*depth;
		for (int idx=0;idx<n;idx++)
		{
			int val = buffer.getElem(idx);
			if (val < lo)
				lo = val;
			if (val > hi)
				hi = val;
		}

		minVal_ = lo;
		maxVal_ = hi;
	}

	/** Computes the global min/max if either is still at its "not computed" sentinel. */
	private void ensureMinMax()
	{
		if (minVal_ == Integer.MAX_VALUE || maxVal_ == Integer.MIN_VALUE)
			computeMinMax();
	}
	
	
	/**
	 * Copies this image's worth of voxels (this.buffer.getSize()) from I.
	 * No bounds checking: I must hold at least as many voxels. The static flag is not consulted.
	 */
	public void copyFrom(ScalarImageI3 I)
	{
		System.arraycopy(I.buffer.getData(),0, this.buffer.getData(), 0, this.buffer.getSize());
	}

	/**
	 * Prints the voxel at (x,y,z) as read through each slice raster and each plane sample model.
	 * Debugging aid; the six printed values should agree.
	 */
	public void debugCheckSample(int x, int y, int z)
	{
		System.out.println("checking the pixel in raster XY "+ rastersXY[z].getSample(x,y,0));
		System.out.println("checking the pixel in raster XZ "+ rastersXZ[y].getSample(x,z,0));
		System.out.println("checking the pixel in raster YZ "+ rastersZY[x].getSample(z,y,0));
		
		// third component in these calls is the "band" which depends on the sampleModel in use
		System.out.println("checking pixel in csmXY + buffer " + sampleModelXY.getSample(x,y,z,buffer));
		System.out.println("checking pixel in csmXZ + buffer " + sampleModelXZ.getSample(x,z,y,buffer));
		System.out.println("checking pixel in csmYZ + buffer " + sampleModelZY.getSample(z,y,x,buffer));
	}
	
	final public short getBitsPerPixel()
	{
		return 32;
	}
	
	
	/** Returns the backing voxel buffer (not a copy). */
	public DataBufferInt getBuffer()
	{
		return buffer;
	}

	public Colormap getColormap()
	{
		return theColormap;
	}

	public int getDepth()
	{
		return depth;
	}

	public int getHeight()
	{
		return height;
	}
	
	/** Returns the maximum voxel value, computing it on first use. */
	public float getMaxValue()
	{
		if (this.maxVal_ == Integer.MIN_VALUE)
			computeMinMax();
		return maxVal_;
	}

	/** Returns the minimum voxel value, computing it on first use. */
	public float getMinValue()
	{
		if (this.minVal_ == Integer.MAX_VALUE)
			computeMinMax();
		return minVal_;
	}

	// interface methods for Image3Protocol
	/** Returns the number of slices along the current slice orientation. */
	public int getNumberOfSlices()
	{
		if (sliceOrientation == Image3Protocol.SLICE_XY)
			return depth;
		else if (sliceOrientation == Image3Protocol.SLICE_ZY)
			return width;
		else 
			return height;
	}

	public Object getPixel(int x, int y, int z) {
		return (int)buffer.getElem(z*wh+y*width+x);
	}

	/** Returns the low 8 bits of the voxel value. */
	public byte getPixelByte(int x, int y, int z) {
		return (byte)buffer.getElem(z*wh+y*width+x);
	}
	
	
	public float getPixelFloat(int x, int y, int z) {
		return (float)buffer.getElem(z*wh+y*width+x);
	}
	
	public int getPixelInt(int x, int y, int z) {
		return (int)buffer.getElem(z*wh+y*width+x);
	}
	
	/** Returns the low 16 bits of the voxel value. */
	public short getPixelShort(int x, int y, int z) {
		return (short)buffer.getElem(z*wh+y*width+x);
	}
	

	public int getPixelValueAt(int x, int y, int z)
	{
		return buffer.getElem(z*wh+y*width+x);
	}

	/** Returns the voxel at P, truncating its coordinates to int. */
	public int getPixelValueAt(Point3D P)
	{
		return getPixelValueAt((int)P.getX(),(int)P.getY(),(int)P.getZ());
	}

	/**
	 * Renders slice i of the current orientation into the reusable BufferedImage for that plane.
	 * <p>
	 * Packed ARGB images: voxel values are copied verbatim.
	 * <p>
	 * Grayscale images, XY with scaleBySlice: scaled by the slice's own range.
	 * <p>
	 * Grayscale images, otherwise: (v - min) / (max - min) of the global range (see
	 * {@link #setRange}) is mapped onto 0..255, clamping values outside the range.
	 *
	 * @param i slice index along the current orientation
	 * @return the shared image for the current plane; its contents change on the next call
	 */
	public BufferedImage getSlice(int i)
	{
		// The slice rasters and the slice images use different sample models, so the data is
		//  copied per element instead of via BufferedImage.setData(Raster).
		if (sliceOrientation == Image3Protocol.SLICE_XY)
		{
			lastXYSlice = i;
			if (imageIsARGBPacked)
				copyToPacked(rastersXY[i], imageXYbuffer, width, height);
			else if (scaleBySlice)
				fillXYGrayScaledBySlice(i);
			else
				fillGrayFromGlobalRange(rastersXY[i], imageXYRaster, width, height);
			return imageXY;
		}
		else if (sliceOrientation == Image3Protocol.SLICE_XZ)
		{
			lastXZSlice = i;
			if (imageIsARGBPacked)
				copyToPacked(rastersXZ[i], imageXZbuffer, width, depth);
			else
				fillGrayFromGlobalRange(rastersXZ[i], imageXZRaster, width, depth);
			return imageXZ;
		}
		else
		{
			lastZYSlice = i;
			if (imageIsARGBPacked)
				copyToPacked(rastersZY[i], imageZYbuffer, depth, height);
			else
				fillGrayFromGlobalRange(rastersZY[i], imageZYRaster, depth, height);
			return imageZY;
		}
	}

	/** Copies a w x h single-band raster into a row-major packed buffer, values unchanged. */
	private static void copyToPacked(Raster src, DataBufferInt dst, int w, int h)
	{
		for (int v=0;v<h;v++)
			for (int u=0;u<w;u++)
				dst.setElem(v*w+u, src.getSample(u, v, 0));
	}

	/** Fills a w x h gray raster from src, scaled by the global [min, max] range. */
	private void fillGrayFromGlobalRange(Raster src, WritableRaster dst, int w, int h)
	{
		ensureMinMax();
		long range = (long)maxVal_ - minVal_;
		for (int v=0;v<h;v++)
			for (int u=0;u<w;u++)
				dst.setSample(u, v, 0, scaleToGray(src.getSample(u, v, 0), minVal_, range));
	}

	/**
	 * Fills imageXYRaster from XY slice z, scaled by that slice's own range.
	 * Only positive voxels define the slice minimum. Zero voxels are background and map to 0;
	 * values below the minimum clamp to 0.
	 */
	private void fillXYGrayScaledBySlice(int z)
	{
		int sliceMin = Integer.MAX_VALUE;
		int sliceMax = Integer.MIN_VALUE;
		for (int y=0;y<height;y++)
			for (int x=0;x<width;x++)
			{
				int v = getPixelInt(x, y, z);
				if (v > sliceMax)
					sliceMax = v;
				if (v > 0 && v < sliceMin)
					sliceMin = v; // throw away non-data zero pixels
			}

		// Use long: the difference can overflow int. It is also negative when the slice has no
		//  positive voxels, because sliceMin is then still Integer.MAX_VALUE.
		long denom = (long)sliceMax - sliceMin;
		float scale = (denom > 0) ? 255.0f/denom : 0.0f;

		for (int y=0;y<height;y++)
			for (int x=0;x<width;x++)
			{
				float f = getPixelFloat(x, y, z);
				int gray = (f == 0.0f) ? 0 : clampToGray((f - sliceMin) * scale);
				imageXYRaster.setSample(x, y, 0, gray);
			}
	}

	/**
	 * Maps v linearly from [lo, lo+range] onto 0..255 using integer arithmetic.
	 * Values outside the range are clamped; returns 0 when range is not positive.
	 */
	private static int scaleToGray(int v, int lo, long range)
	{
		if (range <= 0)
			return 0;
		long scaled = 255L * ((long)v - lo) / range;
		return (int)Math.max(0L, Math.min(255L, scaled));
	}

	/** Truncates an already-scaled value to a gray level in 0..255, clamping; NaN maps to 0. */
	private static int clampToGray(double v)
	{
		if (!(v > 0.0))
			return 0;
		if (v >= 255.0)
			return 255;
		return (int)v;
	}
	
	/**
	 * Returns XY slice z as bytes holding the low 8 bits of each voxel.
	 * NOTE(review): despite the name, values are truncated, not clamped.
	 */
	public byte[] getSliceClamp(int z)
	{
		byte[] buf = new byte[width*height];

		for (int y=0;y<height;y++)
			for (int x=0;x<width;x++)
				buf[y*width+x] = (byte)getPixelValueAt(x,y,z);

		return buf;
	}
	
	public int getSliceOrientation()
	{
		return sliceOrientation;
	}
	
	/**
	 * Returns XY slice z scaled by the global range (computed on first use) to unsigned bytes.
	 * Values outside the range are clamped.
	 */
	public byte[] getSliceToBuffer(int z)
	{
		ensureMinMax();
		return scaleSliceToBytes(z, minVal_, maxVal_);
	}

	/** Returns XY slice z scaled to unsigned bytes by the min/max of that slice alone. */
	public byte[] getSliceToBufferSliceNorm(int z)
	{
		int lo = Integer.MAX_VALUE;
		int hi = Integer.MIN_VALUE;

		for(int y=0;y<height;y++)
			for(int x=0;x<width;x++)
			{
				int val = getPixelValueAt(x,y,z);
				if (val < lo)
					lo = val;
				if (val > hi)
					hi = val;
			}

		return scaleSliceToBytes(z, lo, hi);
	}

	/** Scales XY slice z from [lo, hi] onto unsigned bytes 0..255, clamping out-of-range values. */
	private byte[] scaleSliceToBytes(int z, int lo, int hi)
	{
		byte[] buf = new byte[width*height];

		float range = (float)hi - (float)lo;
		float scale = (range > 0.0f) ? 255.0f/range : 0.0f;
		for (int y=0;y<height;y++)
			for (int x=0;x<width;x++)
				buf[y*width+x] = (byte)clampToGray(((float)getPixelValueAt(x,y,z)-(float)lo) * scale);

		return buf;
	}

	public int getWidth()
	{
		return width;
	}

	public boolean hasColormap()
	{
		return hasColormap;
	}

	/** Returns true if (x,y,z) lies inside the image bounds. */
	public final boolean isValidVoxel(int x, int y, int z)
	{
		return x >= 0 && x < getWidth()
			&& y >= 0 && y < getHeight()
			&& z >= 0 && z < getDepth();
	}

	/** Returns true if P, with coordinates truncated to int, lies inside the image bounds. */
	public final boolean isValidVoxel(Point3D P)
	{
		return isValidVoxel((int)(P.getX()),(int)(P.getY()),(int)(P.getZ()));
	}

	/**
	 * Marks the image as static and computes the min/max range.
	 * Call this when no further modifications to the underlying image data are permitted;
	 * afterwards setData and setPixelAt ignore writes.
	 */
	public void makeImageStatic()
	{
		imageIsStaticBool = true;
		computeMinMax();
	}

	/**
	 *  Uses the one-dim vector of integer data to set the values of the pixels in the image.
	 *  D is in x-fastest order and is assumed to be the same size as the image.
	 *  Ignored while the image is static.
	 */
	public void setData(int D[])
	{
		if (imageIsStaticBool)
			return;
		int n = wh*depth;
		for (int idx=0;idx<n;idx++)
			buffer.setElem(idx,D[idx]);
	}

	/** Sets the voxel at (x,y,z) to v; ignored while the image is static. */
	public void setPixelAt(int x, int y, int z, int v)
	{
		if (!imageIsStaticBool)
			buffer.setElem(z*wh+y*width+x,v);
	}

	/** Overrides the min/max range used for grayscale scaling. */
	public void setRange(int minv, int maxv)
	{
		this.maxVal_ = maxv;
		this.minVal_ = minv;
	}

	public void setSliceOrientation(int O)
	{
		sliceOrientation = O;
	}
}
