"""Module to handle conversion between in-core image formats."""
#
# Image-conversion helper classes and routines
#
import imgformat
import imgcolormap
import imageop
import imgop

# Exceptions raised by this module (old-style string exceptions).
error = 'imgconvert.error'                          # conversion failures
unsupported_error = 'imgconvert.unsupported_error'  # no converter path

# Module-wide switches; change them with setquality() / settrace().
HIGH_QUALITY = 1    # quality flag handed to colormap dithering
TRACE = 0           # non-zero: print every conversion chain as it runs

def setquality(onoff):
    """Turn high-quality conversions on or off.

    Pass 0 to disable them, non-zero to enable them.  The value is passed
    as the quality flag to colormap dithering.  Returns the previous
    setting so the caller can restore it later.
    """
    global HIGH_QUALITY
    previous, HIGH_QUALITY = HIGH_QUALITY, onoff
    return previous

def settrace(onoff):
    """Turn conversion trace printing on or off.

    Pass a non-zero value to have every conversion chain and each of its
    steps printed while it runs.  Returns the previous setting so the
    caller can restore it later.
    """
    global TRACE
    previous, TRACE = TRACE, onoff
    return previous

_colormaps = {}

#
# getfmtcolormap - Create a colormap for an 8-bit format
def getfmtcolormap(fmt):
        """Return a colormap suitable for the 8-bit format argument"""
	
	if _colormaps.has_key(fmt):
		return _colormaps[fmt]
		
	try:
		size = fmt.descr['size']
	except AttributeError:
		raise error, 'Argument must be an image-format object'
	if size <> 8:
		raise error, 'Only 8-bit formats supported'
	comp = fmt.descr['comp']
	if len(comp) not in (1, 3):
		raise error, 'Only 1- or 3-component formats supported'
		
	map = imgcolormap.new('\0\0\0\0'*256)
	if len(comp) == 1:
		for i in range(1<<comp[0][1]):
			map[i] = (i, i, i)
	else:
		rmax = 1<<comp[0][1]
		gmax = 1<<comp[1][1]
		bmax = 1<<comp[2][1]
		for r in range(rmax):
			rv = r*255/(rmax-1)
			for g in range(gmax):
				gv = g*255/(gmax-1)
				for b in range(bmax):
					bv = b*255/(bmax-1)
					i = (r<<comp[0][0]) | (g<<comp[1][0]) | (b<<comp[2][0])
					map[i] = (rv, gv, bv)
	_colormaps[fmt] = map
	return map	

#
# _reverse - Convert top-to-bottom to bottom-to-top and vice versa.
#
def _reverse(data, reader, srcfmt, dstfmt):
    width = reader.width
    height = len(data) / width
    if height*width <> len(data):
	raise error, 'Incorrect datasize'
    rv = ''
    pos = len(data)-width
    while pos >= 0:
	rv = rv + data[pos:pos+width]
	pos = pos - width
    return rv

def _reverse8(data, reader, srcfmt, dstfmt):
    width = reader.width
    width = (width + 3) & ~3
    height = len(data) / width
    if height*width <> len(data):
	raise error, 'Incorrect datasize'
    rv = ''
    pos = len(data)-width
    while pos >= 0:
	rv = rv + data[pos:pos+width]
	pos = pos - width
    return rv

def _xreverse8(data, reader, srcfmt, dstfmt):
    width = reader.width
    height = len(data) / width
    if height*width <> len(data):
	raise error, 'Incorrect datasize'
    rv = ''
    pos = len(data)-width
    while pos >= 0:
	rv = rv + data[pos:pos+width]
	pos = pos - width
    return rv

#
# _maptorgb - Convert colormap to rgb
#
def _maptorgb(data, reader, srcfmt, dstfmt):
    width = reader.width
    return reader.colormap.map(data, width, srcfmt, dstfmt)

#
# _greytorgb - Convert greyscale to rgb
#
def _greytorgb(data, reader, srcfmt, dstfmt):
    """Convert greyscale data to RGB via a grey-ramp colormap.

    colormap.map() expects a colormap format.  Each grey format is
    therefore relabelled as its colormap counterpart with the same layout
    (grey -> colormap, xgrey -> xcolormap, grey_b2t -> colormap_b2t).
    Previously the grey_b2t branch assigned grey_b2t to itself, a no-op,
    so map() received a grey format.
    """
    greymap = getfmtcolormap(srcfmt)
    if srcfmt == imgformat.grey:
        srcfmt = imgformat.colormap
    elif srcfmt == imgformat.xgrey:
        srcfmt = imgformat.xcolormap
    elif srcfmt == imgformat.grey_b2t:
        srcfmt = imgformat.colormap_b2t
    return greymap.map(data, reader.width, srcfmt, dstfmt)

#
# _rgbtoxgrey - Convert rgb to greyscale (xgrey format)
#
def _rgbtoxgrey(data, reader, srcfmt, dstfmt):
    data = imageop.rgb2grey(data, reader.width, reader.height)
    return data

#
# _rgb2rgb8 - Convert rgb to xrgb8, simplistic method
#
#def _rgbtoxrgb8(data, reader, srcfmt, dstfmt):
#    return imageop.rgb2rgb8(data, reader.width, reader.height)

    
#
# _shuffle - Convert various RGB formats to each other
#
def _shuffle(data, reader, srcfmt, dstfmt):
    """Repack pixels from one RGB-style layout to another via imgop.shuffle."""
    width, height = reader.width, reader.height
    return imgop.shuffle(data, width, height, srcfmt, dstfmt)

#
# _dither - Convert 8-bit grey to 1-bit grey
def _dither(data, reader, srcfmt, dstfmt):
    """Dither 8-bit greyscale down to a 1-bit bitmap via imgop.dither."""
    width, height = reader.width, reader.height
    return imgop.dither(data, width, height, srcfmt, dstfmt)
	
#
# _zapbits - Remove bits from an RGB color (for colormap clustering)
#
class Struct:
    """Bare attribute container; _zapbits uses it as a stand-in format object."""

def _zapbits(data, reader, fmt):
    # Create a new format, initially identical to the old one
    newfmt = Struct()
    newfmt.name = 'Scaled-down RGB format for map-clustering'
    newfmt.descr = {}
    for k in fmt.descr.keys():
	newfmt.descr[k] = fmt.descr[k]
    # Now remove one bit from each color component
    components = newfmt.descr['comp']
    newcomp = ()
    for pos, len in components:
	if len <= 1:
	    raise error, 'Map-clustering failed'
	newcomp = newcomp + ((pos+1, len-1),)
    if TRACE:
	print '    Scaling RGB values to', newcomp
    newfmt.descr['comp'] = newcomp
    data = imgop.shuffle(data, reader.width, reader.height, fmt, newfmt)
    return data, newfmt

def _scalergbdata(data, reader, srcfmt):
    try:
	map = imgcolormap.fromimage(data, reader.width, reader.height, srcfmt)
	return data, map
    except imgcolormap.error, arg:
	if arg[:15] != 'Too many colors':
	    raise imgcolormap.error, arg
    # Do color-clustering
    newfmt = srcfmt
    while 1:
	data, newfmt = _zapbits(data, reader, newfmt)
	data = imgop.shuffle(data, reader.width, reader.height,
			     newfmt, srcfmt)
	try:
	    map = imgcolormap.fromimage(data, reader.width, reader.height,
					srcfmt)
	    return data, map
	except imgcolormap.error:
	    pass

#
# _maprgb - Convert RGB to xcolormap format
#
def _maprgb(data, reader, srcfmt, dstfmt):
    data, map = _scalergbdata(data, reader, srcfmt)
    reader.colormap = map
    data, newfmt = map.dither(data, reader.width, reader.height,
			      srcfmt, HIGH_QUALITY)
    if newfmt != imgformat.xcolormap:
	raise error, 'Internal error: map.dither returned format '+`newfmt`
    return data
    
def _rgbtoxrgb8(data, reader, srcfmt, dstfmt):
    map = getfmtcolormap(dstfmt)
    data, newfmt = map.dither(data, reader.width, reader.height,
			      srcfmt, HIGH_QUALITY)
    if newfmt != imgformat.xcolormap:
	raise error, 'Internal error: map.dither returned format '+`newfmt`
    return data
    
#
# _removestride - Remove the stride from an 8-bit image
#
def _removestride(data, reader, srcfmt, dstfmt):
    rv = ''
    dstwidth = reader.width
    srcwidth = ((dstwidth+3) & ~3)
    if srcwidth == dstwidth:
	return data
    for i in range(0, len(data), srcwidth):
	rv = rv + data[i:i+dstwidth]
    return rv

