### Site Tools

## lab: em_for_mediation_model
## EM for the mediation model

## Convert second moments into mediation-model parameters.
##
## S is a length-6 vector ordered as
##   (X^2, M^2, Y^2, X*M, X*Y, M*Y)
## i.e. the variances of X, M, Y followed by the XM, XY and MY covariances.
##
## Returns c(a, b, c, s1, s2, s3) where
##   a      = slope of M on X
##   b, c   = slopes of Y on M and on X (joint regression)
##   s1     = variance of X
##   s2, s3 = residual variances of M and of Y
est <- function(S) {
  var_x  <- S[1]
  var_m  <- S[2]
  var_y  <- S[3]
  cov_xm <- S[4]
  cov_xy <- S[5]
  cov_my <- S[6]

  ## determinant of the 2x2 (X, M) moment matrix; shared by b, c and s3
  det_xm <- var_x * var_m - cov_xm^2

  a_hat <- cov_xm / var_x
  b_hat <- (cov_my * var_x - cov_xm * cov_xy) / det_xm
  c_hat <- (cov_xy * var_m - cov_xm * cov_my) / det_xm

  resid_m <- var_m - cov_xm^2 / var_x
  ## determinant of the full 3x3 moment matrix divided by det_xm
  resid_y <- (var_x * var_m * var_y - var_x * cov_my^2 - var_m * cov_xy^2 -
                var_y * cov_xm^2 + 2 * cov_xm * cov_xy * cov_my) / det_xm

  c(a_hat, b_hat, c_hat, var_x, resid_m, resid_y)
}

## expectation step
##
## Computes the expected second moments of (X, M, Y) given the observed data
## and the current parameter values, for the model
##   M = a*X + e2,        e2 ~ N(0, s2)
##   Y = b*M + c*X + e3,  e3 ~ N(0, s3)
##
## Arguments:
##   data  matrix / data frame with columns X, M, Y (in that order).
##         X must be fully observed; M and Y may contain NA.
##   para  c(a, b, c, s1, s2, s3) as returned by est().
##
## Returns c(sX2, sM2, sY2, sXM, sXY, sMY): sums of the (expected) raw
## products divided by n - 1, in the order expected by est().  The n - 1
## divisor mirrors cov(); the moments are NOT mean-centred, so the variables
## are treated as having mean zero.
EMe <- function(data, para) {
  X <- data[, 1]
  M <- data[, 2]
  Y <- data[, 3]

  if (anyNA(X)) {
    stop("EMe: X (column 1 of 'data') must be completely observed",
         call. = FALSE)
  }

  pa  <- para[[1]]
  pb  <- para[[2]]
  pc  <- para[[3]]
  ps2 <- para[[5]]
  ps3 <- para[[6]]

  n <- nrow(data)

  ## missing-data patterns
  m_mis      <- is.na(M)
  y_mis      <- is.na(Y)
  only_y_mis <- !m_mis & y_mis
  only_m_mis <- m_mis & !y_mis
  both_mis   <- m_mis & y_mis

  ## conditional moments; start from the observed values (complete cases)
  eM  <- M
  eM2 <- M^2
  eY  <- Y
  eY2 <- Y^2
  eMY <- M * Y

  ## X, M observed, Y missing:  Y | X, M ~ N(b*M + c*X, s3)
  k <- only_y_mis
  eY[k]  <- pb * M[k] + pc * X[k]
  eY2[k] <- eY[k]^2 + ps3
  eMY[k] <- M[k] * eY[k]

  ## X, Y observed, M missing: condition on BOTH X and Y.
  ## Given X: M ~ N(a*X, s2), Y ~ N((a*b + c)*X, b^2*s2 + s3),
  ## Cov(M, Y | X) = b*s2, hence
  ##   E[M | X, Y]   = a*X + gain * (Y - (a*b + c)*X)
  ##   Var[M | X, Y] = s2 - gain * b * s2
  total   <- pa * pb + pc
  var_y_x <- pb^2 * ps2 + ps3
  ## degenerate case (no residual variance at all): fall back to E[M | X]
  gain <- if (isTRUE(var_y_x > 0)) pb * ps2 / var_y_x else 0
  k <- only_m_mis
  eM[k]  <- pa * X[k] + gain * (Y[k] - total * X[k])
  eM2[k] <- eM[k]^2 + (ps2 - gain * pb * ps2)
  eMY[k] <- eM[k] * Y[k]

  ## only X observed: use the joint distribution of (M, Y) given X
  k <- both_mis
  eM[k]  <- pa * X[k]
  eM2[k] <- eM[k]^2 + ps2
  eY[k]  <- total * X[k]
  eY2[k] <- eY[k]^2 + pb^2 * ps2 + ps3
  eMY[k] <- eM[k] * eY[k] + pb * ps2

  c(sum(X^2), sum(eM2), sum(eY2),
    sum(X * eM), sum(X * eY), sum(eMY)) / (n - 1)
}

