# BBD's Krita Script Starter Feb 2018
'''
    SPDX-FileCopyrightText: 2024 Agata Cacko <cacko.azh@gmail.com>

    This file is part of Fast Sketch Cleanup Plugin

    SPDX-License-Identifier: GPL-3.0-or-later
'''


from krita import Extension

from PyQt5.QtWidgets import QDialog, QHBoxLayout, QVBoxLayout, QPushButton, QToolButton, QCheckBox, QFileDialog, QLabel, QTextEdit, QLineEdit, QRadioButton, QMessageBox, QGroupBox, QFrame, QComboBox
from PyQt5.QtGui import QImage
from PyQt5.QtCore import QBuffer, QByteArray, pyqtSignal

import threading
#from threading import Thread, Lock

from os.path import isfile
import sys

import openvino as ov
from PIL import Image
import numpy as np
import math
import sys

from os.path import isdir, isfile, join
from os import mkdir, listdir

import time


import sys
import os

# Make this plugin's sibling modules (converter, interference) importable by
# the star imports that follow. The membership check keeps re-execution of
# this module from appending the same directory to sys.path again.
_pluginDirectory = os.path.dirname(__file__)
if _pluginDirectory not in sys.path:
    sys.path.append(_pluginDirectory)

from converter import *
from interference import *


# Identifier under which createActions() registers the plugin's menu action.
EXTENSION_ID: str = "pykrita_fast_sketch_cleanup"
# Text shown for that action in Krita's menu.
MENU_ENTRY: str = "Fast Sketch Cleanup"


class FastSketchCleanup(Extension):
    """Krita extension that cleans up a sketch with an OpenVINO model.

    The plugin registers a menu entry under "tools/scripts". Triggering it
    opens a dialog for choosing a model (one of the ``*.xml`` files found in
    this script's directory), the inference device and a few options. Running
    it renders the active document to grayscale, feeds it through the model
    and adds the result to the document as a new paint layer.
    """

    # Widgets and state populated by createDialog().
    executeButton : QPushButton = None
    invert : bool = False
    modelFilename : str = ""
    dialog : QDialog = None

    # Group and key names used with Krita.readSetting / Krita.writeSetting.
    settingGroup = "fast_sketch_cleanup_plugin"
    modelSettingName = "last_model"
    doubleRunSettingName = "double_run"
    invertSettingName = "invert"
    deviceSettingName = "device"

    # Document width/height are rounded up to a multiple of this value, and it
    # is also passed to interference() as its margin.
    margin = 64

    # Non-blocking guard against concurrent run() calls; replaced by a
    # per-instance lock in __init__.
    executionMutex = threading.Lock()
    # How many more times run() may execute for the current dialog; reset to 1
    # by action_triggered() after the dialog closes.
    executionsAllowedCounter = 1

    # Directory containing this script and the bundled model files.
    basePath = ""

    # The parent object passed to __init__.
    _parent = None

    # Emitted by updateGuiAndRun(); connected to run() in __init__.
    signalToRun = pyqtSignal()

    def __init__(self, parent):
        # Always initialise the superclass.
        # This is necessary to create the underlying C++ object
        super().__init__(parent)
        self._parent = parent
        self.executionMutex = threading.Lock()
        # Connect exactly once: every connect() call adds another connection,
        # so connecting per click would invoke run() several times per emit.
        self.signalToRun.connect(self.run)

    def setup(self):
        """Krita extension hook; nothing to initialise here."""
        pass

    def createActions(self, window):
        """Register the menu action that opens the plugin dialog."""
        action = window.createAction(EXTENSION_ID, MENU_ENTRY, "tools/scripts")
        # parameter 1 = the name that Krita uses to identify the action
        # parameter 2 = the text to be added to the menu entry for this script
        # parameter 3 = location of menu entry
        action.triggered.connect(self.action_triggered)

    def createDialog(self):
        """Build the plugin dialog and show it modally.

        Restores the last used model, device and options from Krita's
        settings, lists the ``*.xml`` model files found next to this script,
        then blocks in ``QDialog.exec()`` until the dialog is closed (``run()``
        closes it when it finishes).
        """
        dialog = QDialog()
        self.dialog = dialog

        mainLayout = QVBoxLayout()

        self.basePath = os.path.dirname(os.path.realpath(__file__))

        self.executeButton = QPushButton("Run")
        self.executeButton.clicked.connect(self.updateGui)
        self.executeButton.clicked.connect(self.updateGuiAndRun)

        self.invert = Krita.readSetting(self.settingGroup, self.invertSettingName, "false") == "true"
        self.invertCheckbox = QCheckBox("Invert input and output")
        self.invertCheckbox.stateChanged.connect(self.invertCheckboxChanged)
        self.invertCheckbox.setChecked(self.invert)

        self.doubleRun = Krita.readSetting(self.settingGroup, self.doubleRunSettingName, "true") == "true"
        self.doubleRunCheckbox = QCheckBox("Run twice through the model")
        self.doubleRunCheckbox.setChecked(self.doubleRun)
        self.doubleRunCheckbox.stateChanged.connect(self.doubleRunCheckboxChanged)

        mainLayout.addWidget(QLabel("Choose model file:"))

        # Models are chosen from the *.xml files next to this script; the
        # free-form chooser (chooseModelFileButtonClicked) is not wired up here.
        self.modelCombobox = QComboBox()
        for f in sorted(listdir(self.basePath)):
            if f.endswith(".xml") and isfile(join(self.basePath, f)):
                self.modelCombobox.addItem(f)

        # modelComboboxTextChanged stores a full path, but the combobox items
        # are bare file names, so the saved model must be matched by base name.
        lastUsedModel = Krita.readSetting(self.settingGroup, self.modelSettingName, "")
        lastUsedModelBase = os.path.basename(lastUsedModel)
        if lastUsedModelBase and isfile(join(self.basePath, lastUsedModelBase)):
            self.modelCombobox.setCurrentText(lastUsedModelBase)

        self.modelFilename = join(self.basePath, self.modelCombobox.currentText())

        self.modelCombobox.currentTextChanged.connect(self.modelComboboxTextChanged)
        mainLayout.addWidget(self.modelCombobox)

        self.device = Krita.readSetting(self.settingGroup, self.deviceSettingName, "CPU")

        self.deviceGroupBox = QGroupBox("Device to use:")

        self.cpuRadioButton = QRadioButton("CPU")
        self.npuRadioButton = QRadioButton("NPU")
        self.gpuRadioButton = QRadioButton("GPU")

        if self.device == "CPU":
            self.cpuRadioButton.setChecked(True)
        elif self.device == "NPU":
            self.npuRadioButton.setChecked(True)
        else:
            self.gpuRadioButton.setChecked(True)

        radioButtonsByDevice = {
            "CPU": self.cpuRadioButton,
            "GPU": self.gpuRadioButton,
            "NPU": self.npuRadioButton,
        }
        for button in radioButtonsByDevice.values():
            button.toggled.connect(self.deviceRadioButtonChanged)
            # Re-enabled below only if OpenVINO reports the device as available.
            button.setEnabled(False)

        # OpenVINO may report numbered instances such as "GPU.0" / "GPU.1" when
        # several devices of one type exist, so match on the part before ".".
        for availableDevice in ov.Core().get_available_devices():
            button = radioButtonsByDevice.get(availableDevice.split(".")[0])
            if button is not None:
                button.setEnabled(True)

        radioButtonLayout = QVBoxLayout()
        radioButtonLayout.addWidget(self.cpuRadioButton)
        radioButtonLayout.addWidget(self.gpuRadioButton)
        radioButtonLayout.addWidget(self.npuRadioButton)

        self.deviceGroupBox.setLayout(radioButtonLayout)
        mainLayout.addWidget(self.deviceGroupBox)

        self.advancedOptionsGroupBox = QGroupBox("Advanced options:")
        advancedOptionsLayout = QHBoxLayout()
        advancedOptionsLayout.addWidget(self.invertCheckbox)
        advancedOptionsLayout.addWidget(self.doubleRunCheckbox)

        self.advancedOptionsGroupBox.setLayout(advancedOptionsLayout)
        mainLayout.addWidget(self.advancedOptionsGroupBox)

        line = QFrame()
        line.setFrameShape(QFrame.HLine)
        line.setFrameShadow(QFrame.Plain)
        mainLayout.addWidget(line)

        mainLayout.addWidget(QLabel("(The dialog will close after converting the image)"))
        mainLayout.addWidget(self.executeButton)

        dialog.setLayout(mainLayout)

        # Apply the selected model's own config (e.g. its "invert" flag).
        self.readConfigForModel(self.modelFilename)

        dialog.exec()

    def action_triggered(self):
        """Show the dialog; once it is closed, allow one run for the next one."""
        self.createDialog()
        with self.executionMutex:
            self.executionsAllowedCounter = 1

    def invertCheckboxChanged(self, value):
        """Store the invert option and persist it in Krita's settings."""
        self.invert = self.invertCheckbox.isChecked()
        Krita.writeSetting(self.settingGroup, self.invertSettingName, "true" if self.invert else "false")

    def doubleRunCheckboxChanged(self, value):
        """Store the double-run option (as a bool) and persist it."""
        # stateChanged delivers an int check state; read the checkbox instead so
        # self.doubleRun stays a bool, like self.invert.
        self.doubleRun = self.doubleRunCheckbox.isChecked()
        Krita.writeSetting(self.settingGroup, self.doubleRunSettingName, "true" if self.doubleRun else "false")

    def chooseModelFileButtonClicked(self, value):
        """Let the user pick an arbitrary ``*.xml`` model file.

        createDialog() does not currently connect this to any widget.
        Cancelling the file dialog keeps the current model and setting.
        """
        modelFilename = QFileDialog.getOpenFileName(self.dialog, caption="Choose model file...", filter="*.xml")[0]
        if not modelFilename:
            # Cancelled: an empty string must not overwrite the stored model.
            return
        self.modelFilename = modelFilename
        # The textbox only exists when the file-chooser UI is built.
        textbox = getattr(self, "modelFileFilenameTextbox", None)
        if textbox is not None:
            textbox.setText(self.modelFilename)
        Krita.writeSetting(self.settingGroup, self.modelSettingName, self.modelFilename)

        self.readConfigForModel(self.modelFilename)

    def modelComboboxTextChanged(self, text):
        """Switch to the model file named ``text`` (relative to basePath) and persist it."""
        self.modelFilename = join(self.basePath, text)
        Krita.writeSetting(self.settingGroup, self.modelSettingName, self.modelFilename)
        self.readConfigForModel(self.modelFilename)
        print(f"### Current model file: {self.modelFilename}")

    def deviceRadioButtonChanged(self, value):
        """Store the device of the checked radio button and persist it."""
        if self.cpuRadioButton.isChecked():
            self.device = "CPU"
        elif self.gpuRadioButton.isChecked():
            self.device = "GPU"
        else:
            self.device = "NPU"
        Krita.writeSetting(self.settingGroup, self.deviceSettingName, self.device)

    def _ensureDivisableByMargin(self, size, margin):
        """Round ``size`` up to the nearest multiple of ``margin``.

        For example ``(100, 64) -> 128``; exact multiples and 0 are returned
        unchanged. Expects a non-negative size and a positive margin.
        """
        return int((size + margin - 1) // margin * margin)

    def convertKritaImageToNumpy(self, margin):
        """Render the active document to 8-bit grayscale data for the model.

        The document projection is taken over an area whose width and height
        are rounded up to a multiple of ``margin``. It is converted to
        ``QImage.Format_Grayscale8`` and passed to ``convertImageToNumpy``.

        Returns:
            ``(data, width, height)``: the result of ``convertImageToNumpy``
            and the padded width and height.
        """
        currentDoc = Krita.instance().activeDocument()
        currentDoc.refreshProjection()

        width = self._ensureDivisableByMargin(currentDoc.width(), margin)
        height = self._ensureDivisableByMargin(currentDoc.height(), margin)

        projection = currentDoc.projection(0, 0, width, height)  # QImage
        projection = projection.convertToFormat(QImage.Format_Grayscale8)

        # QImage scanlines are padded to a 4-byte boundary, so the buffer holds
        # bytesPerLine() * height bytes, which can exceed width * height.
        bytesPerLine = projection.bytesPerLine()
        bitsy = projection.bits()
        bitsy.setsize(bytesPerLine * height)

        # NOTE(review): `projection` owns the memory behind `bitsy`; this
        # assumes convertImageToNumpy copies the pixels rather than keeping a view.
        return (convertImageToNumpy(bitsy, width, height, [bytesPerLine, 1, 1], self.invert), width, height)

    def convertOutputToKritaLayer(self, output, width, height, invertOutput):
        """Add the model output to the active document as a new paint layer.

        ``output`` is turned into pixel data by ``convertOutputToLayerData`` and
        written to a new top-level paint layer, which becomes the active node.
        """
        bytesArray = convertOutputToLayerData(output, invertOutput)

        currentDoc = Krita.instance().activeDocument()
        root = currentDoc.rootNode()

        # NOTE(review): "[sumAll]" and "[average]" look like unfilled placeholders.
        result = currentDoc.createNode(f"result [sumAll] avg: [average], out of {width*height} pxs", "paintLayer")
        root.addChildNode(result, None)

        # Presumably convertOutputToLayerData emits 8-bit RGBA pixels; the layer's
        # color space is set to match before writing them — confirm.
        result.setColorSpace("RGBA", "U8", "sRGB-elle-V2-srgbtrc.icc")
        result.setPixelData(bytesArray, 0, 0, width, height)

        currentDoc.setActiveNode(result)
        currentDoc.refreshProjection()

    def updateGui(self):
        """Mark the Run button as busy and disable it."""
        b : QPushButton = self.executeButton
        b.setText("RUNNING...")
        b.setDisabled(True)

    def updateGuiAndRun(self):
        """Mark the Run button as busy and trigger run() via signalToRun."""
        b : QPushButton = self.executeButton
        b.setText("RUNNING...")
        b.setDisabled(True)

        # signalToRun is connected to run() once, in __init__.
        self.signalToRun.emit()

    def readConfigForModel(self, modelFile):
        """Apply the ``.yaml`` config with the same base name as ``modelFile``."""
        # splitext only touches the final extension, unlike str.replace, which
        # would also rewrite ".xml" appearing elsewhere in the path.
        self.readConfig(os.path.splitext(modelFile)[0] + ".yaml")

    def readConfig(self, configFile):
        """Apply options from a simple YAML-style config file, if it exists.

        Only a top-level ``invert:`` line is recognised; its value sets the
        invert checkbox.
        """
        if not isfile(configFile):
            return
        with open(configFile, "r", encoding="utf-8", errors="replace") as file:
            for line in file:
                if not line.startswith("invert:"):
                    continue
                # Drop a trailing comment and compare case-insensitively so that
                # "invert: True" and "invert: true  # note" are both honoured.
                value = line[len("invert:"):].split("#", 1)[0].strip().lower()
                invert = value in ("true", "yes", "on")
                print(f"line = [{line}], so invert = {invert}")
                self.invertCheckbox.setChecked(invert)

    def run(self):
        """Run the selected model on the active document and add the result as a layer.

        Invoked through signalToRun. The non-blocking executionMutex rejects
        concurrent calls and executionsAllowedCounter limits execution to once
        per dialog. Errors are reported with message boxes, and the dialog is
        closed whenever this method returns.
        """
        canRun = self.executionMutex.acquire(False)
        print(f"[{threading.current_thread().ident}] Trying to acquire the lock, result = {canRun}", file=sys.stderr)
        if not canRun:
            self.dialog.close()
            return
        try:
            if self.executionsAllowedCounter < 1:
                return
            self.executionsAllowedCounter -= 1

            if Krita.instance().activeDocument() is None:
                QMessageBox.critical(self.dialog, "No active document", "Open an image before running Fast Sketch Cleanup.")
                return

            if not isfile(self.modelFilename):
                QMessageBox.critical(self.dialog, "The model file cannot be read", f"The model file cannot be read: {self.modelFilename}.")
                return

            (data, width, height) = self.convertKritaImageToNumpy(self.margin)

            compiled_model = None
            ie = ov.Core()
            try:
                model = ie.read_model(model=self.modelFilename)
                if model is None:
                    QMessageBox.critical(self.dialog, "The model file cannot be read", f"The model file cannot be read: {self.modelFilename}.")
                    return
                compiled_model = ie.compile_model(model=model, device_name=self.device)
            except Exception as e:
                # Reported to the user below as a compilation failure.
                print(f"there was an exception: {e}", file=sys.stderr)

            if compiled_model is None:
                print("The model cannot be compiled.", file=sys.stderr)
                QMessageBox.critical(self.dialog, "The model cannot be compiled", f"The model cannot be compiled to specified device: {self.device}.")
                return

            compiledModelInputShape = compiled_model.input().get_shape()
            partSize = compiledModelInputShape[2]
            # Explicit check instead of `assert`, which is stripped under -O.
            if partSize != compiledModelInputShape[3]:
                QMessageBox.critical(self.dialog, "Unsupported model input shape", f"Part size must be equal in both dimensions, got input shape {compiledModelInputShape}.")
                return

            # NOTE(review): self.doubleRun ("Run twice through the model") is stored
            # but not used here; the model is applied once.
            outputData = interference(compiled_model, data, partSize, margin=self.margin)

            self.convertOutputToKritaLayer(outputData, width, height, self.invert)
        finally:
            self.dialog.close()
            self.executionMutex.release()
