package org.bensjavaspot.handlers;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.xml.soap.SOAPException;
import javax.xml.transform.stream.StreamSource;
import org.apache.axis.AxisFault;
import org.apache.axis.Message;
import org.apache.axis.MessageContext;
import org.apache.axis.handlers.BasicHandler;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.axis.SOAPPart;

/**
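 * Custom Axis handler that intercepts the SOAP response message and replaces
 * every decimal numeric character reference ({@code &#nnn;}) with a single
 * space before passing the cleaned content back into the chain. Such
 * references may encode characters that are not legal in XML.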
 *
 * @author Ben Hills
 */
public class SOAPCleanerHandler extends BasicHandler
{

   private static final Log log = LogFactory.getLog(SOAPCleanerHandler.class);

   /**
    * Matches decimal numeric character references such as {@code &#11;}.
    * Compiled once because {@link Pattern} instances are immutable and
    * thread-safe, so there is no reason to recompile on every invocation.
    * The digit run is optional, so a bare {@code &#;} is matched as well.
    */
   private static final Pattern NUMERIC_CHAR_REF = Pattern.compile("&#[0-9]*;");

   /** Encoding used to re-encode the cleaned text if the SOAP part reports none. */
   private static final String DEFAULT_ENCODING = "UTF-8";

   /**
    * Cleans the response message's SOAP part. Every decimal numeric character
    * reference ({@code &#nnn;}) is replaced with a single space. The part's
    * content is then replaced with a stream over the cleaned text, encoded
    * with the part's own encoding.
    * <p>
    * Does nothing when the context has no response message or the SOAP part
    * has no string form. {@link SOAPException} and {@link IOException}
    * failures are logged and swallowed, so the chain continues with the
    * message as it was.
    * <p>
    * NOTE(review): all decimal references are removed, including ones that
    * encode legitimate characters (e.g. {@code &#233;}). Hexadecimal references
    * ({@code &#x..;}) are not matched and are left untouched.
    *
    * @param mtx the current message context; its response message is cleaned
    * @throws AxisFault declared for the handler interface; checked SOAP and
    *         I/O failures are logged here rather than rethrown
    */
   public void invoke(MessageContext mtx) throws AxisFault {

      Message message = mtx.getResponseMessage();
      if (message == null) {
         return;
      }

      try {
         /* Message normally consists of a single part plus optional attachments; only the main part is cleaned */
         SOAPPart part = (SOAPPart) message.getSOAPPart();

         String messageAsString = message.getSOAPPartAsString();
         if (messageAsString == null) {
            /* Nothing to clean; leave the part untouched */
            return;
         }

         String cleaned = stripNumericCharacterReferences(messageAsString);

         /* Re-encode with the part's own encoding so the bytes match what the part declares */
         String encoding = part.getEncoding();
         if (encoding == null) {
            /* NOTE(review): fallback guards against an unset encoding; confirm Axis can actually return null here */
            encoding = DEFAULT_ENCODING;
         }
         ByteArrayInputStream cleanedStream = new ByteArrayInputStream(cleaned.getBytes(encoding));

         /*
          * Hand a fresh StreamSource back to the part. The original cast the
          * existing content Source to StreamSource, which fails with a
          * ClassCastException whenever the part holds another Source type.
          */
         part.setContent(new StreamSource(cleanedStream));
      }
      catch (SOAPException ex) {
         log.error("Failed to get and set source", ex);
      }
      catch (IOException ex) {
         log.error("Error reading input stream", ex);
      }
   }

   /**
    * Replaces every match of {@link #NUMERIC_CHAR_REF} in the given text with a
    * single space.
    *
    * @param xml the serialized SOAP part; must not be {@code null}
    * @return the text with each decimal numeric character reference replaced by {@code " "}
    */
   private static String stripNumericCharacterReferences(String xml) {
      return NUMERIC_CHAR_REF.matcher(xml).replaceAll(" ");
   }
}

