/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.ars_nouveau.sandbox.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import org.apache.lucene.ars_nouveau.index.FieldInfo;
import org.apache.lucene.ars_nouveau.index.FieldInfos;
import org.apache.lucene.ars_nouveau.index.IndexReader;
import org.apache.lucene.ars_nouveau.index.LeafReaderContext;
import org.apache.lucene.ars_nouveau.index.PostingsEnum;
import org.apache.lucene.ars_nouveau.index.Term;
import org.apache.lucene.ars_nouveau.index.TermState;
import org.apache.lucene.ars_nouveau.index.TermStates;
import org.apache.lucene.ars_nouveau.index.TermsEnum;
import org.apache.lucene.ars_nouveau.sandbox.search.MultiNormsLeafSimScorer;
import org.apache.lucene.ars_nouveau.search.BooleanClause;
import org.apache.lucene.ars_nouveau.search.BooleanQuery;
import org.apache.lucene.ars_nouveau.search.CollectionStatistics;
import org.apache.lucene.ars_nouveau.search.DisiPriorityQueue;
import org.apache.lucene.ars_nouveau.search.DisiWrapper;
import org.apache.lucene.ars_nouveau.search.DisjunctionDISIApproximation;
import org.apache.lucene.ars_nouveau.search.DocIdSetIterator;
import org.apache.lucene.ars_nouveau.search.Explanation;
import org.apache.lucene.ars_nouveau.search.IndexSearcher;
import org.apache.lucene.ars_nouveau.search.Matches;
import org.apache.lucene.ars_nouveau.search.Query;
import org.apache.lucene.ars_nouveau.search.QueryVisitor;
import org.apache.lucene.ars_nouveau.search.ScoreMode;
import org.apache.lucene.ars_nouveau.search.Scorer;
import org.apache.lucene.ars_nouveau.search.ScorerSupplier;
import org.apache.lucene.ars_nouveau.search.TermQuery;
import org.apache.lucene.ars_nouveau.search.TermScorer;
import org.apache.lucene.ars_nouveau.search.TermStatistics;
import org.apache.lucene.ars_nouveau.search.Weight;
import org.apache.lucene.ars_nouveau.search.similarities.Similarity;
import org.apache.lucene.ars_nouveau.util.Accountable;
import org.apache.lucene.ars_nouveau.util.BytesRef;
import org.apache.lucene.ars_nouveau.util.IOSupplier;
import org.apache.lucene.ars_nouveau.util.RamUsageEstimator;

/**
 * A query that searches a set of terms across several fields, treating those fields as one
 * combined field for scoring.
 *
 * <p>Every field has a weight of at least 1. When scores are needed:
 * <ul>
 *   <li>A document's frequency for the combined field is the sum, over every matching (field,
 *       term) pair, of the term frequency multiplied by that field's weight.</li>
 *   <li>The per-field collection statistics and per-term statistics are merged into a single
 *       "pseudo_field" / "pseudo_term" before being passed to the searcher's
 *       {@link Similarity}.</li>
 * </ul>
 *
 * <p>All fields are required to agree on whether they have norms; see
 * {@link #createWeight(IndexSearcher, ScoreMode, float)}.
 *
 * <p>When scores are not needed, the query is executed as a plain {@link BooleanQuery}
 * disjunction of one {@link TermQuery} per (field, term) pair.
 *
 * <p>Instances are created through {@link Builder}.
 */
public final class CombinedFieldQuery
extends Query
implements Accountable {
    private static final long BASE_RAM_BYTES = RamUsageEstimator.shallowSizeOfInstance(CombinedFieldQuery.class);
    /** Fields keyed (and therefore iterated) in sorted name order. */
    private final TreeMap<String, FieldAndWeight> fieldAndWeights;
    /** The distinct query terms, sorted by the constructor. */
    private final BytesRef[] terms;
    /** Cross product of fields and terms, ordered by field name first, then by term. */
    private final Term[] fieldTerms;
    private final long ramBytesUsed;

    /**
     * Creates the query. It is only called from {@link Builder#build()}.
     *
     * @param fieldAndWeights the fields to search, keyed by field name
     * @param terms the distinct terms to search for; this array is sorted in place
     * @throws IndexSearcher.TooManyClauses if the number of (field, term) pairs exceeds
     *     {@link IndexSearcher#getMaxClauseCount()}
     */
    private CombinedFieldQuery(TreeMap<String, FieldAndWeight> fieldAndWeights, BytesRef[] terms) {
        this.fieldAndWeights = fieldAndWeights;
        this.terms = terms;
        // Multiply in long arithmetic: an int product of a huge field count and term count could
        // wrap to a negative value, pass the limit check, and then blow up on array allocation.
        long numFieldTerms = (long) fieldAndWeights.size() * terms.length;
        if (numFieldTerms > IndexSearcher.getMaxClauseCount()) {
            throw new IndexSearcher.TooManyClauses();
        }
        // Safe narrowing: the value is non-negative and bounded by an int limit at this point.
        this.fieldTerms = new Term[(int) numFieldTerms];
        Arrays.sort(terms);
        int pos = 0;
        for (String field : fieldAndWeights.keySet()) {
            for (BytesRef term : terms) {
                this.fieldTerms[pos++] = new Term(field, term);
            }
        }
        this.ramBytesUsed = BASE_RAM_BYTES + RamUsageEstimator.sizeOfObject(fieldAndWeights) + RamUsageEstimator.sizeOfObject(this.fieldTerms) + RamUsageEstimator.sizeOfObject(terms);
    }

    /**
     * Returns every (field, term) pair this query searches.
     *
     * @return an unmodifiable list ordered by field name, then by term
     */
    public List<Term> getTerms() {
        return Collections.unmodifiableList(Arrays.asList(this.fieldTerms));
    }

    /**
     * Renders the query as {@code CombinedFieldQuery((f1 f2^w)(t1 t2))}.
     *
     * <p>A field's weight is shown only when it is not 1. The {@code field} argument is ignored.
     */
    @Override
    public String toString(String field) {
        StringBuilder builder = new StringBuilder("CombinedFieldQuery((");
        boolean first = true;
        for (FieldAndWeight fieldWeight : this.fieldAndWeights.values()) {
            if (!first) {
                builder.append(" ");
            }
            first = false;
            builder.append(fieldWeight.field);
            if (fieldWeight.weight != 1.0f) {
                builder.append("^");
                builder.append(fieldWeight.weight);
            }
        }
        builder.append(")(");
        first = true;
        for (BytesRef term : this.terms) {
            if (!first) {
                builder.append(" ");
            }
            first = false;
            builder.append(term.utf8ToString());
        }
        builder.append("))");
        return builder.toString();
    }

    /** Two queries are equal when they have the same fields, the same weights and the same terms. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!this.sameClassAs(o)) {
            return false;
        }
        CombinedFieldQuery that = (CombinedFieldQuery) o;
        return Objects.equals(this.fieldAndWeights, that.fieldAndWeights) && Arrays.equals(this.terms, that.terms);
    }

    @Override
    public int hashCode() {
        int result = this.classHash();
        result = 31 * result + Objects.hash(this.fieldAndWeights);
        result = 31 * result + Arrays.hashCode(this.terms);
        return result;
    }

    @Override
    public long ramBytesUsed() {
        return this.ramBytesUsed;
    }

    /**
     * Rewrites to an empty {@link BooleanQuery} when there are no terms or no fields.
     * Otherwise it returns this query unchanged.
     */
    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        if (this.terms.length == 0 || this.fieldAndWeights.isEmpty()) {
            return new BooleanQuery.Builder().build();
        }
        return this;
    }

    /**
     * Reports the (field, term) pairs whose field the visitor accepts.
     *
     * <p>They are reported to a SHOULD sub-visitor, and only if at least one pair is accepted.
     */
    @Override
    public void visit(QueryVisitor visitor) {
        Term[] selectedTerms = Arrays.stream(this.fieldTerms)
                .filter(t -> visitor.acceptField(t.field()))
                .toArray(Term[]::new);
        if (selectedTerms.length > 0) {
            QueryVisitor v = visitor.getSubVisitor(BooleanClause.Occur.SHOULD, this);
            v.consumeTerms(this, selectedTerms);
        }
    }

    /** Builds a disjunction with one SHOULD {@link TermQuery} per (field, term) pair. */
    private BooleanQuery rewriteToBoolean() {
        BooleanQuery.Builder bq = new BooleanQuery.Builder();
        for (Term term : this.fieldTerms) {
            bq.add(new TermQuery(term), BooleanClause.Occur.SHOULD);
        }
        return bq.build();
    }

    /**
     * Creates a weight for this query.
     *
     * <p>When scores are needed, it returns a {@link CombinedFieldWeight} that scores the fields as
     * one. Otherwise it delegates to the weight of the equivalent boolean disjunction.
     *
     * @throws IllegalArgumentException if some searched fields have norms and others omit them
     */
    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        this.validateConsistentNorms(searcher.getIndexReader());
        if (scoreMode.needsScores()) {
            return new CombinedFieldWeight(this, searcher, scoreMode, boost);
        }
        BooleanQuery bq = this.rewriteToBoolean();
        return searcher.rewrite(bq).createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, boost);
    }

    /**
     * Checks that all searched fields agree on norms.
     *
     * <p>Every leaf is examined. Fields a leaf does not know about are skipped. The check fails
     * when the fields are neither all "has norms" nor all "omits norms".
     */
    private void validateConsistentNorms(IndexReader reader) {
        boolean allFieldsHaveNorms = true;
        boolean noFieldsHaveNorms = true;
        for (LeafReaderContext context : reader.leaves()) {
            FieldInfos fieldInfos = context.reader().getFieldInfos();
            for (String field : this.fieldAndWeights.keySet()) {
                FieldInfo fieldInfo = fieldInfos.fieldInfo(field);
                if (fieldInfo == null) {
                    continue;
                }
                allFieldsHaveNorms &= fieldInfo.hasNorms();
                noFieldsHaveNorms &= fieldInfo.omitsNorms();
            }
        }
        if (!allFieldsHaveNorms && !noFieldsHaveNorms) {
            throw new IllegalArgumentException(this.getClass().getSimpleName() + " requires norms to be consistent across fields: some fields cannot  have norms enabled, while others have norms disabled");
        }
    }

    /** A field name paired with its scoring weight, which the {@link Builder} guarantees is at least 1. */
    record FieldAndWeight(String field, float weight) {
    }

    /**
     * Scoring weight for this query.
     *
     * <p>It builds one {@link Similarity.SimScorer} from term and collection statistics that have
     * been merged across all fields.
     */
    class CombinedFieldWeight
    extends Weight {
        private final IndexSearcher searcher;
        /** Per-(field, term) states, index-aligned with {@code fieldTerms}. */
        private final TermStates[] termStates;
        /** Scorer over the merged pseudo statistics, or {@code null} if no term occurs in the index. */
        private final Similarity.SimScorer simWeight;

        CombinedFieldWeight(Query query, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
            super(query);
            assert scoreMode.needsScores();
            this.searcher = searcher;
            // Merged term statistics: the largest docFreq of any single (field, term) pair, and
            // the sum of each pair's totalTermFreq scaled by its field weight.
            long docFreq = 0L;
            long totalTermFreq = 0L;
            this.termStates = new TermStates[CombinedFieldQuery.this.fieldTerms.length];
            for (int i = 0; i < this.termStates.length; ++i) {
                Term fieldTerm = CombinedFieldQuery.this.fieldTerms[i];
                TermStates ts = TermStates.build(searcher, fieldTerm, true);
                this.termStates[i] = ts;
                if (ts.docFreq() <= 0) {
                    continue;
                }
                FieldAndWeight field = CombinedFieldQuery.this.fieldAndWeights.get(fieldTerm.field());
                TermStatistics termStats = searcher.termStatistics(fieldTerm, ts.docFreq(), ts.totalTermFreq());
                docFreq = Math.max(termStats.docFreq(), docFreq);
                totalTermFreq = (long) ((double) totalTermFreq + (double) field.weight * (double) termStats.totalTermFreq());
            }
            if (docFreq > 0L) {
                CollectionStatistics pseudoCollectionStats = this.mergeCollectionStatistics(searcher);
                TermStatistics pseudoTermStatistics = new TermStatistics(new BytesRef("pseudo_term"), docFreq, Math.max(1L, totalTermFreq));
                this.simWeight = searcher.getSimilarity().scorer(boost, pseudoCollectionStats, pseudoTermStatistics);
            } else {
                this.simWeight = null;
            }
        }

        /**
         * Merges the per-field collection statistics into one "pseudo_field".
         *
         * <p>maxDoc, docCount and sumDocFreq are each the maximum over fields. sumTotalTermFreq is
         * the sum of each field's value scaled by the field weight. Fields without statistics are
         * skipped.
         */
        private CollectionStatistics mergeCollectionStatistics(IndexSearcher searcher) throws IOException {
            long maxDoc = 0L;
            long docCount = 0L;
            long sumTotalTermFreq = 0L;
            long sumDocFreq = 0L;
            for (FieldAndWeight fieldWeight : CombinedFieldQuery.this.fieldAndWeights.values()) {
                CollectionStatistics collectionStats = searcher.collectionStatistics(fieldWeight.field);
                if (collectionStats == null) {
                    continue;
                }
                maxDoc = Math.max(collectionStats.maxDoc(), maxDoc);
                docCount = Math.max(collectionStats.docCount(), docCount);
                sumDocFreq = Math.max(collectionStats.sumDocFreq(), sumDocFreq);
                sumTotalTermFreq = (long) ((double) sumTotalTermFreq + (double) fieldWeight.weight * (double) collectionStats.sumTotalTermFreq());
            }
            return new CollectionStatistics("pseudo_field", maxDoc, docCount, sumTotalTermFreq, sumDocFreq);
        }

        /** Delegates match reporting to the equivalent boolean disjunction. */
        @Override
        public Matches matches(LeafReaderContext context, int doc) throws IOException {
            Weight weight = this.searcher.rewrite(CombinedFieldQuery.this.rewriteToBoolean()).createWeight(this.searcher, ScoreMode.COMPLETE, 1.0f);
            return weight.matches(context, doc);
        }

        /**
         * Explains the score of {@code doc}.
         *
         * <p>The combined weighted frequency is fed to a fresh {@link MultiNormsLeafSimScorer}.
         */
        @Override
        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
            Scorer scorer = this.scorer(context);
            if (scorer == null || scorer.iterator().advance(doc) != doc) {
                return Explanation.noMatch("no matching term", new Explanation[0]);
            }
            assert scorer instanceof CombinedFieldScorer;
            float freq = ((CombinedFieldScorer) scorer).freq();
            MultiNormsLeafSimScorer docScorer = new MultiNormsLeafSimScorer(this.simWeight, context.reader(), CombinedFieldQuery.this.fieldAndWeights.values(), true);
            Explanation freqExplanation = Explanation.match((Number) Float.valueOf(freq), "termFreq=" + freq, new Explanation[0]);
            Explanation scoreExplanation = docScorer.explain(doc, freqExplanation);
            return Explanation.match(scoreExplanation.getValue(), "weight(" + this.getQuery() + " in " + doc + "), result of:", scoreExplanation);
        }

        /**
         * Creates a scorer supplier for one leaf.
         *
         * <p>The scorer iterates the disjunction of every (field, term) postings list that exists
         * in the leaf. Returns {@code null} when none of them exists.
         */
        @Override
        public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
            List<PostingsEnum> iterators = new ArrayList<>();
            List<FieldAndWeight> fields = new ArrayList<>();
            for (int i = 0; i < CombinedFieldQuery.this.fieldTerms.length; ++i) {
                IOSupplier<TermState> supplier = this.termStates[i].get(context);
                TermState state = supplier == null ? null : supplier.get();
                if (state == null) {
                    continue;
                }
                Term fieldTerm = CombinedFieldQuery.this.fieldTerms[i];
                TermsEnum termsEnum = context.reader().terms(fieldTerm.field()).iterator();
                termsEnum.seekExact(fieldTerm.bytes(), state);
                // Only term frequencies are needed for scoring; no positions, offsets or payloads.
                iterators.add(termsEnum.postings(null, PostingsEnum.FREQS));
                fields.add(CombinedFieldQuery.this.fieldAndWeights.get(fieldTerm.field()));
            }
            if (iterators.isEmpty()) {
                return null;
            }
            MultiNormsLeafSimScorer scoringSimScorer = new MultiNormsLeafSimScorer(this.simWeight, context.reader(), CombinedFieldQuery.this.fieldAndWeights.values(), true);
            DisiPriorityQueue queue = new DisiPriorityQueue(iterators.size());
            for (int i = 0; i < iterators.size(); ++i) {
                float weight = fields.get(i).weight;
                queue.add(new WeightedDisiWrapper(new TermScorer(iterators.get(i), this.simWeight, null), weight));
            }
            DisjunctionDISIApproximation iterator = new DisjunctionDISIApproximation(queue);
            CombinedFieldScorer scorer = new CombinedFieldScorer(queue, iterator, scoringSimScorer);
            return new Weight.DefaultScorerSupplier(scorer);
        }

        /** This weight is never reported as cacheable. */
        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
            return false;
        }
    }

    /**
     * Scores a document from the sum of the weighted frequencies of all sub-iterators positioned on it.
     */
    private static class CombinedFieldScorer
    extends Scorer {
        private final DisiPriorityQueue queue;
        private final DocIdSetIterator iterator;
        private final MultiNormsLeafSimScorer simScorer;
        /**
         * Constant returned by {@link #getMaxScore(int)}. It is the similarity's score for an
         * infinite frequency and a norm of 1, presumably an upper bound for any real document.
         */
        private final float maxScore;

        CombinedFieldScorer(DisiPriorityQueue queue, DocIdSetIterator iterator, MultiNormsLeafSimScorer simScorer) {
            this.queue = queue;
            this.iterator = iterator;
            this.simScorer = simScorer;
            this.maxScore = simScorer.getSimScorer().score(Float.POSITIVE_INFINITY, 1L);
        }

        @Override
        public int docID() {
            return this.iterator.docID();
        }

        /**
         * Sums {@code weight * termFreq} over the sub-iterators positioned on the current document.
         *
         * <p>The sub-iterators are taken from the queue's top list. The result is capped at
         * {@link Integer#MAX_VALUE} if the running sum ever turns negative.
         */
        float freq() throws IOException {
            DisiWrapper w = this.queue.topList();
            float freq = ((WeightedDisiWrapper) w).freq();
            for (w = w.next; w != null; w = w.next) {
                freq += ((WeightedDisiWrapper) w).freq();
                // Defensive overflow guard. Assuming postings frequencies are non-negative (weights
                // are >= 1), a float sum would saturate at infinity rather than go negative.
                if (freq < 0.0f) {
                    return Integer.MAX_VALUE;
                }
            }
            return freq;
        }

        @Override
        public float score() throws IOException {
            return this.simScorer.score(this.iterator.docID(), this.freq());
        }

        @Override
        public DocIdSetIterator iterator() {
            return this.iterator;
        }

        /** Returns the same precomputed bound regardless of {@code upTo}. */
        @Override
        public float getMaxScore(int upTo) throws IOException {
            return this.maxScore;
        }
    }

    /** A {@link DisiWrapper} that scales the wrapped postings' term frequency by a field weight. */
    private static class WeightedDisiWrapper
    extends DisiWrapper {
        final PostingsEnum postingsEnum;
        final float weight;

        WeightedDisiWrapper(Scorer scorer, float weight) {
            super(scorer, false);
            this.weight = weight;
            // The scorers built by scorerSupplier iterate directly over a PostingsEnum.
            this.postingsEnum = (PostingsEnum) scorer.iterator();
        }

        /** Returns {@code weight * freq} for the current document of the wrapped postings. */
        float freq() throws IOException {
            return this.weight * (float) this.postingsEnum.freq();
        }
    }

    /** Builder for {@link CombinedFieldQuery}. */
    public static class Builder {
        private final Map<String, FieldAndWeight> fieldAndWeights = new HashMap<>();
        private final Set<BytesRef> termsSet = new HashSet<>();

        /**
         * Adds a field with weight 1.
         *
         * @param field the field name; must not be {@code null}
         * @return this builder
         */
        public Builder addField(String field) {
            return this.addField(field, 1.0f);
        }

        /**
         * Adds a field with the given weight. Re-adding a field replaces its weight.
         *
         * @param field the field name; must not be {@code null}
         * @param weight the field weight; must be at least 1 (NaN is rejected)
         * @return this builder
         * @throws IllegalArgumentException if {@code weight} is less than 1 or NaN
         */
        public Builder addField(String field, float weight) {
            Objects.requireNonNull(field, "field");
            // Negated >= so that NaN, for which every comparison is false, is rejected as well.
            if (!(weight >= 1.0f)) {
                throw new IllegalArgumentException("weight must be greater or equal to 1");
            }
            this.fieldAndWeights.put(field, new FieldAndWeight(field, weight));
            return this;
        }

        /**
         * Adds a term to search for in every field. Duplicate terms are ignored.
         *
         * @param term the term bytes; must not be {@code null}
         * @return this builder
         * @throws IndexSearcher.TooManyClauses if adding a new term would exceed
         *     {@link IndexSearcher#getMaxClauseCount()}
         */
        public Builder addTerm(BytesRef term) {
            Objects.requireNonNull(term, "term");
            // A term that is already present does not grow the set, so only new terms hit the limit.
            if (!this.termsSet.contains(term) && this.termsSet.size() >= IndexSearcher.getMaxClauseCount()) {
                throw new IndexSearcher.TooManyClauses();
            }
            this.termsSet.add(term);
            return this;
        }

        /**
         * Builds the query.
         *
         * @throws IndexSearcher.TooManyClauses if fields x terms exceeds
         *     {@link IndexSearcher#getMaxClauseCount()}
         */
        public CombinedFieldQuery build() {
            // long arithmetic prevents the product from wrapping around and slipping past the check.
            long size = (long) this.fieldAndWeights.size() * this.termsSet.size();
            if (size > IndexSearcher.getMaxClauseCount()) {
                throw new IndexSearcher.TooManyClauses();
            }
            BytesRef[] terms = this.termsSet.toArray(new BytesRef[0]);
            return new CombinedFieldQuery(new TreeMap<>(this.fieldAndWeights), terms);
        }
    }
}

