/**
 * ScalarImageB3: an 8-bit (unsigned byte) three-dimensional scalar image.
 */
package edu.ucla.ccb.graphshifts.image;


import java.awt.image.BufferedImage;
import java.awt.image.ComponentSampleModel;
import java.awt.image.DataBuffer;
import java.awt.image.DataBufferByte;
import java.awt.image.Raster;
import java.awt.image.SampleModel;
import java.awt.image.WritableRaster;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import edu.ucla.ccb.graphshifts.data.Point3D;




/**
 * A hopefully efficient version of the ScalarImage for 3D byte images.
 * All voxels live in a single DataBufferByte laid out with x varying fastest,
 * then y, then z. Per-slice WritableRasters for the XY, XZ and ZY planes are
 * views onto that same buffer. One reusable BufferedImage per plane is
 * returned by getSlice and its contents are modified directly.
 * 
 * @author jcorso
 *
 */
public class ScalarImageB3 implements Image3Protocol, Image3PixelAccess
{
	/**
	 * Reads an image through ImageUtilities.readViaImageIO and copies it voxel by
	 * voxel into a new byte volume of the same dimensions.
	 * Each value is truncated with an (int) cast and stored with setPixelAt(int),
	 * which keeps only the low 8 bits, so values outside 0..255 wrap around.
	 * @param path file path handed to ImageUtilities.readViaImageIO
	 * @return the new byte volume
	 */
	public static ScalarImageB3 readFromFilePath(String path)
	{
		ScalarImageF3 image = ImageUtilities.readViaImageIO(path);
		ScalarImageB3 label = new ScalarImageB3(image.getWidth(),image.getHeight(),image.getDepth());
		for (int z=0;z<image.getDepth();z++)
			for (int y=0;y<image.getHeight();y++)
				for (int x=0;x<image.getWidth();x++)
					label.setPixelAt(x,y,z,(int)(image.getPixelValueAt(x,y,z)));
		return label;
	}
	protected int width;
	protected int height;
	protected int depth;
	// number of voxels in one XY slice (width*height); the stride between z slices
	protected int wh;  
	/* Cached intensity range. Voxels are stored as bytes but are always read back
	 * as unsigned values in 0..255 (see getPixelInt), so the defaults cover the
	 * full range. computeMinMax() recomputes both values from the voxel data.
	 */
	protected int minVal=0,maxVal=255;
	
	/* the pixel values in this image have been scaled by the following
	 * linear function  
	 * currentValue = (originalValue - scaledMinVal) * 255 / (scaledMaxVal - scaledMinVal)
	 * Initial settings of 0,255 give a null scaling as expected.
	 */
	protected float scaledMinVal=0.0f,scaledMaxVal=255.0f;
	// the single bank of voxel data that every raster below views
	protected DataBufferByte buffer;
	// one sample model per plane orientation, in which each slice is a separate band
	protected ComponentSampleModel sampleModelXY;
	protected ComponentSampleModel sampleModelXZ;
	protected ComponentSampleModel sampleModelZY;
	// single-band subsets of the plane sample models, one per slice
	protected SampleModel sampleModelsXY[];
	protected SampleModel sampleModelsXZ[];
	protected SampleModel sampleModelsZY[];
	// one raster per slice in each orientation, all backed by buffer
	protected WritableRaster rastersXY[];
	protected WritableRaster rastersXZ[];

	protected WritableRaster rastersZY[];
	// reusable display images returned by getSlice
	protected BufferedImage imageXY;
	protected BufferedImage imageXZ;
	protected BufferedImage imageZY;
	// store the rasters for each image too because we will need to update them on the fly
	protected WritableRaster imageXYRaster;
	protected WritableRaster imageXZRaster;
	protected WritableRaster imageZYRaster;
	
	protected ImageLabelSet labels=null;
	// index of the slice most recently returned by getSlice per orientation (-1 = none)
	protected int lastXYSlice=-1;
	protected int lastXZSlice=-1;
	protected int lastZYSlice=-1;
	
