Huffman 压缩算法
创建:2023-10-27 16:47
更新:2023-10-27 17:06
#pragma once
#include <memory.h>

/// Byte-oriented Huffman coder with a self-describing header.
///
/// Stream layout produced by compress() (integers are copied with memcpy,
/// i.e. in host byte order, 4 bytes each):
///   [head size][original length][(count, byte value) per used byte value][code bits]
/// When the input contains a single distinct byte value the stream is the
/// 13-byte header only; decompress() rebuilds the data from the count.
class huffman {
    /// Fixed-size store of n bits, addressed from the front: bit 0 is the most
    /// significant bit of v[0]. Author's note: std::bitset is deliberately not
    /// used because of its storage order and the cost of repeated bit operations.
    template<int n>
    struct bits {
        static const int size = n / 8;    // storage size in bytes
        unsigned char v[size] = {0};

        /// Returns bit i (0 or 1), counting from the front of the store.
        unsigned char get(int i) const {
            int a = i / 8;
            int b = 7 - i % 8;
            return (v[a] & (1 << b)) >> b;
        }

        /// ORs p into the last byte of the store.
        void or_tail(unsigned char p) {
            v[size - 1] |= p;
        }

        /// ORs len bits of p, starting at bit `from`, into this store starting
        /// at bit `target`. Destination bits are only ever set, never cleared.
        void or_merge(const bits& p, int from, int target, int len) {
            for (int i = 0; i < len; i++) {
                int a = (i + target) / 8;
                int b = (i + target) % 8;
                v[a] |= p.get(i + from) << (7 - b);
            }
        }
    };

    typedef bits<256> code_type;

    /// One tree node. Slots 0..255 start out as the byte values; internal
    /// nodes are created in whatever slot tree() picks for them.
    struct element {
        unsigned int count;    // frequency while building the tree; code length in bits afterwards
        unsigned short i;      // node id: its home slot in the index array (byte value for leaves)

        unsigned short left;   // id of the child reached by a 0 bit
        unsigned short right;  // id of the child reached by a 1 bit
        code_type code;        // code bits, right-aligned; 256 bits, so the tree may be at most 255 levels deep

        /// Orders by count, descending.
        static int compare1(element* a, element* b) {
            if (b->count > a->count)
                return 1;
            if (b->count < a->count)
                return -1;
            return 0;
        }
        /// Orders by node id, ascending.
        static int compare2(element* a, element* b) {
            if (a->i > b->i)
                return 1;
            if (a->i < b->i)
                return -1;
            return 0;
        }
    };

    /// In-place quicksort of eles[start..end] (inclusive) using the first
    /// element as pivot. compare(a, b) < 0 means a sorts before b.
    /// The sort is not stable, but it is deterministic; compress() and
    /// decompress() rely on that to rebuild identical trees.
    static void quick_sort(element* eles, int start, int end, int (*compare)(element* a, element* b)) {
        if (start < end) {
            int i = start, j = end;
            element it = eles[i];    // pivot; taking it out leaves a hole at i
            while (i < j) {
                // Scan from the end for an element that sorts before the pivot,
                // move it into the hole at i; the hole is now at j.
                while (i < j && compare(&eles[j], &it) >= 0) {
                    j--;
                }
                if (i < j) {
                    eles[i] = eles[j];
                    i++;
                }
                // Scan from the start for an element that does not sort before
                // the pivot, move it into the hole at j; the hole is now at i.
                while (i < j && compare(&eles[i], &it) < 0) {
                    i++;
                }
                if (i < j) {
                    eles[j] = eles[i];
                    j--;
                }
            }
            eles[i] = it;
            quick_sort(eles, start, i - 1, compare);
            quick_sort(eles, i + 1, end, compare);
        }
    }

    /// Assigns codes to every node below root: 0 for the left branch, 1 for
    /// the right. On entry root->count is root's code length; each child gets
    /// root's code shifted left by one bit, and its count set to length + 1.
    static void encoding(element* root, element* index) {
        if (root->left == 0 && root->right == 0) {
            return;    // leaf: no children to label
        }
        index[root->left].count = root->count + 1;
        index[root->right].count = root->count + 1;

        const unsigned int buf_size = sizeof(root->code) * 8;
        index[root->left].code = {0};
        // Copy the parent's right-aligned code one position further left,
        // leaving the new last bit at 0.
        index[root->left].code.or_merge(root->code, buf_size - root->count, buf_size - root->count - 1, root->count);
        index[root->right].code = index[root->left].code;
        index[root->right].code.or_tail(1);

        encoding(&index[root->left], index);
        encoding(&index[root->right], index);
    }