#
# _addstride - Add stride to an 8-bit image
#
def _addstride(data, reader, srcfmt, dstfmt):
    rv = ''
    srcwidth = reader.width
    dstwidth = ((srcwidth+3) & ~3)
    if srcwidth == dstwidth:
	return data
    extra = '\0' * (dstwidth-srcwidth)
    for i in range(0, len(data), srcwidth):
	rv = rv + data[i:i+srcwidth] + extra
    return rv

#
# 'lossiness' is a scalar value. Use 0 if nothing changes in the pixels,
# 1 if no information is lost but bits are wasted (grey->rgb), 2 if
# information is lost (rgb->grey), 3 if the converter also produces a colormap.
#
# Table of single-step converters: (srcfmt, dstfmt, function, lossiness).
# Order matters: when two chains tie, the one found first is kept.
_converters = [
    # Vertical flips between top-to-bottom and bottom-to-top layouts.
    (imgformat.grey,          imgformat.grey_b2t,       _reverse8,     0),
    (imgformat.grey_b2t,      imgformat.grey,           _reverse8,     0),
    (imgformat.xgrey,         imgformat.xgrey_b2t,      _xreverse8,    0),
    (imgformat.xgrey_b2t,     imgformat.xgrey,          _xreverse8,    0),
    (imgformat.colormap,      imgformat.colormap_b2t,   _reverse8,     0),
    (imgformat.colormap_b2t,  imgformat.colormap,       _reverse8,     0),
    (imgformat.rgb,           imgformat.rgb_b2t,        _reverse,      0),
    (imgformat.rgb_b2t,       imgformat.rgb,            _reverse,      0),
    (imgformat.rgb8,          imgformat.rgb8_b2t,       _reverse8,     0),
    (imgformat.rgb8_b2t,      imgformat.rgb8,           _reverse8,     0),
    (imgformat.xrgb8,         imgformat.xrgb8_b2t,      _xreverse8,    0),
    (imgformat.xrgb8_b2t,     imgformat.xrgb8,          _xreverse8,    0),

    # Greyscale to RGB.
    (imgformat.grey,          imgformat.rgb,            _greytorgb,    1),
    (imgformat.xgrey,         imgformat.rgb,            _greytorgb,    1),
    (imgformat.grey_b2t,      imgformat.rgb_b2t,        _greytorgb,    1),

    # Colormapped to RGB.
    (imgformat.colormap,      imgformat.rgb,            _maptorgb,     1),
    (imgformat.xcolormap,     imgformat.rgb,            _maptorgb,     1),
    (imgformat.colormap_b2t,  imgformat.rgb_b2t,        _maptorgb,     1),

    # RGB to greyscale.
    (imgformat.rgb,           imgformat.xgrey,          _rgbtoxgrey,   2),

    # Shuffles between the various RGB layouts.
    (imgformat.macrgb,        imgformat.rgb,            _shuffle,      0),
    (imgformat.macrgb16,      imgformat.rgb,            _shuffle,      1),
    (imgformat.rgb8,          imgformat.rgb,            _shuffle,      1),
    (imgformat.xrgb8,         imgformat.rgb,            _shuffle,      1),
    (imgformat.rgb,           imgformat.macrgb,         _shuffle,      0),
    (imgformat.macrgb16,      imgformat.macrgb,         _shuffle,      1),
    (imgformat.rgb8,          imgformat.macrgb,         _shuffle,      1),
    (imgformat.xrgb8,         imgformat.macrgb,         _shuffle,      1),
    (imgformat.rgb,           imgformat.macrgb16,       _shuffle,      2),
    (imgformat.macrgb,        imgformat.macrgb16,       _shuffle,      2),
    (imgformat.rgb8,          imgformat.macrgb16,       _shuffle,      1),
    (imgformat.xrgb8,         imgformat.macrgb16,       _shuffle,      1),

    # RGB to 8-bit RGB by dithering against a fixed colormap.
    (imgformat.rgb,           imgformat.xrgb8,          _rgbtoxrgb8,   2),

    (imgformat.xrgb8,         imgformat.rgb8,           _shuffle,      0),
    (imgformat.rgb8,          imgformat.xrgb8,          _shuffle,      0),

    # Bitmaps.
    (imgformat.pbmbitmap,     imgformat.grey,           _shuffle,      1),
    (imgformat.grey,          imgformat.pbmbitmap,      _dither,       2),

    # Removing / adding the 4-byte row alignment of 8-bit formats.
    (imgformat.grey,          imgformat.xgrey,          _removestride, 0),
    (imgformat.grey_b2t,      imgformat.xgrey_b2t,      _removestride, 0),
    (imgformat.rgb8_b2t,      imgformat.xrgb8_b2t,      _removestride, 0),
    (imgformat.colormap,      imgformat.xcolormap,      _removestride, 0),
    (imgformat.xgrey,         imgformat.grey,           _addstride,    0),
    (imgformat.xgrey_b2t,     imgformat.grey_b2t,       _addstride,    0),
    (imgformat.xrgb8_b2t,     imgformat.rgb8_b2t,       _addstride,    0),
    (imgformat.xcolormap,     imgformat.colormap,       _addstride,    0),

    # RGB to a colormap computed from the image itself (lossiness 3).
    (imgformat.rgb,           imgformat.xcolormap,      _maprgb,       3),
]

#
# The converters we have built, indexed by source format.
# Each entry is another dictionary (indexed by destination format).
# The entries of these dictionaries are lists of
# [lossiness, length, [converter tuples]].
#
_generated = {}

#
# Add a converter from 'srcfmt' to 'dstfmt' to the list, possibly
# replacing an existing converter
#
def addconverter(srcfmt, dstfmt, func, lossy):
    """Tell imgconvert about a new converter.

    Args: source_format, dest_format, function, lossy.
    lossy is 0 (not lossy), 1 (wastes bits), 2 (loses bits) or
    3 (converts to colormap format).

    function is called as function(data, reader, srcfmt, dstfmt)

    An existing converter for the same (srcfmt, dstfmt) pair is replaced;
    otherwise the new one is appended to the table.
    """
    entry = (srcfmt, dstfmt, func, lossy)
    # Multi-step chains cached by getconverter() were computed from the
    # old table.  Drop them so the new converter is taken into account;
    # previously it was ignored for any source format already queried.
    _generated.clear()
    for i in range(len(_converters)):
        isrcfmt, idstfmt, irtn, ilossy = _converters[i]
        if (srcfmt, dstfmt) == (isrcfmt, idstfmt):
            _converters[i] = entry
            return
    _converters.append(entry)

#
# Returns [lossiness, length, list-of-converter-tuples]; applying the
# functions of those tuples in order converts srcfmt to dstfmt.
#
def getconverter(srcfmt, dstfmt):
        """Return a converter from srcfmt to dstfmt.
	A converter is a list [lossy, length, list-of-tuples],
	where each tuple is (srcfmt, dstfmt, func, lossy).
	Calling each of the functions in order will convert your image.
	"""
	
        global _generated
	#
	# If formats are the same return the dummy converter
	#
	if srcfmt == dstfmt: return []
	#
	# Otherwise, if we have a converter, return that one
	#
	for this in _converters:
	        isrcfmt, idstfmt, irtn, ilossy = this
		if (srcfmt, dstfmt) == (isrcfmt, idstfmt):
			return [ilossy, 1, [this]]
	#
	# Finally, we try to create a converter
	#
	if not _generated.has_key(srcfmt):
	        # Not there yet. Try to create it.
		_generated[srcfmt] = _enumerate_converters(srcfmt)
		
	if not _generated[srcfmt].has_key(dstfmt):
		raise unsupported_error, (srcfmt, dstfmt)

	cf = _generated[srcfmt][dstfmt]
	return cf

def _enumerate_converters(srcfmt):
	cvs = {}
	formats = [srcfmt]
	steps = 0
	while 1:
		workdone = 0
		for this in _converters:
		        isrcfmt, idstfmt, irtn, ilossy = this
		        #
			# First see if the source format is of any use.
			#
			if isrcfmt == srcfmt:
			        #
				# This converter directly understands our
				# source format. Remember it.
				#
				template = [ilossy, 1, [this]]
			elif cvs.has_key(isrcfmt):
			        #
				# We have a path to this format, so
				# this converter can help us further.
				#
				template = cvs[isrcfmt][:]
				template[0] = max(template[0], ilossy)
				template[1] = template[1] + 1
				template[2] = template[2] + [this]
			else:
				continue
			#
			# Next, check whether we want this converter
			# (if it is the first one for this dstfmt, or
			# if it is better than what we have)
			#
			if not cvs.has_key(idstfmt):
				cvs[idstfmt] = template
				workdone = 1
			else:
				previous = cvs[idstfmt]
				if template < previous:
					cvs[idstfmt] = template
					workdone = 1
		if not workdone:
			break
		#
		# Finally, a check for loops.
		#
		steps = steps + 1
		if steps > len(_converters):
			print '------------------loop in emunerate_converters--------'
			print 'CONVERTERS:'
			print _converters
			print 'RESULTS:'
			print cvs
			raise error, 'Internal error - loop'
	return cvs

