#!<launcher_dir>\python.exe

# priditherpng
# Error-diffusion image dithering.
# Now with serpentine scanning.

# test with:
# priforgepng grl | priditherpng | kitty icat

# See http://www.efg2.com/Lab/Library/ImageProcessing/DHALF.TXT
# archived at http://web.archive.org/web/20160727202727/http://www.efg2.com/Lab/Library/ImageProcessing/DHALF.TXT

# https://docs.python.org/3.5/library/bisect.html
from bisect import bisect_left


import png


def dither(
    out,
    input,
    bitdepth=1,
    targetgamma=None,
):
    """Dither the input PNG `input` into an image with a smaller bit depth
    and write the result image onto `out`.  `bitdepth` specifies the bit
    depth of the new image.

    `input` is handed to ``png.Reader(file=...)``; the decoded image
    must have exactly one plane (channel).  `out` is the file-like
    object the new PNG is written to by ``png.Writer.write``.

    The source image gamma is used to convert to
    linear light space before dithering.
    If there is no gamma in the source image, a gamma of 1.0 is
    used (no conversion, assumed to be linear already).

    Use `prichunkpng` to add a `gAMA` chunk if needed.

    The gamma of the output image is, by default, the same as the input
    image.  The `targetgamma` argument can be used to specify a
    different gamma for the output image.  This effectively recodes the
    image to a different gamma, dithering as we go.  The gamma specified
    is the exponent used to encode the output file (and appears in the
    output PNG's ``gAMA`` chunk); it is usually less than 1.

    Raises ValueError if the decoded input image does not have exactly
    one plane.
    """

    # Encoding is what happened when the PNG was made (and also what
    # happens when we output the PNG).  Decoding is what we do to the
    # source PNG in order to process it.

    # The dithering algorithm is not completely general; it
    # can only do bit depth reduction, not arbitrary palette changes.

    maxval = 2 ** bitdepth - 1
    r = png.Reader(file=input)

    _, _, pixels, info = r.asDirect()
    planes = info["planes"]
    if planes != 1:
        # A real exception rather than `assert`, so the check survives
        # `python -O`.
        raise ValueError(
            "can only dither single-plane images; input has %d planes" % planes
        )
    width = info["size"][0]
    sourcemaxval = 2 ** info["bitdepth"] - 1

    gamma = info.get("gamma", 1.0)

    # Calculate an effective gamma for input and output;
    # then build tables using those.

    # `gamma` (whether it was obtained from the input file or an
    # assumed value) is the encoding gamma.
    # We need the decoding gamma, which is the reciprocal.
    decode = 1.0 / gamma

    # `targetdecode` is the assumed gamma that is going to be used
    # to decode the target PNG.
    # Note that even though we will _encode_ the target PNG we
    # still need the decoding gamma, because
    # the table we use maps from PNG pixel value to linear light level.
    if targetgamma is None:
        targetdecode = decode
    else:
        targetdecode = 1.0 / targetgamma

    incode = build_decode_table(sourcemaxval, decode)

    # For encoding, we still build a decode table, because we
    # use it inverted (searching with bisect).
    outcode = build_decode_table(maxval, targetdecode)

    # The table used for choosing output codes: the cut-off point
    # between each pair of adjacent output levels.  With p = 0.5 each
    # cut-off is the midpoint of the two linear levels.  (Previous code
    # exposed this as a cutoff parameter that made the cut-off points
    # darker or lighter; that parameter has been removed.)
    p = 0.5
    choosecode = [hi * p + lo * (1.0 - p) for lo, hi in zip(outcode, outcode[1:])]

    # Prime the error-diffusion state by dithering copies of the first
    # row, then drop those warm-up rows from the output.
    rows = repeat_header(pixels)
    dithered_rows = run_dither(incode, choosecode, outcode, width, rows)
    dithered_rows = remove_header(dithered_rows)

    info["bitdepth"] = bitdepth
    info["gamma"] = 1.0 / targetdecode
    # NOTE(review): every other entry of `info` (e.g. a background colour,
    # if png.Reader reports one) is passed to png.Writer unchanged and is
    # not rescaled to the new bit depth — confirm this is acceptable.
    w = png.Writer(**info)
    w.write(out, dithered_rows)


def build_decode_table(maxval, gamma):
    """Return a list mapping every pixel value 0..`maxval` to linear light.

    Entry ``v`` is ``v * (1.0 / maxval)`` raised to the power `gamma`;
    the exponentiation is skipped entirely when `gamma` equals 1.0.
    `maxval` must be a positive integer.
    """

    assert maxval == int(maxval)
    assert maxval > 0

    step = 1.0 / maxval
    codes = range(maxval + 1)
    if gamma == 1.0:
        # Already linear: the normalised ramp is the answer.
        return [step * code for code in codes]
    return [(step * code) ** gamma for code in codes]


def run_dither(incode, choosecode, outcode, width, rows):
    """
    Error-diffuse `rows` using serpentine (boustrophedon) scanning.

    Each pixel value is mapped to linear light with `incode`, and the
    error carried down from the previous row is added.  The output code
    is then ``bisect_left(choosecode, value)``, so `choosecode` must be
    sorted ascending.  `outcode[code]` is the linear level of an output
    code; the difference from the value is the quantisation error.

    Rows 0, 2, 4, ... are scanned left to right and rows 1, 3, 5, ...
    right to left.  Yields one list of `width` output codes per input
    row, always in left-to-right order.
    """

    # Error pushed down into the next row; kept left-to-right between rows.
    below = [0.0] * width
    backwards = False
    for source in rows:
        linear = [incode[px] for px in source]
        line = [carried + level for carried, level in zip(below, linear)]
        if backwards:
            line.reverse()

        # From here on, `pos` counts along the scan direction.
        codes = [0] * width
        for pos in range(len(line)):
            level = line[pos]
            code = bisect_left(choosecode, level)
            codes[pos] = code
            # Sierra "Filter Lite" weights (x = current pixel), all / 4:
            #        x  2
            #     1  1
            half = (level - outcode[code]) * 0.5
            if pos + 1 < width:
                line[pos + 1] += half  # 2/4 to the next pixel in scan order
            quarter = half * 0.5
            # 1/4 straight down; overwriting is safe because the previous
            # row's carry at `pos` was already folded into `line`.
            below[pos] = quarter
            if pos > 0:
                below[pos - 1] += quarter  # 1/4 down and behind

        if backwards:
            # Put both the output and the carry back in left-to-right order.
            below.reverse()
            codes.reverse()
        yield codes
        backwards = not backwards


# Number of extra copies of the first row that repeat_header() prepends
# and remove_header() later discards, so the error-diffusion state has
# settled before the first row that is actually kept.
WARMUP_ROWS = 32


def repeat_header(rows):
    """Repeat the first row, to "warm up" the error register.

    Yields the first row of `rows` WARMUP_ROWS + 1 times, followed by the
    remaining rows.  `rows` may be any iterable (a list or an iterator);
    an empty input yields nothing.
    """
    # iter() makes this correct for re-iterable inputs such as lists;
    # without it, `yield from rows` would emit the first row again.
    it = iter(rows)
    missing = object()
    first = next(it, missing)
    if first is missing:
        return
    for _ in range(WARMUP_ROWS + 1):
        yield first
    yield from it


def remove_header(rows):
    """Remove the same number of rows that repeat_header added.

    Skips the first WARMUP_ROWS items of `rows` (any iterable) and yields
    the rest.  If there are WARMUP_ROWS items or fewer, nothing is
    yielded (rather than leaking StopIteration as a RuntimeError).
    """
    import itertools

    yield from itertools.islice(rows, WARMUP_ROWS, None)


def main(argv=None):
    """Command-line entry point: dither a greyscale PNG to a lower bit depth.

    `argv` is the full argument vector including the program name
    (defaults to ``sys.argv``).  The input PNG is named by the optional
    positional argument (default ``-``), opened with ``png.cli_open``;
    the result is written to ``png.binary_stdout()``.
    """
    import sys

    # https://docs.python.org/3.5/library/argparse.html
    import argparse

    if argv is None:
        argv = sys.argv

    # argv[0] is the program name; argparse only parses what follows.
    _progname, *args = argv

    parser = argparse.ArgumentParser()
    parser.add_argument("--bitdepth", type=int, default=1, help="bitdepth of output")
    parser.add_argument(
        "--targetgamma",
        type=float,
        default=None,
        help="encoding gamma for the output image (default: same as input)",
    )
    parser.add_argument(
        "input", nargs="?", default="-", type=png.cli_open, metavar="PNG"
    )

    ns = parser.parse_args(args)

    # dither() computes 1.0 / targetgamma, so zero (or a negative value)
    # would crash or produce nonsense; reject it with a usage error.
    if ns.targetgamma is not None and ns.targetgamma <= 0:
        parser.error("--targetgamma must be positive")

    return dither(png.binary_stdout(), **vars(ns))


if __name__ == "__main__":
    # Script entry point; with no argument, main() parses sys.argv.
    main()