    /// Builds the Huffman tree from the counts in index[0..255] and assigns
    /// codes. index must have 512 slots with index[k].i == k. On return every
    /// node sits in the slot equal to its id, and count holds the code length.
    /// Precondition: at least one count is non-zero (callers in this class
    /// only call it with two or more used byte values).
    /// @return pointer to the root node inside index.
    static element* tree(element* index) {
        // Sort by frequency; the used byte values end up at the front.
        quick_sort(index, 0, 255, element::compare1);
        int size = 0;
        for (; size < 256; size++) {
            if (index[size].count == 0) {
                break;
            }
        }

        // Merge the two least frequent live nodes into a new node until a
        // single live node (the root) remains.
        for (int i = size; i > 1; i--) {
            element* it = &index[size];    // next slot past the sorted range becomes the new node
            int m = size - 1;
            int n = size - 2;
            it->left = index[m].i;
            it->right = index[n].i;
            it->count = index[m].count + index[n].count;
            index[m].count = 0xffffffff;    // sorts to the very front: treated as removed
            index[n].count = 0xffffffff;
            size++;
            quick_sort(index, 0, size - 1, element::compare1);
        }
        // The only live node sorts last; that is the root.
        unsigned short root = index[size - 1].i;
        // The array is now in arbitrary order: move every node back to the
        // slot named by its id. Sorted by id, each target is >= its source,
        // so walking from the back never overwrites a node not yet moved.
        quick_sort(index, 0, size - 1, element::compare2);
        for (int i = size - 1; i >= 0; i--) {
            index[i].count = 0;
            index[index[i].i] = index[i];
        }

        // Assign the codes.
        encoding(&index[root], index);
        return &index[root];
    }

public:
    /// Largest possible header: 8 fixed bytes plus a 5-byte entry for each of
    /// the 256 byte values.
    static const int MAX_HEAD_SIZE = 8 + (4 + 1) * 256;

    /// Compresses input_len bytes from input_data into output_data.
    ///
    /// @param input_data  bytes to compress
    /// @param input_len   number of input bytes
    /// @param output_data destination; the caller must provide enough room.
    ///                    input_len + MAX_HEAD_SIZE bytes is presumably enough,
    ///                    since an optimal prefix code never needs more than
    ///                    8 bits per byte on average.
    /// @return number of bytes written to output_data, or 0 when a pointer is
    ///         null or input_len is 0.
    static unsigned int compress(unsigned char* input_data, unsigned int input_len, unsigned char* output_data) {
        if (!input_data || input_len == 0 || !output_data) {
            return 0;
        }

        element index[512] = {{0}};

        // Record each slot's id: sorting shuffles the array and the id is
        // what lets tree() put the nodes back.
        for (unsigned int i = 0; i < 512; i++) {
            index[i].i = i;
        }

        // Count byte frequencies.
        for (unsigned int i = 0; i < input_len; i++) {
            index[input_data[i]].count += 1;
        }

        // Write the header.
        unsigned int out_offset = 0;
        out_offset += 4;    // reserve room for the header size, filled in below
        memcpy(output_data + out_offset, &input_len, 4);
        out_offset += 4;
        for (unsigned int i = 0; i < 256; i++) {
            unsigned char c = i;
            if (index[i].count > 0) {
                memcpy(output_data + out_offset, &index[i].count, 4);
                out_offset += 4;
                memcpy(output_data + out_offset, &c, 1);
                out_offset += 1;
            }
        }
        memcpy(output_data, &out_offset, 4);

        // A 13-byte header means a single distinct byte value: the header
        // alone describes the data, nothing to encode.
        if (out_offset == 13) {
            return out_offset;
        }

        // Build the tree and the codes.
        tree(index);

        // Emit the codes through a 256-bit staging buffer.
        code_type buf = {0};
        unsigned int buf_len = 0;    // bits currently held in buf
        const unsigned int buf_size = sizeof(buf) * 8;
        for (unsigned int i = 0; i < input_len; i++) {
            unsigned char c = input_data[i];
            element* it = &index[c];

            unsigned int left = buf_size - buf_len;    // free bits in buf
            if (left <= it->count) {
                // The code fills (or overflows) the buffer: store what fits,
                // flush the full buffer, then start a new one with the rest.
                buf.or_merge(it->code, buf_size - it->count, buf_len, left);

                memcpy(output_data + out_offset, &buf, sizeof(buf));
                out_offset += sizeof(buf);
                buf = {0};
                buf_len = 0;

                buf.or_merge(it->code, buf_size - it->count + left, buf_len, it->count - left);
                buf_len += it->count - left;
            } else {
                buf.or_merge(it->code, buf_size - it->count, buf_len, it->count);
                buf_len += it->count;
            }
        }

        if (buf_len > 0) {
            // Flush the partially filled buffer: write exactly the bytes that
            // hold code bits and advance by that same amount, so the returned
            // length never covers bytes that were not written.
            const unsigned int tail_bytes = (buf_len + 7) / 8;
            memcpy(output_data + out_offset, &buf, tail_bytes);
            out_offset += tail_bytes;
        }
        return out_offset;
    }

    /// Decompresses a stream produced by compress().
    ///
    /// @param input_data  compressed stream
    /// @param input_len   number of bytes in input_data
    /// @param output_data destination; must have room for the original length
    ///                    recorded in the stream header (not verifiable here)
    /// @return number of bytes written to output_data, or 0 when a pointer is
    ///         null, input_len is 0, or the header is truncated or malformed.
    static unsigned int decompress(unsigned char* input_data, unsigned int input_len, unsigned char* output_data) {
        if (!input_data || input_len == 0 || !output_data) {
            return 0;
        }

        const unsigned int fixed_head = 8;    // head size + original length
        const unsigned int entry_size = 5;    // 4-byte count + 1-byte value
        if (input_len < fixed_head) {
            return 0;    // too short to hold even the fixed header fields
        }

        element index[512] = {{0}};
        unsigned int out_size = 0;

        // Record each slot's id; sorting shuffles the array.
        for (unsigned int i = 0; i < 512; i++) {
            index[i].i = i;
        }

        // Read the header.
        unsigned int head_size = 0;
        unsigned int data_size = 0;
        unsigned int offset = 0;
        memcpy(&head_size, input_data + offset, 4);
        offset += 4;
        memcpy(&data_size, input_data + offset, 4);
        offset += 4;

        // Reject headers that would make the table read run past the input
        // or stop in the middle of an entry.
        if (head_size < fixed_head || head_size > input_len ||
            (head_size - fixed_head) % entry_size != 0) {
            return 0;
        }

        unsigned int _count = 0;
        unsigned char _c = 0;
        while (offset < head_size) {
            memcpy(&_count, input_data + offset, 4);
            offset += 4;
            memcpy(&_c, input_data + offset, 1);
            offset += 1;
            index[_c].count = _count;
        }

        // Single distinct byte value: the header alone describes the data.
        if (offset == 13) {
            for (unsigned int i = 0; i < _count; i++) {
                *(output_data + out_size) = _c;
                out_size += 1;
            }
            return out_size;
        }

        // A decodable tree needs at least two used byte values; anything less
        // here is a corrupt table (and an empty one would make tree() read
        // before the start of the array).
        unsigned int used = 0;
        for (unsigned int i = 0; i < 256; i++) {
            if (index[i].count > 0) {
                used++;
            }
        }
        if (used < 2) {
            return 0;
        }

        // Rebuild the same tree the compressor used.
        element* root = tree(index);

        // Walk the tree one bit at a time, most significant bit first.
        unsigned char buf = 0;
        element* current = root;
        const int bits_count = sizeof(buf) * 8;
        while (offset < input_len && out_size < data_size) {
            memcpy(&buf, input_data + offset, sizeof(buf));
            offset += sizeof(buf);
            for (int i = 0; i < bits_count; i++) {
                unsigned char tmp = (buf >> (bits_count - 1 - i)) & 1;
                if (tmp == 0) {
                    current = &index[current->left];
                } else {
                    current = &index[current->right];
                }
                if (current->left == 0 && current->right == 0) {
                    // Leaf reached: its id is the decoded byte value.
                    *(output_data + out_size) = current->i;
                    out_size += 1;
                    current = root;
                }
                if (out_size >= data_size) {
                    break;    // remaining bits in this byte are padding
                }
            }
        }

        return out_size;
    }
};

测试例子:

#include "huffman.h"
#include <stdio.h>
#include <string>
#include <iostream>
#include <chrono>

// Demo driver: reads the file named by argv[1], compresses it to
// "<file>.hzip", decompresses that result to "<file>.huzip", and prints the
// time taken by each stage. Returns 0 on success and 1 on any failure.
int main(int argc, char const *argv[]) {
    // Exactly one argument is required: argv[1] is used below, so a missing
    // argument has to be rejected as well as a surplus one.
    if (argc != 2) {
        printf("usage: %s <file>\n", argc > 0 ? argv[0] : "huffman");
        return 1;
    }

    // Slack added to the output buffer: the coder stages its output through a
    // 32-byte bit buffer, so leave room in case the final flush is accounted
    // as a whole buffer.
    const unsigned int kSlack = 32;

    auto start = std::chrono::high_resolution_clock::now();
    FILE *file = fopen(argv[1], "rb");
    if (!file) {
        printf("not find file: %s\n", argv[1]);
        return 1;
    }
    fseek(file, 0, SEEK_END);
    const long file_size = ftell(file);
    fseek(file, 0, SEEK_SET);
    // Lengths are carried as unsigned int; make sure the size (plus header
    // and slack) fits before narrowing.
    const unsigned long max_len = 0xFFFFFFFFul - huffman::MAX_HEAD_SIZE - kSlack;
    if (file_size < 0 || static_cast<unsigned long>(file_size) > max_len) {
        printf("cannot handle file size: %s\n", argv[1]);
        fclose(file);
        return 1;
    }
    unsigned int len = static_cast<unsigned int>(file_size);
    unsigned char *text = new unsigned char[len];
    memset(text, 0, len);
    const size_t got = fread(text, 1, len, file);
    fclose(file);
    if (got != len) {
        printf("read failed: %s\n", argv[1]);
        delete[] text;
        return 1;
    }
    printf("input len : %u\n", len);

    auto end = std::chrono::high_resolution_clock::now();
    std::cout << "read: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
    start = end;

    const unsigned int out_capacity = len + huffman::MAX_HEAD_SIZE + kSlack;
    unsigned char *output = new unsigned char[out_capacity];
    memset(output, 0, out_capacity);

    unsigned int olen = huffman::compress(text, len, output);
    printf("output len: %u\n", olen);

    end = std::chrono::high_resolution_clock::now();
    int times = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
    std::cout << "compress: " << times << "ms" << std::endl;
    std::cout << "    " << 1.0f * len / times * 1000 / 1024 << "kb/s" << std::endl;
    start = end;

    const std::string packed_path = std::string(argv[1]) + ".hzip";
    FILE *ofile = fopen(packed_path.c_str(), "wb+");
    if (!ofile) {
        printf("cannot open for writing: %s\n", packed_path.c_str());
        delete[] output;
        delete[] text;
        return 1;
    }
    fwrite(output, 1, olen, ofile);
    fclose(ofile);

    end = std::chrono::high_resolution_clock::now();
    std::cout << "write: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
    start = end;

    // Round trip: decode back into the (cleared) input buffer.
    memset(text, 0, len);
    len = huffman::decompress(output, olen, text);
    const std::string unpacked_path = std::string(argv[1]) + ".huzip";
    FILE *rfile = fopen(unpacked_path.c_str(), "wb+");
    if (!rfile) {
        printf("cannot open for writing: %s\n", unpacked_path.c_str());
        delete[] output;
        delete[] text;
        return 1;
    }
    fwrite(text, 1, len, rfile);
    fclose(rfile);

    end = std::chrono::high_resolution_clock::now();
    times = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
    std::cout << "decompress: " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
    std::cout << "    " << 1.0f * len / times * 1000 / 1024 << "kb/s" << std::endl;
    start = end;

    delete[] output;
    delete[] text;
    return 0;
}

输出示例:

input len : 139246592
read: 48ms
output len: 136604744
compress: 8726ms
    15583.7kb/s
write: 676ms
decompress: 10629ms
    12793.6kb/s