#include "JZsdkLib.h"
#include "version_choose.h"

#ifdef SPEEX_STATUS_ON

#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "speex/speex_echo.h"
#include "speex/speex_preprocess.h"

// ========== Tunable parameters ==========
#define TAIL_MS         300         // Echo tail length, in milliseconds
#define FRAME_SAMPLES   640         // Samples per frame (640 shorts)
#define SAMPLE_RATE     16000       // Sample rate; not referenced in the visible code (Speex_Init takes the rate as an argument)
#define PLAYBACK_DELAY_FRAMES  19   // Playback delay in frames; a software latency of 0.75 s is roughly 19 frames
// =================================

// State shared by all Speex helpers in this file: stream parameters, the
// Speex echo/preprocess handles, and the history ring buffer bookkeeping.
typedef struct JZ_SpeexInfo {
    int SampleRate;         // Sample rate passed to Speex_Init
    int FrameSize;          // Samples per processed frame
    int TailLen;            // Echo tail length, in samples
    int PlaybackDelaySamples; // Playback delay, in samples

    SpeexEchoState* EchoState;             // Echo canceller handle
    SpeexPreprocessState* PreprocessState; // Preprocessor that Speex_Init links to EchoState

    // Ring buffer holding past output frames (i.e. data that has been played)
    short* HistoryBuffer;
    int    HistorySize;     // Ring buffer capacity, in samples
    int    WritePos;        // Index of the next sample to be written
    int    TotalWritten;    // Running count of samples written since the buffer was (re)initialised

    SpeexPreprocessState* DenoiseOnlyState; // Stand-alone preprocessor configured for noise suppression
    int DenoiseOnlyFlag;    // Not referenced in the visible code

    int Flag;               // Set to JZ_FLAGCODE_ON by Speex_Init and to JZ_FLAGCODE_OFF by Speex_Deinit
} JZ_SpeexInfo;

// Single file-scope instance; starts out zero-initialised.
static JZ_SpeexInfo g_SpeexInfo = { 0 };

// (Re)allocate the history ring buffer, zero-filled, and reset its bookkeeping.
//
// Any previously allocated buffer is released first. If the request is
// rejected or the allocation fails, the ring buffer is left in a consistent
// empty state (NULL buffer, zero size, zero positions).
//
// size_samples: capacity of the ring buffer in samples; must be > 0.
// Returns 0 on success, -1 on invalid size or allocation failure.
static int InitHistoryBuffer(int size_samples)
{
    // free(NULL) is a no-op, so no guard is needed.
    free(g_SpeexInfo.HistoryBuffer);
    g_SpeexInfo.HistoryBuffer = NULL;
    g_SpeexInfo.HistorySize = 0;
    g_SpeexInfo.WritePos = 0;
    g_SpeexInfo.TotalWritten = 0;

    if (size_samples <= 0)
    {
        return -1;
    }

    // calloc zero-fills and checks the count * size multiplication for overflow.
    g_SpeexInfo.HistoryBuffer = (short*)calloc((size_t)size_samples, sizeof(short));
    if (!g_SpeexInfo.HistoryBuffer)
    {
        return -1;
    }

    g_SpeexInfo.HistorySize = size_samples;
    return 0;
}

// Append one frame (FrameSize samples) to the history ring buffer, wrapping
// around the end of the buffer when necessary, and advance WritePos.
//
// frame: FrameSize samples to store; not modified.
// The call is ignored when the ring buffer is not usable (no buffer, frame
// larger than the buffer, write position out of range) or frame is NULL.
static void WriteHistoryFrame(const short* frame)
{
    int frame_len = g_SpeexInfo.FrameSize;
    int capacity = g_SpeexInfo.HistorySize;
    int pos = g_SpeexInfo.WritePos;
    int until_end;

    // Reject states that would write out of bounds or take a modulo by zero.
    if (!frame || !g_SpeexInfo.HistoryBuffer ||
        frame_len <= 0 || frame_len > capacity ||
        pos < 0 || pos >= capacity)
    {
        return;
    }

    until_end = capacity - pos;
    if (frame_len <= until_end)
    {
        memcpy(g_SpeexInfo.HistoryBuffer + pos, frame, (size_t)frame_len * sizeof(short));
    }
    else
    {
        // Frame straddles the end of the buffer: fill the tail, then wrap to the start.
        memcpy(g_SpeexInfo.HistoryBuffer + pos, frame, (size_t)until_end * sizeof(short));
        memcpy(g_SpeexInfo.HistoryBuffer, frame + until_end, (size_t)(frame_len - until_end) * sizeof(short));
    }

    g_SpeexInfo.WritePos = (pos + frame_len) % capacity;

    // Saturate instead of overflowing: signed overflow is undefined behaviour.
    if (g_SpeexInfo.TotalWritten <= INT_MAX - frame_len)
    {
        g_SpeexInfo.TotalWritten += frame_len;
    }
    else
    {
        g_SpeexInfo.TotalWritten = INT_MAX;
    }
}