	private int sliceOrientation = Image3Protocol.SLICE_XY;
	// when true, getSlice contrast-stretches XY slices for display
	protected boolean scaleBySlice=true;

	
	/**
	 * Creates a zero-filled volume of the given size.
	 * The labels member is not created during construction.
	 * @param w width (number of columns); must be positive
	 * @param h height (number of rows); must be positive
	 * @param d depth (number of XY slices); must be positive
	 * @throws IllegalArgumentException if a dimension is not positive or the
	 *         volume has more than Integer.MAX_VALUE voxels
	 */
	public ScalarImageB3(int w, int h, int d)
	{
		// allocate the single bank of voxel data and reuse the wrapping constructor
		this(w, h, d, new byte[voxelCount(w, h, d)]);
	}
	
	/**
	 * Wraps an existing voxel array as a volume of the given size. The array is
	 * not copied. Voxel (x,y,z) is buffer[z*w*h + y*w + x].
	 * @param w width; must be positive
	 * @param h height; must be positive
	 * @param d depth; must be positive
	 * @param buffer voxel data; must hold at least w*h*d bytes
	 * @throws IllegalArgumentException if a dimension is not positive, the volume
	 *         is too large, or buffer is shorter than w*h*d
	 */
	public ScalarImageB3(int w, int h, int d, byte[] buffer)
	{
		int voxels = voxelCount(w, h, d);
		// DataBufferByte does not check the array length itself; a short array
		// would otherwise fail later with an ArrayIndexOutOfBoundsException
		if (buffer.length < voxels)
			throw new IllegalArgumentException("Buffer holds "+buffer.length+" bytes but a "
					+w+"x"+h+"x"+d+" image needs "+voxels);
		width = w;
		height = h;
		depth = d;
		wh = width*height;
	
		// create the databuffer itself with only one "bank" of data
		this.buffer = new DataBufferByte(buffer,voxels);
		
		initializeBuffers(w, h, d);
	}

	/**
	 * Validates volume dimensions and returns the total number of voxels.
	 * @throws IllegalArgumentException if a dimension is not positive or the
	 *         product does not fit in an int (one Java array)
	 */
	private static int voxelCount(int w, int h, int d)
	{
		if (w <= 0 || h <= 0 || d <= 0)
			throw new IllegalArgumentException("Image dimensions must be positive: "+w+"x"+h+"x"+d);
		long count = (long)w * h * d;
		if (count > Integer.MAX_VALUE)
			throw new IllegalArgumentException("Image too large for a single byte buffer: "+w+"x"+h+"x"+d);
		return (int)count;
	}

	/**
	 * Recomputes minVal and maxVal as the smallest and largest unsigned voxel
	 * values in the volume.
	 */
	public void computeMinMax()
	{
		minVal = Integer.MAX_VALUE;
		maxVal = Integer.MIN_VALUE;

		for(int z=0;z<depth;z++)
			for(int y=0;y<height;y++)
				for(int x=0;x<width;x++)
				{
					int val = getPixelInt(x,y,z);
					if (val < minVal)
						minVal = val;
					if (val > maxVal)
						maxVal = val;
				}
	}
	
	/**
	 * Prints voxel (x,y,z) as read through each orientation's per-slice raster
	 * and through each orientation's multi-band sample model. All six reads view
	 * the same buffer, so the printed values should agree.
	 * @param x
	 * @param y
	 * @param z
	 */
	public void debugCheckSample(int x, int y, int z)
	{
		System.out.println("checking the pixel in raster XY "+ rastersXY[z].getSample(x,y,0));
		System.out.println("checking the pixel in raster XZ "+ rastersXZ[y].getSample(x,z,0));
		System.out.println("checking the pixel in raster YZ "+ rastersZY[x].getSample(z,y,0));
		
		// third component in these calls is the "band" which depends on the sampleModel in use
		System.out.println("checking pixel in csmXY + buffer " + sampleModelXY.getSample(x,y,z,buffer));
		System.out.println("checking pixel in csmXZ + buffer " + sampleModelXZ.getSample(x,z,y,buffer));
		System.out.println("checking pixel in csmYZ + buffer " + sampleModelZY.getSample(z,y,x,buffer));
	}
	
	
	
	/** @return 8, the number of bits stored per voxel */
	final public short getBitsPerPixel()
	{
		return 8;
	}

	/** @return the backing data buffer (shared, not a copy) */
	public DataBufferByte getBuffer()
	{
		return buffer;
	}
	
	/** @return always null; this image has no colormap */
	public Colormap getColormap() {
		return null;
	}
	
	
	/* (non-Javadoc)
	 * @see jcorso.image.Image3Protocol#getDepth()
	 */
	public int getDepth()
	{
		return depth;
	}
	
	public int getHeight()
	{
		return height;
	}
	

	/** @return the cached maximum value (see computeMinMax / setMaxValue) */
	public float getMaxValue()
	{
		return (float)maxVal;
	}
	
	/** @return the cached minimum value (see computeMinMax / setMinValue) */
	public float getMinValue()
	{
		return (float)minVal;
	}

	/**
	 * Returns the number of slices along the axis normal to the current orientation.
	 * XY gives depth, ZY gives width, and any other orientation (XZ) gives height.
	 * @see jcorso.image.Image3Protocol#getNumberOfSlices()
	 */
	public int getNumberOfSlices()
	{
		if (sliceOrientation == Image3Protocol.SLICE_XY)
			return depth;
		else if (sliceOrientation == Image3Protocol.SLICE_ZY)
			return width;
		else 
			return height;
	}
	
	/** @return the unsigned voxel value at (x,y,z), boxed as an Integer */
    final public Object getPixel(int x, int y, int z) {
		return buffer.getElem(z*wh+y*width+x);
	}

	/** @return the raw stored byte at (x,y,z); values above 127 come back negative */
	final public byte getPixelByte(int x, int y, int z) {
		return (byte)buffer.getElem(z*wh+y*width+x);
	}

	/** All pixel accesses that cast, must treat the byte as an unsigned byte */
	final public float getPixelFloat(int x, int y, int z) {
		return (float)(0xFF & buffer.getElem(z*wh+y*width+x));
	}

	/** @return the voxel value at (x,y,z) as an unsigned int in 0..255 */
	final public int getPixelInt(int x, int y, int z) {
		return 0xFF & buffer.getElem(z*wh+y*width+x);
	}
	/** @return the voxel value at (x,y,z) as an unsigned short in 0..255 */
	final public short getPixelShort(int x, int y, int z) {
		return (short)(0xFF & buffer.getElem(z*wh+y*width+x));
	}
	
	
	

	/** @return the voxel value at (x,y,z) as an unsigned int in 0..255 */
	public int getPixelValueAt(int x, int y, int z)
	{
		return 0xFF & buffer.getElem(z*wh+y*width+x);
	}

	/** @return the voxel value at P, whose coordinates are truncated with (int) casts */
	public int getPixelValueAt(Point3D P)
	{
		return getPixelValueAt((int)P.getX(),(int)P.getY(),(int)P.getZ());
	}
	
	
	/**
	 * Returns slice i of the current orientation as a grayscale image.
	 * The same BufferedImage instance is reused and overwritten by every call for
	 * a given orientation. The index is remembered so that setPixelAt can keep
	 * that image in sync. XY slices are contrast-stretched when scaleBySlice is
	 * set (see fillContrastStretchedXY). Otherwise the raw voxel values are copied.
	 * @param i slice index along the axis normal to the current orientation
	 * @see jcorso.image.Image3Protocol#getSlice(int)
	 */
	public BufferedImage getSlice(int i)
	{
		if (sliceOrientation == Image3Protocol.SLICE_XY)
		{
			lastXYSlice = i;
			if (scaleBySlice)
				fillContrastStretchedXY(i);
			else
				// copy the raster data from slice i in XY plane to the buffered image
				imageXY.setData(rastersXY[i]);
			return imageXY;
		}
		else if (sliceOrientation == Image3Protocol.SLICE_XZ)
		{
			lastXZSlice = i;
			imageXZ.setData(rastersXZ[i]);
			return imageXZ;
		}
		else
		{
			lastZYSlice = i;
			imageZY.setData(rastersZY[i]);
			return imageZY;
		}
	}

