Permalink
Browse files

finish core support for prediction variances (#17)

  • Loading branch information...
leeper committed Aug 3, 2018
1 parent 3bcf0cf commit da74f403803c4495d4b3732debed6b611c683c2f
Showing with 253 additions and 120 deletions.
  1. +3 −5 NAMESPACE
  2. +2 −0 NEWS.md
  3. +1 −1 R/mean_or_mode.R
  4. +19 −12 R/prediction.R
  5. +36 −5 R/prediction_glm.R
  6. +17 −9 R/prediction_lm.R
  7. +1 −1 R/prediction_mclogit.R
  8. +0 −8 R/prediction_summary.R
  9. +59 −64 R/print.R
  10. +65 −0 R/summary.R
  11. +25 −9 man/prediction.Rd
  12. +25 −6 tests/testthat/tests-core.R
@@ -80,21 +80,19 @@ S3method(prediction,train)
S3method(prediction,truncreg)
S3method(prediction,zeroinfl)
S3method(print,prediction)
S3method(print,summary.prediction)
S3method(summary,prediction)
S3method(tail,prediction)
export(build_datalist)
export(find_data)
export(mean_or_mode)
export(median_or_mode)
export(prediction)
export(prediction_summary)
export(seq_range)
import(stats)
importFrom(data.table,rbindlist)
importFrom(stats,aggregate)
importFrom(stats,get_all_vars)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,predict)
importFrom(stats,setNames)
importFrom(stats,terms)
importFrom(utils,head)
importFrom(utils,tail)
@@ -1,6 +1,8 @@
# prediction 0.3.7

