/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.java.decompiler.modules.decompiler;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.jetbrains.java.decompiler.main.DecompilerContext;
import org.jetbrains.java.decompiler.modules.decompiler.exps.AssignmentExprent;
import org.jetbrains.java.decompiler.modules.decompiler.exps.ConstExprent;
import org.jetbrains.java.decompiler.modules.decompiler.exps.Exprent;
import org.jetbrains.java.decompiler.modules.decompiler.exps.FunctionExprent;
import org.jetbrains.java.decompiler.modules.decompiler.exps.InvocationExprent;
import org.jetbrains.java.decompiler.modules.decompiler.exps.VarExprent;
import org.jetbrains.java.decompiler.modules.decompiler.stats.RootStatement;
import org.jetbrains.java.decompiler.modules.decompiler.stats.Statement;
import org.jetbrains.java.decompiler.struct.StructClass;
import org.jetbrains.java.decompiler.struct.StructMethod;
import org.jetbrains.java.decompiler.struct.gen.CodeType;
import org.jetbrains.java.decompiler.struct.gen.TypeFamily;
import org.jetbrains.java.decompiler.struct.gen.VarType;
import org.jetbrains.java.decompiler.struct.gen.generics.GenericMethodDescriptor;
import org.jetbrains.java.decompiler.util.Pair;

/**
 * Collapses chains of nested reference casts, e.g. {@code (A) (B) x}, into a single cast node whose operand list
 * is {@code [x, A, B, ...]} (presumably rendered as an intersection cast {@code (A & B) x}).
 *
 * <p>Only casts that appear directly as a method-call argument, or as the right-hand side of an assignment to a
 * local variable, are considered. When a generic type-parameter bound at that position is not satisfied by any of
 * the cast types but is satisfied by the type of the innermost expression, that type is added to the intersection
 * as well.
 */
public class IntersectionCastProcessor {
    /** Class access flag {@code ACC_INTERFACE} (JVMS 4.1). */
    private static final int ACC_INTERFACE = 0x0200;
    /** Class access flag {@code ACC_ENUM} (JVMS 4.1). */
    private static final int ACC_ENUM = 0x4000;

    /**
     * Applies the intersection-cast rewrite to every expression of a method body.
     *
     * @param root the method's root statement
     * @return {@code true} if at least one cast chain was rewritten
     */
    public static boolean makeIntersectionCasts(RootStatement root) {
        return makeIntersectionCastsRec(root, root);
    }

    /** Visits every top-level expression reachable from {@code stat}; returns whether anything changed. */
    private static boolean makeIntersectionCastsRec(Statement stat, RootStatement root) {
        boolean changed = false;
        if (stat.getExprents() != null) {
            for (Exprent e : stat.getExprents()) {
                changed |= makeIntersectionCasts(e, root);
            }
            return changed;
        }
        for (Object o : stat.getSequentialObjects()) {
            if (o instanceof Statement s) {
                changed |= makeIntersectionCastsRec(s, root);
            } else if (o instanceof Exprent e) {
                changed |= makeIntersectionCasts(e, root);
            }
        }
        return changed;
    }

    /** Rewrites {@code exp} if it is a call or assignment, otherwise recurses into its sub-expressions. */
    private static boolean makeIntersectionCasts(Exprent exp, RootStatement root) {
        boolean handled;
        if (exp instanceof InvocationExprent inv) {
            handled = handleInvocation(inv);
        } else if (exp instanceof AssignmentExprent assignment) {
            handled = handleAssignment(assignment, root);
        } else {
            handled = false;
        }
        if (handled) {
            // Sub-expressions of a node that was just rewritten are not visited in this call.
            return true;
        }
        boolean changed = false;
        for (Exprent sub : exp.getAllExprents()) {
            changed |= makeIntersectionCasts(sub, root);
        }
        return changed;
    }

    /** Rewrites cast chains passed directly as arguments of {@code exp}; returns whether any was rewritten. */
    private static boolean handleInvocation(InvocationExprent exp) {
        List<Exprent> parameters = exp.getLstParameters();
        boolean changed = false;
        for (int i = 0; i < parameters.size(); i++) {
            if (!(parameters.get(i) instanceof FunctionExprent cast) || !isValidCast(cast)) {
                continue;
            }
            Pair<List<Exprent>, Exprent> casts = getCasts(cast);
            List<Exprent> types = casts.a;
            Exprent inner = casts.b;

            List<VarType> unmetBounds = getBounds(exp, i).stream()
                .filter(bound -> !isCoveredBy(types, bound))
                .toList();
            // NOTE(review): here the inner type must satisfy *all* unmet bounds, whereas handleAssignment
            // accepts *any*; behavior kept as-is — confirm the asymmetry is intended.
            if (!unmetBounds.isEmpty()
                    && unmetBounds.stream().allMatch(bound -> isSubtype(inner.getExprType(), bound))) {
                types.add(new ConstExprent(inner.getExprType(), null, null));
            }
            changed |= replaceCasts(cast, types, inner);
        }
        return changed;
    }

