// $Id: filter-spamassassin.go 66 2024-04-14 16:44:05Z umaxx $
// Copyright (c) 2019-2024 Joerg Jung <mail@umaxx.net>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

// filter-spamassassin - OpenSMTPD filter for SpamAssassin

package main

import (
	"bufio"
	"fmt"
	"log"
	"log/syslog"
	"net"
	"os"
	"strconv"
	"strings"
)

// Version number and copyright years printed by "filter-spamassassin version".
const v, yr = "0.8", "2019-2024"

// spamassassin is the filter state of a single smtpd session.
type spamassassin struct {
	// sid is the smtpd session id this state belongs to.
	sid string
	// buf holds the data-lines of the current message, collected for spamd.
	buf []string
	// sz is the summed length of the lines in buf, newlines not counted.
	sz uint64
	// r is the verdict for the current message: nil while unknown or after
	// a failed scan, false for "not spam" (also used for a skipped message),
	// true for "spam".
	r *bool
}

// Process-wide configuration and state; addr and lim may be overridden on
// the command line in main.
var (
	addr = "localhost:783"            // spamd address (host:port), argv[1]
	lim  = uint64(512 * 1024)         // size limit in bytes, argv[2]
	l3   *syslog.Writer               // syslog writer, opened in main
	sas  = map[string]*spamassassin{} // per-session state keyed by session id
)

// reset hands every buffered line back to smtpd unchanged as a data-line for
// transaction tok, then clears the buffer, the size counter and the verdict.
func (sa *spamassassin) reset(tok string) {
	for i := range sa.buf {
		fmt.Printf("filter-dataline|%s|%s|%s\n", sa.sid, tok, sa.buf[i])
	}
	sa.buf, sa.sz, sa.r = nil, 0, nil
}

// status checks the first line of spamd's reply, which has the shape
// "SPAMD/<major>.<minor> <code> <message>". When the line cannot be parsed,
// or the code is non-zero, the problem is logged and the buffered original
// message is handed back to smtpd via reset (which also clears the verdict).
func (sa *spamassassin) status(tok string, ln string) {
	var (
		major, minor uint
		code         int
		msg          string
	)
	_, err := fmt.Sscanf(ln, "SPAMD/%d.%d %d %s", &major, &minor, &code, &msg)
	switch {
	case err != nil:
		l3.Err(fmt.Sprintln(sa.sid, "sscanf", err))
		sa.reset(tok)
	case code != 0:
		l3.Err(fmt.Sprintln(sa.sid, "status", msg))
		sa.reset(tok)
	}
}

// header looks at one header line of spamd's reply. Only the "Spam: " header
// matters: it is logged, and its leading boolean becomes the verdict in sa.r.
// If that boolean cannot be parsed, the original message is handed back to
// smtpd via reset, which clears the verdict again.
func (sa *spamassassin) header(tok string, ln string) {
	if !strings.HasPrefix(ln, "Spam: ") {
		return
	}
	l3.Info(fmt.Sprintln(sa.sid, "result", ln))
	verdict := new(bool)
	sa.r = verdict
	if _, err := fmt.Sscanf(ln, "Spam: %t ; ", verdict); err != nil {
		l3.Err(fmt.Sprintln(sa.sid, "sscanf", err))
		sa.reset(tok)
	}
}

// Parser states of response while reading spamd's reply. response moves
// from one state to the next with ++, so the values must stay consecutive.
const (
	STATUS  = 0 // expecting the status line
	HEADER  = 1 // reading header lines up to the first empty line
	MESSAGE = 2 // relaying the (possibly rewritten) message
)

// response reads spamd's reply line by line from in. The first line goes to
// status. The following lines up to the first empty line go to header. Every
// line after that is relayed to smtpd as a data-line for transaction tok. A
// read error is logged and the buffered original message is handed back via
// reset.
//
// NOTE(review): a failing status line is only logged/flushed inside status;
// this loop still goes on to read headers and body afterwards.
func (sa *spamassassin) response(tok string, in *bufio.Scanner) {
	state := STATUS
	for in.Scan() {
		ln := in.Text()
		if state == STATUS {
			sa.status(tok, ln)
			state++
			continue
		}
		if state == HEADER {
			if ln == "" { // blank line ends the header section
				state++
			} else {
				sa.header(tok, ln)
			}
			continue
		}
		if state == MESSAGE {
			fmt.Printf("filter-dataline|%s|%s|%s\n", sa.sid, tok, ln)
			continue
		}
		// Unreachable with the three states above; kept as a safety net.
		l3.Err(fmt.Sprintln(sa.sid, "status", state))
		sa.reset(tok)
		return
	}
	if err := in.Err(); err != nil {
		l3.Err(fmt.Sprintln(sa.sid, "scanner", err))
		sa.reset(tok)
	}
}

// process sends the buffered message to spamd at addr using the SPAMC
// PROCESS command, then relays spamd's reply to smtpd via response. line
// starts it in its own goroutine once the end-of-message "." arrives.
//
// If connecting or writing fails, the original message is handed back to
// smtpd unchanged via reset. reset leaves no verdict, so commit answers with
// a temporary failure. On every path the stream is terminated with the "."
// data-line that smtpd waits for.
func (sa *spamassassin) process(tok string) {
	// Registered first so it runs last (after con.Close) and also on the
	// early error returns below; without it a spamd outage would leave the
	// smtpd session waiting for end-of-message forever.
	defer fmt.Printf("filter-dataline|%s|%s|.\n", sa.sid, tok)
	con, e := net.Dial("tcp", addr)
	if e != nil {
		l3.Err(fmt.Sprintln(sa.sid, e))
		sa.reset(tok)
		return
	}
	defer con.Close()
	// Buffer the request so the message is not written one syscall per line.
	w := bufio.NewWriter(con)
	if _, e = fmt.Fprintf(w, "PROCESS SPAMC/1.5\r\n\r\n"); e != nil { // spamd.raw source: content length header is optional
		l3.Err(fmt.Sprintln(sa.sid, "write", e))
		sa.reset(tok)
		return
	}
	for _, ln := range sa.buf {
		if _, e = fmt.Fprintf(w, "%s\n", ln); e != nil {
			l3.Err(fmt.Sprintln(sa.sid, "write", e))
			sa.reset(tok)
			return
		}
	}
	if e = w.Flush(); e != nil {
		l3.Err(fmt.Sprintln(sa.sid, "write", e))
		sa.reset(tok)
		return
	}
	// No Content-length header was sent, so signal the end of the message
	// by half-closing the connection; the reply is still read afterwards.
	if c, ok := con.(*net.TCPConn); ok {
		if e = c.CloseWrite(); e != nil {
			l3.Warning(fmt.Sprintln(sa.sid, "closewrite", e))
		}
	}
	sa.response(tok, bufio.NewScanner(con))
}

// skip gives up scanning the current message. Buffered lines are handed back
// via reset, and the verdict is pinned to "not spam" so that later lines pass
// straight through and commit lets the message proceed.
func (sa *spamassassin) skip(tok string) {
	sa.reset(tok)
	notSpam := false
	sa.r = &notSpam
}

// line handles one data-line ln of the message in transaction tok. Lines are
// buffered for spamd. The end-of-message "." hands the buffer to process in
// a new goroutine.
//
// Once the buffered size passes the limit, the message is skipped: the buffer
// is flushed back to smtpd via skip, and this and every later line
// (including ".") is passed through unscanned with a "not spam" verdict.
func (sa *spamassassin) line(tok string, ln string) {
	// Same threshold as sz > lim+1 (the +1 was meant as slack for the
	// end-of-message dot). Written this way because lim+1 wraps to 0 when
	// lim is the maximum uint64, which main's ParseUint accepts on 64-bit
	// platforms; the wrapped form would skip every message.
	if sa.sz > lim && sa.sz-lim > 1 {
		l3.Warning(fmt.Sprintln(sa.sid, "limit reached, skip scan"))
		sa.skip(tok)
	}
	if sa.r != nil && !*sa.r { // skipped message
		fmt.Printf("filter-dataline|%s|%s|%s\n", sa.sid, tok, ln)
		return
	}
	if ln == "." {
		go sa.process(tok)
		return
	}
	sa.buf = append(sa.buf, ln)
	sa.sz += uint64(len(ln))
}

// commit answers smtpd's commit request for transaction tok with the verdict
// for the message just received:
//   - no verdict (scan failed or returned no "Spam:" result): 451 tempfail
//   - spam: 554 permanent reject
//   - not spam, or skipped because of the size limit: proceed
//
// Afterwards the per-message state is cleared, so the next message in the
// same SMTP session is scanned on its own. Otherwise a stale "not spam"
// verdict would make line pass every later message through unscanned.
func (sa *spamassassin) commit(tok string) {
	switch {
	case sa.r == nil:
		l3.Warning(fmt.Sprintln(sa.sid, "reject filter failed"))
		fmt.Printf("filter-result|%s|%s|reject|451 4.7.1 Spam filter failed\n", sa.sid, tok)
	case *sa.r:
		l3.Info(fmt.Sprintln(sa.sid, "reject spam"))
		fmt.Printf("filter-result|%s|%s|reject|554 5.7.1 Message considered spam\n", sa.sid, tok)
	default:
		l3.Debug(fmt.Sprintln(sa.sid, "accept"))
		fmt.Printf("filter-result|%s|%s|proceed\n", sa.sid, tok)
	}
	// Clear without calling reset: reset would replay sa.buf as data-lines,
	// but the message was already terminated with "." before commit arrived.
	sa.buf = nil
	sa.sz = 0
	sa.r = nil
}