* `summary(prediction(...))` now reports variances of average predictions, along with test statistics, p-values, and confidence intervals, where supported. (#17)
* Added a function `prediction_summary()` which simply calls `summary(prediction(...))`.
* All methods now return additional attributes.

# prediction 0.3.6

@@ -14,7 +14,7 @@
#' median_or_mode(iris)
#'
#' @seealso \code{\link{prediction}}, \code{\link{build_datalist}}, \code{\link{seq_range}}
#' @importFrom stats median setNames
#' @import stats
#' @export
mean_or_mode <- function(x) {
UseMethod("mean_or_mode")
@@ -6,7 +6,8 @@
#' @param data A data.frame over which to calculate marginal effects. If missing, \code{\link{find_data}} is used to specify the data frame.
#' @param at A list of one or more named vectors, specifically values at which to calculate the predictions. These are used to modify the value of \code{data} (see \code{\link{build_datalist}} for details on use).
#' @param type A character string indicating the type of marginal effects to estimate. Mostly relevant for non-linear models, where the reasonable options are \dQuote{response} (the default) or \dQuote{link} (i.e., on the scale of the linear predictor in a GLM). For models of class \dQuote{polr} (from \code{\link[MASS]{polr}}), possible values are \dQuote{class} or \dQuote{probs}; both are returned.
#' @param calculate_se A logical indicating whether to calculate standard errors (if possible). The output will always contain a \dQuote{calculate_se} column regardless of this value; this only controls the calculation of standard errors. Setting it to \code{FALSE} may improve speed.
#' @param vcov A matrix containing the variance-covariance matrix for estimated model coefficients, or a function to perform the estimation with \code{model} as its only argument.
#' @param calculate_se A logical indicating whether to calculate standard errors for observation-specific predictions and average predictions (if possible). The output will always contain a \dQuote{calculate_se} column regardless of this value; this only controls the calculation of standard errors. Setting it to \code{FALSE} may improve speed.
#' @param category For multi-level or multi-category outcome models (e.g., ordered probit, multinomial logit, etc.), a value specifying which of the outcome levels should be used for the \code{"fitted"} column. If missing, some default is chosen automatically.
#' @param \dots Additional arguments passed to \code{\link[stats]{predict}} methods.
#' @details This function is simply a wrapper around \code{\link[stats]{predict}} that returns a data frame containing the value of \code{data} and the predicted values with respect to all variables specified in \code{data}.
@@ -76,7 +77,7 @@
#' \item \dQuote{zeroinfl}, see \code{\link[pscl]{zeroinfl}}
#' }
#'
#' @return A data frame with class \dQuote{prediction} that has a number of rows equal to number of rows in \code{data}, or a multiple thereof, if \code{!is.null(at)}. The return value contains \code{data} (possibly modified by \code{at} using \code{\link{build_datalist}}), plus a column containing fitted/predicted values (\code{"fitted"}) and a column containing the standard errors thereof (\code{"calculate_se"}). Additional columns may be reported depending on the object class.
#' @return A data frame with class \dQuote{prediction} that has a number of rows equal to number of rows in \code{data}, or a multiple thereof, if \code{!is.null(at)}. The return value contains \code{data} (possibly modified by \code{at} using \code{\link{build_datalist}}), plus a column containing fitted/predicted values (\code{"fitted"}) and a column containing the standard errors thereof (\code{"calculate_se"}). Additional columns may be reported depending on the object class. The data frame also carries attributes used by \code{print} and \code{summary}, which will be lost during subsetting.
#' @examples
#' require("datasets")
#' x <- lm(Petal.Width ~ Sepal.Length * Sepal.Width * Species, data = iris)
@@ -87,7 +88,10 @@
#' prediction(x, iris[1,])
#'
#' # basic use of 'at' argument
#' prediction(x, at = list(Species = c("setosa", "virginica")))
#' summary(prediction(x, at = list(Species = c("setosa", "virginica"))))
#'
#' # basic use of 'at' argument
#' prediction(x, at = list(Sepal.Length = seq_range(iris$Sepal.Length, 5)))
#'
#' # prediction at means/modes of input variables
#' prediction(x, at = lapply(iris, mean_or_mode))
@@ -104,7 +108,7 @@
#'
#' @keywords models
#' @seealso \code{\link{find_data}}, \code{\link{build_datalist}}, \code{\link{mean_or_mode}}, \code{\link{seq_range}}
#' @importFrom stats predict get_all_vars model.frame
#' @import stats
#' @export
prediction <- function(model, ...) {
UseMethod("prediction")
@@ -133,9 +137,7 @@ function(model,
}
} else {
# setup data
if (is.null(at)) {
data <- data
} else {
if (!is.null(at)) {
data <- build_datalist(data, at = at, as.data.frame = TRUE)
at_specification <- attr(data, "at_specification")
}
@@ -151,12 +153,17 @@ function(model,
}
}

# obs-x-(ncol(data)+2) data.frame of predictions
# variance(s) of average predictions
vc <- NA_real_

# output
structure(pred,
class = c("prediction", "data.frame"),
row.names = seq_len(nrow(pred)),
at = if (is.null(at)) at else at_specification,
vc = vc,
model.class = class(model),
type = type)
type = type,
call = if ("call" %in% names(model)) model[["call"]] else NULL,
model_class = class(model),
row.names = seq_len(nrow(pred)),
vcov = vc,
weighted = FALSE)
}
@@ -40,12 +40,43 @@ function(model,
}
}

# obs-x-(ncol(data)+2) data frame
# variance(s) of average predictions
if (isTRUE(calculate_se)) {
# handle case where SEs are calculated
model_terms <- delete.response(terms(model))
if (is.null(at)) {
# no 'at_specification', so calculate variance of overall average prediction
model_frame <- model.frame(model_terms, data, na.action = na.pass, xlev = model$xlevels)
model_mat <- model.matrix(model_terms, model_frame, contrasts.arg = model$contrasts)
means_for_prediction <- colMeans(model_mat)
vc <- (means_for_prediction %*% vcov %*% means_for_prediction)[1L, 1L, drop = TRUE]
} else {
# with 'at_specification', calculate variance of all counterfactual predictions
datalist <- build_datalist(data, at = at, as.data.frame = FALSE)
vc <- unlist(lapply(datalist, function(one) {
model_frame <- model.frame(model_terms, one, na.action = na.pass, xlev = model$xlevels)
model_mat <- model.matrix(model_terms, model_frame, contrasts.arg = model$contrasts)
means_for_prediction <- colMeans(model_mat)
means_for_prediction %*% vcov %*% means_for_prediction
}))
}
} else {
# handle case where SEs are *not* calculated
if (length(at)) {
vc <- rep(NA_real_, nrow(at_specification))
} else {
vc <- NA_real_
}
}

# output
structure(pred,
class = c("prediction", "data.frame"),
row.names = seq_len(nrow(pred)),
class = c("prediction", "data.frame"),
at = if (is.null(at)) at else at_specification,
type = type,
call = if ("call" %in% names(model)) model[["call"]] else NULL,
model_class = class(model),
row.names = seq_len(nrow(pred)),
vcov = vc,
model.class = class(model),
type = type)
weighted = FALSE)
}
@@ -42,33 +42,41 @@ function(model,

# variance(s) of average predictions
if (isTRUE(calculate_se)) {
# handle case where SEs are calculated
model_terms <- delete.response(terms(model))
if (is.null(at)) {
# no 'at_specification', so calculate variance of overall average prediction
model_frame <- model.frame(model_terms, data, na.action = na.pass, xlev = model$xlevels)
model_mat <- model.matrix(model_terms, model_frame, contrasts.arg = model$contrasts)
predicted_means <- colMeans(model_mat, na.rm = TRUE)
vc <- (predicted_means %*% vcov %*% predicted_means)[1L, 1L, drop = TRUE]
means_for_prediction <- colMeans(model_mat)
vc <- (means_for_prediction %*% vcov %*% means_for_prediction)[1L, 1L, drop = TRUE]
} else {
# with 'at_specification', calculate variance of all counterfactual predictions
datalist <- build_datalist(data, at = at, as.data.frame = FALSE)
vc <- unlist(lapply(datalist, function(one) {
model_frame <- model.frame(model_terms, one, na.action = na.pass, xlev = model$xlevels)
model_mat <- model.matrix(model_terms, model_frame, contrasts.arg = model$contrasts)
predicted_means <- colMeans(model_mat, na.rm = TRUE)
predicted_means %*% vcov %*% predicted_means
means_for_prediction <- colMeans(model_mat)
means_for_prediction %*% vcov %*% means_for_prediction
}))
}
} else {
vc <- NA_real_
# handle case where SEs are *not* calculated
if (length(at)) {
vc <- rep(NA_real_, nrow(at_specification))
} else {
vc <- NA_real_
}
}

# obs-x-(ncol(data)+2) data frame
# output
structure(pred,
class = c("prediction", "data.frame"),
row.names = seq_len(nrow(pred)),
at = if (is.null(at)) at else at_specification,
type = type,
call = if ("call" %in% names(model)) model[["call"]] else NULL,
model_class = class(model),
row.names = seq_len(nrow(pred)),
vcov = vc,
model.class = class(model),
type = type)
weighted = FALSE)
}
@@ -1,3 +1,3 @@
#' @rdname prediction
#' @export
prediction.mclogit <- prediction.glm
prediction.mclogit <- prediction.default

This file was deleted.

Oops, something went wrong.
123 R/print.R
@@ -1,84 +1,79 @@
#' @importFrom utils head
#' @importFrom stats aggregate
#' @export
summary.prediction <- function(object, digits = 4, ...) {
# summary method
# function also called by `print.prediction()` if object has an 'at' specification

# gather metadata
f <- object[["fitted"]]
fc <- object[["fitted.class"]]
vc <- attributes(object)[["vcov"]]

# convert 'at_specification' into data frame
at <- attributes(object)[["at"]]
if (is.null(at)) {
objectby <- list(rep(1L, nrow(object)))
} else {
objectby <- object[ , setdiff(names(at), "index"), drop = FALSE]
}

# calculate average/modal predictions
if (!"fitted.class" %in% names(object) || is.list(fc)) {
# numeric outcome
## aggregate average predictions from data
out <- aggregate(object[["fitted"]], objectby, FUN = mean, na.rm = TRUE)
## extract calculated variance from object
out$SE <- sqrt(vc)

# message
message(paste0("Data frame with average ", ngettext(nrow(out), "prediction", "predictions"),
" for ", length(f)/nrow(out), " ",
ngettext(nrow(out), "observation", "observations"), ":"))
} else {
# factor outcome
out <- aggregate(object[["fitted.class"]], objectby, FUN = function(set) names(sort(table(set), decreasing = TRUE))[1L])

# message
message(paste0("Data frame with modal ", ngettext(nrow(out), "prediction", "predictions"),
" (of ", nlevels(factor(fc)), " ", ngettext(nlevels(factor(fc)), "level", "levels"),
") for ", length(fc), " ", ngettext(length(fc), "observation", "observations"), ": "))
}

# cleanup output
names(out)[!names(out) %in% c("x", "SE")] <- paste0("at(", names(out)[!names(out) %in% c("x", "SE")], ")")
names(out)[names(out) == "x"] <- "Prediction"
if (is.null(at)) {
out <- out[, c("Prediction", "SE"), drop = FALSE]
}

# print and return
print(out, digits = digits, row.names = FALSE, ...)
invisible(out)
}