def stackreader(dstfmt, reader):
    """Create a reader-like object that reads image file data and
    converts it to the requested format.
    Args: format, original_reader
    """
    
    if dstfmt in reader.format_choices:
	reader.format = dstfmt
	return reader
    # Nope, not supported directly. Find all possible converters
    list = []
    for srcfmt in reader.format_choices:
	try:
	    rv = getconverter(srcfmt, dstfmt)
	except unsupported_error:
	    continue
	if rv:
	    [lossy, len, funclist] = rv
	    list.append(lossy, len, srcfmt, funclist)
    if not list:
	raise unsupported_error, (reader.format_choices, dstfmt)
    # Now, sort and use the best
    list.sort()
    lossy, len, srcfmt, funclist = list[0]
    if lossy == 3:
	return _MapReaderStack(reader, dstfmt, srcfmt, funclist)
    else:
	return _ConverterStack(reader, dstfmt, srcfmt, funclist)

def stackwriter(srcfmt, writer):
    """Create a writer-like object that writes an image file from source
    data in the specified format.
    Args: source_format, destination_writer
    """
    
    if srcfmt in writer.format_choices:
	writer.format = srcfmt
	return writer
    # Nope, not supported directly. Find all possible converters
    list = []
    for dstfmt in writer.format_choices:
	try:
	    [lossy, len, funclist] = getconverter(srcfmt, dstfmt)
	except unsupported_error:
	    continue
	list.append(lossy, len, dstfmt, funclist)
    if not list:
	raise unsupported_error, (srcfmt, writer.format_choices)
    # Now, sort and use the best
    list.sort()
    lossy, len, dstfmt, funclist = list[0]
    return _ConverterStack(writer, srcfmt, dstfmt, funclist)

#
# _ConverterStack - wraps a reader/writer and converts data between the
# caller's format and the base object's format
class _ConverterStack:
    def __init__(self, base, ourfmt, basefmt, funclist):
	self._base = base
	self._funclist = funclist
	self._copyattrtoself()
	self.format_choices = (ourfmt,)
	self.format = ourfmt
	self._basefmt = basefmt

    def _copyattrtoself(self):
	srcdict = self._base.__dict__
	dstdict = self.__dict__
	for k in srcdict.keys():
	    if k[0] <> '_':
		dstdict[k] = srcdict[k]

    def _copyattrfromself(self):
	srcdict = self.__dict__
	dstdict = self._base.__dict__
	for k in srcdict.keys():
	    if k[0] <> '_':
		dstdict[k] = srcdict[k]

    def read(self):
	self._copyattrfromself()
	self._base.format = self._basefmt
	data = self._base.read()
	if TRACE:
	    print 'Converting', self._basefmt.name, 'to', self.format.name,
	    print 'in', len(self._funclist), 'steps:'
	for f in self._funclist:
	    if TRACE:
		print '  ',f[0].name, 'to',f[1].name
	    data = apply(f[2], (data, self, f[0], f[1]))
	return data

    def write(self, data):
	if TRACE:
	    print 'Converting', self.format.name, 'to', self._basefmt.name,
	    print 'in', len(self._funclist), 'steps:'
	for f in self._funclist:
	    if TRACE:
		print '  ',f[0].name, 'to',f[1].name
	    data = apply(f[2], (data, self, f[0], f[1]))
	self._copyattrfromself()
	self._base.format = self._basefmt
	self._base.write(data)

#
# A MapReaderStack is used if one of the converters also returns a
# colormap. In this case we have to read upon init to set the colormap
# attribute.
#
class _MapReaderStack(_ConverterStack):
    def __init__(self, base, ourfmt, basefmt, funclist):
	_ConverterStack.__init__(self, base, ourfmt, basefmt, funclist)
	self._base.format = self._basefmt
	data = self._base.read()
	if TRACE:
	    print 'Converting', self._basefmt.name, 'to', self.format.name,
	    print 'in', len(self._funclist), 'steps:'
	for f in self._funclist:
	    if TRACE:
		print '  ',f[0].name, 'to',f[1].name
	    data = apply(f[2], (data, self, f[0], f[1]))
	self._data = data

    def read(self):
	return self._data
