Advertisement
JadonChan

Untitled

May 4th, 2025
21
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.96 KB | None | 0 0
  1. #include <algorithm>
  2. #include <cassert>
  3. #include <chrono>
  4. #include <cmath>
  5. #include <cstdint>
  6. #include <cstdio>
  7. #include <memory>
  8. #include <random>
  9. #include <string_view>
  10. #include <vector>
  11.  
  12. namespace
  13. {
  14.  
  // Allocates a 1 GiB scratch buffer and stores a zero into one byte out of
  // every 64. Judging by the name, the purpose is to push earlier data out of
  // the CPU caches before a timed run (64 is presumably the cache-line size).
  // The element type is volatile, so the stores are observable side effects.
  // The buffer is released when the function returns.
  void invalidate_cache() noexcept
  {
    constexpr size_t kScratchBytes = size_t{1} << 30; // 1GB
    constexpr size_t kStrideBytes = 64;

    const auto scratch = std::make_unique<volatile char[]>(kScratchBytes);

    size_t offset = 0;
    while (offset < kScratchBytes)
    {
      scratch[offset] = 0;
      offset += kStrideBytes;
    }
  }
  22.  
  // One lookup request.
  struct Query
  {
    // Index of the array (State::data[k]) that this query searches.
    std::uint32_t k;
    // Value whose lower-bound position inside that array is requested.
    float lookup;
  };
  28.  
  29. // Test state.
  30. struct State
  31. {
  32. public:
  33. std::vector<std::vector<float>> data;
  34. std::vector<Query> queries;
  35. std::vector<int64_t> answers;
  36.  
  37. private:
  38. std::vector<int64_t> std_answers_;
  39. bool need_validate_;
  40. const char *algo_name_;
  41.  
  42. using hrc_clock = std::chrono::high_resolution_clock;
  43. hrc_clock::time_point start_time_;
  44.  
  45. static constexpr uint32_t random_seed = 0;
  46.  
  47. public:
  48. State(uint32_t data_batch_size, uint32_t data_size, uint32_t query_batch_size, bool need_validate)
  49. : need_validate_(need_validate)
  50. {
  51. std::printf("generating test data... ");
  52. assert(query_batch_size >= data_batch_size);
  53. assert(data_size >= 8);
  54.  
  55. // Create arrays in ascending order
  56. // NOTE Please assume they are generated randomly. Another random methods will be used for the final evaluation.
  57. data.resize(data_batch_size);
  58. data.front().resize(data_size);
  59. for (uint32_t i = 0; i < data_size; ++i)
  60. data.front()[i] = float(i);
  61. for (auto &d : data)
  62. d = data.front();
  63.  
  64. // Your answers to the queries will be saved in this vector.
  65. answers.resize(query_batch_size);
  66.  
  67. // Create queries
  68. std::mt19937_64 generator(random_seed); /* NOLINT(cert-*) */
  69. std::uniform_real_distribution<float> distribution(-1.0f, float(data_size));
  70. queries.resize(query_batch_size);
  71. if (need_validate)
  72. std_answers_.resize(query_batch_size);
  73. const uint32_t query_per_data_batch = query_batch_size / data_batch_size;
  74. for (uint32_t i = 0; i < query_batch_size; ++i)
  75. {
  76. const uint32_t k = std::min(i / query_per_data_batch, data_batch_size - 1);
  77. const float lookup = distribution(generator);
  78. queries[i] = {.k = k, .lookup = lookup};
  79.  
  80. if (need_validate)
  81. std_answers_[i] = std::lower_bound(data[k].begin(), data[k].end(), lookup) - data[k].begin();
  82. }
  83.  
  84. std::printf("done: n=%u, size=%u, m=%u\n", data_batch_size, data_size, query_batch_size);
  85. }
  86.  
  87. void start(const char *algo_name) noexcept
  88. {
  89. this->algo_name_ = algo_name;
  90. invalidate_cache();
  91. start_time_ = hrc_clock::now();
  92. }
  93.  
  94. void stop() noexcept
  95. {
  96. auto stop_time = hrc_clock::now();
  97. std::chrono::duration<double, std::nano> elapsed = stop_time - start_time_;
  98.  
  99. auto divby = std::log2((double)data.front().size());
  100. auto perop = elapsed.count() / divby / double(queries.size());
  101.  
  102. std::printf("%g ns per lookup/log2(size)\n", perop);
  103.  
  104. if (need_validate_ && answers != std_answers_)
  105. {
  106. std::printf("!!!! Wrong answer\n");
  107. }
  108. }
  109. };
  110.  
  // Prints a blank line, then `msg` (when non-null, followed by a newline),
  // then the command-line help text. Always returns 1 so callers can write
  // `return usage(...)` to exit with a failure status.
  int usage(const char *msg = nullptr)
  {
    static constexpr const char kHelpText[] =
      " Usage: bisect_test <algo> <size>\n\n"
      " <algo>: naive t0 t1 t2 t3\n"
      " <size>: 8 big\n\n";

    std::puts("");
    if (msg != nullptr)
    {
      std::puts(msg);
    }
    // The help text contains no conversion specifiers, so fputs emits it verbatim.
    std::fputs(kHelpText, stdout);

    return 1;
  }
  122.  
  // Parameter bundle describing one benchmark configuration. main() passes the
  // four fields, in this order, to the State constructor.
  struct test_param
  {
    std::uint32_t data_batch_size;  // number of arrays
    std::uint32_t data_size;        // elements per array
    std::uint32_t query_batch_size; // number of queries
    bool need_validate;             // compare results against reference answers
  };
  130.  
  131. using test_function_type = void (*)(State &s) noexcept;
  132.  
  133. void test_naive(State &s) noexcept
  134. {
  135. for (uint32_t i = 0, qn = s.queries.size(); i != qn; ++i)
  136. {
  137. const auto [k, lookup] = s.queries[i];
  138. s.answers[i] = std::lower_bound(s.data[k].begin(), s.data[k].end(), lookup) - s.data[k].begin();
  139. }
  140. }
  141.  
  142. void test_t0(State &s) noexcept
  143. {
  144. for (uint32_t i = 0, qn = s.queries.size(); i != qn; ++i)
  145. {
  146. const auto [k, lookup] = s.queries[i];
  147. int64_t size = s.data[k].size();
  148. if (size <= 512) {
  149. int64_t index = 0;
  150. for (; index < size; ++index) {
  151. if (s.data[k][index] >= lookup) break;
  152. }
  153. s.answers[i] = index;
  154. } else {
  155. s.answers[i] = std::lower_bound(s.data[k].begin(), s.data[k].end(), lookup) - s.data[k].begin();
  156. }
  157. }
  158. }
  159.  
  160. void test_t1(State &s) noexcept
  161. {
  162. // [Optional] Enter your answer No.1 here.
  163. }
  164.  
  165. void test_t2(State &s) noexcept
  166. {
  167. // [Optional] Enter your answer No.2 here.
  168. }
  169.  
  170. void test_t3(State &s) noexcept
  171. {
  172. // [Optional] Enter your answer No.3 here.
  173. }
  174.  
  175. } // namespace
  176.  
  177. int main(int argc, const char **argv)
  178. {
  179. if (argc != 3)
  180. return usage();
  181.  
  182. test_function_type test_fn = nullptr;
  183.  
  184. using namespace std::string_view_literals;
  185. if (argv[1] == "naive"sv)
  186. test_fn = &test_naive;
  187. else if (argv[1] == "t0"sv)
  188. test_fn = &test_t0;
  189. else if (argv[1] == "t1"sv)
  190. test_fn = &test_t1;
  191. else if (argv[1] == "t2"sv)
  192. test_fn = &test_t2;
  193. else if (argv[1] == "t3"sv)
  194. test_fn = &test_t3;
  195. else
  196. return usage("invalid algorithm name\n\n");
  197.  
  198. test_param param;
  199. if (argv[2] == "8"sv)
  200. param = test_param{
  201. .data_batch_size = uint32_t(1e6),
  202. .data_size = 8,
  203. .query_batch_size = uint32_t(2e6),
  204. .need_validate = true,
  205. };
  206. else if (argv[2] == "big"sv)
  207. param = test_param{
  208. .data_batch_size = 500u,
  209. .data_size = uint32_t(1e6),
  210. .query_batch_size = 2000u,
  211. .need_validate = true,
  212. };
  213. else
  214. return usage("invalid size\n\n");
  215.  
  216. State s(param.data_batch_size, param.data_size, param.query_batch_size, param.need_validate);
  217.  
  218. s.start(argv[1]);
  219. (*test_fn)(s);
  220. s.stop();
  221.  
  222. return 0;
  223. }
  224.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement