You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

fann_train.pl 5.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. #!/usr/bin/env perl
  2. # This script is a very simple prototype to learn fann from rspamd logs
  3. # For now, it is intended for internal use only
  4. use strict;
  5. use warnings FATAL => 'all';
  6. use AI::FANN qw(:all);
  7. use Getopt::Std;
  8. my %sym_idx; # Symbols by index
  9. my %sym_names; # Symbols by name
  10. my $num = 1; # Number of symbols
  11. my @spam;
  12. my @ham;
  13. my $max_samples = -1;
  14. my $split = 1;
  15. my $preprocessed = 0; # output is in format <score>:<0|1>:<SYM1,...SYMN>
  16. my $score_spam = 12;
  17. my $score_ham = -6;
  18. sub process {
  19. my ( $input, $spam, $ham ) = @_;
  20. my $samples = 0;
  21. while (<$input>) {
  22. if ( !$preprocessed ) {
  23. if (/^.*rspamd_task_write_log.*: \[(-?\d+\.?\d*)\/(\d+\.?\d*)\]\s*\[(.+)\].*$/) {
  24. if ( $1 > $score_spam ) {
  25. $_ = "$1:1: $3";
  26. }
  27. elsif ( $1 < $score_ham ) {
  28. $_ = "$1:0: $3\n";
  29. }
  30. else {
  31. # Out of boundary
  32. next;
  33. }
  34. }
  35. else {
  36. # Not our log message
  37. next;
  38. }
  39. }
  40. $_ =~ /^(-?\d+\.?\d*):([01]):\s*(\S.*)$/;
  41. my $is_spam = 0;
  42. if ( $2 == 1 ) {
  43. $is_spam = 1;
  44. }
  45. my @ar = split /,/, $3;
  46. my %sample;
  47. foreach my $sym (@ar) {
  48. chomp $sym;
  49. if ( !$sym_idx{$sym} ) {
  50. $sym_idx{$sym} = $num;
  51. $sym_names{$num} = $sym;
  52. $num++;
  53. }
  54. $sample{ $sym_idx{$sym} } = 1;
  55. }
  56. if ($is_spam) {
  57. push @{$spam}, \%sample;
  58. }
  59. else {
  60. push @{$ham}, \%sample;
  61. }
  62. $samples++;
  63. if ( $max_samples > 0 && $samples > $max_samples ) {
  64. return;
  65. }
  66. }
  67. }
  68. # Shuffle array
  69. sub fisher_yates_shuffle {
  70. my $array = shift;
  71. my $i = @$array;
  72. while ( --$i ) {
  73. my $j = int rand( $i + 1 );
  74. @$array[ $i, $j ] = @$array[ $j, $i ];
  75. }
  76. }
  77. # Train network
  78. sub train {
  79. my ( $ann, $sample, $result ) = @_;
  80. my @row;
  81. for ( my $i = 1 ; $i < $num ; $i++ ) {
  82. if ( $sample->{$i} ) {
  83. push @row, 1;
  84. }
  85. else {
  86. push @row, 0;
  87. }
  88. }
  89. #print "@row -> @{$result}\n";
  90. $ann->train( \@row, \@{$result} );
  91. }
  92. sub test {
  93. my ( $ann, $sample ) = @_;
  94. my @row;
  95. for ( my $i = 1 ; $i < $num ; $i++ ) {
  96. if ( $sample->{$i} ) {
  97. push @row, 1;
  98. }
  99. else {
  100. push @row, 0;
  101. }
  102. }
  103. my $ret = $ann->run( \@row );
  104. return $ret;
  105. }
  106. my %opts;
  107. getopts( 'o:i:s:n:t:hpS:H:', \%opts );
  108. if ( $opts{'h'} ) {
  109. print "$0 [-i input] [-o output] [-s scores] [-n max_samples] [-S spam_score] [-H ham_score] [-ph]\n";
  110. exit;
  111. }
  112. my $input = *STDIN;
  113. if ( $opts{'i'} ) {
  114. open( $input, '<', $opts{'i'} ) or die "cannot open $opts{i}";
  115. }
  116. if ( $opts{'n'} ) {
  117. $max_samples = $opts{'n'};
  118. }
  119. if ( $opts{'t'} ) {
  120. # Test split
  121. $split = $opts{'t'};
  122. }
  123. if ( $opts{'p'} ) {
  124. $preprocessed = 1;
  125. }
  126. if ( $opts{'H'} ) {
  127. $score_ham = $opts{'H'};
  128. }
  129. if ( $opts{'S'} ) {
  130. $score_spam = $opts{'S'};
  131. }
  132. # ham_prob, spam_prob
  133. my @spam_out = (1);
  134. my @ham_out = (0);
  135. process( $input, \@spam, \@ham );
  136. fisher_yates_shuffle( \@spam );
  137. fisher_yates_shuffle( \@ham );
  138. my $nspam = int( scalar(@spam) / $split );
  139. my $nham = int( scalar(@ham) / $split );
  140. my $ann = AI::FANN->new_standard( $num - 1, ( $num + 2 ) / 2, 1 );
  141. my @train_data;
  142. # Train ANN
  143. for ( my $i = 0 ; $i < $nham ; $i++ ) {
  144. push @train_data, [ $ham[$i], \@ham_out ];
  145. }
  146. for ( my $i = 0 ; $i < $nspam ; $i++ ) {
  147. push @train_data, [ $spam[$i], \@spam_out ];
  148. }
  149. fisher_yates_shuffle( \@train_data );
  150. foreach my $train_row (@train_data) {
  151. train( $ann, @{$train_row}[0], @{$train_row}[1] );
  152. }
  153. print "Trained $nspam SPAM and $nham HAM samples\n";
  154. # Now run fann
  155. if ( $split > 1 ) {
  156. my $sample = 0.0;
  157. my $correct = 0.0;
  158. for ( my $i = $nham ; $i < $nham * $split ; $i++ ) {
  159. my $ret = test( $ann, $ham[$i] );
  160. #print "@{$ret}\n";
  161. if ( @{$ret}[0] < 0.5 ) {
  162. $correct++;
  163. }
  164. $sample++;
  165. }
  166. print "Tested $sample HAM samples, correct matched: $correct, rate: " . ( $correct / $sample ) . "\n";
  167. $sample = 0.0;
  168. $correct = 0.0;
  169. for ( my $i = $nspam ; $i < $nspam * $split ; $i++ ) {
  170. my $ret = test( $ann, $spam[$i] );
  171. #print "@{$ret}\n";
  172. if ( @{$ret}[0] > 0.5 ) {
  173. $correct++;
  174. }
  175. $sample++;
  176. }
  177. print "Tested $sample SPAM samples, correct matched: $correct, rate: " . ( $correct / $sample ) . "\n";
  178. }
  179. if ( $opts{'o'} ) {
  180. $ann->save( $opts{'o'} ) or die "cannot save ann into $opts{o}";
  181. }
  182. if ( $opts{'s'} ) {
  183. open( my $scores, '>', $opts{'s'} ) or die "cannot open score file $opts{'s'}";
  184. print $scores "{";
  185. for ( my $i = 1 ; $i < $num ; $i++ ) {
  186. my $n = $i - 1;
  187. if ( $i != $num - 1 ) {
  188. print $scores "\"$sym_names{$i}\":$n,";
  189. }
  190. else {
  191. print $scores "\"$sym_names{$i}\":$n}\n";
  192. }
  193. }
  194. }