diff --git a/CMakeLists.txt b/CMakeLists.txt index cf3ef4d8442d78900f454840a8a5742b7720c05c..c73627b67245c388ce03926d11a65f1cf190b586 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -226,6 +226,8 @@ if(KLEIDIAI_BUILD_TESTS) if(MSVC) add_library(kleidiai_test_framework test/common/data_type.cpp + test/common/bfloat16.cpp + test/common/bfloat16_asm.S test/common/float16.cpp test/common/float16.S test/common/cpu_info.cpp @@ -242,6 +244,7 @@ if(KLEIDIAI_BUILD_TESTS) test/common/rect.cpp test/common/round.cpp test/common/bfloat16.cpp + test/common/bfloat16_asm.S test/common/float16.cpp test/common/float16.S test/common/cpu_info.cpp @@ -268,6 +271,7 @@ if(KLEIDIAI_BUILD_TESTS) ) if(MSVC) + set_source_files_properties(test/common/bfloat16_asm.S PROPERTIES LANGUAGE ASM_MARMASM) set_source_files_properties(test/common/float16.S PROPERTIES LANGUAGE ASM_MARMASM) endif() @@ -279,10 +283,12 @@ if(KLEIDIAI_BUILD_TESTS) if(MSVC) add_executable(kleidiai_test + test/tests/bfloat16_test.cpp test/tests/float16_test.cpp ) else() add_executable(kleidiai_test + test/tests/bfloat16_test.cpp test/tests/float16_test.cpp test/tests/matmul_test.cpp test/tests/matmul_clamp_f32_f32_f32p_test.cpp diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 834520f19e91ec072400a546fb89ed9990034273..9019364a1f10c9fdc530aa6269337f963b60b6c6 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -30,6 +30,9 @@ kai_cxx_library( ), # compare.cpp requires fp16 and bf16 support cpu_uarch = kai_cpu_bf16() + kai_cpu_fp16(), + textual_hdrs = [ + "common/assembly.h", + ], ) kai_cxx_library( diff --git a/test/common/assembly.h b/test/common/assembly.h new file mode 100644 index 0000000000000000000000000000000000000000..094970934efb0080d991f6813169aa9140737cb6 --- /dev/null +++ b/test/common/assembly.h @@ -0,0 +1,47 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> +// +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef KAI_TEST_COMMON_ASSEMBLY_H +#define KAI_TEST_COMMON_ASSEMBLY_H + +// clang-format off + +#ifdef _MSC_VER + +#define KAI_ASM_HEADER AREA |.text|, CODE, READONLY, ALIGN=4 +#define KAI_ASM_LABEL(label) |label| +#define KAI_ASM_TARGET(label, direction) |label| +#define KAI_ASM_FUNCTION(label) |label| +#define KAI_ASM_EXPORT(label) global label +#define KAI_ASM_FOOTER end +#define KAI_ASM_INST(num) dcd num + +#else // _MSC_VER + +#define KAI_ASM_HEADER .text +#define KAI_ASM_LABEL(label) label: +#define KAI_ASM_TARGET(label, direction) label##direction + +#ifdef __APPLE__ +#define KAI_ASM_FUNCTION(label) _##label: +#define KAI_ASM_EXPORT(label) \ + .global _##label; \ + .type _##label, %function +#else // __APPLE__ +#define KAI_ASM_FUNCTION(label) label: +#define KAI_ASM_EXPORT(label) \ + .global label; \ + .type label, %function +#endif // __APPLE__ + +#define KAI_ASM_FOOTER +#define KAI_ASM_INST(num) .inst num + +#endif // _MSC_VER + +// clang-format on + +#endif // KAI_TEST_COMMON_ASSEMBLY_H diff --git a/test/common/bfloat16.cpp b/test/common/bfloat16.cpp index d9581b2f80fe67aae2e57382902b0135fdd2c970..26e9259a7c3283af4ddd43ed206ece50c82628f7 100644 --- a/test/common/bfloat16.cpp +++ b/test/common/bfloat16.cpp @@ -7,9 +7,25 @@ #include "test/common/bfloat16.hpp" #include <iostream> +#include <type_traits> namespace kai::test { +static_assert(sizeof(BFloat16) == 2); + +static_assert(std::is_trivially_destructible_v<BFloat16>); +static_assert(std::is_nothrow_destructible_v<BFloat16>); + +static_assert(std::is_trivially_copy_constructible_v<BFloat16>); +static_assert(std::is_trivially_copy_assignable_v<BFloat16>); +static_assert(std::is_trivially_move_constructible_v<BFloat16>); +static_assert(std::is_trivially_move_assignable_v<BFloat16>); + +static_assert(std::is_nothrow_copy_constructible_v<BFloat16>); +static_assert(std::is_nothrow_copy_assignable_v<BFloat16>); +static_assert(std::is_nothrow_move_constructible_v<BFloat16>); +static_assert(std::is_nothrow_move_assignable_v<BFloat16>); + std::ostream& operator<<(std::ostream& os, BFloat16 value) { return os << static_cast<float>(value); } diff --git a/test/common/bfloat16.hpp b/test/common/bfloat16.hpp index 9291918a638499b94545e6b9ff70c30310c46874..da529e03cd480ce8dfad9e4649b035c900026ae6 100644 --- a/test/common/bfloat16.hpp +++ b/test/common/bfloat16.hpp @@ -12,82 +12,58 @@ #include "test/common/type_traits.hpp" +extern "C" { + +/// Converts single-precision floating-point to half-precision brain floating-point. +/// +/// @params[in] value The single-precision floating-point value. +/// +/// @return The half-precision brain floating-point value reinterpreted as 16-bit unsigned integer. +uint16_t kai_test_bfloat16_from_float(float value); + +} // extern "C" + namespace kai::test { /// Half-precision brain floating-point. -/// -/// This class encapsulates `bfloat16_t` data type provided by `arm_bf16.h`. class BFloat16 { public: /// Constructor. BFloat16() = default; - /// Destructor. - ~BFloat16() = default; - - /// Copy constructor. - BFloat16(const BFloat16&) = default; - - /// Copy assignment. - BFloat16& operator=(const BFloat16&) = default; - - /// Move constructor. - BFloat16(BFloat16&&) = default; - - /// Move assignment. - BFloat16& operator=(BFloat16&&) = default; - /// Creates a new object from the specified numeric value. - BFloat16(float value) : _data(0) { -#ifdef __ARM_FEATURE_BF16 - __asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(_data) : [input] "w"(value)); -#else - const uint32_t* value_i32 = reinterpret_cast<const uint32_t*>(&value); - _data = (*value_i32 >> 16); -#endif + explicit BFloat16(float value) : m_data(kai_test_bfloat16_from_float(value)) { } /// Assigns to the specified numeric value which will be converted to `bfloat16_t`. template <typename T, std::enable_if_t<is_arithmetic<T>, bool> = true> BFloat16& operator=(T value) { const auto value_f32 = static_cast<float>(value); -#ifdef __ARM_FEATURE_BF16 - __asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(_data) : [input] "w"(value_f32)); -#else - const uint32_t* value_i32 = reinterpret_cast<const uint32_t*>(&value_f32); - _data = (*value_i32 >> 16); -#endif + m_data = kai_test_bfloat16_from_float(value_f32); return *this; } - /// Converts to floating-point. - operator float() const { + /// Converts to single-precision floating-point. + explicit operator float() const { union { float f32; uint32_t u32; } data; - data.u32 = static_cast<uint32_t>(_data) << 16; + data.u32 = static_cast<uint32_t>(m_data) << 16; return data.f32; } +private: /// Equality operator. - bool operator==(BFloat16 rhs) const { - return _data == rhs._data; - } - - /// Unequality operator. - bool operator!=(BFloat16 rhs) const { - return _data != rhs._data; - } - - uint16_t data() const { - return _data; + [[nodiscard]] friend bool operator==(BFloat16 lhs, BFloat16 rhs) { + return lhs.m_data == rhs.m_data; } - void set_data(uint16_t data) { - _data = data; + /// Inequality operator. + [[nodiscard]] friend bool operator!=(BFloat16 lhs, BFloat16 rhs) { + return lhs.m_data != rhs.m_data; } /// Writes the value to the output stream. @@ -98,8 +74,7 @@ public: /// @return The output stream. friend std::ostream& operator<<(std::ostream& os, BFloat16 value); -private: - uint16_t _data; + uint16_t m_data; }; } // namespace kai::test diff --git a/test/common/bfloat16_asm.S b/test/common/bfloat16_asm.S new file mode 100644 index 0000000000000000000000000000000000000000..9f16cda4d7a70fef9b31d3e8c5a55e0bc1ae90b7 --- /dev/null +++ b/test/common/bfloat16_asm.S @@ -0,0 +1,18 @@ +// +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com> +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/assembly.h" + + KAI_ASM_HEADER + + KAI_ASM_EXPORT(kai_test_bfloat16_from_float) + +KAI_ASM_FUNCTION(kai_test_bfloat16_from_float) + KAI_ASM_INST(0x1e634000) // bfcvt h0, s0 + fmov w0, h0 + ret + + KAI_ASM_FOOTER diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 69b756c60a32610cd24e308b83b395d4445e8a8d..dc44ad1097797ad41a5aaef001210d7ccaa17dc8 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -26,14 +26,14 @@ namespace kai::test { namespace { -uint16_t convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) { +BFloat16 convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) { KAI_ASSUME((src_dtype == DataType::FP32 || src_dtype == DataType::FP16) && dst_dtype == DataType::BF16); switch (src_dtype) { case DataType::FP32: - return BFloat16(*reinterpret_cast<const float*>(src_ptr_elm)).data(); + return BFloat16(*reinterpret_cast<const float*>(src_ptr_elm)); case DataType::FP16: - return BFloat16(static_cast<float>(*reinterpret_cast<const float16_t*>(src_ptr_elm))).data(); + return BFloat16(static_cast<float>(*reinterpret_cast<const float16_t*>(src_ptr_elm))); default: KAI_ERROR("Unsupported Data Type"); } @@ -77,7 +77,7 @@ std::vector<uint8_t> pack_block( x_element) * src_esize; - const uint16_t src_value = convert(src_ptr_elm, src_dtype, dst_dtype); + const BFloat16 src_value = convert(src_ptr_elm, src_dtype, dst_dtype); memcpy(dst_ptr, &src_value, dst_esize); } } @@ -149,7 +149,7 @@ std::vector<uint8_t> pack_bias_per_row( x_element) * src_esize; - const uint16_t dst_value = convert(src_ptr_elm, src_dtype, dst_dtype); + const BFloat16 dst_value = convert(src_ptr_elm, src_dtype, dst_dtype); memcpy(dst_ptr, &dst_value, dst_esize); } } diff --git a/test/tests/bfloat16_test.cpp b/test/tests/bfloat16_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a8d4bb881b803be8a7435a907bdb15a1b104aafc --- /dev/null +++ b/test/tests/bfloat16_test.cpp @@ -0,0 +1,33 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/bfloat16.hpp" + +#include <gtest/gtest.h> + +#include "test/common/cpu_info.hpp" + +namespace kai::test { + +TEST(BFloat16, SimpleTest) { + if (!cpu_has_bf16()) { + GTEST_SKIP(); + } + + ASSERT_EQ(static_cast<float>(BFloat16()), 0.0F); + ASSERT_EQ(static_cast<float>(BFloat16(1.25F)), 1.25F); + ASSERT_EQ(static_cast<float>(BFloat16(3)), 3.0F); + + ASSERT_FALSE(BFloat16(1.25F) == BFloat16(2.0F)); + ASSERT_TRUE(BFloat16(1.25F) == BFloat16(1.25F)); + ASSERT_FALSE(BFloat16(2.0F) == BFloat16(1.25F)); + + ASSERT_TRUE(BFloat16(1.25F) != BFloat16(2.0F)); + ASSERT_FALSE(BFloat16(1.25F) != BFloat16(1.25F)); + ASSERT_TRUE(BFloat16(2.0F) != BFloat16(1.25F)); +} + +} // namespace kai::test