#!/bin/python3
'''
    SPDX-FileCopyrightText: 2024 Agata Cacko <cacko.azh@gmail.com>

    This file is part of Fast Sketch Cleanup Plugin

    SPDX-License-Identifier: GPL-3.0-or-later
'''

import numpy as np
from PyQt5.QtGui import QImage
from PyQt5.QtCore import QBuffer, QByteArray


from PIL import Image
import sys


def byteArrayToString(data, maxBytes = 20):
    """Format the first ``maxBytes`` elements of ``data`` for debug output.

    Each element is rendered with ``str.format`` semantics and followed by a
    single space, so a non-empty result ends with ``" "``. If ``data`` holds
    fewer than ``maxBytes`` elements, only the available ones are shown
    (previously this raised IndexError).

    Args:
        data: any indexable sequence with ``len()`` (e.g. ``bytes``).
        maxBytes: maximum number of elements to include.

    Returns:
        The space-separated string, or ``""`` when nothing is shown.
    """
    count = min(maxBytes, len(data))
    return "".join(f"{data[i]} " for i in range(count))

def numpyArrayShortened(data: np.ndarray, maxBytes = 20):
    """Render the leading elements of a flattened array as debug text.

    The array is flattened in row-major order and at most ``maxBytes``
    elements are emitted, each followed by one space. An empty array (or a
    non-positive ``maxBytes``) yields ``""``.
    """
    flat = data.reshape(-1)
    shown = range(min(maxBytes, len(flat)))
    return "".join(f"{flat[idx]} " for idx in shown)


def convertImageToNumpy(data, width, height, strides, invert=False):
    """Turn a buffer of 8-bit pixels into a float32 input array.

    Args:
        data: an object exposing the buffer protocol (e.g. ``bytes`` or a
            C-contiguous ``np.ndarray``) holding one byte per pixel, row
            after row.
        width: number of pixels per row.
        height: number of rows.
        strides: row-pitch information from the caller. It is only consulted
            when it is an integer larger than ``width`` and ``data`` is too
            large to be a tightly packed image. In that case it is treated as
            bytes per line and the per-row padding is dropped. Any other value
            (e.g. a list) is ignored.
        invert: when truthy, every value v becomes ``1 - v``. Defaults to
            False so four-argument calls keep working.

    Returns:
        ``np.float32`` array of shape ``(1, 1, height, width)`` with values
        in ``[0, 1]``.

    Raises:
        ValueError: if ``data`` cannot be reshaped into ``height`` rows of
            ``width`` pixels.
    """
    # NOTE: the old implementation printed the whole raw buffer here, which
    # formatted megabytes of text per image; debug output was removed.
    pixels = np.frombuffer(data, dtype=np.uint8)

    packedSize = width * height
    if (pixels.size != packedSize
            and isinstance(strides, (int, np.integer))
            and strides > width
            and pixels.size >= strides * height):
        # NOTE(review): assumes `strides` is bytes-per-line of a padded 8-bit
        # image (e.g. QImage scanlines are 32-bit aligned) - confirm with the
        # caller. Previously this case raised on reshape.
        pixels = pixels[:strides * height].reshape((height, strides))[:, :width]

    # Same layout the old expand_dims/expand_dims/squeeze(4) chain produced.
    tensor = pixels.reshape((1, 1, height, width)).astype(np.float32) / 255.0
    if invert:
        tensor = 1.0 - tensor
    return tensor




def saveAsImage(data, filename):
    """Write ``data`` to ``filename`` using Pillow.

    ``data`` is handed to ``Image.fromarray`` unchanged and the resulting
    image is saved with ``Image.save``; nothing is returned.
    """
    Image.fromarray(data).save(filename)

def convertOutputToLayerData(output, invert):
    """Convert a float array into raw 4-bytes-per-pixel layer data.

    The previous version also ran a per-pixel Python loop that filled a
    QByteArray which was never used. It also did a Pillow RGBA round-trip
    used only for a debug print. Both were removed; the returned bytes are
    unchanged.

    Args:
        output: array whose first two axes have length 1, e.g. shape
            ``(1, 1, rows, cols)`` as produced by ``convertImageToNumpy``
            (presumably the network's output - confirm). Values are
            presumably in ``[0, 1]``; anything outside is clipped after
            scaling by 255.
        invert: when truthy, every value v is replaced by ``1 - v`` before
            scaling.

    Returns:
        ``bytes`` of length ``rows * cols * 4`` in row-major pixel order.
        For each pixel, bytes 0-2 hold ``clip(v * 255, 0, 255)`` truncated
        to an integer, and byte 3 is 255.
    """
    values = np.asarray(output)
    if invert:
        values = 1 - values

    # Drop the two leading length-1 axes: (1, 1, rows, cols) -> (rows, cols).
    grey = values.squeeze(0).squeeze(0)
    # astype truncates toward zero, matching the old assignment into uint8.
    grey = np.clip(grey * 255.0, 0, 255).astype(np.uint8)

    rows, cols = grey.shape
    rgba = np.empty((rows, cols, 4), dtype=np.uint8)
    rgba[:, :, :3] = grey[:, :, np.newaxis]
    rgba[:, :, 3] = 255  # alpha channel: always opaque
    return rgba.tobytes()


    
    
if __name__ == "__main__":
    # Ad-hoc smoke test of the buffer <-> numpy conversions in this module.
    test1 = b'\x00\x00\x00\xff\x00\x00\x00\xff\x00\x00\x00\xff\x00\x00\x00\xff\x00\x00\x00\xff\x00\x00\x00\xff' # 6 x 4 bytes
    test1array = np.ndarray((24), buffer=test1, strides=[1], dtype=np.uint8)
    test1array2 = np.ndarray((6, 4, 1), buffer=test1, dtype=np.uint8)

    print(f"test1 array = {test1array}")
    print(f"test1 array 2 = {test1array2}")

    # 2 rows x 3 columns, one 8-bit channel per pixel.
    test = np.array([[[10], [20], [30]], [[50], [100], [150]]], dtype=np.uint8)

    testbytes = test.tobytes()
    print(f"test bytes = {byteArrayToString(testbytes, 6)}")
    testarray = np.ndarray((6), dtype=np.uint8, buffer=testbytes)
    testarray2 = np.frombuffer(testbytes, dtype=np.uint8)

    print(f"test array = {testarray}, {testarray2}")

    print(f"test = {test}, TEST SHAPE = {test.shape}")
    # `invert` is passed explicitly: the original call omitted it and raised
    # TypeError because the parameter had no default.
    result1 = convertImageToNumpy(test, 3, 2, [1, 1, 1], False)
    result2 = convertOutputToLayerData(result1, False)
    print(result2)