#' @export
print.prediction <- function(x, digits = 4, ...) {

# gather metadata
f <- x[["fitted"]]
fc <- x[["fitted.class"]]
## at
at <- attributes(x)[["at"]]
vc <- attributes(x)[["vcov"]]
at_names <- setdiff(names(attr(x, "at")), "index")

## weights
is_weighted <- attr(x, "weighted")
if (isTRUE(is_weighted)) {
wts <- x[["_weights"]]
}

# calculate overall predictions
## if no 'at_specification', simply calculate overall average/mode and print
if (is.null(at)) {
# object is a single replication with no 'at' specification
if (!"fitted.class" %in% names(x) || is.list(fc)) {
# numeric outcome
m <- sprintf(paste0("%0.", digits, "f"), mean(f, na.rm = TRUE))
message(paste0("Data frame with average prediction for ", length(f), " ", ngettext(length(f), "observation", "observations"),
": ", m, " (se = ", sprintf(paste0("%0.", digits, "f"), sqrt(vc)), ")"))
} else {
if ("fitted.class" %in% names(x) || is.list(fc)) {
# factor outcome
m <- sort(table(x[["fitted.class"]]), decreasing = TRUE)[1L]
message(paste0("Data frame with modal prediction (of ", nlevels(factor(fc)), " ", ngettext(nlevels(f), "level", "levels"),
") for ", length(fc), " ", ngettext(length(fc), "observation", "observations"), ": ", shQuote(names(m))))
message(
sprintf("Data frame with %d %s%swith modal prediction (of %d %s):",
length(fc),
ngettext(length(fc), "prediction", "predictions"),
if (!is.null(attr(x, "call"))) sprintf(" from\n %s\n", paste0(deparse(attr(x, "call")), collapse = "\n")) else "",
nlevels(factor(fc)),
ngettext(nlevels(f), "level", "levels"),
shQuote(names(m))
)
)
} else {
# numeric outcome
message(
sprintf("Data frame with %d %s%swith average prediction: %s",
length(f),
ngettext(length(fc), "prediction", "predictions"),
if (!is.null(attr(x, "call"))) sprintf(" from\n %s\n", paste0(deparse(attr(x, "call")), collapse = "\n")) else "",
sprintf(paste0("%0.", digits, "f"), mean(f, na.rm = TRUE))
)
)
}
} else {
# otherwise, object has an 'at' specification, reflecting multiple requested predictions
summary(object = x, ...)

# convert 'at_specification' into data frame
xby <- x[ , setdiff(names(at), "index"), drop = FALSE]

if ("fitted.class" %in% names(x) || is.list(fc)) {
# factor outcome
out <- aggregate(x[["fitted.class"]], xby, FUN = function(set) names(sort(table(set), decreasing = TRUE))[1L])
message(
sprintf("Data frame with %d %s%swith modal %s (of %d %s):",
nrow(x),
ngettext(nrow(x), "prediction", "predictions"),
if (!is.null(attr(x, "call"))) sprintf(" from\n %s\n", paste0(deparse(attr(x, "call")), collapse = "\n")) else "",
ngettext(nrow(out), "prediction", "predictions"),
nlevels(factor(fc)),
ngettext(nlevels(fc), "level", "levels")
)
)
} else {
# numeric outcome
out <- aggregate(x[["fitted"]], xby, FUN = mean, na.rm = TRUE)
message(
sprintf("Data frame with %d %s%swith average %s:",
nrow(x),
ngettext(nrow(x), "prediction", "predictions"),
if (!is.null(attr(x, "call"))) sprintf(" from\n %s\n", paste0(deparse(attr(x, "call")), collapse = "\n")) else "",
ngettext(nrow(out), "prediction", "predictions")
)
)
}
print(out, digits = digits, row.names = FALSE, ...)
}

# return invisibly
Oops, something went wrong.

0 comments on commit da74f40

Please sign in to comment.