import com.sf.presto.jdbc.internal.airlift.slice.Slice;
import com.sf.presto.jdbc.internal.airlift.slice.Slices;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.column.Encoding;
import org.apache.parquet.column.ParquetProperties;
import org.apache.parquet.column.statistics.BinaryStatistics;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.example.data.simple.SimpleGroupFactory;
import org.apache.parquet.hadoop.ParquetReader;
import org.apache.parquet.hadoop.ParquetWriter;
import org.apache.parquet.hadoop.example.GroupReadSupport;
import org.apache.parquet.hadoop.example.GroupWriteSupport;
import org.apache.parquet.hadoop.metadata.BlockMetaData;
import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.MessageType;
import org.junit.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static java.util.Arrays.asList;
import static org.apache.parquet.column.Encoding.DELTA_BYTE_ARRAY;
import static org.apache.parquet.column.Encoding.PLAIN;
import static org.apache.parquet.column.Encoding.PLAIN_DICTIONARY;
import static org.apache.parquet.column.Encoding.RLE_DICTIONARY;
import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_1_0;
import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_2_0;
import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER;
import static org.apache.parquet.hadoop.ParquetFileReader.readFooter;
import static org.apache.parquet.hadoop.metadata.CompressionCodecName.UNCOMPRESSED;
import static org.apache.parquet.schema.MessageTypeParser.parseMessageType;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

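/**
 * Tests that Parquet's binary min/max statistics agree with Airlift Slice ordering
 * for UTF-8 strings (multi-byte characters and the empty string included), and that
 * binary columns round-trip through ParquetWriter with the expected encodings.
 */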
public class TestParquet
{
    private final List<String> testStrings = new ArrayList<>(Arrays.asList("中文", "2日本語", "3English"));

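    // min/max statistics should agree with Slice ordering, with and without the empty string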
    @Test
    public void testParquet()
    {
        testParquet(testStrings, false);
        testParquet(testStrings, true);
    }

    private void testParquet(List<String> strings, boolean containEmptyString)
    {
        if (containEmptyString) {
            // copy first so the caller's list is not mutated
            strings = new ArrayList<>(strings);
            strings.add("");
        }
        // generate Parquet binary statistics the way the writer would
        BinaryStatistics binaryStatistics = new BinaryStatistics();
        for (String testString : strings) {
            binaryStatistics.updateStats(Binary.fromString(testString));
        }

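        // min/max as reported by the statistics, decoded back to UTF-8 strings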
        String max = binaryStatistics.genericGetMax().toStringUsingUTF8();
        String min = binaryStatistics.genericGetMin().toStringUsingUTF8();

        // compute the reference min/max using byte-wise Slice comparison
        Slice maxSlice = null;
        Slice minSlice = null;
        for (String testString : strings) {
            Slice slice = Slices.utf8Slice(testString);
            if (maxSlice == null || maxSlice.compareTo(slice) < 0) {
                maxSlice = slice;
            }
            if (minSlice == null || minSlice.compareTo(slice) > 0) {
                minSlice = slice;
            }
        }
        assertEquals("Min-binary and Min-slice mismatch", min, minSlice.toStringUtf8());
        assertEquals("Max-binary and Max-slice mismatch", max, maxSlice.toStringUtf8());
    }

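    // end-to-end: write, read back, and verify the column encodings chosen per writer version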
    @Test
    public void testFile()
            throws Exception
    {
        testFile(testStrings, false);
        testFile(testStrings, true);
    }

    private void testFile(List<String> strings, boolean containEmptyString)
            throws Exception
    {
        if (containEmptyString) {
            // copy first so the caller's list is not mutated
            strings = new ArrayList<>(strings);
            strings.add("");
        }
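        // write under target/ and start from an empty output directory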
        Configuration conf = new Configuration();
        Path root = new Path("target/tests/TestParquetWriter/");
        TestUtils.enforceEmptyDir(conf, root);
        MessageType schema = parseMessageType(
                "message test { "
                        + "required int32 int32_field; "
                        + "required binary binary_field; "
                        + "required fixed_len_byte_array(3) flba_field; "
                        + "} ");
        GroupWriteSupport.setSchema(schema, conf);
        SimpleGroupFactory f = new SimpleGroupFactory(schema);
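        // expected encoding of binary_field, keyed by "<modulo>-<writer version>": low-cardinality
        // values (modulo 10) should be dictionary encoded; high-cardinality values (modulo 1000)
        // fall back to PLAIN (v1) or DELTA_BYTE_ARRAY (v2)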
        Map<String, Encoding> expected = new HashMap<>();
        expected.put("10-" + PARQUET_1_0, PLAIN_DICTIONARY);
        expected.put("1000-" + PARQUET_1_0, PLAIN);
        expected.put("10-" + PARQUET_2_0, RLE_DICTIONARY);
        expected.put("1000-" + PARQUET_2_0, DELTA_BYTE_ARRAY);
        for (int modulo : asList(10, 1000)) {
            for (ParquetProperties.WriterVersion version : ParquetProperties.WriterVersion.values()) {
                Path file = new Path(root, version.name() + "_" + modulo);
                ParquetWriter<Group> writer = new ParquetWriter<Group>(
                        file,
                        new GroupWriteSupport(),
                        // codec, block size, page size, dictionary page size, enable dictionary, validate
                        UNCOMPRESSED, 1024, 1024, 512, true, false, version, conf);
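                // write 1000 rows per test string; j % modulo controls binary_field cardinality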
                for (int i = 0; i < strings.size(); i++) {
                    for (int j = 0; j < 1000; j++) {
                        writer.write(f.newGroup()
                                .append("int32_field", i)
                                .append("binary_field", strings.get(i) + j % modulo)
                                .append("flba_field", "foo")
                        );
                    }
                }

                writer.close();
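                // read everything back and verify each row round-trips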
                ParquetReader<Group> reader = ParquetReader.builder(new GroupReadSupport(), file)
                        .withConf(conf)
                        .build();
                for (int i = 0; i < strings.size(); i++) {
                    for (int j = 0; j < 1000; j++) {
                        Group group = reader.read();
                        assertEquals(strings.get(i) + j % modulo, group.getBinary("binary_field", 0).toStringUsingUTF8());
                    }
                }
                reader.close();
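                // the footer should report the expected encoding for binary_field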
                ParquetMetadata footer = readFooter(conf, file, NO_FILTER);
                for (BlockMetaData blockMetaData : footer.getBlocks()) {
                    for (ColumnChunkMetaData column : blockMetaData.getColumns()) {
                        if (column.getPath().toDotString().equals("binary_field")) {
                            String key = modulo + "-" + version;
                            Encoding expectedEncoding = expected.get(key);
                            assertTrue(
                                    key + ":" + column.getEncodings() + " should contain " + expectedEncoding,
                                    column.getEncodings().contains(expectedEncoding));
                        }
                    }
                }
            }
        }
    }
}