    /**
     * Rewrites a cast chain assigned to a local variable, taking into account the bounds at every place the
     * variable is later passed as a generic argument or cast. Marks the variable as intersection-typed on success.
     */
    private static boolean handleAssignment(AssignmentExprent exp, RootStatement root) {
        if (!(exp.getLeft() instanceof VarExprent varExp)) {
            return false;
        }
        if (!(exp.getRight() instanceof FunctionExprent cast) || !isValidCast(cast)) {
            return false;
        }
        Pair<List<Exprent>, Exprent> casts = getCasts(cast);
        List<Exprent> types = casts.a;
        Exprent inner = casts.b;

        Set<VarType> bounds = new HashSet<>();
        for (VariablePosition position : findReferences(varExp, root)) {
            bounds.addAll(boundsAt(position));
        }
        bounds.removeIf(bound -> isCoveredBy(types, bound));
        if (!bounds.isEmpty() && bounds.stream().anyMatch(bound -> isSubtype(inner.getExprType(), bound))) {
            types.add(new ConstExprent(inner.getExprType(), null, null));
        }
        if (replaceCasts(cast, types, inner)) {
            varExp.setIntersectionType(true);
            return true;
        }
        return false;
    }

    /** Returns the type bounds the variable must satisfy at one recorded use site. */
    private static List<VarType> boundsAt(VariablePosition position) {
        return switch (position.position()) {
            case METHOD_PARAMETER -> getBounds((InvocationExprent) position.exp(), position.index());
            case CASTED -> {
                FunctionExprent func = (FunctionExprent) position.exp();
                yield func.getLstOperands().size() == 2
                    ? List.of(func.getLstOperands().get(1).getExprType())
                    : List.of();
            }
        };
    }

    /**
     * Returns the declared bounds of the generic type variable used as the type of argument {@code parameter}
     * of the invoked method, or an empty list if that parameter is not a type variable or cannot be mapped.
     */
    private static List<VarType> getBounds(InvocationExprent exp, int parameter) {
        StructMethod method = exp.getDesc();
        GenericMethodDescriptor gmd = method != null ? method.getSignature() : null;
        if (gmd == null) {
            return List.of();
        }
        // Enum constructors take two leading synthetic arguments that the generic signature does not list.
        int start = isEnumConstructor(method) ? 2 : 0;
        int index = parameter - start;
        // The generic signature's parameter count need not match the call's argument count (JVMS 4.7.9.1);
        // without this check an unmatched argument would throw IndexOutOfBoundsException.
        if (index < 0 || index >= gmd.parameterTypes.size()) {
            return List.of();
        }
        VarType type = gmd.parameterTypes.get(index);
        if (type.type != CodeType.GENVAR) {
            return List.of();
        }
        int typeParameterIndex = gmd.typeParameters.indexOf(type.value);
        return typeParameterIndex == -1 ? List.of() : gmd.typeParameterBounds.get(typeParameterIndex);
    }

    /** Returns whether {@code method} is a constructor of a class flagged {@code ACC_ENUM}. */
    private static boolean isEnumConstructor(StructMethod method) {
        if (!"<init>".equals(method.getName())) {
            return false;
        }
        StructClass owner = DecompilerContext.getStructContext().getClass(method.getClassQualifiedName());
        return owner != null && owner.hasModifier(ACC_ENUM);
    }

    /** Returns whether some cast target in {@code types} is an instance of {@code bound} per the struct context. */
    private static boolean isCoveredBy(List<Exprent> types, VarType bound) {
        return types.stream().anyMatch(type -> isSubtype(type.getExprType(), bound));
    }

    /** Delegates to {@code StructContext.instanceOf} on the two types' names. */
    private static boolean isSubtype(VarType type, VarType bound) {
        return DecompilerContext.getStructContext().instanceOf(type.value, bound.value);
    }

    /** Collects every place the SSA version of {@code varExp} is passed as a call argument or cast. */
    private static List<VariablePosition> findReferences(VarExprent varExp, RootStatement root) {
        List<VariablePosition> list = new ArrayList<>();
        findReferencesRec(varExp, root, list);
        return list;
    }