// Copy the reference frame out of the history ring buffer: the FrameSize
// samples that start (PlaybackDelaySamples + FrameSize) samples behind the
// current write position, which is meant to line up with the current
// microphone frame.
//
// out_ref: destination for FrameSize samples.
// Returns 0 on success, -1 when the buffer is unusable, the requested lag does
// not fit in the buffer, or not enough samples have been written yet.
static int ReadRefFrame(short* out_ref)
{
    int frame_len = g_SpeexInfo.FrameSize;
    int capacity = g_SpeexInfo.HistorySize;
    int delay = g_SpeexInfo.PlaybackDelaySamples;
    int lag;
    int start;
    int until_end;

    if (!out_ref || !g_SpeexInfo.HistoryBuffer ||
        frame_len <= 0 || frame_len > capacity)
    {
        return -1;
    }

    // The lag must fit inside the buffer, otherwise a single wrap cannot bring
    // the read index back into range. Written this way to avoid overflow.
    if (delay < 0 || delay > capacity - frame_len)
    {
        return -1;
    }
    lag = delay + frame_len;

    // Not enough history yet.
    if (g_SpeexInfo.TotalWritten < lag)
    {
        return -1;
    }

    start = g_SpeexInfo.WritePos - lag;
    if (start < 0)
    {
        start += capacity;
    }
    if (start < 0 || start >= capacity)
    {
        return -1;
    }

    until_end = capacity - start;
    if (frame_len <= until_end)
    {
        memcpy(out_ref, g_SpeexInfo.HistoryBuffer + start, (size_t)frame_len * sizeof(short));
    }
    else
    {
        // Frame straddles the end of the buffer: read the tail, then wrap to the start.
        memcpy(out_ref, g_SpeexInfo.HistoryBuffer + start, (size_t)until_end * sizeof(short));
        memcpy(out_ref + until_end, g_SpeexInfo.HistoryBuffer, (size_t)(frame_len - until_end) * sizeof(short));
    }
    return 0;
}

// Release every Speex resource held in g_SpeexInfo and reset the state.
//
// Safe to call when nothing is initialised. Always returns
// JZ_ERROR_SYSTEM_MODULE_CODE_SUCCESS.
T_JZsdkReturnCode Speex_Deinit()
{
    // Destroy the stand-alone denoiser BEFORE the state is zeroed below;
    // zeroing first would drop the only pointer to it and leak the state.
    if (g_SpeexInfo.DenoiseOnlyState)
    {
        speex_preprocess_state_destroy(g_SpeexInfo.DenoiseOnlyState);
        g_SpeexInfo.DenoiseOnlyState = NULL;
    }

    // Echo / preprocess / history pointers are only treated as live while the
    // module is flagged ON.
    if (g_SpeexInfo.Flag == JZ_FLAGCODE_ON)
    {
        if (g_SpeexInfo.EchoState)
        {
            speex_echo_state_destroy(g_SpeexInfo.EchoState);
            g_SpeexInfo.EchoState = NULL;
        }

        if (g_SpeexInfo.PreprocessState)
        {
            speex_preprocess_state_destroy(g_SpeexInfo.PreprocessState);
            g_SpeexInfo.PreprocessState = NULL;
        }

        // free(NULL) is a no-op, so no guard is needed.
        free(g_SpeexInfo.HistoryBuffer);
        g_SpeexInfo.HistoryBuffer = NULL;
    }

    memset(&g_SpeexInfo, 0, sizeof(g_SpeexInfo));
    g_SpeexInfo.Flag = JZ_FLAGCODE_OFF;

    JZSDK_LOG_DEBUG("Speex_Deinit success\n");
    return JZ_ERROR_SYSTEM_MODULE_CODE_SUCCESS;
}

