/*******************************************************************************
* Copyright 2023-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_INTEL_CONV_JIT_KEY_HPP
#define GPU_INTEL_CONV_JIT_KEY_HPP

#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "gpu/intel/jit/ir/core.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace intel {
namespace conv {
namespace jit {

using namespace intel::jit;

class config_t;
class key_impl_t;

// Represents a key with hash/equality functionality for convolution problems.
// Mainly used for the lookup table with key -> <optimal convolution parameters>
// mapping.
// When used for lookup tables, a key_t object may instead represent a filter,
// which is compared against concrete keys through the matches() API.
// Examples of differences between a key and a filter:
// 1) Type:
//    - Convolution problem: s8s8s32
//    - Filter:              x8x8*
// 2) Batch size:
//    - Convolution problem: mb32(blocked)
//    - Filter:              mb32+(blocked)
//
// All state lives in a key_impl_t held through std::shared_ptr, so copies of a
// key_t share the same underlying implementation object.
class key_t {
public:
    // Default-constructed key: impl_ is left null.
    // NOTE(review): confirm which members tolerate a null impl_.
    key_t() = default;
    // Builds a key from the given configuration; with make_filter = true a
    // filter is built instead.
    // NOTE(review): this constructor is callable with one argument and is not
    // explicit, so config_t converts implicitly to key_t. Check whether any
    // caller relies on that before marking it explicit.
    key_t(const config_t &cfg, bool make_filter = false);
    // Returns the description string of this key (presumably human-readable;
    // confirm against the implementation).
    const std::string &desc() const;
    // Makes a filter from the given key.
    key_t to_filter() const;
    // Computes the distance between this key (typically a filter) and
    // `other`, which must be a non-filter key. Among filters compared against
    // the same key, the one with the smaller distance is the better match.
    dim_t distance(const key_t &other) const;
    bool operator==(const key_t &other) const;
    // Inequality counterpart of operator==. C++17 does not synthesize
    // operator!= from operator==, so it is provided explicitly.
    bool operator!=(const key_t &other) const { return !operator==(other); }
    // Returns whether this filter matches `other`.
    // NOTE(review): direction assumed to be this = filter, other = concrete
    // key; confirm against the implementation.
    bool matches(const key_t &other) const;
    // Hash value of the key; conv_key_hash_t forwards to this.
    size_t get_hash() const;
    // Presumably (de)serialization: stringify() writes this key to `out` and
    // parse() reads a key from `in`, expected to be in the format stringify()
    // produces. TODO confirm.
    void stringify(std::ostream &out) const;
    void parse(std::istream &in);
    // Returns a string form of the key. csv = true presumably selects a CSV
    // row layout whose column names are given by csv_keys(); confirm.
    std::string str(bool csv = false) const;
    static std::vector<std::string> csv_keys();

    IR_DEFINE_DUMP()

private:
    // Shared implementation; copying a key_t copies this pointer, not the
    // pointee.
    std::shared_ptr<key_impl_t> impl_;
};

struct conv_key_hash_t {
    size_t operator()(const key_t &key) const { return key.get_hash(); }
};

} // namespace jit
} // namespace conv
} // namespace intel
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif
