lower_bound よりも自分で二分探索

競技プログラミングの文脈の話です.std::lower_bound() と std::upper_bound() の使い方がなかなか覚えづらく,使いづらい印象があります.自分で二分探索を行う関数を書いておいて貼り付けた方が楽ではないか,という趣旨の話を書きます.

以下,次のように書いてあるものとします.

#include <bits/stdc++.h>
typedef long long int ll;
using namespace std;

lower_bound()

整列済みの vector<int> があったとして,この中に入っている値で K 未満のものがいくつあるかを知りたい,という場合,lower_bound() を使うと思います.

  vector<int> vec = ...
  auto it = lower_bound(vec.begin(), vec.end(), K);
  cout << distance(vec.begin(), it) << endl;

これ自体には何も問題ありませんが,問題を解く際に,この K を決定する時に間違いやすい,ということがよくあります.たとえば,vec.at(i) * T <= S を満たすものの数がほしい,となったら,ceil((S+1)/T) 未満のものを数えることになるのだと思いますが,この言い換えのところで時間を使った上によく間違えてしまいます.無理に lower_bound の仕様に合わせるのではなく,自然な条件 (ここでは,vec.at(i) * T <= S) をそのまま書きたいと思います.

二分探索

ということで,二分探索を行う関数を書いておいて,コンテストの時にはそのまま貼ることにしました.そんなもんその場で書けよ,という意見ももちろんあると思うのですが,それもまたよく間違えるので....

template<typename T>
T binsearch(auto check, T yes, T no, T err = (T)1) {
  while (abs(yes - no) > err) {
    T mid = (yes + no) / (T)2;
    if (check(mid)) yes = mid;
    else            no  = mid;
  }
  return yes;
}

T が vector 等に入っている値の型.check は判定関数で,T を一つ受け取って bool を返す.yes と no は,それぞれ判定関数が必ず true, false になる端の値です.yes/noのどちらが大きくとも構いません.err は範囲の幅がこれ以下になったら繰り返しを終了するという値で,上の例のようにベクトルの添字を求めるような場合には,既定値の1を使えば良いです.double の境界を求める場合などには誤差条件などから適切に設定します.返される値は,境界の2つの値のうち,判定関数が true を返す方,ということになります.

この関数があれば,先の例は次のように自然な条件をそのまま書けます.

  vector<int> vec = ...
  cout << binsearch<int>([&](int i) -> bool { return vec.at(i) * T <= S; },
                         -1, vec.size()) << endl;

スピード

書きやすくなった (と私は思う) のですが,極端に遅くなるのでは,コンテストで使えません.速度を比較してみました.整列済みのvector X に対し,各 i について,Y.at(i) <= X.at(j) となる最小の添字 j (なければ X.size()) を返す,というもので,X, Y はサイズ 106でランダムに作ったものです.

これを,lower_bound と 上の binsearch で実行したところ,lower_bound では 223ms, binsearch では 255ms となりました.14% 遅くなっただけですので,まあ使えないこともないと思っています.

  ll N, M; cin >> N >> M;
  vector<ll> X(N), Y(M);
  for (ll i = 0; i < N; i++) cin >> X.at(i);
  for (ll i = 0; i < M; i++) cin >> Y.at(i);
  sort(X.begin(), X.end());

  ll s1 = 0;
  auto start1 = chrono::system_clock::now();
  for (ll i = 0; i < M; i++) {
    auto it = lower_bound(X.begin(), X.end(), Y.at(i));
    s1 += distance(X.begin(), it);
  }
  auto end1 = chrono::system_clock::now();

  ll s2 = 0;
  auto start2 = chrono::system_clock::now();
  for (ll i = 0; i < M; i++) {
    s2 += binsearch<ll>([&](ll j) -> bool { return Y.at(i) <= X.at(j); },
                        N, -1LL);
  }
  auto end2 = chrono::system_clock::now();

  ll mega = (ll)1e6;
  cout << "lower_bound " << (end1 - start1).count() / mega
       << "ms" << endl;
  cout << "binsearch   " << (end2 - start2).count() / mega
       << "ms" << endl;

  assert(s1 == s2);