Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: R poly compatibility #92

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/API-reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ Spline regression
.. autofunction:: cc
.. autofunction:: te

Polynomial
----------

.. autofunction:: poly

Working with formulas programmatically
--------------------------------------

Expand Down
3 changes: 3 additions & 0 deletions patsy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,8 @@ def _reexport(mod):
import patsy.mgcv_cubic_splines
_reexport(patsy.mgcv_cubic_splines)

import patsy.polynomials
_reexport(patsy.polynomials)

# XX FIXME: we aren't exporting any of the explicit parsing interface
# yet. Need to figure out how to do that.
9 changes: 4 additions & 5 deletions patsy/contrasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from patsy.util import (repr_pretty_delegate, repr_pretty_impl,
safe_issubdtype,
no_pickling, assert_no_pickling)
from patsy.polynomials import Poly as Polynomial

class ContrastMatrix(object):
"""A simple container for a matrix used for coding categorical factors.
Expand Down Expand Up @@ -263,11 +264,9 @@ def _code_either(self, intercept, levels):
# quadratic, etc., functions of the raw scores, and then use 'qr' to
# orthogonalize each column against those to its left.
scores -= scores.mean()
raw_poly = scores.reshape((-1, 1)) ** np.arange(n).reshape((1, -1))
q, r = np.linalg.qr(raw_poly)
q *= np.sign(np.diag(r))
q /= np.sqrt(np.sum(q ** 2, axis=1))
# The constant term is always all 1's -- we don't normalize it.
raw_poly = Polynomial.vander(scores, n - 1)
alpha, norm, beta = Polynomial.gen_qr(raw_poly, n - 1)
q = Polynomial.apply_qr(raw_poly, n - 1, alpha, norm, beta)
q[:, 0] = 1
names = [".Constant", ".Linear", ".Quadratic", ".Cubic"]
names += ["^%s" % (i,) for i in range(4, n)]
Expand Down
204 changes: 204 additions & 0 deletions patsy/polynomials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# This file is part of Patsy
# Copyright (C) 2012-2013 Nathaniel Smith <[email protected]>
# See file LICENSE.txt for license information.

# R-compatible poly function

# These are made available in the patsy.* namespace
import numpy as np

from patsy.util import have_pandas, no_pickling, assert_no_pickling
from patsy.state import stateful_transform

__all__ = ["poly"]

if have_pandas:
import pandas


class Poly(object):
    """poly(x, degree=3, raw=False)

    Generates a polynomial transformation of ``x`` of the given degree.
    Generic usage is something along the lines of::

      y ~ 1 + poly(x, 4)

    to fit ``y`` as a function of ``x``, with a 4th degree polynomial.

    :arg x: The data to expand. Must be 1-d, or a 2-d column vector.
    :arg degree: The degree of the polynomial expansion. Must be a positive
      integer; for orthogonal polynomials it must also be smaller than the
      number of unique values in ``x``.
    :arg raw: When raw is False (the default), will return orthogonal
      polynomials. When True, plain powers of ``x`` are returned.

    The result has one column per degree ``1..degree``; the constant
    (degree 0) column is not included.

    .. versionadded:: 0.4.1
    """

    def __init__(self):
        # Scratch space used while memorizing; removed by memorize_finish().
        self._tmp = {}

    def memorize_chunk(self, x, degree=3, raw=False):
        """Validate one chunk of ``x`` and store it for memorize_finish().

        :raises ValueError: if ``x`` is neither 1-d nor a 2-d column vector.
        """
        self._tmp["args"] = {"degree": degree, "raw": raw}
        # XX: check whether we need x values before saving them
        x = np.atleast_1d(x)
        if x.ndim == 2 and x.shape[1] == 1:
            x = x[:, 0]
        if x.ndim > 1:
            raise ValueError("input to 'poly' must be 1-d, "
                             "or a 2-d column vector")
        # The orthogonalization coefficients computed in memorize_finish()
        # depend on every observation, so all chunks are kept until then.
        x = np.array(x, dtype=float)
        self._tmp.setdefault("xs", []).append(x)

    def memorize_finish(self):
        """Validate the arguments and compute the orthogonalization state.

        Sets ``self.degree`` and ``self.raw``; when ``raw`` is false it also
        sets ``self.alpha``, ``self.norm`` and ``self.beta``.

        :raises ValueError: if ``degree`` is not a positive integer, or (for
          orthogonal polynomials) is not smaller than the number of unique
          memorized points.
        """
        tmp = self._tmp
        args = tmp["args"]
        del self._tmp

        degree = args["degree"]
        if degree < 1:
            raise ValueError("degree must be greater than 0 (not %r)"
                             % (degree,))
        if int(degree) != degree:
            raise ValueError("degree must be an integer (not %r)"
                             % (degree,))
        # Integer-valued floats such as 3.0 pass the check above, but array
        # shapes and polyvander() need a true integer.
        degree = int(degree)

        # These are guaranteed to all be 1d vectors by memorize_chunk()
        scores = np.concatenate(tmp["xs"])

        self.degree = degree
        self.raw = args["raw"]

        if not self.raw:
            # With too few distinct points the Vandermonde matrix is rank
            # deficient, some norms are zero and the output would be inf/NaN.
            # R's poly() rejects this case as well.
            n_unique = np.unique(scores).size
            if degree >= n_unique:
                raise ValueError("degree must be less than the number of "
                                 "unique points (degree=%r, unique points=%r)"
                                 % (degree, n_unique))
            raw_poly = self.vander(scores, degree)
            self.alpha, self.norm, self.beta = self.gen_qr(raw_poly, degree)

    def transform(self, x, degree=3, raw=False):
        """Expand ``x`` using the memorized degree and coefficients.

        The ``degree`` and ``raw`` arguments are not read here; the values
        stored by memorize_finish() are used instead.

        :returns: an array with ``self.degree`` columns, or a
          ``pandas.DataFrame`` carrying the index of ``x`` when ``x`` is a
          pandas Series or DataFrame.
        """
        to_pandas = (have_pandas
                     and isinstance(x, (pandas.Series, pandas.DataFrame)))
        if to_pandas:
            idx = x.index
        x = np.array(x, ndmin=1).flatten()

        n = self.degree
        p = self.vander(x, n)

        if not self.raw:
            p = self.apply_qr(p, n, self.alpha, self.norm, self.beta)

        # Drop the constant (degree 0) column.
        p = p[:, 1:]
        if to_pandas:
            p = pandas.DataFrame(p)
            p.index = idx
        return p

    @staticmethod
    def vander(x, n):
        """Return the Vandermonde matrix of ``x`` with columns x**0..x**n."""
        return np.polynomial.polynomial.polyvander(x, n)

    @staticmethod
    def gen_qr(raw_poly, n):
        """Compute three-term recurrence coefficients from ``raw_poly``.

        :arg raw_poly: Vandermonde matrix as returned by :meth:`vander`;
          column 1 holds the original data.
        :arg n: The polynomial degree.
        :returns: a tuple ``(alpha, norm, beta)`` suitable for
          :meth:`apply_qr`.
        """
        x = raw_poly[:, 1]
        q, r = np.linalg.qr(raw_poly)
        # Q is now orthogonal of degree n. To match what R is doing, we
        # need to use the three-term recurrence technique to calculate
        # new alpha, beta, and norm.
        q_sq = q[:, :n] ** 2
        alpha = (np.sum(x.reshape((-1, 1)) * q_sq, axis=0)
                 / np.sum(q_sq, axis=0))

        # Scaling each column of q by the matching diagonal entry of r and
        # taking column norms gives the norms that R's output is based on.
        norm = np.linalg.norm(q * np.diag(r), axis=0)
        beta = (norm[1:] / norm[:n]) ** 2
        return alpha, norm, beta

    @staticmethod
    def apply_qr(x, n, alpha, norm, beta):
        """Evaluate the orthogonal polynomials at ``x``.

        :arg x: Either the 1-d data, or a 2-d Vandermonde matrix whose
          column 1 holds the data.
        :arg n: The polynomial degree.
        :arg alpha, norm, beta: Coefficients returned by :meth:`gen_qr`.
        :returns: an array of shape ``(len(x), n + 1)``; column 0 is the
          (normalized) constant term.
        """
        # This is where the three-term recurrence is unwound for the QR
        # decomposition.
        x = np.asarray(x)
        if np.ndim(x) == 2:
            x = x[:, 1]
        p = np.empty((x.shape[0], n + 1))
        p[:, 0] = 1

        for i in range(n):
            p[:, i + 1] = (x - alpha[i]) * p[:, i]
            if i > 0:
                p[:, i + 1] -= beta[i - 1] * p[:, i - 1]
        p /= norm
        return p

    __getstate__ = no_pickling


# The public ``poly`` callable listed in __all__, built by wrapping the Poly
# class with patsy.state.stateful_transform.
poly = stateful_transform(Poly)


def test_poly_compat():
    """Check Poly against reference output recorded from R's ``poly``.

    The reference data is a sequence of blocks delimited by
    ``--BEGIN TEST CASE--`` / ``--END TEST CASE--`` lines, each holding
    ``key=value`` lines for ``degree``, ``raw`` and ``output``.
    """
    from patsy.test_state import check_stateful
    from patsy.test_poly_data import (R_poly_test_x,
                                      R_poly_test_data,
                                      R_poly_num_tests)
    from numpy.testing import assert_allclose

    begin_marker = "--BEGIN TEST CASE--"
    end_marker = "--END TEST CASE--"

    lines = R_poly_test_data.split("\n")
    tests_ran = 0
    start_idx = lines.index(begin_marker)
    # Bound the index so data without a trailing newline ends the loop
    # instead of raising IndexError.
    while start_idx < len(lines) and lines[start_idx] == begin_marker:
        start_idx += 1
        stop_idx = lines.index(end_marker, start_idx)
        test_data = {}
        for line in lines[start_idx:stop_idx]:
            key, value = line.split("=", 1)
            test_data[key] = value
        # Translate the R output into Python calling conventions
        kwargs = {
            # integer
            "degree": int(test_data["degree"]),
            # boolean
            "raw": (test_data["raw"] == 'TRUE'),
        }
        # NOTE: eval() is applied to the generated test-data module only;
        # never feed it data from outside the repository.
        output = np.asarray(eval(test_data["output"]))
        # Do the actual test
        check_stateful(Poly, False, R_poly_test_x, output, **kwargs)
        # Also exercise the static helpers directly.
        raw_poly = Poly.vander(R_poly_test_x, kwargs['degree'])
        if kwargs['raw']:
            actual = raw_poly[:, 1:]
        else:
            alpha, norm, beta = Poly.gen_qr(raw_poly, kwargs['degree'])
            actual = Poly.apply_qr(R_poly_test_x, kwargs['degree'], alpha,
                                   norm, beta)[:, 1:]
        assert_allclose(actual, output)
        tests_ran += 1
        # Set up for the next one
        start_idx = stop_idx + 1
    assert tests_ran == R_poly_num_tests


def test_poly_errors():
    """Check that poly rejects bad input shapes and bad degrees."""
    # numpy.testing is used rather than nose.tools: nose is unmaintained and
    # no longer imports on current Python versions.
    from numpy.testing import assert_raises
    x = np.arange(27)
    # Invalid input shape
    assert_raises(ValueError, poly, x.reshape((3, 3, 3)))
    assert_raises(ValueError, poly, x.reshape((3, 3, 3)), raw=True)
    # Invalid degree
    for bad_degree in (-1, 0, 3.5):
        assert_raises(ValueError, poly, x, degree=bad_degree)

    assert_no_pickling(Poly())
37 changes: 37 additions & 0 deletions patsy/test_poly_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# This file auto-generated by tools/get-R-poly-test-vectors.R
# Using: R version 3.2.4 Revised (2016-03-16 r70336)
import numpy as np
R_poly_test_x = np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, ])
R_poly_test_data = """
--BEGIN TEST CASE--
degree=1
raw=TRUE
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, ]).reshape((20, 1, ), order="F")
--END TEST CASE--
--BEGIN TEST CASE--
degree=1
raw=FALSE
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, ]).reshape((20, 1, ), order="F")
--END TEST CASE--
--BEGIN TEST CASE--
degree=3
raw=TRUE
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, 1, 2.25, 5.0625, 11.390625, 25.62890625, 57.6650390625, 129.746337890625, 291.92926025390625, 656.84083557128906, 1477.8918800354004, 3325.2567300796509, 7481.8276426792145, 16834.112196028233, 37876.752441063523, 85222.692992392927, 191751.05923288409, 431439.8832739892, 970739.73736647563, 2184164.4090745705, 4914369.9204177829, 1, 3.375, 11.390625, 38.443359375, 129.746337890625, 437.89389038085938, 1477.8918800354004, 4987.8850951194763, 16834.112196028233, 56815.128661595285, 191751.05923288409, 647159.82491098379, 2184164.4090745705, 7371554.8806266747, 24878997.722115029, 83966617.312138215, 283387333.4284665, 956432250.32107437, 3227958844.8336263, 10894361101.313488, ]).reshape((20, 3, ), order="F")
--END TEST CASE--
--BEGIN TEST CASE--
degree=3
raw=FALSE
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, 0.11682670564764953, 0.11622774987820758, 0.1153299112445243, 0.11398449209008393, 0.11196937564961051, 0.10895347864407183, 0.10444488285989936, 0.097716301062945765, 0.087700630095951776, 0.072850827534442664, 0.05096695744238839, 0.019020528242278005, -0.026920519697452645, -0.091380250921070119, -0.1780532062130448, -0.28552519567824058, -0.39602393206231051, -0.44767622905753701, -0.26843910749340033, 0.57802660073100254, -0.11560888340228653, -0.11436481217184656, -0.11250218782662975, -0.10971608089390825, -0.10555451667328646, -0.099351692934324679, -0.090136150925525155, -0.076511614727544461, -0.05651941299388475, -0.027522538371457093, 0.013772191900731716, 0.070864671671751547, 0.14593497036033168, 0.23591981919395397, 0.32391016867398448, 0.36336942185480259, 0.25890497941187346, -0.11572025100301592, -0.66076386903314166, 0.27159578788942196, ]).reshape((20, 3, ), order="F")
--END TEST CASE--
--BEGIN TEST CASE--
degree=5
raw=TRUE
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, 1, 2.25, 5.0625, 11.390625, 25.62890625, 57.6650390625, 129.746337890625, 291.92926025390625, 656.84083557128906, 1477.8918800354004, 3325.2567300796509, 7481.8276426792145, 16834.112196028233, 37876.752441063523, 85222.692992392927, 191751.05923288409, 431439.8832739892, 970739.73736647563, 2184164.4090745705, 4914369.9204177829, 1, 3.375, 11.390625, 38.443359375, 129.746337890625, 437.89389038085938, 1477.8918800354004, 4987.8850951194763, 16834.112196028233, 56815.128661595285, 191751.05923288409, 647159.82491098379, 2184164.4090745705, 7371554.8806266747, 24878997.722115029, 83966617.312138215, 283387333.4284665, 956432250.32107437, 3227958844.8336263, 10894361101.313488, 1, 5.0625, 25.62890625, 129.746337890625, 656.84083557128906, 3325.2567300796509, 16834.112196028233, 85222.692992392927, 431439.8832739892, 2184164.4090745705, 11057332.320940012, 55977744.87475881, 283387333.4284665, 1434648375.4816115, 7262907400.875659, 36768468716.933022, 186140372879.47342, 942335637702.33411, 4770574165868.0674, 24151031714707.086, 1, 7.59375, 57.6650390625, 437.89389038085938, 3325.2567300796509, 25251.168294042349, 191751.05923288409, 1456109.6060497134, 11057332.320940012, 83966617.31213823, 637621500.21404958, 4841938267.2504387, 36768468716.933022, 279210559319.21014, 2120255184830.252, 16100687809804.727, 122264598055704.64, 928446791485507, 7050392822843070, 53538920498464552, ]).reshape((20, 5, ), order="F")
--END TEST CASE--
--BEGIN TEST CASE--
degree=5
raw=FALSE
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, 0.11682670564764953, 0.11622774987820758, 0.1153299112445243, 0.11398449209008393, 0.11196937564961051, 0.10895347864407183, 0.10444488285989936, 0.097716301062945765, 0.087700630095951776, 0.072850827534442664, 0.05096695744238839, 0.019020528242278005, -0.026920519697452645, -0.091380250921070119, -0.1780532062130448, -0.28552519567824058, -0.39602393206231051, -0.44767622905753701, -0.26843910749340033, 0.57802660073100254, -0.11560888340228653, -0.11436481217184656, -0.11250218782662975, -0.10971608089390825, -0.10555451667328646, -0.099351692934324679, -0.090136150925525155, -0.076511614727544461, -0.05651941299388475, -0.027522538371457093, 0.013772191900731716, 0.070864671671751547, 0.14593497036033168, 0.23591981919395397, 0.32391016867398448, 0.36336942185480259, 0.25890497941187346, -0.11572025100301592, -0.66076386903314166, 0.27159578788942196, 0.11925766326375063, 0.11701962699862156, 0.11367531238125347, 0.10868744714732725, 0.10126981942884175, 0.090287103769210786, 0.074134201646975206, 0.050620044131431986, 0.016933017097416861, -0.030116712154368355, -0.093138533517390085, -0.17160263551697441, -0.25618209006285081, -0.3183631162695052, -0.29707753517866498, -0.10102478727647804, 0.30185248746535442, 0.55289166632880227, -0.46108564710186972, 0.081962667419115426, -0.12626707822019206, -0.12250155553682644, -0.11689136915447108, -0.10856147160045609, -0.096257598068575617, -0.078227654788373013, -0.052128116579684983, -0.015063001240831148, 0.035988153544508683, 0.10280803884977513, 0.18263307034840112, 0.26144732880503613, 0.30325203347309243, 0.24116709207723347, -0.00082575540196283526, -0.37830141983168153, -0.42887161757203512, 0.55207091753656046, -0.17171017635275559, 0.016240179713238136, ]).reshape((20, 5, ), order="F")
--END TEST CASE--
"""
R_poly_num_tests = 6
62 changes: 62 additions & 0 deletions tools/get-R-poly-test-vectors.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Emit the header lines of the generated Python test-data module.
cat("# This file auto-generated by tools/get-R-poly-test-vectors.R\n")
cat(sprintf("# Using: %s\n", R.Version()$version.string))
cat("import numpy as np\n")

# Raise the printing precision so the dumped values carry enough digits to be
# compared against floating-point results on the Python side.
options(digits=20)
# NOTE(review): poly() is not called with a splines:: prefix anywhere in this
# script; this library() call may be unnecessary -- confirm before removing.
library(splines)
# Input vector shared by every test case: 1.5^0, 1.5^1, ..., 1.5^19.
x <- (1.5)^(0:19)

# Sentinel value; is.missing() compares against it.
MISSING <- "MISSING"

# Report whether `obj` is the MISSING sentinel: it must have length one and
# compare equal to MISSING. The length check short-circuits the comparison.
is.missing <- function(obj) {
  has.single.element <- length(obj) == 1
  has.single.element && obj == MISSING
}

# Write `arr` to stdout as a Python expression followed by a newline.
# The MISSING sentinel becomes `None`. Anything else becomes
# `np.array([v1, v2, ... ])`, with a `.reshape((d1, d2, ... ), order="F")`
# suffix appended when `arr` has a dim attribute.
pyprint <- function(arr) {
  if (is.missing(arr)) {
    cat("None\n")
    return(invisible(NULL))
  }

  cat("np.array([")
  for (val in arr) {
    # Every element, including the last, is followed by ", ".
    cat(val, ", ", sep="")
  }
  cat("])")

  shape <- dim(arr)
  if (!is.null(shape)) {
    cat(".reshape((", sprintf("%s, ", shape), "), order=\"F\")", sep="")
  }
  cat("\n")
}

# Number of test cases written so far; dump.poly() increments it.
num.tests <- 0

# Write one delimited test case: the `degree` and `raw` arguments, then the
# result of calling poly() on the global vector `x` with those arguments.
dump.poly <- function(degree, raw) {
  cat("--BEGIN TEST CASE--\n",
      sprintf("degree=%s\n", degree),
      sprintf("raw=%s\n", raw),
      sep="")

  result <- do.call(poly, list(x=x, degree=degree, raw=raw))

  cat("output=")
  pyprint(result)
  cat("--END TEST CASE--\n")

  # The counter lives in the global environment, so update it there.
  updated.count <- num.tests + 1
  assign("num.tests", updated.count, envir=.GlobalEnv)
}

# Emit the shared input vector as R_poly_test_x.
cat("R_poly_test_x = ")
pyprint(x)
# Open the Python triple-quoted string that holds all test cases.
cat("R_poly_test_data = \"\"\"\n")

# One test case per (degree, raw) combination: 3 degrees x 2 raw settings.
for (degree in c(1, 3, 5)) {
for (raw in c(TRUE, FALSE)) {
dump.poly(degree, raw)
}
}
# Close the triple-quoted string and record how many cases were written.
cat("\"\"\"\n")
cat(sprintf("R_poly_num_tests = %s\n", num.tests))