Skip to content

Commit

Permalink
Add code that guesses a different capitalization in case of ambiguity.
Browse files Browse the repository at this point in the history
…Fixes #161.
  • Loading branch information
hwalinga committed Sep 16, 2020
1 parent b4cf182 commit 1daa536
Showing 1 changed file with 54 additions and 26 deletions.
80 changes: 54 additions & 26 deletions quantulum3/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
# Standard library
import json
import logging
import pkg_resources
import os
import multiprocessing
import os

import pkg_resources

# Semi-dependencies
try:
Expand All @@ -27,9 +28,9 @@
wikipedia = None

# Quantulum
from . import load
from . import language, load
from .load import cached
from . import language


_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -109,7 +110,6 @@ def _clean_text_lang(lang):
def train_classifier(
parameters=None, ngram_range=(1, 1), store=True, lang="en_US", n_jobs=None
):

"""
Train the intent classifier
TODO auto invoke if sklearn version is new or first install or sth
Expand Down Expand Up @@ -245,7 +245,35 @@ def disambiguate_unit(unit, text, lang="en_US"):
"""
Resolve ambiguity between units with same names, symbols or abbreviations.
"""
new_unit = disambiguate_unit_by_score(unit, text, lang)
if len(new_unit) == 1:
return next(iter(new_unit))

try:
# Instead of picking a random one now, we first change the
# capitalization of the unit and see if we can improve.
unit_changed = unit[:-1] + unit[-1].swapcase()
text_changed = text.replace(unit, unit_changed)

new_unit_changed = disambiguate_unit_by_score(unit_changed, text_changed, lang)
if len(new_unit_changed) == 1:
return next(iter(new_unit_changed))

if 0 < len(new_unit_changed) < len(new_unit):
# See if we have improved, otherwise we stick with the old new_unit.
new_unit = new_unit_changed

except KeyError:
pass # Attempt failed, we just pick a random from new_unit now.

_LOGGER.warning(
"Could not resolve ambiguous units: '{}'. For unit '{}' in text '{}'. "
"Taking a random.".format(", ".join(str(u) for u in new_unit), unit, text)
)
return next(iter(new_unit))


def disambiguate_unit_by_score(unit, text, lang):
new_unit = (
load.units(lang).symbols.get(unit)
or load.units(lang).surfaces.get(unit)
Expand All @@ -254,25 +282,25 @@ def disambiguate_unit(unit, text, lang="en_US"):
)
if not new_unit:
raise KeyError('Could not find unit "%s" from "%s"' % (unit, text))
if len(new_unit) == 1:
return new_unit

if len(new_unit) > 1:
transformed = classifier(lang).tfidf_model.transform([clean_text(text, lang)])
scores = classifier(lang).classifier.predict_proba(transformed).tolist()[0]
scores = zip(scores, classifier(lang).target_names)

# Filter for possible names
names = [i.name for i in new_unit]
scores = [i for i in scores if i[1] in names]

# Sort by rank
scores = sorted(scores, key=lambda x: x[0], reverse=True)
try:
final = load.units(lang).names[scores[0][1]]
_LOGGER.debug('\tAmbiguity resolved for "%s" (%s)' % (unit, scores))
except IndexError:
_LOGGER.debug('\tAmbiguity not resolved for "%s"' % unit)
final = next(iter(new_unit))
else:
final = next(iter(new_unit))

return final
# Start scoring
transformed = classifier(lang).tfidf_model.transform(
[clean_text(text, lang)]
)
scores = classifier(lang).classifier.predict_proba(transformed).tolist()[0]
scores = zip(scores, classifier(lang).target_names)

# Filter for possible names
names = [i.name for i in new_unit]
scores = [i for i in scores if i[1] in names]

# Sort by rank
scores = sorted(scores, key=lambda x: x[0], reverse=True)
try:
return [load.units(lang).names[scores[0][1]]]
_LOGGER.debug('\tAmbiguity resolved for "%s" (%s)' % (unit, scores))
except IndexError:
_LOGGER.debug('\tAmbiguity not resolved for "%s"' % unit)
return new_unit

0 comments on commit 1daa536

Please sign in to comment.