package dp; import java.io.*; import java.util.*; import org.biojava.bio.*; import org.biojava.bio.symbol.*; import org.biojava.bio.seq.*; import org.biojava.bio.seq.db.*; import org.biojava.bio.seq.io.*; import org.biojava.bio.seq.db.*; import org.biojava.bio.dist.*; import org.biojava.bio.dp.*; public class SearchProfile { public static Distribution nullModel; public static void main(String [] args) { try { File seqFile = new File(args[0]); FiniteAlphabet PROTEIN = ProteinTools.getAlphabet(); nullModel = new UniformDistribution(PROTEIN); System.out.println("Loading sequences"); SequenceDB seqDB = readSequenceDB(seqFile, PROTEIN); System.out.println("Creating profile HMM"); ProfileHMM profile = createProfile(seqDB, PROTEIN); System.out.println("make dp object"); DP dp = DPFactory.createDP(profile); // dumpDP(dp); Sequence [] seq1 = { seqDB.sequenceIterator().nextSequence() }; System.out.println("Viterbi: " + dp.viterbi(seq1).getScore()); System.out.println("Forward: " + dp.forward(seq1)); System.out.println("Backward: " + dp.backward(seq1)); System.out.println("Training whole profile"); TrainingAlgorithm ta = new BaumWelchTrainer(dp); ta.train(seqDB, nullModel, 5, new StoppingCriteria() { public boolean isTrainingComplete(TrainingAlgorithm ta) { System.out.println("Cycle " + ta.getCycle() + " completed"); System.out.println("Score: " + ta.getCurrentScore()); if(ta.getCycle() == 5) { return true; } else { return false; } } }); System.out.println("Alignining sequences to the model"); for(SequenceIterator si = seqDB.sequenceIterator(); si.hasNext(); ) { Sequence seq = si.nextSequence(); SymbolList [] rl = { seq }; StatePath statePath = dp.viterbi(rl); double fScore = dp.forward(rl); double bScore = dp.backward(rl); System.out.println( seq.getName() + " viterbi: " + statePath.getScore() + ", forwards: " + fScore + ", backwards: " + bScore ); for(int i = 0; i <= statePath.length() / 60; i++) { for(int j = i*60; j < Math.min((i+1)*60, statePath.length()); j++) { System.out.print(statePath.symbolAt(StatePath.SEQUENCE, j+1).getToken()); } System.out.print("\n"); for(int j = i*60; j < Math.min((i+1)*60, statePath.length()); j++) { System.out.print(statePath.symbolAt(StatePath.STATES, j+1).getToken()); } System.out.print("\n"); System.out.print("\n"); } } } catch (Throwable t) { t.printStackTrace(); } } private static ProfileHMM createProfile(SequenceDB seqs, Alphabet alpha) throws Exception { double l = 0; for(SequenceIterator i = seqs.sequenceIterator(); i.hasNext(); ) { l+=Math.log(i.nextSequence().length()); } l /= seqs.ids().size(); int length = (int) Math.exp(l); System.out.println("Estimating alignment as having length " + length); ProfileHMM profile = new ProfileHMM( alpha, length, DistributionFactory.DEFAULT, DistributionFactory.DEFAULT ); randomize(profile); return profile; } public static SequenceDB readSequenceDB(File seqFile, Alphabet alpha) throws Exception { HashSequenceDB seqDB = new HashSequenceDB(HashSequenceDB.byName); SequenceFactory sFact = new SimpleSequenceFactory(); FastaFormat fFormat = new FastaFormat(); SequenceIterator stateI = null; for( SequenceIterator seqI = new StreamReader( new FileInputStream(seqFile), fFormat, alpha.getParser("symbol"), sFact ); seqI.hasNext(); ) { Sequence seq = seqI.nextSequence(); seqDB.addSequence(seq); } return seqDB; } private static void randomize(MarkovModel model) throws Exception { ModelTrainer mt = new SimpleModelTrainer(model, nullModel, 0.001, 0.00001, 1.0); for(Iterator i = model.stateAlphabet().symbols().iterator(); i.hasNext(); ) { State s = (State) i.next(); if(s instanceof EmissionState && !(s instanceof MagicalState) ) { EmissionState es = (EmissionState) s; Distribution dis = es.getDistribution(); FiniteAlphabet fa = (FiniteAlphabet) dis.getAlphabet(); for( Iterator j = fa.iterator(); j.hasNext(); ) { Symbol r = (Symbol) j.next(); mt.addCount(es.getDistribution(), r, Math.random()); } } for(Iterator j = model.transitionsFrom(s).iterator(); j.hasNext(); ) { State t = (State) j.next(); mt.addTransitionCount(s, t, Math.random()); } } mt.train(); mt.clearCounts(); } private static void dumpDP(DP dp) { State [] states = dp.getStates(); System.out.print("states: "); for(int i = 0; i < states.length; i++) { System.out.print(" " + states[i].getName()); } System.out.println("\n"); int [][] forwardT = dp.getForwardTransitions(); double [][] forwardTS = dp.getForwardTransitionScores(); for(int i = 0; i < states.length; i++) { System.out.print("Transitions from " + i + ": "); for(int j = 0; j < forwardT[i].length; j++) { System.out.print( " " + forwardT[i][j] + "(" + forwardTS[i][j] + ")" ); } } } }