## EM for the mediation model
 
## Convert first- and second-order moments into mediation-model parameters.
##
## S : numeric vector of length 9 laid out as
##       (var X, var M, var Y, cov XM, cov XY, cov MY, mean X, mean M, mean Y)
##
## Returns a numeric vector of length 9:
##   [1] slope of M on X
##   [2] slope of Y on M (X held fixed)
##   [3] slope of Y on X (M held fixed)
##   [4] variance of X
##   [5] residual variance of M given X
##   [6] residual variance of Y given X and M
##   [7] mean of X
##   [8] intercept of the M equation
##   [9] intercept of the Y equation
est <- function(S) {
  var_x  <- S[1]
  var_m  <- S[2]
  var_y  <- S[3]
  cov_xm <- S[4]
  cov_xy <- S[5]
  cov_my <- S[6]
  mean_x <- S[7]
  mean_m <- S[8]
  mean_y <- S[9]

  ## determinant of the 2x2 covariance matrix of the predictors (X, M)
  det_xm <- var_x * var_m - cov_xm^2

  ## regression slopes
  slope_a <- cov_xm / var_x
  slope_b <- (cov_my * var_x - cov_xm * cov_xy) / det_xm
  slope_c <- (cov_xy * var_m - cov_xm * cov_my) / det_xm

  ## (residual) variances; the last one is det(cov(X, M, Y)) / det(cov(X, M))
  resid_x <- var_x
  resid_m <- var_m - cov_xm^2 / var_x
  resid_y <- (var_x * var_m * var_y - var_x * cov_my^2 - var_m * cov_xy^2 -
                var_y * cov_xm^2 + 2 * cov_xm * cov_xy * cov_my) / det_xm

  ## intercepts
  int_x <- mean_x
  int_m <- mean_m - slope_a * mean_x
  int_y <- mean_y - slope_b * mean_m - slope_c * mean_x

  c(slope_a, slope_b, slope_c, resid_x, resid_m, resid_y, int_x, int_m, int_y)
}
 
 
## expectation step
##
## Computes the expected sufficient statistics of (X, M, Y) given the observed
## data and the current parameter values, and returns them as moments in the
## layout that est() takes as input.
##
## dset : matrix (or data frame) whose columns 1, 2, 3 are X, M and Y.
##        M and Y may contain NA; X is treated as fully observed (missing X
##        is not handled).
## para : numeric vector of length 9 in the order returned by est():
##        (a, b, c, var X, resid var M, resid var Y, mean X, int M, int Y)
##        for the model  M = int M + a*X + e2,  Y = int Y + b*M + c*X + e3.
##
## Returns c(var X, var M, var Y, cov XM, cov XY, cov MY,
##           mean X, mean M, mean Y); (co)variances use the n - 1 divisor.
EMe <- function(dset, para) {
  X <- dset[, 1]
  M <- dset[, 2]
  Y <- dset[, 3]

  pa  <- para[1]
  pb  <- para[2]
  pc  <- para[3]
  ps2 <- para[5]
  ps3 <- para[6]
  pm2 <- para[8]
  pm3 <- para[9]

  n <- dim(dset)[1]
  m <- n - 1

  ## missing-data patterns (rows with neither flag set are complete)
  miss_m <- is.na(M)
  miss_y <- is.na(Y)
  only_y <- !miss_m & miss_y
  only_m <- miss_m & !miss_y
  both   <- miss_m & miss_y

  ## moments implied by the model once M is integrated out
  mu_m    <- pm2 + pa * X          # E[M | X]
  var_y_x <- pb^2 * ps2 + ps3      # Var(Y | X)
  ## regression coefficient of M on Y given X: Cov(M, Y | X) / Var(Y | X)
  gain <- if (isTRUE(var_y_x > 0)) pb * ps2 / var_y_x else 0

  ## conditional means, variances and covariance of the (possibly missing)
  ## M and Y; observed values have zero conditional variance
  e_m  <- M
  e_y  <- Y
  v_m  <- numeric(n)
  v_y  <- numeric(n)
  c_my <- numeric(n)

  ## X, M observed, Y missing: Y | X, M
  e_y[only_y] <- pm3 + pb * M[only_y] + pc * X[only_y]
  v_y[only_y] <- ps3

  ## X, Y observed, M missing: M | X, Y.
  ## The observed Y carries information about M, so the expectation must
  ## condition on it; using E[M | X] alone would not be a valid E-step.
  resid_y <- Y[only_m] - (pm3 + pb * mu_m[only_m] + pc * X[only_m])
  e_m[only_m] <- mu_m[only_m] + gain * resid_y
  v_m[only_m] <- ps2 - gain * pb * ps2

  ## only X observed: (M, Y) | X
  e_m[both]  <- mu_m[both]
  v_m[both]  <- ps2
  e_y[both]  <- pm3 + pb * mu_m[both] + pc * X[both]
  v_y[both]  <- var_y_x
  c_my[both] <- pb * ps2

  ## turn expected sums of squares / cross-products into means and
  ## (n - 1)-divisor covariances
  CM <- rep(0, 9)
  CM[7] <- sum(X) / n
  CM[8] <- sum(e_m) / n
  CM[9] <- sum(e_y) / n
  CM[1] <- sum(X^2) / m - n / m * CM[7]^2
  CM[2] <- sum(e_m^2 + v_m) / m - n / m * CM[8]^2
  CM[3] <- sum(e_y^2 + v_y) / m - n / m * CM[9]^2
  CM[4] <- sum(X * e_m) / m - n / m * CM[7] * CM[8]
  CM[5] <- sum(X * e_y) / m - n / m * CM[7] * CM[9]
  CM[6] <- sum(e_m * e_y + c_my) / m - n / m * CM[9] * CM[8]

  CM
}
 