## Simulation to compare the EM method with listwise and pairwise deletion

## Matrices to store the results: one row per replication, columns are
## (a, b, c, s1, s2, s3) as returned by est().
##   true    - estimates from the full data before any value is removed
##   comp    - listwise deletion (cov(..., use = 'complete.obs'))
##   listdel - pairwise deletion (cov(..., use = 'pairwise.complete.obs'));
##             note the name is historical, it is NOT listwise deletion
##   em      - EM estimates
R <- 100
comp <- array(NA, dim = c(R, 6))
em <- array(NA, dim = c(R, 6))
true <- array(NA, dim = c(R, 6))
listdel <- array(NA, dim = c(R, 6))

## EM stopping rule: stop when the summed absolute parameter change drops
## to tol or below, or give up (with a warning) after max_iter updates.
tol <- 1e-11
max_iter <- 10000

for (j in seq_len(R)) {
  ## data generation
  X <- rnorm(1000)
  M <- .5 * X + .1 * rnorm(1000)
  Y <- .5 * M + .1 * X + .1 * rnorm(1000)
  temp <- cov(cbind(X, M, Y))
  par <- est(c(temp[1, 1], temp[2, 2], temp[3, 3],
               temp[1, 2], temp[1, 3], temp[2, 3]))
  true[j, ] <- par

  ## Create missing data
  #for (i in 1:1000){
  #  if (runif(1)<.1){ M[i]<-NA }
  #  if (runif(1)<.1){ Y[i]<-NA }
  #}
  Y[801:900] <- NA
  M[901:1000] <- NA
  data <- cbind(X, M, Y)

  ## the listwise deletion method (rows with any NA are dropped)
  temp <- cov(data, use = 'complete.obs')
  par <- est(c(temp[1, 1], temp[2, 2], temp[3, 3],
               temp[1, 2], temp[1, 3], temp[2, 3]))
  comp[j, ] <- par

  ## the pairwise deletion method; also used as the EM starting values
  temp <- cov(data, use = 'pairwise.complete.obs')
  par <- est(c(temp[1, 1], temp[2, 2], temp[3, 3],
               temp[1, 2], temp[1, 3], temp[2, 3]))
  listdel[j, ] <- par

  ## EM iterations: alternate E-step (EMe) and M-step (est)
  e <- 1
  iter <- 0
  while (e > tol) {
    if (iter >= max_iter) {
      warning("EM did not converge within ", max_iter,
              " iterations in replication ", j, call. = FALSE)
      break
    }
    SS <- EMe(data, par)
    para <- est(SS)
    e <- sum(abs(para - par))
    if (!is.finite(e)) {
      stop("EM update produced non-finite parameter estimates in replication ",
           j, call. = FALSE)
    }
    par <- para
    iter <- iter + 1
    #print(para,digits=20)
  }

  em[j, ] <- para
}

## Monte Carlo means of each parameter estimate, per method
apply(true, 2, mean)
apply(comp, 2, mean)
apply(em, 2, mean)
apply(listdel, 2, mean)