/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kyuubi.shaded.thrift.transport;

import java.lang.ref.WeakReference;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;
import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.kyuubi.shaded.thrift.transport.TSaslTransport;
import org.apache.kyuubi.shaded.thrift.transport.TTransport;
import org.apache.kyuubi.shaded.thrift.transport.TTransportException;
import org.apache.kyuubi.shaded.thrift.transport.TTransportFactory;
import org.apache.kyuubi.shaded.thrift.transport.sasl.NegotiationStatus;
import org.apache.kyuubi.shaded.thrift.transport.sasl.TSaslServerDefinition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TSaslServerTransport
extends TSaslTransport {
    private static final Logger LOGGER = LoggerFactory.getLogger(TSaslServerTransport.class);
    private SaslServerFactory saslServerFactory = new JdkSaslServerFactory();
    private Map<String, TSaslServerDefinition> serverDefinitionMap = new HashMap<String, TSaslServerDefinition>();

    public TSaslServerTransport(TTransport transport) throws TTransportException {
        super(transport);
    }

    public TSaslServerTransport(String mechanism, String protocol, String serverName, Map<String, String> props, CallbackHandler cbh, TTransport transport) throws TTransportException {
        super(transport);
        this.addServerDefinition(mechanism, protocol, serverName, props, cbh);
    }

    private TSaslServerTransport(Map<String, TSaslServerDefinition> serverDefinitionMap, TTransport transport) throws TTransportException {
        super(transport);
        this.serverDefinitionMap.putAll(serverDefinitionMap);
    }

    public void addServerDefinition(String mechanism, String protocol, String serverName, Map<String, String> props, CallbackHandler cbh) {
        this.serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName, props, cbh));
    }

    public void setSaslServerFactory(SaslServerFactory f) {
        this.saslServerFactory = f;
    }

    @Override
    protected TSaslTransport.SaslRole getRole() {
        return TSaslTransport.SaslRole.SERVER;
    }

    @Override
    protected void handleSaslStartMessage() throws TTransportException, SaslException {
        TSaslTransport.SaslResponse message = this.receiveSaslMessage();
        LOGGER.debug("Received start message with status {}", (Object)message.status);
        if (message.status != NegotiationStatus.START) {
            throw this.sendAndThrowMessage(NegotiationStatus.ERROR, "Expecting START status, received " + (Object)((Object)message.status));
        }
        String mechanismName = new String(message.payload, StandardCharsets.UTF_8);
        TSaslServerDefinition serverDefinition = this.serverDefinitionMap.get(mechanismName);
        LOGGER.debug("Received mechanism name '{}'", (Object)mechanismName);
        if (serverDefinition == null) {
            throw this.sendAndThrowMessage(NegotiationStatus.BAD, "Unsupported mechanism type " + mechanismName);
        }
        SaslServer saslServer = this.saslServerFactory.create(serverDefinition);
        this.setSaslServer(saslServer);
    }

    public static class Factory
    extends TTransportFactory {
        private static Map<TTransport, WeakReference<TSaslServerTransport>> transportMap = Collections.synchronizedMap(new WeakHashMap());
        private Map<String, TSaslServerDefinition> serverDefinitionMap = new HashMap<String, TSaslServerDefinition>();
        private SaslServerFactory saslServerFactory = new JdkSaslServerFactory();

        public Factory() {
        }

        public Factory(String mechanism, String protocol, String serverName, Map<String, String> props, CallbackHandler cbh) {
            this.addServerDefinition(mechanism, protocol, serverName, props, cbh);
        }

        public void addServerDefinition(String mechanism, String protocol, String serverName, Map<String, String> props, CallbackHandler cbh) {
            this.serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName, props, cbh));
        }

        public void setSaslServerFactory(SaslServerFactory f) {
            this.saslServerFactory = f;
        }

        @Override
        public TTransport getTransport(TTransport base) throws TTransportException {
            WeakReference<TSaslServerTransport> ret = transportMap.get(base);
            if (ret == null || ret.get() == null) {
                LOGGER.debug("transport map does not contain key {}", (Object)base);
                TSaslServerTransport t = new TSaslServerTransport(this.serverDefinitionMap, base);
                t.setSaslServerFactory(this.saslServerFactory);
                ret = new WeakReference<TSaslServerTransport>(t);
                try {
                    ((TSaslServerTransport)ret.get()).open();
                }
                catch (TTransportException e) {
                    LOGGER.debug("failed to open server transport", (Throwable)e);
                    throw new RuntimeException(e);
                }
                transportMap.put(base, ret);
            } else {
                LOGGER.debug("transport map does contain key {}", (Object)base);
            }
            return (TTransport)ret.get();
        }
    }

    public static class JdkSaslServerFactory
    implements SaslServerFactory {
        @Override
        public SaslServer create(TSaslServerDefinition d) throws SaslException {
            return Sasl.createSaslServer(d.mechanism, d.protocol, d.serverName, d.props, d.cbh);
        }
    }

    public static interface SaslServerFactory {
        public SaslServer create(TSaslServerDefinition var1) throws SaslException;
    }
}