## Simulation to compare the EM method with ML (Mplus), listwise deletion, and pairwise deletion
 
## Matrices to store the results
## number of replications and sample size per replication
R <- 1000
N <- 100

## one row per replication, one column per parameter (order as in est())
pairdel <- matrix(NA, nrow = R, ncol = 9)
em      <- matrix(NA, nrow = R, ncol = 9)
true    <- matrix(NA, nrow = R, ncol = 9)
listdel <- matrix(NA, nrow = R, ncol = 9)
ml      <- matrix(NA, nrow = R, ncol = 9)

## number of complete cases left after listwise deletion, per replication
sampN <- numeric(R)
 
## One replication per pass: generate complete data, record the complete-data
## estimates, impose missingness on M and Y, then estimate the parameters by
## EM, ML (external Mplus run), listwise deletion and pairwise deletion.
for (j in seq_len(R)) {
  ## data generation
  X <- rnorm(N)
  M <- .5 * X + .1 * rnorm(N)
  Y <- .5 * M + .1 * X + .1 * rnorm(N)

  ## complete-data estimates serve as the benchmark and as EM starting values
  temp <- cov(cbind(X, M, Y))
  par <- est(c(temp[1, 1], temp[2, 2], temp[3, 3],
               temp[1, 2], temp[1, 3], temp[2, 3],
               mean(X), mean(M), mean(Y)))
  true[j, ] <- par

  ## Create missing data: M and Y are deleted independently, each with a
  ## probability that depends only on the (always observed) X.
  p1 <- 1 / (1 + exp(2 - .1 * X))
  ## Column i holds the two uniform draws for case i (first for M, second for
  ## Y), i.e. the same draw order as calling runif(1) twice per case.
  u <- matrix(runif(2 * N), nrow = 2)
  M[u[1, ] < p1] <- NA
  Y[u[2, ] < p1] <- NA

  dset <- cbind(X, M, Y)

  ## EM method: alternate E-step (EMe) and M-step (est) until the summed
  ## absolute parameter change drops below the tolerance
  tol <- .00001
  max_iter <- 10000
  iter <- 0
  e <- 1

  while (e > tol && iter < max_iter) {
    SS <- EMe(dset, par)
    para <- est(SS)
    e <- sum(abs(para - par))
    par <- para
    iter <- iter + 1
    # print(par,digits=10)
  }
  if (e > tol) {
    warning("EM did not converge in replication ", j, call. = FALSE)
  }

  em[j, ] <- para

  ## ML method
  ## save the data, run Mplus on it and read the estimates back
  ## NOTE(review): presumably mle.inp reads data.dat and writes est.txt, and
  ## the index vector below reorders the Mplus output into the est() layout;
  ## confirm against mle.inp.
  write.table(dset, "data.dat", na = '.', row.names = FALSE, col.names = FALSE)
  system('c:\\programs\\mplus\\mplus.exe mle.inp', show.output.on.console = FALSE)
  tempres <- scan('est.txt', quiet = TRUE)
  ml[j, ] <- tempres[c(6, 4, 5, 9, 8, 7, 3, 2, 1)]

  ## the listwise delete method
  ## drop = FALSE keeps a matrix even when a single complete row remains
  listdata <- dset[complete.cases(dset), , drop = FALSE]
  sampN[j] <- nrow(listdata)

  temp <- cov(listdata, use = 'complete.obs')
  par <- est(c(temp[1, 1], temp[2, 2], temp[3, 3],
               temp[1, 2], temp[1, 3], temp[2, 3],
               mean(listdata[, 1]), mean(listdata[, 2]), mean(listdata[, 3])))
  listdel[j, ] <- par

  ## pairwise delete
  temp <- cov(dset, use = 'pairwise.complete.obs')
  par <- est(c(temp[1, 1], temp[2, 2], temp[3, 3],
               temp[1, 2], temp[1, 3], temp[2, 3],
               mean(dset[, 1], na.rm = TRUE),
               mean(dset[, 2], na.rm = TRUE),
               mean(dset[, 3], na.rm = TRUE)))
  pairdel[j, ] <- par
}
 
## Average of each parameter estimate over the replications, per method
## (columns follow the est() ordering)
apply(X = true, MARGIN = 2, FUN = mean)
apply(X = pairdel, MARGIN = 2, FUN = mean)
apply(X = em, MARGIN = 2, FUN = mean)
apply(X = listdel, MARGIN = 2, FUN = mean)
apply(X = ml, MARGIN = 2, FUN = mean)