    /** Walks every top-level expression reachable from {@code stat}, collecting references into {@code list}. */
    private static void findReferencesRec(VarExprent varExp, Statement stat, List<VariablePosition> list) {
        if (stat.getExprents() != null) {
            for (Exprent e : stat.getExprents()) {
                findReferencesInExprent(varExp, e, list);
            }
            return;
        }
        for (Object o : stat.getSequentialObjects()) {
            if (o instanceof Statement s) {
                findReferencesRec(varExp, s, list);
            } else if (o instanceof Exprent e) {
                findReferencesInExprent(varExp, e, list);
            }
        }
    }

    /** Records references to {@code varExp} in {@code exp} and, recursively, in all of its sub-expressions. */
    private static void findReferencesInExprent(VarExprent varExp, Exprent exp, List<VariablePosition> list) {
        if (exp instanceof InvocationExprent inv) {
            findParameterReferences(varExp, inv, list);
        } else if (exp instanceof FunctionExprent func
                && func.getFuncType() == FunctionExprent.FunctionType.CAST
                && func.getLstOperands().get(0) instanceof VarExprent otherVar
                && varExp.getVarVersionPair().equals(otherVar.getVarVersionPair())) {
            list.add(new VariablePosition(VariablePositionEnum.CASTED, exp, -1));
        }
        for (Exprent sub : exp.getAllExprents()) {
            findReferencesInExprent(varExp, sub, list);
        }
    }

    /** Records each argument slot of {@code inv} that is exactly the same variable version as {@code varExp}. */
    private static void findParameterReferences(VarExprent varExp, InvocationExprent inv,
                                                List<VariablePosition> list) {
        List<Exprent> parameters = inv.getLstParameters();
        for (int i = 0; i < parameters.size(); i++) {
            if (parameters.get(i) instanceof VarExprent otherVar
                    && varExp.getVarVersionPair().equals(otherVar.getVarVersionPair())) {
                list.add(new VariablePosition(VariablePositionEnum.METHOD_PARAMETER, inv, i));
            }
        }
    }

    /**
     * Unwraps a chain of valid casts, outermost first.
     *
     * @return the cast target-type expressions (mutable list, outermost first) and the innermost non-cast operand
     */
    private static Pair<List<Exprent>, Exprent> getCasts(Exprent exp) {
        List<Exprent> types = new ArrayList<>();
        Exprent inner = exp;
        while (inner instanceof FunctionExprent cast && isValidCast(cast)) {
            types.add(cast.getLstOperands().get(1));
            inner = cast.getLstOperands().get(0);
        }
        return Pair.of(types, inner);
    }

    /** Returns whether {@code cast} is a plain two-operand cast to a non-array object type. */
    private static boolean isValidCast(FunctionExprent cast) {
        if (cast.getFuncType() != FunctionExprent.FunctionType.CAST || cast.getLstOperands().size() != 2) {
            return false;
        }
        VarType type = cast.getLstOperands().get(1).getExprType();
        return type.typeFamily == TypeFamily.OBJECT && type.arrayDim == 0;
    }

    /**
     * Replaces the operands of {@code cast} with {@code inner} followed by all {@code types}, moving the single
     * class type (if any) to the front as required for intersection casts. Classes that cannot be resolved are
     * treated like interfaces.
     *
     * @return {@code false} if there are fewer than two types or more than one class type; otherwise {@code true}
     */
    private static boolean replaceCasts(FunctionExprent cast, List<Exprent> types, Exprent inner) {
        if (types.size() <= 1) {
            return false;
        }
        Exprent nonInterface = null;
        for (Exprent type : types) {
            StructClass clazz = DecompilerContext.getStructContext().getClass(type.getExprType().value);
            if (clazz == null || clazz.hasModifier(ACC_INTERFACE)) {
                continue;
            }
            if (nonInterface != null) {
                // An intersection may contain at most one class type.
                return false;
            }
            nonInterface = type;
        }
        if (nonInterface != null) {
            types.remove(nonInterface);
            types.add(0, nonInterface);
        }
        List<Exprent> operands = cast.getLstOperands();
        operands.clear();
        operands.add(inner);
        operands.addAll(types);
        return true;
    }

    /** A use site of a variable: what kind of use, the enclosing expression, and the argument index (or -1). */
    private record VariablePosition(VariablePositionEnum position, Exprent exp, int index) {
    }

    /** Kinds of variable use sites that contribute type bounds. */
    private enum VariablePositionEnum {
        METHOD_PARAMETER,
        CASTED
    }
}

