diff options
Diffstat (limited to 'utils/fann_train.pl')
-rwxr-xr-x | utils/fann_train.pl | 138 |
1 files changed, 70 insertions, 68 deletions
diff --git a/utils/fann_train.pl b/utils/fann_train.pl index 46b539489..2ce422eb4 100755 --- a/utils/fann_train.pl +++ b/utils/fann_train.pl @@ -8,28 +8,28 @@ use warnings FATAL => 'all'; use AI::FANN qw(:all); use Getopt::Std; -my %sym_idx; # Symbols by index -my %sym_names; # Symbols by name -my $num = 1; # Number of symbols +my %sym_idx; # Symbols by index +my %sym_names; # Symbols by name +my $num = 1; # Number of symbols my @spam; my @ham; -my $max_samples = -1; -my $split = 1; -my $preprocessed = 0; # output is in format <score>:<0|1>:<SYM1,...SYMN> -my $score_spam = 12; -my $score_ham = -6; +my $max_samples = -1; +my $split = 1; +my $preprocessed = 0; # output is in format <score>:<0|1>:<SYM1,...SYMN> +my $score_spam = 12; +my $score_ham = -6; sub process { - my ($input, $spam, $ham) = @_; + my ( $input, $spam, $ham ) = @_; my $samples = 0; - while(<$input>) { - if (!$preprocessed) { + while (<$input>) { + if ( !$preprocessed ) { if (/^.*rspamd_task_write_log.*: \[(-?\d+\.?\d*)\/(\d+\.?\d*)\]\s*\[(.+)\].*$/) { - if ($1 > $score_spam) { + if ( $1 > $score_spam ) { $_ = "$1:1: $3"; } - elsif ($1 < $score_ham) { + elsif ( $1 < $score_ham ) { $_ = "$1:0: $3\n"; } else { @@ -47,7 +47,7 @@ sub process { my $is_spam = 0; - if ($2 == 1) { + if ( $2 == 1 ) { $is_spam = 1; } @@ -56,13 +56,13 @@ sub process { foreach my $sym (@ar) { chomp $sym; - if (!$sym_idx{$sym}) { - $sym_idx{$sym} = $num; + if ( !$sym_idx{$sym} ) { + $sym_idx{$sym} = $num; $sym_names{$num} = $sym; $num++; } - $sample{$sym_idx{$sym}} = 1; + $sample{ $sym_idx{$sym} } = 1; } if ($is_spam) { @@ -73,32 +73,31 @@ sub process { } $samples++; - if ($max_samples > 0 && $samples > $max_samples) { + if ( $max_samples > 0 && $samples > $max_samples ) { return; } } } # Shuffle array -sub fisher_yates_shuffle -{ +sub fisher_yates_shuffle { my $array = shift; - my $i = @$array; + my $i = @$array; while ( --$i ) { my $j = int rand( $i + 1 ); - @$array[$i, $j] = @$array[$j, $i]; + @$array[ $i, $j ] = @$array[ $j, $i ]; } } # Train network sub train { - my ($ann, $sample, $result) = @_; + my ( $ann, $sample, $result ) = @_; my @row; - for (my $i = 1; $i < $num; $i++) { - if ($sample->{$i}) { + for ( my $i = 1 ; $i < $num ; $i++ ) { + if ( $sample->{$i} ) { push @row, 1; } else { @@ -108,16 +107,16 @@ sub train { #print "@row -> @{$result}\n"; - $ann->train(\@row, \@{$result}); + $ann->train( \@row, \@{$result} ); } sub test { - my ($ann, $sample) = @_; + my ( $ann, $sample ) = @_; my @row; - for (my $i = 1; $i < $num; $i++) { - if ($sample->{$i}) { + for ( my $i = 1 ; $i < $num ; $i++ ) { + if ( $sample->{$i} ) { push @row, 1; } else { @@ -125,117 +124,120 @@ sub test { } } - my $ret = $ann->run(\@row); + my $ret = $ann->run( \@row ); return $ret; } my %opts; -getopts('o:i:s:n:t:hpS:H:', \%opts); +getopts( 'o:i:s:n:t:hpS:H:', \%opts ); -if ($opts{'h'}) { +if ( $opts{'h'} ) { print "$0 [-i input] [-o output] [-s scores] [-n max_samples] [-S spam_score] [-H ham_score] [-ph]\n"; exit; } my $input = *STDIN; -if ($opts{'i'}) { - open($input, '<', $opts{'i'}) or die "cannot open $opts{i}"; +if ( $opts{'i'} ) { + open( $input, '<', $opts{'i'} ) or die "cannot open $opts{i}"; } -if ($opts{'n'}) { +if ( $opts{'n'} ) { $max_samples = $opts{'n'}; } -if ($opts{'t'}) { +if ( $opts{'t'} ) { + # Test split $split = $opts{'t'}; } -if ($opts{'p'}) { +if ( $opts{'p'} ) { $preprocessed = 1; } -if ($opts{'H'}) { +if ( $opts{'H'} ) { $score_ham = $opts{'H'}; } -if ($opts{'S'}) { +if ( $opts{'S'} ) { $score_spam = $opts{'S'}; } # ham_prob, spam_prob my @spam_out = (1); -my @ham_out = (0); +my @ham_out = (0); -process($input, \@spam, \@ham); -fisher_yates_shuffle(\@spam); -fisher_yates_shuffle(\@ham); +process( $input, \@spam, \@ham ); +fisher_yates_shuffle( \@spam ); +fisher_yates_shuffle( \@ham ); -my $nspam = int(scalar(@spam) / $split); -my $nham = int(scalar(@ham) / $split); +my $nspam = int( scalar(@spam) / $split ); +my $nham = int( scalar(@ham) / $split ); -my $ann = AI::FANN->new_standard($num - 1, ($num + 2) / 2, 1); +my $ann = AI::FANN->new_standard( $num - 1, ( $num + 2 ) / 2, 1 ); my @train_data; + # Train ANN -for (my $i = 0; $i < $nham; $i++) { +for ( my $i = 0 ; $i < $nham ; $i++ ) { push @train_data, [ $ham[$i], \@ham_out ]; } -for (my $i = 0; $i < $nspam; $i++) { +for ( my $i = 0 ; $i < $nspam ; $i++ ) { push @train_data, [ $spam[$i], \@spam_out ]; } -fisher_yates_shuffle(\@train_data); +fisher_yates_shuffle( \@train_data ); foreach my $train_row (@train_data) { - train($ann, @{$train_row}[0], @{$train_row}[1]); + train( $ann, @{$train_row}[0], @{$train_row}[1] ); } print "Trained $nspam SPAM and $nham HAM samples\n"; # Now run fann -if ($split > 1) { - my $sample = 0.0; +if ( $split > 1 ) { + my $sample = 0.0; my $correct = 0.0; - for (my $i = $nham; $i < $nham * $split; $i++) { - my $ret = test($ann, $ham[$i]); + for ( my $i = $nham ; $i < $nham * $split ; $i++ ) { + my $ret = test( $ann, $ham[$i] ); + #print "@{$ret}\n"; - if (@{$ret}[0] < 0.5) { + if ( @{$ret}[0] < 0.5 ) { $correct++; } $sample++; } - print "Tested $sample HAM samples, correct matched: $correct, rate: ".($correct / $sample)."\n"; + print "Tested $sample HAM samples, correct matched: $correct, rate: " . ( $correct / $sample ) . "\n"; - $sample = 0.0; + $sample = 0.0; $correct = 0.0; - for (my $i = $nspam; $i < $nspam * $split; $i++) { - my $ret = test($ann, $spam[$i]); + for ( my $i = $nspam ; $i < $nspam * $split ; $i++ ) { + my $ret = test( $ann, $spam[$i] ); + #print "@{$ret}\n"; - if (@{$ret}[0] > 0.5) { + if ( @{$ret}[0] > 0.5 ) { $correct++; } $sample++; } - print "Tested $sample SPAM samples, correct matched: $correct, rate: ".($correct / $sample)."\n"; + print "Tested $sample SPAM samples, correct matched: $correct, rate: " . ( $correct / $sample ) . "\n"; } -if ($opts{'o'}) { - $ann->save($opts{'o'}) or die "cannot save ann into $opts{o}"; +if ( $opts{'o'} ) { + $ann->save( $opts{'o'} ) or die "cannot save ann into $opts{o}"; } -if ($opts{'s'}) { - open(my $scores, '>', - $opts{'s'}) or die "cannot open score file $opts{'s'}"; +if ( $opts{'s'} ) { + open( my $scores, '>', $opts{'s'} ) or die "cannot open score file $opts{'s'}"; print $scores "{"; - for (my $i = 1; $i < $num; $i++) { + for ( my $i = 1 ; $i < $num ; $i++ ) { my $n = $i - 1; - if ($i != $num - 1) { + if ( $i != $num - 1 ) { print $scores "\"$sym_names{$i}\":$n,"; } else { |