Browse Source

Add a simple script to learn fann from rspamd logs

tags/1.1.0
Vsevolod Stakhov 8 years ago
parent
commit
a0902bd2ac
1 changed files with 245 additions and 0 deletions
  1. 245
    0
      utils/fann_train.pl

+ 245
- 0
utils/fann_train.pl View File

@@ -0,0 +1,245 @@
#!/usr/bin/env perl

# This script is a very simple prototype to learn fann from rspamd logs
# For now, it is intended for internal use only

use strict;
use warnings FATAL => 'all';
use AI::FANN qw(:all);
use Getopt::Std;

my %sym_idx; # Symbols by index
my %sym_names; # Symbols by name
my $num = 1; # Number of symbols
my @spam;
my @ham;
my $max_samples = -1;
my $split = 1;
my $preprocessed = 0; # ouptut is in format <score>:<0|1>:<SYM1,...SYMN>
my $score_spam = 12;
my $score_ham = -6;

sub process {
my ($input, $spam, $ham) = @_;
my $samples = 0;

while(<$input>) {
if (!$preprocessed) {
if (/^.*rspamd_task_write_log.*: \[(-?\d+\.?\d*)\/(\d+\.?\d*)\]\s*\[(.+)\].*$/) {
if ($1 > $score_spam) {
$_ = "$1:1: $3";
}
elsif ($1 < $score_ham) {
$_ = "$1:0: $3\n";
}
else {
# Out of boundary
next;
}
}
else {
# Not our log message
next;
}
}

$_ =~ /^(-?\d+\.?\d*):([01]):\s*(\S.*)$/;

my $is_spam = 0;

if ($2 == 1) {
$is_spam = 1;
}

my @ar = split /,/, $3;
my %sample;

foreach my $sym (@ar) {
chomp $sym;
if (!$sym_idx{$sym}) {
$sym_idx{$sym} = $num;
$sym_names{$num} = $sym;
$num++;
}

$sample{$sym_idx{$sym}} = 1;
}

if ($is_spam) {
push @{$spam}, \%sample;
}
else {
push @{$ham}, \%sample;
}

$samples++;
if ($max_samples > 0 && $samples > $max_samples) {
return;
}
}
}

# Shuffle array
sub fisher_yates_shuffle
{
my $array = shift;
my $i = @$array;

while ( --$i ) {
my $j = int rand( $i + 1 );
@$array[$i, $j] = @$array[$j, $i];
}
}

# Train network
sub train {
my ($ann, $sample, $result) = @_;

my @row;

for (my $i = 1; $i < $num; $i++) {
if ($sample->{$i}) {
push @row, 1;
}
else {
push @row, 0;
}
}

#print "@row -> @{$result}\n";

$ann->train(\@row, \@{$result});
}

sub test {
my ($ann, $sample) = @_;

my @row;

for (my $i = 1; $i < $num; $i++) {
if ($sample->{$i}) {
push @row, 1;
}
else {
push @row, 0;
}
}

my $ret = $ann->run(\@row);

return $ret;
}

my %opts;
getopts('o:i:s:n:t:hpS:H:', \%opts);

if ($opts{'h'}) {
print "$0 [-i input] [-o output] [-s scores] [-n max_samples] [-S spam_score] [-H ham_score] [-ph]\n";
exit;
}

my $input = *STDIN;

if ($opts{'i'}) {
open($input, '<', $opts{'i'}) or die "cannot open $opts{i}";
}

if ($opts{'n'}) {
$max_samples = $opts{'n'};
}

if ($opts{'t'}) {
# Test split
$split = $opts{'t'};
}
if ($opts{'p'}) {
$preprocessed = 1;
}

if ($opts{'H'}) {
$score_ham = $opts{'H'};
}

if ($opts{'S'}) {
$score_spam = $opts{'S'};
}

# ham_prob, spam_prob
my @spam_out = (1);
my @ham_out = (0);

process($input, \@spam, \@ham);
fisher_yates_shuffle(\@spam);
fisher_yates_shuffle(\@ham);

my $nspam = int(scalar(@spam) / $split);
my $nham = int(scalar(@ham) / $split);

my $ann = AI::FANN->new_standard($num - 1, ($num + 2) / 2, 1);

my @train_data;
# Train ANN
for (my $i = 0; $i < $nham; $i++) {
push @train_data, [ $ham[$i], \@ham_out ];
}

for (my $i = 0; $i < $nspam; $i++) {
push @train_data, [ $spam[$i], \@spam_out ];
}

fisher_yates_shuffle(\@train_data);

foreach my $train_row (@train_data) {
train($ann, @{$train_row}[0], @{$train_row}[1]);
}

print "Trained $nspam SPAM and $nham HAM samples\n";

# Now run fann
if ($split > 1) {
my $sample = 0.0;
my $correct = 0.0;
for (my $i = $nham; $i < $nham * $split; $i++) {
my $ret = test($ann, $ham[$i]);
#print "@{$ret}\n";
if (@{$ret}[0] < 0.5) {
$correct++;
}
$sample++;
}

print "Tested $sample HAM samples, correct matched: $correct, rate: ".($correct / $sample)."\n";

$sample = 0.0;
$correct = 0.0;

for (my $i = $nspam; $i < $nspam * $split; $i++) {
my $ret = test($ann, $spam[$i]);
#print "@{$ret}\n";
if (@{$ret}[0] > 0.5) {
$correct++;
}
$sample++;
}

print "Tested $sample SPAM samples, correct matched: $correct, rate: ".($correct / $sample)."\n";
}

if ($opts{'o'}) {
$ann->save($opts{'o'}) or die "cannot save ann into $opts{o}";
}

if ($opts{'s'}) {
open(my $scores, '>',
$opts{'s'}) or die "cannot open score file $opts{'s'}";
print $scores "{";
for (my $i = 1; $i < $num; $i++) {
my $n = $i - 1;
if ($i != $num - 1) {
print $scores "\"$sym_names{$i}\":$n,";
}
else {
print $scores "\"$sym_names{$i}\":$n}\n";
}
}
}

Loading…
Cancel
Save