using System;
using System.Collections;
using System.Data;
namespace TradeWeapon.Classification.Bayesian
{
/// <summary>
/// Categorical naive Bayes classifier whose training set is a <see cref="DataTable"/>.
/// One column holds the class label; every other column is treated as a categorical
/// feature. Probabilities are estimated by counting matching rows with
/// <see cref="DataTable.Select(string)"/>. No smoothing is applied, so a feature value
/// that never occurs together with a class drives that class's score to zero.
/// </summary>
public class NaiveBayesClassifier
{
    #region fields
    // Training rows; null until one of the LoadDataSource overloads succeeds.
    private DataTable sourceData;
    // Distinct, non-null class labels stored as strings, in ascending order of the
    // underlying column values. Rebuilt from scratch by ExtractClasses.
    private readonly ArrayList classes = new ArrayList();
    // Name of the column in sourceData that holds the class label.
    private string classColumnName;
    // Name given to the scratch table produced by SelectDistinct.
    private const string TABLE_NAME = "source";
    #endregion fields

    #region ctors
    /// <summary>
    /// Creates an empty classifier. Call one of the <c>LoadDataSource</c> overloads
    /// before adding rows or classifying.
    /// </summary>
    public NaiveBayesClassifier()
    {
    }
    #endregion ctors

    #region public methods
    /// <summary>
    /// Uses <paramref name="source"/> as the training set and extracts its class labels.
    /// The table is referenced, not copied.
    /// </summary>
    /// <param name="source">Training data.</param>
    /// <param name="classColName">Column of <paramref name="source"/> holding the class label.</param>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="classColName"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="classColName"/> is not a column of <paramref name="source"/>.</exception>
    public void LoadDataSource(DataTable source, string classColName)
    {
        if (source == null)
        {
            throw new ArgumentNullException("source");
        }
        if (classColName == null)
        {
            throw new ArgumentNullException("classColName");
        }
        if (!source.Columns.Contains(classColName))
        {
            throw new ArgumentException("Column '" + classColName + "' does not exist in the source table.", "classColName");
        }

        // Validate before mutating state so a bad call leaves the previous source intact.
        sourceData = source;
        classColumnName = classColName;
        ExtractClasses();
    }

    /// <summary>
    /// Uses the table named <paramref name="tableName"/> inside <paramref name="source"/>
    /// as the training set.
    /// </summary>
    /// <param name="source">Data set containing the training table.</param>
    /// <param name="tableName">Name of the training table within <paramref name="source"/>.</param>
    /// <param name="classColName">Column holding the class label.</param>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="classColName"/> is null.</exception>
    /// <exception cref="ArgumentException">The table or the class column does not exist.</exception>
    public void LoadDataSource(DataSet source, string tableName, string classColName)
    {
        if (source == null)
        {
            throw new ArgumentNullException("source");
        }
        DataTable table = source.Tables[tableName];
        if (table == null)
        {
            throw new ArgumentException("Table '" + tableName + "' does not exist in the data set.", "tableName");
        }
        LoadDataSource(table, classColName);
    }

    /// <summary>
    /// Appends <paramref name="newRow"/> to the training set without refreshing the class list.
    /// </summary>
    /// <param name="newRow">Row to add, e.g. one created by <see cref="GetPrototypeRow"/>.</param>
    public void AddToSource(DataRow newRow)
    {
        AddToSource(newRow, false);
    }

    /// <summary>
    /// Appends <paramref name="newRow"/> to the training set.
    /// </summary>
    /// <param name="newRow">Row to add, e.g. one created by <see cref="GetPrototypeRow"/>.</param>
    /// <param name="updateClasses">
    /// When true, the class list is rebuilt so a label first seen in <paramref name="newRow"/>
    /// becomes a candidate for <see cref="MostLikelyClass"/>.
    /// </param>
    /// <exception cref="NullReferenceException">
    /// No data source has been loaded (exception type kept for compatibility with existing callers).
    /// </exception>
    public void AddToSource(DataRow newRow, bool updateClasses)
    {
        if (sourceData == null)
        {
            throw new NullReferenceException("sourceData has not been initialized");
        }
        sourceData.Rows.Add(newRow);
        if (updateClasses)
        {
            ExtractClasses();
        }
    }

    /// <summary>
    /// Creates a new row with the training table's schema. The row is not added to the table;
    /// fill it in and pass it to <see cref="MostLikelyClass"/> or <c>AddToSource</c>.
    /// </summary>
    /// <exception cref="NullReferenceException">
    /// No data source has been loaded (exception type kept for compatibility with existing callers).
    /// </exception>
    public DataRow GetPrototypeRow()
    {
        if (sourceData == null)
        {
            throw new NullReferenceException("sourceData has not been initialized");
        }
        return sourceData.NewRow();
    }

    /// <summary>
    /// Returns the class label with the highest naive Bayes score
    /// P(class) * product over feature columns of P(value | class), where the feature values
    /// come from <paramref name="unknown"/>. Values are compared by their ToString() text.
    /// Ties keep the earlier label. If every class scores zero, the first label in ascending
    /// order is returned.
    /// </summary>
    /// <param name="unknown">Row whose non-class columns are classified; its class column is ignored.</param>
    /// <returns>The most likely class label, as a string.</returns>
    /// <exception cref="NullReferenceException">No data source has been loaded.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="unknown"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The training data contains no class labels.</exception>
    public string MostLikelyClass(DataRow unknown)
    {
        if (sourceData == null)
        {
            throw new NullReferenceException("sourceData has not been initialized");
        }
        if (unknown == null)
        {
            throw new ArgumentNullException("unknown");
        }
        if (classes.Count == 0)
        {
            throw new InvalidOperationException("The source data contains no class labels.");
        }

        string mostLikelyClass = (string)classes[0];
        double bestProbability = 0;
        double totalCount = sourceData.Rows.Count;

        foreach (string className in classes)
        {
            // The prior and the conditional denominators must use the class under evaluation.
            int countOfClass = FrequencyOfClass(className);
            if (countOfClass == 0)
            {
                // Rows may have been removed since the labels were extracted. A class with no
                // rows has zero probability, and continuing would divide by zero.
                continue;
            }

            double probability = countOfClass / totalCount;
            foreach (DataColumn col in sourceData.Columns)
            {
                string name = col.ColumnName;
                if (name == classColumnName)
                {
                    continue;
                }
                int count = CountOccurrencesWhere(name, unknown[name].ToString(), className);
                if (count == 0)
                {
                    // The product is now zero; skip the remaining table scans for this class.
                    probability = 0;
                    break;
                }
                probability *= (double)count / countOfClass;
            }

            if (probability > bestProbability)
            {
                mostLikelyClass = className;
                bestProbability = probability;
            }
        }
        return mostLikelyClass;
    }
    #endregion public methods

    #region private methods
    /// <summary>
    /// Rebuilds <c>classes</c> from the distinct values of the class column.
    /// </summary>
    private void ExtractClasses()
    {
        // Start from scratch so repeated calls do not accumulate duplicate or stale labels.
        classes.Clear();
        DataTable distinctTable = SelectDistinct(TABLE_NAME, sourceData, classColumnName);
        foreach (DataRow row in distinctTable.Rows)
        {
            object label = row[classColumnName];
            if (label == DBNull.Value)
            {
                // Unlabelled rows cannot be matched by the string-literal filters; skip them.
                continue;
            }
            // Store text so non-string class columns can be enumerated as strings.
            classes.Add(label.ToString());
        }
    }

    /// <summary>Counts rows whose <paramref name="attributeName"/> column equals <paramref name="attributeValue"/>.</summary>
    private int CountOccurrences(string attributeName, string attributeValue)
    {
        string filter = string.Format("[{0}] = '{1}'",
            EscapeColumnName(attributeName), EscapeLiteral(attributeValue));
        return sourceData.Select(filter).Length;
    }

    /// <summary>
    /// Counts rows whose <paramref name="attributeName"/> column equals <paramref name="attributeValue"/>
    /// and whose class column equals <paramref name="className"/>.
    /// </summary>
    private int CountOccurrencesWhere(string attributeName, string attributeValue, string className)
    {
        // DataTable expressions use a single '=' for equality; '==' is a syntax error.
        string filter = string.Format("[{0}] = '{1}' AND [{2}] = '{3}'",
            EscapeColumnName(attributeName), EscapeLiteral(attributeValue),
            EscapeColumnName(classColumnName), EscapeLiteral(className));
        return sourceData.Select(filter).Length;
    }

    /// <summary>Number of training rows labelled <paramref name="className"/>.</summary>
    private int FrequencyOfClass(string className)
    {
        return CountOccurrences(classColumnName, className);
    }

    /// <summary>
    /// Builds a one-column table named <paramref name="tableName"/> holding the distinct values of
    /// <paramref name="columnName"/> in <paramref name="sourceTable"/>, sorted ascending.
    /// </summary>
    private DataTable SelectDistinct(string tableName, DataTable sourceTable, string columnName)
    {
        DataTable dt = new DataTable(tableName);
        dt.Columns.Add(columnName, sourceTable.Columns[columnName].DataType);
        object lastValue = null;
        // Bracket the sort column so names ending in " ASC"/" DESC" are not misparsed.
        foreach (DataRow dr in sourceTable.Select("", "[" + columnName + "]"))
        {
            // After sorting, equal values are adjacent, so comparing with the previous value suffices.
            if (lastValue == null || !ColumnEqual(lastValue, dr[columnName]))
            {
                lastValue = dr[columnName];
                dt.Rows.Add(new object[] { lastValue });
            }
        }
        return dt;
    }

    /// <summary>Equality that treats two DBNull values as equal and DBNull as unequal to anything else.</summary>
    private bool ColumnEqual(object A, object B)
    {
        if (A == DBNull.Value && B == DBNull.Value)
        {
            return true;
        }
        if (A == DBNull.Value || B == DBNull.Value)
        {
            return false;
        }
        return A.Equals(B);
    }

    /// <summary>Escapes text for use inside a single-quoted DataTable expression literal.</summary>
    private static string EscapeLiteral(string value)
    {
        return value.Replace("'", "''");
    }

    /// <summary>Escapes a column name for use inside [brackets] in a DataTable expression.</summary>
    private static string EscapeColumnName(string name)
    {
        return name.Replace("\\", "\\\\").Replace("]", "\\]");
    }
    #endregion private methods
}
}