---
title: "Tidy, Type-Safe 'prediction()' Methods"
output: github_document
---

<img src="man/figures/logo.png" align="right" />

The **prediction** and **margins** packages are a combined effort to port the functionality of Stata's (closed source) [`margins`](http://www.stata.com/help.cgi?margins) command to (open source) R. **prediction** is focused on one function, `prediction()`, which provides type-safe methods for generating predictions from fitted regression models. `prediction()` is an S3 generic that always returns a `"data.frame"` class object, rather than the mix of vectors, lists, etc. returned by the `predict()` methods for various model types. It provides a key piece of underlying infrastructure for the **margins** package. Users interested in generating marginal (partial) effects, like those generated by Stata's `margins, dydx(*)` command, should consider using `margins()` from the sibling project, [**margins**](https://cran.r-project.org/package=margins).

In addition to `prediction()`, this package provides a number of utility functions for generating useful predictions:

- `find_data()`, an S3 generic with methods that find the data frame used to estimate a regression model. This is a wrapper around `get_all_vars()` that attempts to locate the data as well as modify it according to the `subset` and `na.action` arguments used in the original modelling call.
- `mean_or_mode()` and `median_or_mode()`, which provide a convenient way to compute the data needed for predicted values *at means* (or *at medians*), respecting the differences between factor and numeric variables.
- `seq_range()`, which generates a vector of *n* values based upon the range of values in a variable.
- `build_datalist()`, which generates a list of data frames from an input data frame and a specified set of replacement `at` values (mimicking the `atlist` option of Stata's `margins` command).
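
The logic behind `mean_or_mode()` and `seq_range()` is simple enough to sketch in base R. The helpers below (`mode_value()`, `mean_or_mode_sketch()`, and `seq_range_sketch()`) are hypothetical stand-ins written for illustration, not the package's own code. They assume that numeric columns are summarized by their mean, that factor or character columns are summarized by their most frequent value, and that a "range" sequence runs evenly from the minimum to the maximum.

```r
# Hypothetical helpers illustrating the idea behind mean_or_mode() and
# seq_range(); they are not the prediction package's implementation.

# Most frequent value of a vector (ties go to the first value in table order).
mode_value <- function(x) {
  tab <- table(x)
  names(tab)[which.max(tab)]
}

# One-row data frame: numeric columns -> mean, factors -> modal level
# (keeping the original levels), anything else -> modal value as character.
mean_or_mode_sketch <- function(data) {
  out <- lapply(data, function(col) {
    if (is.numeric(col)) {
      mean(col, na.rm = TRUE)
    } else if (is.factor(col)) {
      factor(mode_value(col), levels = levels(col))
    } else {
      mode_value(col)
    }
  })
  as.data.frame(out, stringsAsFactors = FALSE)
}

# n evenly spaced values from the minimum to the maximum of x.
seq_range_sketch <- function(x, n = 2) {
  seq(min(x, na.rm = TRUE), max(x, na.rm = TRUE), length.out = n)
}

d <- data.frame(hp = mtcars$hp, cyl = factor(mtcars$cyl))
mean_or_mode_sketch(d)
seq_range_sketch(mtcars$hp, 5)
```

Keeping the factor's original levels matters: a one-row "typical" data frame built this way can be passed to `predict()` as `newdata` without level mismatches.
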

## Simple code examples

```{r opts, echo = FALSE}
library("knitr")
options(width = 100)
opts_knit$set(upload.fun = imgur_upload, base.url = NULL)
opts_chunk$set(fig.width=7, fig.height=4)
```

A major downside of the `predict()` methods for common modelling classes is that the result is not type-safe: the class of the returned object depends on the arguments supplied. Consider the following simple example:

```{r predict}
library("stats")
library("datasets")
x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
class(predict(x))
class(predict(x, se.fit = TRUE))
```

**prediction** solves this issue by providing a wrapper around `predict()`, called `prediction()`, that always returns a tidy data frame with a very simple `print()` method:

```{r prediction}
library("prediction")
(p <- prediction(x))
class(p)
head(p)
```

The output always contains the original data (i.e., either the data found using the `find_data()` function or the data passed to the `data` argument of `prediction()`). This makes it much simpler to pass predictions to, e.g., further summary or plotting functions.
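
A minimal base-R sketch of the idea: recover the estimation data, predict with standard errors, and bind everything into one data frame. `tidy_predict_sketch()` is a hypothetical helper, not part of **prediction**. It assumes the model's `predict()` method accepts `newdata` and `se.fit = TRUE` and returns a list with `fit` and `se.fit` elements (as `predict.lm()` and `predict.glm()` do); the output column names `fitted` and `se.fitted` are chosen for illustration. Unlike `find_data()`, it does not apply the original `subset` or `na.action`.

```r
# Hypothetical helper, not part of the prediction package: returns the
# data used for prediction with fitted values and standard errors appended.
tidy_predict_sketch <- function(model, data = NULL) {
  if (is.null(data)) {
    # Re-evaluate the original `data` argument in the formula's environment,
    # then keep only the variables the model formula refers to.
    # (Unlike find_data(), this ignores `subset` and `na.action`.)
    data <- eval(model$call$data, environment(formula(model)))
    data <- get_all_vars(formula(model), data = data)
  }
  pr <- predict(model, newdata = data, se.fit = TRUE)
  data.frame(data, fitted = unname(pr$fit), se.fitted = unname(pr$se.fit))
}

x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
p <- tidy_predict_sketch(x)
class(p)  # a data frame regardless of se.fit
head(p)
```

Because the return value is always a data frame, downstream code can rely on column access (e.g., `p$fitted`) without checking whether `predict()` returned a vector or a list.
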

Additionally, most methods accept an `at` argument, which can be used to obtain predicted values from a modified version of `data` in which specified variables are held at particular values:

```{r at_arg}
prediction(x, at = list(hp = seq_range(mtcars$hp, 5)))
```

This serves, more or less, as a direct R port of the subset of Stata's `margins` functionality that calculates predictive margins (average predicted values at specified covariate values). For calculation of marginal or partial effects, see the [**margins**](https://cran.r-project.org/package=margins) package.
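
The `at` mechanism can be sketched in base R: build one copy of the data per combination of `at` values, predict on each copy, and average the predictions. `build_datalist_sketch()` and `average_prediction_sketch()` are hypothetical helpers for illustration only, not the package's `build_datalist()` or `prediction()`. They assume a model whose `predict()` method accepts `newdata` and returns a numeric vector.

```r
# Hypothetical helpers, not the prediction package's code.

# One copy of `data` per row of expand.grid(at), with the `at` variables
# overwritten by that row's values.
build_datalist_sketch <- function(data, at) {
  grid <- expand.grid(at, KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
  lapply(seq_len(nrow(grid)), function(i) {
    d <- data
    for (v in names(grid)) d[[v]] <- grid[[v]][i]
    d
  })
}

# Average predicted value for each combination of `at` values.
average_prediction_sketch <- function(model, data, at) {
  grid <- expand.grid(at, KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
  datalist <- build_datalist_sketch(data, at)
  grid$avg_fitted <- vapply(datalist,
                            function(d) mean(predict(model, newdata = d)),
                            numeric(1))
  grid
}

x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
average_prediction_sketch(
  x, mtcars,
  at = list(hp = seq(min(mtcars$hp), max(mtcars$hp), length.out = 5))
)
```

Each row of the result is one scenario from the `at` grid; the averaging step mirrors the predictive margins reported by Stata's `margins`.
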

## Supported model classes

The currently supported model classes are:

- "lm" from `stats::lm()`
- "glm" from `stats::glm()`, `MASS::glm.nb()`, `glmx::glmx()`, `glmx::hetglm()`, `brglm::brglm()`
- "ar" from `stats::ar()`
- "Arima" from `stats::arima()`
- "arima0" from `stats::arima0()`
- "biglm" from `biglm::biglm()` (including `"ffdf"` backed models)
- "betareg" from `betareg::betareg()`
- "bruto" from `mda::bruto()`
- "clm" from `ordinal::clm()`
- "coxph" from `survival::coxph()`
- "crch" from `crch::crch()`
- "earth" from `earth::earth()`
- "fda" from `mda::fda()`
- "Gam" from `gam::gam()`
- "gausspr" from `kernlab::gausspr()`
- "gee" from `gee::gee()`
- "glimML" from `aod::betabin()`, `aod::negbin()`
- "glimQL" from `aod::quasibin()`, `aod::quasipois()`
- "glmnet" from `glmnet::glmnet()`
- "gls" from `nlme::gls()`
- "hurdle" from `pscl::hurdle()`
- "hxlr" from `crch::hxlr()`
- "ivreg" from `AER::ivreg()`
- "knnreg" from `caret::knnreg()`
- "kqr" from `kernlab::kqr()`
- "ksvm" from `kernlab::ksvm()`
- "lda" from `MASS::lda()`
- "lme" from `nlme::lme()`
- "loess" from `stats::loess()`
- "lqs" from `MASS::lqs()`
- "mars" from `mda::mars()`
- "mca" from `MASS::mca()`
- "mclogit" from `mclogit::mclogit()`
- "mda" from `mda::mda()`
- "merMod" from `lme4::lmer()` and `lme4::glmer()`
- "mnlogit" from `mnlogit::mnlogit()`
- "mnp" from `MNP::mnp()`
- "naiveBayes" from `e1071::naiveBayes()`
- "nlme" from `nlme::nlme()`
- "nls" from `stats::nls()`
- "nnet" from `nnet::nnet()`, `nnet::multinom()`
- "plm" from `plm::plm()`
- "polr" from `MASS::polr()`
- "ppr" from `stats::ppr()`
- "princomp" from `stats::princomp()`
- "qda" from `MASS::qda()`
- "rlm" from `MASS::rlm()`
- "rpart" from `rpart::rpart()`
- "rq" from `quantreg::rq()`
- "selection" from `sampleSelection::selection()`
- "speedglm" from `speedglm::speedglm()`
- "speedlm" from `speedglm::speedlm()`
- "survreg" from `survival::survreg()`
- "svm" from `e1071::svm()`
- "svyglm" from `survey::svyglm()`
- "tobit" from `AER::tobit()`
- "train" from `caret::train()`
- "truncreg" from `truncreg::truncreg()`
- "zeroinfl" from `pscl::zeroinfl()`
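
Because `prediction()` is an S3 generic, support for each class above comes from a method that R selects using the model object's class. The toy generic below shows how that dispatch works; `toy_prediction()` and its methods are hypothetical and are not the package's methods. A `"glm"` object has class `c("glm", "lm")`, so it reaches the `glm` method, while a plain `"lm"` object falls through to the default method. Returning the `glm` predictions on the response scale is an assumption of this sketch.

```r
# Toy S3 generic illustrating class-based dispatch (hypothetical names).
toy_prediction <- function(model, ...) UseMethod("toy_prediction")

# Fallback used when no class-specific method exists.
toy_prediction.default <- function(model, data, ...) {
  data.frame(data, fitted = unname(predict(model, newdata = data)))
}

# glm objects get their own method so predictions are on the response scale.
toy_prediction.glm <- function(model, data, type = "response", ...) {
  data.frame(data, fitted = unname(predict(model, newdata = data, type = type)))
}

m1 <- lm(mpg ~ wt, data = mtcars)
m2 <- glm(am ~ wt, data = mtcars, family = binomial)
class(m2)                         # "glm" "lm"
p1 <- toy_prediction(m1, mtcars)  # dispatches to toy_prediction.default
p2 <- toy_prediction(m2, mtcars)  # dispatches to toy_prediction.glm
range(p2$fitted)                  # probabilities, strictly between 0 and 1
```

Every method returns a data frame, so callers get the same type regardless of which class-specific method handled the model.
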

## Requirements and Installation

[![CRAN](https://www.r-pkg.org/badges/version/prediction)](https://cran.r-project.org/package=prediction)
![Downloads](https://cranlogs.r-pkg.org/badges/prediction)
[![Build Status](https://travis-ci.org/leeper/prediction.svg?branch=master)](https://travis-ci.org/leeper/prediction)
[![Build status](https://ci.appveyor.com/api/projects/status/a4tebeoa98cq07gy/branch/master?svg=true)](https://ci.appveyor.com/project/leeper/prediction/branch/master)
[![codecov.io](https://codecov.io/github/leeper/prediction/coverage.svg?branch=master)](https://codecov.io/github/leeper/prediction?branch=master)
[![Project Status: Active - The project has reached a stable, usable state and is being actively developed.](http://www.repostatus.org/badges/latest/active.svg)](http://www.repostatus.org/#active)
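
A stable release is available from [CRAN](https://cran.r-project.org/package=prediction) and can be installed with `install.packages("prediction")`. The development version can be installed directly from GitHub using **remotes**: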

``` r
# Install remotes first if it is not already available
if (!requireNamespace("remotes", quietly = TRUE)) {
  install.packages("remotes")
}
remotes::install_github("leeper/prediction")
```