	/**
	 * Writes XY slice z into imageXY with a per-slice linear contrast stretch.
	 * Zero voxels are treated as non-data background. They are ignored when the
	 * slice minimum is found and are always shown as 0. Every other voxel v is
	 * mapped to floor((v - lo) * 255 / (hi - lo)), where lo is the smallest
	 * nonzero value and hi is the largest value in the slice. If there is no
	 * range to stretch (an all-zero slice, or a single nonzero value), the raw
	 * slice is copied unscaled.
	 */
	private void fillContrastStretchedXY(int z)
	{
		int lo = 256, hi = -1;
		for (int y=0;y<height;y++)
			for (int x=0;x<width;x++)
			{
				int v = getPixelInt(x, y, z);
				if (v > hi)
					hi = v;
				if (v > 0 && v < lo) // throw away non-data zero pixels
					lo = v;
			}

		if (hi <= lo)
		{
			// a zero or negative range cannot be stretched
			imageXY.setData(rastersXY[z]);
			return;
		}

		int range = hi - lo;
		for (int y=0;y<height;y++)
			for (int x=0;x<width;x++)
			{
				int v = getPixelInt(x, y, z);
				// background (v < lo) maps to 0 instead of wrapping when cast to a byte.
				// Integer division is an exact floor, and v <= hi keeps the result within 0..255.
				int out = (v <= lo) ? 0 : ((v - lo) * 255) / range;
				imageXYRaster.setSample(x, y, 0, out);
			}
	}
	
	/* (non-Javadoc)
	 * @see jcorso.image.Image3Protocol#getSliceOrientation()
	 */
	public int getSliceOrientation()
	{
		return sliceOrientation;
	}
	/* (non-Javadoc)
	 * @see jcorso.image.Image3Protocol#getWidth()
	 */
	public int getWidth()
	{
		return width;
	}



	/** @return always false; this image has no colormap */
	public boolean hasColormap() {
		return false;
	}


	
	
	/**
	 * Builds the sample models, per-slice rasters and reusable display images.
	 * All rasters view the shared buffer, so a write through any raster or
	 * through setPixelAt is visible in every orientation.
	 * Requires buffer and wh to be set already.
	 */
	private void initializeBuffers(int w, int h, int d)
	{
		// now we have to create the component sample models for each of the three
		//  plane directions
		// each slice is considered a band by the component sample model
		// XY: pixel stride 1, scanline stride w, slice z starts at z*wh
		int bandOffsets[] = new int[d];
		for (int i=0;i<d;i++)
			bandOffsets[i] = i*wh;
		sampleModelXY = new ComponentSampleModel(DataBuffer.TYPE_BYTE,w,h,1,w,bandOffsets);
		// XZ: pixel stride 1 (x), "scanline" stride wh (z), slice y starts at y*w
		bandOffsets = new int[h];
		for (int i=0;i<h;i++)
			bandOffsets[i] = i*w;
		sampleModelXZ = new ComponentSampleModel(DataBuffer.TYPE_BYTE,w,d,1,wh,bandOffsets);
		// ZY: pixel stride wh (z), scanline stride w (y), slice x starts at x
		bandOffsets = new int[w];
		for (int i=0;i<w;i++)
			bandOffsets[i] = i;
		sampleModelZY = new ComponentSampleModel(DataBuffer.TYPE_BYTE,d,h,wh,w,bandOffsets);
		// note that the bandOffsets array must be realloc'd each plane bc it's length is used to 
		//  determine the number of bands for that plane....
		
		// now, create a ComponentSampleModel for each slice in each plane...needed for raster creation
		int band[] = new int[1];
		sampleModelsXY = new ComponentSampleModel[d];
		for (int i=0;i<d;i++)
		{
			band[0] = i;
			sampleModelsXY[i] = sampleModelXY.createSubsetSampleModel(band);
		}
		sampleModelsXZ = new ComponentSampleModel[h];
		for (int i=0;i<h;i++)
		{
			band[0] = i;
			sampleModelsXZ[i] = sampleModelXZ.createSubsetSampleModel(band);
		}	
		sampleModelsZY = new ComponentSampleModel[w];
		for (int i=0;i<w;i++)
		{
			band[0] = i;
			sampleModelsZY[i] = sampleModelZY.createSubsetSampleModel(band);
		}	
		
		// now, we need to create a WritableRaster for each of the planes and each slice
		rastersXY = new WritableRaster[d];
		for (int i=0;i<d;i++)
			rastersXY[i] = Raster.createWritableRaster(sampleModelsXY[i],buffer,null);
		rastersXZ = new WritableRaster[h];
		for (int i=0;i<h;i++)
			rastersXZ[i] = Raster.createWritableRaster(sampleModelsXZ[i],buffer,null);
		rastersZY = new WritableRaster[w];
		for (int i=0;i<w;i++)
			rastersZY[i] = Raster.createWritableRaster(sampleModelsZY[i],buffer,null);
	
		// a single BufferedImage object will be used and returned during getSlice operations
		//  to remove the need for allocating memory all of the time. 
		imageXY = new BufferedImage(w,h,BufferedImage.TYPE_BYTE_GRAY);
		imageXZ = new BufferedImage(w,d,BufferedImage.TYPE_BYTE_GRAY);
		imageZY = new BufferedImage(d,h,BufferedImage.TYPE_BYTE_GRAY);
		imageXYRaster = imageXY.getRaster();
		imageXZRaster = imageXZ.getRaster();
		imageZYRaster = imageZY.getRaster();
	}
	
	/** @return true if (x,y,z) lies inside the volume */
	public final boolean isValidVoxel(int x, int y, int z)
	{
		if  (  (x < 0) || (x >= width) ||
				(y < 0) || (y >= height) ||
				(z < 0) || (z >= depth) )
			return false;
		return true;		
	}
	
	
	
	/** @return true if P, with coordinates truncated by (int) casts, lies inside the volume */
	public final boolean isValidVoxel(Point3D P)
	{
		return isValidVoxel((int)(P.getX()),(int)(P.getY()),(int)(P.getZ()));
	}
	
	/** Overrides the cached maximum value without inspecting the data. */
	public void setMaxValue(int M)
	{
		maxVal = M;
	}

	/** Overrides the cached minimum value without inspecting the data. */
	public void setMinValue(int M)
	{
		minVal = M;
	}

	/**
	 * Stores b at (x,y,z). If that voxel lies in the XY, XZ or ZY slice most
	 * recently returned by getSlice, b is also written into the matching cached
	 * display image.
	 * NOTE(review): the display value is written unscaled, so an XY slice that
	 * getSlice contrast-stretched (scaleBySlice) shows this voxel's raw value
	 * until getSlice is called again.
	 */
	public void setPixelAt(int x, int y, int z, byte b)
	{
		buffer.setElem(z*wh+y*width+x,b);
		
		// check if we have to update the buffered images
		if (lastXYSlice == z)
		{
			imageXYRaster.setSample(x,y,0,b);
		}
		if (lastXZSlice == y)
		{
			imageXZRaster.setSample(x,z,0,b);
		}
		if (lastZYSlice == x)
		{
			imageZYRaster.setSample(z,y,0,b);
		}
		
	}
	

	/** Stores the low 8 bits of b at (x,y,z); see setPixelAt(int,int,int,byte). */
	public void setPixelAt(int x, int y, int z, int b)
	{
		setPixelAt(x,y,z,(byte)b);
	}

	
	
	/**
	 * Records the linear scaling already applied to the voxel values (see the
	 * scaledMinVal / scaledMaxVal fields). The voxel data is not modified.
	 */
	public void setScaledIntensity(float a, float b)
	{
		scaledMinVal = a;
		scaledMaxVal = b;
	}


	/* (non-Javadoc)
	 * @see jcorso.image.Image3Protocol#setSliceOrientation(int)
	 */
	public void setSliceOrientation(int O)
	{
		sliceOrientation = O;
	}
	
	
	/**
	 * Writes the volume to fn as a gzip-compressed Java object stream.
	 * The stream contains the byte 3 (the dimensionality), width, height and depth
	 * as 16-bit values, and then the voxel byte[].
	 * @param fn output file path
	 * @throws IllegalStateException if a dimension exceeds 65535 and cannot be stored
	 * @throws AssertionError wrapping any IOException
	 */
	public void writeToGzipMatrix(String fn)
	{
		// writeShort keeps only 16 bits; refuse rather than silently truncating
		if (width > 0xFFFF || height > 0xFFFF || depth > 0xFFFF)
			throw new IllegalStateException("Dimensions "+width+"x"+height+"x"+depth
					+" do not fit the 16-bit fields of the gzip matrix format");
		FileOutputStream file = null;
		try
		{
			file = new FileOutputStream(fn);
			ObjectOutputStream s = new ObjectOutputStream(new GZIPOutputStream(file));
			
			s.writeByte(3);
			s.writeShort(width);
			s.writeShort(height);
			s.writeShort(depth);
			s.writeObject(buffer.getData());
			// closing the outer stream writes the gzip trailer and closes the file
			s.close();
		}
		catch (IOException e)
		{
			throw new AssertionError(e);
		}
		finally
		{
			// releases the file if anything above failed; a no-op after a successful close
			if (file != null)
			{
				try { file.close(); }
				catch (IOException ignored) { /* best effort: already failing or already closed */ }
			}
		}
	}
	
	/**
	 * Reads a volume written by writeToGzipMatrix.
	 * Dimensions are read as unsigned 16-bit values. This is byte-compatible with
	 * files written by writeShort, and also handles sizes 32768..65535.
	 * NOTE(review): this uses Java deserialization (readObject), which is unsafe
	 * on files from untrusted sources. Only read files this class wrote.
	 * @param fn input file path
	 * @return the loaded volume, wrapping the deserialized array
	 * @throws AssertionError if the file is not 3-dimensional or cannot be read
	 */
	public static ScalarImageB3 readGzipMatrix(String fn)
	{
		int w,h,d;
		byte[] buf;
		FileInputStream file = null;
		try
		{
			file = new FileInputStream(fn);
			ObjectInputStream s = new ObjectInputStream(new GZIPInputStream(file));
			
			int dim = s.readByte();
			if (dim != 3)
			{
				throw new AssertionError("Cannot read gzip matrices of dimension other than 3");
			}
			w = s.readUnsignedShort();
			h = s.readUnsignedShort();
			d = s.readUnsignedShort();
			buf = (byte[])s.readObject();
			s.close();
		}
		catch (Exception e)
		{
			throw new AssertionError(e);
		}
		finally
		{
			// releases the file if anything above failed; a no-op after a successful close
			if (file != null)
			{
				try { file.close(); }
				catch (IOException ignored) { /* best effort: already failing or already closed */ }
			}
		}
		
		return new ScalarImageB3(w,h,d,buf);
	}
	
}