// register reads (and ignores) smtpd's configuration lines until
// "config|ready". It then announces the events this filter handles,
// followed by "register|ready". If "config|ready" never arrives it returns
// the scanner's error, which is nil at plain EOF.
func register(in *bufio.Scanner) error {
	l3.Info("register")
	for in.Scan() {
		if in.Text() != "config|ready" {
			continue // configuration lines are not used
		}
		for _, reg := range []string{
			"register|report|smtp-in|link-connect",
			"register|filter|smtp-in|data-line",
			"register|filter|smtp-in|commit",
			"register|report|smtp-in|link-disconnect",
			"register|ready",
		} {
			fmt.Println(reg)
		}
		return nil
	}
	return in.Err()
}

// run registers the filter with smtpd, then dispatches protocol lines read
// from stdin until EOF.
//
// Lines are "|"-separated: field 0 is the type (filter or report), field 1
// the protocol version, field 4 the event and field 5 the session id.
// data-line and commit also carry a token in field 6, and data-line carries
// the message line from field 7 on (it may itself contain "|"). Fields 2 and
// 3 are not used here. A malformed or unexpected line is logged and ends the
// filter.
func run() {
	l3.Info("start")
	defer l3.Info("exit")
	in := bufio.NewScanner(os.Stdin)
	// bufio.Scanner fails on lines longer than 64 KiB by default, which would
	// end the filter on one long data-line plus its protocol prefix.
	// NOTE(review): 1 MiB is an assumed upper bound; confirm against smtpd's
	// own line length limit.
	in.Buffer(nil, 1<<20)
	if e := register(in); e != nil {
		l3.Err(fmt.Sprintln("register", e))
		return
	}
	for in.Scan() {
		f := strings.Split(in.Text(), "|")
		if len(f) < 6 {
			l3.Err(fmt.Sprintln("protocol", "fields", len(f)))
			return
		}
		t, ver, ev, sid := f[0], f[1], f[4], f[5]
		if (t != "filter" && t != "report") || ver != "0.7" {
			l3.Err(fmt.Sprintln(sid, "protocol", t, ver))
			return
		}
		if (ev == "data-line" || ev == "commit") && len(f) < 7 {
			l3.Err(fmt.Sprintln(sid, "protocol", ev, "fields", len(f)))
			return
		}
		switch ev {
		case "link-connect":
			sas[sid] = &spamassassin{sid: sid}
		case "data-line":
			if s, ok := sas[sid]; ok {
				s.line(f[6], strings.Join(f[7:], "|"))
			}
		case "commit":
			if s, ok := sas[sid]; ok {
				s.commit(f[6])
			}
		case "link-disconnect":
			delete(sas, sid) // no-op for unknown sessions
		default:
			l3.Err(fmt.Sprintln(sid, "event", ev))
			return
		}
	}
	if e := in.Err(); e != nil {
		l3.Err(fmt.Sprintln("scanner", e))
	}
}

// init configures the standard logger, which main uses for command-line and
// startup errors. Messages are prefixed only with the short source file name
// and line, without date or time.
func init() {
	const flags = log.Lshortfile
	log.SetFlags(flags)
}

// main parses the command line, opens the syslog writer and runs the filter.
//
// Usage:
//
//	filter-spamassassin [<address>] [<limit>]
//	filter-spamassassin version
//
// <address> replaces the default spamd address. <limit> is a positive
// decimal byte count that replaces the default size limit.
func main() {
	argc := len(os.Args)
	if argc == 2 && os.Args[1] == "version" {
		fmt.Println("filter-spamassassin", v, "(c)", yr, "Joerg Jung")
		return
	}
	if argc > 3 {
		// %35s pads the second usage line under the first one, past the
		// "file:line: usage: " prefix produced by the Lshortfile logger.
		log.Fatalf("usage: filter-spamassassin [<address>] [<limit>]\n%35sfilter-spamassassin version\n", "")
	}
	if argc > 1 {
		addr = os.Args[1]
	}
	if argc > 2 {
		n, err := strconv.ParseUint(os.Args[2], 10, 0)
		switch {
		case err != nil:
			log.Fatal(err)
		case n == 0:
			log.Fatal("limit must be larger than 0")
		}
		lim = n
	}
	w, err := syslog.New(syslog.LOG_MAIL, "filter-spamassassin")
	if err != nil {
		log.Fatal(err)
	}
	l3 = w
	defer l3.Close()
	run()
}