// Initialise the echo canceller, its linked preprocessor, the history ring
// buffer and the stand-alone denoiser.
//
// If the module is already ON it is torn down first. The module is flagged ON
// only after every component has been created; on any failure everything
// created by this call is released again and the pointers are cleared.
//
// sample_rate: sampling rate handed to Speex; must be positive and large
//              enough that TAIL_MS yields at least one sample.
// Returns JZ_ERROR_SYSTEM_MODULE_CODE_SUCCESS on success,
//         JZ_ERROR_SYSTEM_MODULE_CODE_FAILURE otherwise.
T_JZsdkReturnCode Speex_Init(int sample_rate)
{
    // All locals are declared up front so that no `goto fail` jumps over an
    // initialisation.
    int frame_samples = FRAME_SAMPLES;
    int playback_delay_frames = PLAYBACK_DELAY_FRAMES;
    int tail_len = 0;
    int hist_size = 0;
    int vad = 0;              // voice activity detection: off
    int agc = 0;              // automatic gain control: off
    int denoise = 1;          // noise suppression: on
    int dereverb = 0;         // dereverberation: off
    /*
        Maximum noise attenuation; lower (more negative) is more aggressive.
        Author's tuning notes:
          -40 removes essentially all of the original noise, but occasionally
              leaves a printer-like artifact
          -15 little effect
          -80 also removes the original noise, but the introduced artifact is
              not improved
          -30 little effect
    */
    int noise_suppress = -40;

    if (g_SpeexInfo.Flag == JZ_FLAGCODE_ON)
    {
        Speex_Deinit();
    }

    if (sample_rate <= 0)
    {
        JZSDK_LOG_DEBUG("Speex_Init: invalid sample rate %d\n", sample_rate);
        return JZ_ERROR_SYSTEM_MODULE_CODE_FAILURE;
    }

    // Widen before multiplying so large sample rates cannot overflow int.
    tail_len = (int)((long long)sample_rate * TAIL_MS / 1000);
    if (tail_len <= 0)
    {
        JZSDK_LOG_DEBUG("Speex_Init: sample rate %d too low for a %d ms tail\n", sample_rate, TAIL_MS);
        return JZ_ERROR_SYSTEM_MODULE_CODE_FAILURE;
    }

    g_SpeexInfo.SampleRate = sample_rate;
    g_SpeexInfo.FrameSize = frame_samples;
    g_SpeexInfo.TailLen = tail_len;
    g_SpeexInfo.PlaybackDelaySamples = playback_delay_frames * frame_samples;

    // Room for the echo tail, the playback delay and two frames of slack.
    hist_size = g_SpeexInfo.TailLen + g_SpeexInfo.PlaybackDelaySamples + frame_samples * 2;

    g_SpeexInfo.EchoState = speex_echo_state_init(frame_samples, g_SpeexInfo.TailLen);
    if (!g_SpeexInfo.EchoState)
    {
        JZSDK_LOG_DEBUG("Speex_Init: speex_echo_state_init failed\n");
        goto fail;
    }
    speex_echo_ctl(g_SpeexInfo.EchoState, SPEEX_ECHO_SET_SAMPLING_RATE, &sample_rate);

    g_SpeexInfo.PreprocessState = speex_preprocess_state_init(frame_samples, sample_rate);
    if (!g_SpeexInfo.PreprocessState)
    {
        JZSDK_LOG_DEBUG("Speex_Init: speex_preprocess_state_init failed\n");
        goto fail;
    }

    speex_preprocess_ctl(g_SpeexInfo.PreprocessState, SPEEX_PREPROCESS_SET_ECHO_STATE, g_SpeexInfo.EchoState);

    if (InitHistoryBuffer(hist_size) != 0)
    {
        JZSDK_LOG_DEBUG("Speex_Init: history buffer allocation failed\n");
        goto fail;
    }

    // Stand-alone denoiser
    g_SpeexInfo.DenoiseOnlyState = speex_preprocess_state_init(FRAME_SAMPLES, sample_rate);
    if (!g_SpeexInfo.DenoiseOnlyState)
    {
        JZSDK_LOG_DEBUG("Speex_Init: denoise-only speex_preprocess_state_init failed\n");
        goto fail;
    }

    speex_preprocess_ctl(g_SpeexInfo.DenoiseOnlyState, SPEEX_PREPROCESS_SET_VAD, &vad);
    speex_preprocess_ctl(g_SpeexInfo.DenoiseOnlyState, SPEEX_PREPROCESS_SET_AGC, &agc);
    speex_preprocess_ctl(g_SpeexInfo.DenoiseOnlyState, SPEEX_PREPROCESS_SET_DENOISE, &denoise);
    speex_preprocess_ctl(g_SpeexInfo.DenoiseOnlyState, SPEEX_PREPROCESS_SET_DEREVERB, &dereverb);
    speex_preprocess_ctl(g_SpeexInfo.DenoiseOnlyState, SPEEX_PREPROCESS_SET_NOISE_SUPPRESS, &noise_suppress);

    // Only now is every component usable.
    g_SpeexInfo.Flag = JZ_FLAGCODE_ON;
    JZSDK_LOG_DEBUG("Speex_Init success: sr=%d, fs=%d, tail=%d, delay=%d samples, hist=%d\n",
        sample_rate, frame_samples, g_SpeexInfo.TailLen,
        g_SpeexInfo.PlaybackDelaySamples, hist_size);

    return JZ_ERROR_SYSTEM_MODULE_CODE_SUCCESS;

fail:
    // Undo whatever this call created and leave no dangling pointers behind.
    if (g_SpeexInfo.DenoiseOnlyState)
    {
        speex_preprocess_state_destroy(g_SpeexInfo.DenoiseOnlyState);
        g_SpeexInfo.DenoiseOnlyState = NULL;
    }
    // The preprocessor holds a pointer to the echo state, so release it first.
    if (g_SpeexInfo.PreprocessState)
    {
        speex_preprocess_state_destroy(g_SpeexInfo.PreprocessState);
        g_SpeexInfo.PreprocessState = NULL;
    }
    if (g_SpeexInfo.EchoState)
    {
        speex_echo_state_destroy(g_SpeexInfo.EchoState);
        g_SpeexInfo.EchoState = NULL;
    }
    free(g_SpeexInfo.HistoryBuffer);
    g_SpeexInfo.HistoryBuffer = NULL;
    g_SpeexInfo.HistorySize = 0;
    g_SpeexInfo.WritePos = 0;
    g_SpeexInfo.TotalWritten = 0;

    return JZ_ERROR_SYSTEM_MODULE_CODE_FAILURE;
}

