//
//  SileroVADNode.hpp
//  SwitchboardSileroVAD
//
//  Created by Tayyab Javed on 2024-08-12.
//

#pragma once

#include "ModelInstance.h"
#include "OnnxMLSinkNode.hpp"
#include <switchboard_v2/Resampler.hpp>
#include <switchboard_v2/StringCallbackParameter.hpp>

#include <functional>

// NOTE(review): this using-directive sits at header scope, so every file that includes
// this header also gets switchboard::extensions::onnx names injected into its global
// lookup. The unqualified names below (e.g. OnnxMLSinkNode) presumably rely on it;
// consider qualifying them and dropping the directive — confirm no includer depends on it.
using namespace switchboard::extensions::onnx;

namespace switchboard::extensions::silerovad {

// Forward declaration only: this header holds VadIterator solely through a
// std::unique_ptr member of SileroVADNode, so the full definition is not needed here.
// NOTE(review): the SileroVADNode destructor must be defined where VadIterator is a
// complete type (it is declared, not defined, in this header) — confirm in the .cpp.
class VadIterator;

/**
 * @brief Voice-activity-detection (VAD) node built on top of OnnxMLSinkNode.
 *
 * The class owns a VadIterator and a float input buffer, stores the VAD tuning
 * values (threshold, silence duration, padding, frame/hop sizes) and exposes two
 * public std::function callbacks for VAD start and end events.
 *
 * NOTE(review): only declarations are visible in this header; the behavior notes
 * below that go beyond the signatures are assumptions to confirm against the
 * implementation file.
 */
class SileroVADNode : public OnnxMLSinkNode {
public:
    SB_WASM_EXPORT(SileroVADNode)

    /**
     * @brief Creates a SileroVADNode instance from a configuration map.
     *
     * @param config Configuration map for the node, keyed by string with type-erased
     *               (std::any) values. The accepted keys are not visible in this
     *               header — TODO document them once confirmed.
     */
    SB_WASM SileroVADNode(const std::map<std::string, std::any>& config);

    /**
     * @brief Creates a SileroVADNode instance without an explicit configuration.
     */
    SB_WASM SileroVADNode();

    /**
     * @brief Destroys the node. Declared here and defined out of line.
     */
    SB_WASM ~SileroVADNode() override;

    /**
     * @brief Sets the path of the model file.
     *
     * @param modelPath Path to the model file (taken by value).
     */
    SB_WASM void setModelFilePath(std::string modelPath);

    /**
     * @brief Returns the model file path.
     *
     * @return The model file path as a string copy.
     */
    SB_WASM std::string getModelFilePath() const;

    /**
     * @brief Returns the sample rate supported by this node.
     *
     * @return Supported sample rate. NOTE(review): presumably the value of the
     *         private sampleRate member (16 kHz per its comment) — confirm.
     */
    SB_WASM uint getSupportedSampleRate() const;

    /**
     * @brief User-assignable callback for VAD start events.
     *
     * Receives (previousStartSample, currentStartSample). NOTE(review): presumably
     * invoked when a speech segment begins; the call site is not visible here, and
     * it is not shown whether an empty callback is checked before invocation.
     */
    SB_WASM std::function<void(const int& previousStartSample, const int& currentStartSample)> onVADStartEvent;

    /**
     * @brief User-assignable callback for VAD end events.
     *
     * Receives (currentStartSample, currentEndSample). NOTE(review): presumably
     * invoked when a speech segment ends — confirm against the implementation.
     */
    SB_WASM std::function<void(const int& currentStartSample, const int& currentEndSample)> onVADEndEvent;

    /**
     * @brief Resets the node. NOTE(review): the exact state that is cleared
     *        (e.g. vadIterator, vadInputBuffer) is not visible in this header.
     */
    SB_WASM void reset();

#pragma mark Overridden methods

    /**
     * @brief Runs prediction on an audio buffer (override).
     *
     * @param audioBuffer Float audio buffer to process; passed by non-const reference.
     */
    SB_WASM void predict(AudioBuffer<float>& audioBuffer) override;

    /**
     * @brief Sets the bus format for this node (override).
     *
     * @param busFormat Bus format, passed by non-const reference.
     * @return bool result. NOTE(review): presumably true when the format is accepted
     *         (see isFrameSizeSupported / sampleRate) — confirm.
     */
    SB_WASM bool setBusFormat(AudioBusFormat& busFormat) override;

    /**
     * @brief Consumes audio from the given bus (override).
     *
     * @param bus Audio bus to consume from.
     * @return bool result; its exact meaning is defined by the base class contract.
     */
    SB_WASM bool consume(AudioBus& bus) override;

    /**
     * @brief Invokes a named action with parameters (override).
     *
     * Unlike the other overrides in this class, this declaration is not marked SB_WASM.
     *
     * @param actionName Name of the action to call.
     * @param params     Action parameters, keyed by string with std::any values.
     * @return Result wrapping a std::any value.
     */
    Result<std::any> callAction(const std::string &actionName, const std::map<std::string, std::any> &params) override;

private:
    /// Returns whether the given frame size is supported (see frameSize for the documented values).
    static bool isFrameSizeSupported(int frameSize);
    /// Maps a frame size to its hop size (see vadHopSize for the documented pairs).
    static int getHopSizeForFrameSize(int frameSize);
    /// Creates the node's parameters. NOTE(review): which parameters are registered is not visible here.
    void createParameters();

    // NOTE(review): none of the scalar members below has an in-class initializer;
    // they are presumably set in the constructors — confirm every constructor does so.
    uint sampleRate; // sample rate; 16kHz
    int vadHopSize; // 32, 64, 96 against frame_size of 512, 1024, 1536 @ 16kHz
    int frameSize; // 512, 1024, 1536 against vadHopSize of 32, 64, 96 @ 16kHz
    float threshold; // the confidence threshold to decide whether a chunk of audio is speech or not speech
    int minSilenceDurationMs; // minimum silence duration in milliseconds to consider a segment as speech (NOTE(review): wording is ambiguous — presumably the silence needed to end a speech segment; confirm)
    int speechPadMs; // padding in milliseconds added before and after detected speech segments

    // Owned VAD helper; VadIterator is only forward-declared in this header.
    std::unique_ptr<VadIterator> vadIterator;
    // Float sample buffer. NOTE(review): presumably accumulates input until a full frame is available — confirm.
    std::vector<float> vadInputBuffer;
};
}
