/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.internal.vectorization;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.util.Optional;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

abstract sealed class Lucene99MemorySegmentByteVectorScorer
extends RandomVectorScorer.AbstractRandomVectorScorer {
    final int vectorByteSize;
    final MemorySegmentAccessInput input;
    final byte[] query;
    byte[] scratch;

    public static Optional<Lucene99MemorySegmentByteVectorScorer> create(VectorSimilarityFunction type, IndexInput input, KnnVectorValues values, byte[] queryVector) {
        assert (values instanceof ByteVectorValues);
        if (!((input = FilterIndexInput.unwrapOnlyTest(input)) instanceof MemorySegmentAccessInput)) {
            return Optional.empty();
        }
        MemorySegmentAccessInput msInput = (MemorySegmentAccessInput)((Object)input);
        Lucene99MemorySegmentByteVectorScorer.checkInvariants(values.size(), values.getVectorByteLength(), input);
        return switch (type) {
            default -> throw new MatchException(null, null);
            case VectorSimilarityFunction.COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector));
            case VectorSimilarityFunction.DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector));
            case VectorSimilarityFunction.EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector));
            case VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductScorer(msInput, values, queryVector));
        };
    }

    Lucene99MemorySegmentByteVectorScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] queryVector) {
        super(values);
        this.input = input;
        this.vectorByteSize = values.getVectorByteLength();
        this.query = queryVector;
    }

    final MemorySegment getSegment(int ord) throws IOException {
        this.checkOrdinal(ord);
        long byteOffset = (long)ord * (long)this.vectorByteSize;
        MemorySegment seg = this.input.segmentSliceOrNull(byteOffset, this.vectorByteSize);
        if (seg == null) {
            if (this.scratch == null) {
                this.scratch = new byte[this.vectorByteSize];
            }
            this.input.readBytes(byteOffset, this.scratch, 0, this.vectorByteSize);
            seg = MemorySegment.ofArray(this.scratch);
        }
        return seg;
    }

    static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
        if (input.length() < (long)vectorByteLength * (long)maxOrd) {
            throw new IllegalArgumentException("input length is less than expected vector data");
        }
    }

    final void checkOrdinal(int ord) {
        if (ord < 0 || ord >= this.maxOrd()) {
            throw new IllegalArgumentException("illegal ordinal: " + ord);
        }
    }

    static final class CosineScorer
    extends Lucene99MemorySegmentByteVectorScorer {
        CosineScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
            super(input, values, query);
        }

        @Override
        public float score(int node) throws IOException {
            this.checkOrdinal(node);
            float raw = PanamaVectorUtilSupport.cosine(this.query, this.getSegment(node));
            return (1.0f + raw) / 2.0f;
        }
    }

    static final class DotProductScorer
    extends Lucene99MemorySegmentByteVectorScorer {
        DotProductScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
            super(input, values, query);
        }

        @Override
        public float score(int node) throws IOException {
            this.checkOrdinal(node);
            float raw = PanamaVectorUtilSupport.dotProduct(this.query, this.getSegment(node));
            return 0.5f + raw / (float)(this.query.length * 32768);
        }
    }

    static final class EuclideanScorer
    extends Lucene99MemorySegmentByteVectorScorer {
        EuclideanScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
            super(input, values, query);
        }

        @Override
        public float score(int node) throws IOException {
            this.checkOrdinal(node);
            float raw = PanamaVectorUtilSupport.squareDistance(this.query, this.getSegment(node));
            return 1.0f / (1.0f + raw);
        }
    }

    static final class MaxInnerProductScorer
    extends Lucene99MemorySegmentByteVectorScorer {
        MaxInnerProductScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
            super(input, values, query);
        }

        @Override
        public float score(int node) throws IOException {
            this.checkOrdinal(node);
            float raw = PanamaVectorUtilSupport.dotProduct(this.query, this.getSegment(node));
            if (raw < 0.0f) {
                return 1.0f / (1.0f + -1.0f * raw);
            }
            return raw + 1.0f;
        }
    }
}

