presage  0.9.1
smoothedNgramPredictor.cpp
/******************************************************
 * Presage, an extensible predictive text entry system
 * ---------------------------------------------------
 *
 * Copyright (C) 2008 Matteo Vescovi <matteo.vescovi@yahoo.co.uk>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 **********(*)*/


#include "smoothedNgramPredictor.h"

#include <sstream>
#include <algorithm>

SmoothedNgramPredictor::SmoothedNgramPredictor(Configuration* config, ContextTracker* ct, const char* name)
    : Predictor(config,
                ct,
                name,
                "SmoothedNgramPredictor, a linear interpolating n-gram predictor",
                "SmoothedNgramPredictor, long description." ),
      db (0),
      cardinality (0),
      learn_mode_set (false),
      dispatcher (this)
{
    LOGGER          = PREDICTORS + name + ".LOGGER";
    DBFILENAME      = PREDICTORS + name + ".DBFILENAME";
    DELTAS          = PREDICTORS + name + ".DELTAS";
    LEARN           = PREDICTORS + name + ".LEARN";
    DATABASE_LOGGER = PREDICTORS + name + ".DatabaseConnector.LOGGER";

    // build notification dispatch map
    dispatcher.map (config->find (LOGGER), & SmoothedNgramPredictor::set_logger);
    dispatcher.map (config->find (DBFILENAME), & SmoothedNgramPredictor::set_dbfilename);
    dispatcher.map (config->find (DATABASE_LOGGER), & SmoothedNgramPredictor::set_database_logger_level);
    dispatcher.map (config->find (DELTAS), & SmoothedNgramPredictor::set_deltas);
    dispatcher.map (config->find (LEARN), & SmoothedNgramPredictor::set_learn);
}
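// Illustrative example (assuming PREDICTORS expands to the usual
// "Presage.Predictors." prefix): with a predictor registered under the name
// "SmoothedNgramPredictor", the constructor above builds configuration keys
// such as "Presage.Predictors.SmoothedNgramPredictor.DBFILENAME" and
// "Presage.Predictors.SmoothedNgramPredictor.DELTAS"; the dispatch map then
// routes changes to those variables to the matching set_* methods below.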


SmoothedNgramPredictor::~SmoothedNgramPredictor()
{
    delete db;
}


void SmoothedNgramPredictor::set_dbfilename (const std::string& filename)
{
    dbfilename = filename;
    logger << INFO << "DBFILENAME: " << dbfilename << endl;

    init_database_connector_if_ready ();
}


void SmoothedNgramPredictor::set_database_logger_level (const std::string& value)
{
    dbloglevel = value;
}


void SmoothedNgramPredictor::set_deltas (const std::string& value)
{
    std::stringstream ss_deltas(value);
    cardinality = 0;
    std::string delta;
    while (ss_deltas >> delta) {
        logger << DEBUG << "Pushing delta: " << delta << endl;
        deltas.push_back (Utility::toDouble (delta));
        cardinality++;
    }
    logger << INFO << "DELTAS: " << value << endl;
    logger << INFO << "CARDINALITY: " << cardinality << endl;

    init_database_connector_if_ready ();
}
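// Example (hypothetical config value): if the DELTAS variable is set to
// "0.01 0.1 0.89", the loop above yields deltas = {0.01, 0.1, 0.89} and
// cardinality = 3, i.e. one interpolation weight per n-gram order, with
// trigrams carrying most of the weight.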

void SmoothedNgramPredictor::set_learn (const std::string& value)
{
    learn_mode = Utility::isYes (value);
    logger << INFO << "LEARN: " << value << endl;

    learn_mode_set = true;

    init_database_connector_if_ready ();
}


void SmoothedNgramPredictor::init_database_connector_if_ready ()
{
    // we can only init the sqlite database connector once we know the
    // following:
    //  - what database file we need to open
    //  - what cardinality we expect the database file to be
    //  - whether we need to open the database in read only or
    //    read/write mode (learning requires read/write access)
    //
    if (! dbfilename.empty()
        && cardinality > 0
        && learn_mode_set ) {

        delete db;

        if (dbloglevel.empty ()) {
            // open database connector
            db = new SqliteDatabaseConnector(dbfilename,
                                             cardinality,
                                             learn_mode);
        } else {
            // open database connector with logger level
            db = new SqliteDatabaseConnector(dbfilename,
                                             cardinality,
                                             learn_mode,
                                             dbloglevel);
        }
    }
}


// convenience function to convert ngram to string
//
static std::string ngram_to_string(const Ngram& ngram)
{
    const char separator[] = "|";
    std::string result = separator;

    for (Ngram::const_iterator it = ngram.begin();
         it != ngram.end();
         it++)
    {
        result += *it + separator;
    }

    return result;
}
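// For example, ngram_to_string applied to the n-gram {"foo", "bar"}
// returns the string "|foo|bar|".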


/** Builds the required n-gram and returns its count.
 *
 * @param tokens      tokens out of which the n-gram is built
 * @param offset      offset of the n-gram's last token with respect to
 *                    the last token in tokens (zero or negative)
 * @param ngram_size  size of the n-gram to build (0 returns the sum of
 *                    all unigram counts)
 *
 * @return count of the n-gram in the database
 */
unsigned int SmoothedNgramPredictor::count(const std::vector<std::string>& tokens, int offset, int ngram_size) const
{
    unsigned int result = 0;

    assert(offset <= 0); // TODO: handle this better
    assert(ngram_size >= 0);

    if (ngram_size > 0) {
        Ngram ngram(ngram_size);
        copy(tokens.end() - ngram_size + offset, tokens.end() + offset, ngram.begin());
        result = db->getNgramCount(ngram);
        logger << DEBUG << "count ngram: " << ngram_to_string (ngram) << " : " << result << endl;
    } else {
        result = db->getUnigramCountsSum();
        logger << DEBUG << "unigram counts sum: " << result << endl;
    }

    return result;
}
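
// To make the offset/ngram_size convention concrete, with a hypothetical
// tokens vector {"the", "quick", "brown"}:
//   count(tokens,  0, 2) looks up the bigram  "quick brown"
//   count(tokens, -1, 2) looks up the bigram  "the quick"
//   count(tokens, -1, 1) looks up the unigram "quick"
//   count(tokens,  0, 0) returns the sum of all unigram counts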

Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
{
    logger << DEBUG << "predict()" << endl;

    // Result prediction
    Prediction prediction;

    // Cache all the needed tokens.
    // tokens[cardinality - 1 - k] corresponds to w_{i-k} in the
    // generalized smoothed n-gram probability formula
    //
    std::vector<std::string> tokens(cardinality);
    for (int i = 0; i < cardinality; i++) {
        tokens[cardinality - 1 - i] = contextTracker->getToken(i);
        logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = " << tokens[cardinality - 1 - i] << endl;
    }
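    // For instance (hypothetical context), with cardinality == 3 and the
    // user having typed "the quick bro", the cache above ends up as
    //   tokens = {"the", "quick", "bro"},
    // where tokens[2] holds the prefix of the word currently being completed.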

    // Generate list of prefix completion candidates.
    //
    // The prefix completion candidates used to be obtained from the
    // _1_gram table, because in a well-constructed ngram database the
    // _1_gram table contains all known tokens. However, this
    // introduced a skew, since the unigram counts take precedence
    // over the higher-order counts.
    //
    // The current solution retrieves candidates from the highest
    // order n-gram table, falling back on lower order n-gram tables
    // if the initial completion set is smaller than required.
    //
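    // Concretely (continuing the hypothetical "the quick bro" context with
    // cardinality 3): the loop below first asks the _3_gram table for
    // entries matching <"the", "quick", bro...>, i.e. trigrams whose last
    // token starts with the typed prefix; if fewer than
    // max_partial_prediction_size distinct candidates are found, it falls
    // back to the _2_gram table with <"quick", bro...> and finally to the
    // _1_gram table with <bro...>.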
    std::vector<std::string> prefixCompletionCandidates;
    for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
        logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;
        // create n-gram used to retrieve initial prefix completion table
        Ngram prefix_ngram(k);
        copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());

        if (logger.shouldLog()) {
            logger << DEBUG << "prefix_ngram: ";
            for (size_t r = 0; r < prefix_ngram.size(); r++) {
                logger << DEBUG << prefix_ngram[r] << ' ';
            }
            logger << DEBUG << endl;
        }

        // obtain initial prefix completion candidates
        db->beginTransaction();

        NgramTable partial;

        if (filter == 0) {
            partial = db->getNgramLikeTable(prefix_ngram, max_partial_prediction_size - prefixCompletionCandidates.size());
        } else {
            partial = db->getNgramLikeTableFiltered(prefix_ngram, filter, max_partial_prediction_size - prefixCompletionCandidates.size());
        }

        db->endTransaction();

        if (logger.shouldLog()) {
            logger << DEBUG << "partial prefixCompletionCandidates" << endl
                   << DEBUG << "----------------------------------" << endl;
            for (size_t j = 0; j < partial.size(); j++) {
                for (size_t k = 0; k < partial[j].size(); k++) {
                    logger << DEBUG << partial[j][k] << " ";
                }
                logger << endl;
            }
        }

        logger << DEBUG << "Partial prefix completion table contains " << partial.size() << " potential completions." << endl;

        // append newly discovered potential completions to prefix
        // completion candidates array to fill it up to
        // max_partial_prediction_size
        //
        std::vector<Ngram>::const_iterator it = partial.begin();
        while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
            // only add new candidates, iterator it points to Ngram,
            // it->end() - 2 points to the token candidate
            //
            std::string candidate = *(it->end() - 2);
            if (find(prefixCompletionCandidates.begin(),
                     prefixCompletionCandidates.end(),
                     candidate) == prefixCompletionCandidates.end()) {
                prefixCompletionCandidates.push_back(candidate);
            }
            it++;
        }
    }

    if (logger.shouldLog()) {
        logger << DEBUG << "prefixCompletionCandidates" << endl
               << DEBUG << "--------------------------" << endl;
        for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
            logger << DEBUG << prefixCompletionCandidates[j] << endl;
        }
    }

    // compute smoothed probabilities for all candidates
    //
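    // The smoothed probability computed below is a linear interpolation of
    // maximum-likelihood estimates of increasing order:
    //
    //   P(w_i | w_{i-n+1} ... w_{i-1}) =
    //       sum over k of  deltas[k] * C(w_{i-k} ... w_i) / C(w_{i-k} ... w_{i-1})
    //
    // where for k == 0 the denominator is the sum of all unigram counts.
    // A worked example with made-up counts: deltas = {0.01, 0.1, 0.89},
    // C("brown") = 20, total unigram count = 1000, C("quick brown") = 5,
    // C("quick") = 10, C("the quick brown") = 2, C("the quick") = 4 gives
    //   probability = 0.01*(20/1000) + 0.1*(5/10) + 0.89*(2/4) = 0.4952.
    //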
    db->beginTransaction();
    // getUnigramCountsSum is an expensive SQL query
    // caching it here saves much time later inside the loop
    int unigrams_counts_sum = db->getUnigramCountsSum();
    for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
        // store w_i candidate at end of tokens
        tokens[cardinality - 1] = prefixCompletionCandidates[j];

        logger << DEBUG << "------------------" << endl;
        logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;

        double probability = 0;
        for (int k = 0; k < cardinality; k++) {
            double numerator = count(tokens, 0, k+1);
            // reuse cached unigrams_counts_sum to speed things up
            double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
            double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
            probability += deltas[k] * frequency;

            logger << DEBUG << "numerator: " << numerator << endl;
            logger << DEBUG << "denominator: " << denominator << endl;
            logger << DEBUG << "frequency: " << frequency << endl;
            logger << DEBUG << "delta: " << deltas[k] << endl;

            // for some sanity checks
            assert(numerator <= denominator);
            assert(frequency <= 1);
        }

        logger << DEBUG << "____________" << endl;
        logger << DEBUG << "probability: " << probability << endl;

        if (probability > 0) {
            prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
        }
    }
    db->endTransaction();

    logger << DEBUG << "Prediction:" << endl;
    logger << DEBUG << "-----------" << endl;
    logger << DEBUG << prediction << endl;

    return prediction;
}

void SmoothedNgramPredictor::learn(const std::vector<std::string>& change)
{
    logger << INFO << "learn(\"" << ngram_to_string(change) << "\")" << endl;

    if (learn_mode) {
        // learning is turned on

        std::map<std::list<std::string>, int> ngramMap;

        // build up ngram map for all cardinalities
        // i.e. learn all ngrams and counts in memory
        for (size_t curr_cardinality = 1;
             curr_cardinality < cardinality + 1;
             curr_cardinality++)
        {
            int change_idx = 0;
            int change_size = change.size();

            std::list<std::string> ngram_list;

            // take care of first N-1 tokens
            for (int i = 0;
                 (i < curr_cardinality - 1 && change_idx < change_size);
                 i++)
            {
                ngram_list.push_back(change[change_idx]);
                change_idx++;
            }

            while (change_idx < change_size)
            {
                ngram_list.push_back(change[change_idx++]);
                ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
                ngram_list.pop_front();
            }
        }
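        // As an illustration (hypothetical input): with cardinality == 2 and
        // change == {"foo", "bar"}, the loops above leave ngramMap holding
        //   {"foo"} -> 1, {"bar"} -> 1, {"foo", "bar"} -> 1.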

        // use (past stream - change) to learn tokens at the boundary of
        // change, i.e.
        //
        // if change is "bar foobar", then "bar" will only occur in a
        // 1-gram, since there are no tokens before it. By dipping into
        // the past stream, we gain additional context to learn a 2-gram
        // by grabbing extra tokens (assuming the past stream ends with
        // token "foo"):
        //
        //   <"foo", "bar"> will be learnt
        //
        // We do this till we have built n-grams up to n equal to
        // cardinality.
        //
        // First check that change is not empty (nothing to learn) and
        // that change and past stream match by sampling the first and
        // last tokens in change and comparing them with the
        // corresponding tokens from the past stream.
        //
        if (change.size() > 0 &&
            change.back() == contextTracker->getToken(1) &&
            change.front() == contextTracker->getToken(change.size()))
        {
            // create ngram list with first (oldest) token from change
            std::list<std::string> ngram_list(change.begin(), change.begin() + 1);

            // prepend token to ngram list by grabbing extra tokens
            // from past stream (if there are any) till we have built
            // up to n==cardinality ngrams, and commit them to
            // ngramMap
            //
            for (int tk_idx = 1;
                 ngram_list.size() < cardinality;
                 tk_idx++)
            {
                // getExtraTokenToLearn returns tokens from
                // past stream that come before and are not in
                // change vector
                //
                std::string extra_token = contextTracker->getExtraTokenToLearn(tk_idx, change);
                logger << DEBUG << "Adding extra token: " << extra_token << endl;

                if (extra_token.empty())
                {
                    break;
                }
                ngram_list.push_front(extra_token);

                ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
            }
        }

        // then write out to language model database
        try
        {
            db->beginTransaction();

            std::map<std::list<std::string>, int>::const_iterator it;
            for (it = ngramMap.begin(); it != ngramMap.end(); it++)
            {
                // convert ngram from list to vector based Ngram
                Ngram ngram((it->first).begin(), (it->first).end());

                // update the counts
                int count = db->getNgramCount(ngram);
                if (count > 0)
                {
                    // ngram already in database, update count
                    db->updateNgram(ngram, count + it->second);
                    check_learn_consistency(ngram);
                }
                else
                {
                    // ngram not in database, insert it
                    db->insertNgram(ngram, it->second);
                }
            }

            db->endTransaction();
            logger << INFO << "Committed learning update to database" << endl;
        }
        catch (PresageException& ex)
        {
            db->rollbackTransaction();
            logger << ERROR << "Rolling back learning update : " << ex.what() << endl;
            throw;
        }
    }

    logger << DEBUG << "end learn()" << endl;
}

void SmoothedNgramPredictor::check_learn_consistency (const Ngram& ngram) const
{
    // no need to begin a new transaction, as we'll be called from
    // within an existing transaction from learn()

    // BEWARE: if the previous sentence is not true, then performance
    // WILL suffer!

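    // The invariant being restored here: the count of an n-gram must never
    // exceed the count of the (n-1)-gram obtained by dropping its last
    // token. For example, if learning has just pushed C("foo bar baz")
    // above C("foo bar"), the loop below increments C("foo bar") so that
    // the sanity checks in predict() (numerator <= denominator) keep
    // holding.
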
    size_t size = ngram.size();
    for (size_t i = 0; i < size; i++) {
        if (count(ngram, -i, size - i) > count(ngram, -(i + 1), size - (i + 1))) {
            logger << INFO << "consistency adjustment needed!" << endl;

            int offset = -(i + 1);
            int sub_ngram_size = size - (i + 1);

            logger << DEBUG << "i: " << i << " | offset: " << offset << " | sub_ngram_size: " << sub_ngram_size << endl;

            Ngram sub_ngram(sub_ngram_size); // need to init to right size for sub_ngram
            copy(ngram.end() - sub_ngram_size + offset, ngram.end() + offset, sub_ngram.begin());

            if (logger.shouldLog()) {
                logger << "ngram to be count adjusted is: ";
                for (size_t i = 0; i < sub_ngram.size(); i++) {
                    logger << sub_ngram[i] << ' ';
                }
                logger << endl;
            }

            db->incrementNgramCount(sub_ngram);
            logger << DEBUG << "consistency adjusted" << endl;
        }
    }
}

void SmoothedNgramPredictor::update (const Observable* var)
{
    logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
    dispatcher.dispatch (var);
}