// Stand-alone noise suppression of one frame (FRAME_SAMPLES samples).
//
// mic: input frame.
// out: output frame; may be the same buffer as mic for in-place processing.
//
// Returns JZ_ERROR_SYSTEM_MODULE_CODE_SUCCESS when the frame was denoised.
// Returns JZ_ERROR_SYSTEM_MODULE_CODE_FAILURE when an argument is NULL, or
// when the denoiser is not initialised; in the latter case out still receives
// an unprocessed copy of mic.
T_JZsdkReturnCode Speex_DenoiseOnly_Process(short* mic, short* out)
{
    if (!mic || !out)
    {
        return JZ_ERROR_SYSTEM_MODULE_CODE_FAILURE;
    }

    // The preprocessor works in place, so stage the input in out first.
    // memmove tolerates partially overlapping buffers.
    if (out != mic)
    {
        memmove(out, mic, FRAME_SAMPLES * sizeof(short));
    }

    if (!g_SpeexInfo.DenoiseOnlyState)
    {
        return JZ_ERROR_SYSTEM_MODULE_CODE_FAILURE;
    }

    speex_preprocess_run(g_SpeexInfo.DenoiseOnlyState, out);
    return JZ_ERROR_SYSTEM_MODULE_CODE_SUCCESS;
}

// 回声消除处理(使用历史输出作为参考信号)
T_JZsdkReturnCode Speex_ProcessMic(short* mic, short* out)
{
    if (g_SpeexInfo.Flag == JZ_FLAGCODE_OFF)
    {
        return JZ_ERROR_SYSTEM_MODULE_CODE_FAILURE;
    }

    short ref_frame[FRAME_SAMPLES];
    if (ReadRefFrame(ref_frame) != 0)
    {
        // 历史数据不足,直接拷贝输出并写入历史缓冲区
        if (out != mic) memcpy(out, mic, g_SpeexInfo.FrameSize * sizeof(short));
        WriteHistoryFrame(out);
        return JZ_ERROR_SYSTEM_MODULE_CODE_SUCCESS;
    }

    // 执行回声消除
    speex_echo_cancellation(g_SpeexInfo.EchoState, mic, ref_frame, out);

    // 后处理降噪
    speex_preprocess_run(g_SpeexInfo.DenoiseOnlyState, out);

    // 将输出帧写入历史缓冲区(供后续帧作为参考)
    //WriteHistoryFrame(out);

    return JZ_ERROR_SYSTEM_MODULE_CODE_SUCCESS;
}

#endif // SPEEX_STATUS_ON