init commit
This commit is contained in:
1
plugin/cron/README.md
Normal file
1
plugin/cron/README.md
Normal file
@@ -0,0 +1 @@
|
||||
Fork from https://github.com/robfig/cron
|
||||
96
plugin/cron/chain.go
Normal file
96
plugin/cron/chain.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JobWrapper decorates the given Job with some behavior.
|
||||
type JobWrapper func(Job) Job
|
||||
|
||||
// Chain is a sequence of JobWrappers that decorates submitted jobs with
|
||||
// cross-cutting behaviors like logging or synchronization.
|
||||
type Chain struct {
|
||||
wrappers []JobWrapper
|
||||
}
|
||||
|
||||
// NewChain returns a Chain consisting of the given JobWrappers.
|
||||
func NewChain(c ...JobWrapper) Chain {
|
||||
return Chain{c}
|
||||
}
|
||||
|
||||
// Then decorates the given job with all JobWrappers in the chain.
|
||||
//
|
||||
// This:
|
||||
//
|
||||
// NewChain(m1, m2, m3).Then(job)
|
||||
//
|
||||
// is equivalent to:
|
||||
//
|
||||
// m1(m2(m3(job)))
|
||||
func (c Chain) Then(j Job) Job {
|
||||
for i := range c.wrappers {
|
||||
j = c.wrappers[len(c.wrappers)-i-1](j)
|
||||
}
|
||||
return j
|
||||
}
|
||||
|
||||
// Recover panics in wrapped jobs and log them with the provided logger.
|
||||
func Recover(logger Logger) JobWrapper {
|
||||
return func(j Job) Job {
|
||||
return FuncJob(func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
err, ok := r.(error)
|
||||
if !ok {
|
||||
err = errors.New("panic: " + fmt.Sprint(r))
|
||||
}
|
||||
logger.Error(err, "panic", "stack", "...\n"+string(buf))
|
||||
}
|
||||
}()
|
||||
j.Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// DelayIfStillRunning serializes jobs, delaying subsequent runs until the
|
||||
// previous one is complete. Jobs running after a delay of more than a minute
|
||||
// have the delay logged at Info.
|
||||
func DelayIfStillRunning(logger Logger) JobWrapper {
|
||||
return func(j Job) Job {
|
||||
var mu sync.Mutex
|
||||
return FuncJob(func() {
|
||||
start := time.Now()
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if dur := time.Since(start); dur > time.Minute {
|
||||
logger.Info("delay", "duration", dur)
|
||||
}
|
||||
j.Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// SkipIfStillRunning skips an invocation of the Job if a previous invocation is
|
||||
// still running. It logs skips to the given logger at Info level.
|
||||
func SkipIfStillRunning(logger Logger) JobWrapper {
|
||||
return func(j Job) Job {
|
||||
var ch = make(chan struct{}, 1)
|
||||
ch <- struct{}{}
|
||||
return FuncJob(func() {
|
||||
select {
|
||||
case v := <-ch:
|
||||
defer func() { ch <- v }()
|
||||
j.Run()
|
||||
default:
|
||||
logger.Info("skip")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
239
plugin/cron/chain_test.go
Normal file
239
plugin/cron/chain_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func appendingJob(slice *[]int, value int) Job {
|
||||
var m sync.Mutex
|
||||
return FuncJob(func() {
|
||||
m.Lock()
|
||||
*slice = append(*slice, value)
|
||||
m.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func appendingWrapper(slice *[]int, value int) JobWrapper {
|
||||
return func(j Job) Job {
|
||||
return FuncJob(func() {
|
||||
appendingJob(slice, value).Run()
|
||||
j.Run()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain(t *testing.T) {
|
||||
var nums []int
|
||||
var (
|
||||
append1 = appendingWrapper(&nums, 1)
|
||||
append2 = appendingWrapper(&nums, 2)
|
||||
append3 = appendingWrapper(&nums, 3)
|
||||
append4 = appendingJob(&nums, 4)
|
||||
)
|
||||
NewChain(append1, append2, append3).Then(append4).Run()
|
||||
if !reflect.DeepEqual(nums, []int{1, 2, 3, 4}) {
|
||||
t.Error("unexpected order of calls:", nums)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainRecover(t *testing.T) {
|
||||
panickingJob := FuncJob(func() {
|
||||
panic("panickingJob panics")
|
||||
})
|
||||
|
||||
t.Run("panic exits job by default", func(*testing.T) {
|
||||
defer func() {
|
||||
if err := recover(); err == nil {
|
||||
t.Errorf("panic expected, but none received")
|
||||
}
|
||||
}()
|
||||
NewChain().Then(panickingJob).
|
||||
Run()
|
||||
})
|
||||
|
||||
t.Run("Recovering JobWrapper recovers", func(*testing.T) {
|
||||
NewChain(Recover(PrintfLogger(log.New(io.Discard, "", 0)))).
|
||||
Then(panickingJob).
|
||||
Run()
|
||||
})
|
||||
|
||||
t.Run("composed with the *IfStillRunning wrappers", func(*testing.T) {
|
||||
NewChain(Recover(PrintfLogger(log.New(io.Discard, "", 0)))).
|
||||
Then(panickingJob).
|
||||
Run()
|
||||
})
|
||||
}
|
||||
|
||||
type countJob struct {
|
||||
m sync.Mutex
|
||||
started int
|
||||
done int
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func (j *countJob) Run() {
|
||||
j.m.Lock()
|
||||
j.started++
|
||||
j.m.Unlock()
|
||||
time.Sleep(j.delay)
|
||||
j.m.Lock()
|
||||
j.done++
|
||||
j.m.Unlock()
|
||||
}
|
||||
|
||||
func (j *countJob) Started() int {
|
||||
defer j.m.Unlock()
|
||||
j.m.Lock()
|
||||
return j.started
|
||||
}
|
||||
|
||||
func (j *countJob) Done() int {
|
||||
defer j.m.Unlock()
|
||||
j.m.Lock()
|
||||
return j.done
|
||||
}
|
||||
|
||||
func TestChainDelayIfStillRunning(t *testing.T) {
|
||||
t.Run("runs immediately", func(*testing.T) {
|
||||
var j countJob
|
||||
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
|
||||
go wrappedJob.Run()
|
||||
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
|
||||
if c := j.Done(); c != 1 {
|
||||
t.Errorf("expected job run once, immediately, got %d", c)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("second run immediate if first done", func(*testing.T) {
|
||||
var j countJob
|
||||
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
|
||||
go func() {
|
||||
go wrappedJob.Run()
|
||||
time.Sleep(time.Millisecond)
|
||||
go wrappedJob.Run()
|
||||
}()
|
||||
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
|
||||
if c := j.Done(); c != 2 {
|
||||
t.Errorf("expected job run twice, immediately, got %d", c)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("second run delayed if first not done", func(*testing.T) {
|
||||
var j countJob
|
||||
j.delay = 10 * time.Millisecond
|
||||
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
|
||||
go func() {
|
||||
go wrappedJob.Run()
|
||||
time.Sleep(time.Millisecond)
|
||||
go wrappedJob.Run()
|
||||
}()
|
||||
|
||||
// After 5ms, the first job is still in progress, and the second job was
|
||||
// run but should be waiting for it to finish.
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
started, done := j.Started(), j.Done()
|
||||
if started != 1 || done != 0 {
|
||||
t.Error("expected first job started, but not finished, got", started, done)
|
||||
}
|
||||
|
||||
// Verify that the second job completes.
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
started, done = j.Started(), j.Done()
|
||||
if started != 2 || done != 2 {
|
||||
t.Error("expected both jobs done, got", started, done)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestChainSkipIfStillRunning(t *testing.T) {
|
||||
t.Run("runs immediately", func(*testing.T) {
|
||||
var j countJob
|
||||
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
|
||||
go wrappedJob.Run()
|
||||
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
|
||||
if c := j.Done(); c != 1 {
|
||||
t.Errorf("expected job run once, immediately, got %d", c)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("second run immediate if first done", func(*testing.T) {
|
||||
var j countJob
|
||||
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
|
||||
go func() {
|
||||
go wrappedJob.Run()
|
||||
time.Sleep(time.Millisecond)
|
||||
go wrappedJob.Run()
|
||||
}()
|
||||
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
|
||||
if c := j.Done(); c != 2 {
|
||||
t.Errorf("expected job run twice, immediately, got %d", c)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("second run skipped if first not done", func(*testing.T) {
|
||||
var j countJob
|
||||
j.delay = 10 * time.Millisecond
|
||||
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
|
||||
go func() {
|
||||
go wrappedJob.Run()
|
||||
time.Sleep(time.Millisecond)
|
||||
go wrappedJob.Run()
|
||||
}()
|
||||
|
||||
// After 5ms, the first job is still in progress, and the second job was
|
||||
// aleady skipped.
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
started, done := j.Started(), j.Done()
|
||||
if started != 1 || done != 0 {
|
||||
t.Error("expected first job started, but not finished, got", started, done)
|
||||
}
|
||||
|
||||
// Verify that the first job completes and second does not run.
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
started, done = j.Started(), j.Done()
|
||||
if started != 1 || done != 1 {
|
||||
t.Error("expected second job skipped, got", started, done)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skip 10 jobs on rapid fire", func(*testing.T) {
|
||||
var j countJob
|
||||
j.delay = 10 * time.Millisecond
|
||||
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
|
||||
for i := 0; i < 11; i++ {
|
||||
go wrappedJob.Run()
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
done := j.Done()
|
||||
if done != 1 {
|
||||
t.Error("expected 1 jobs executed, 10 jobs dropped, got", done)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different jobs independent", func(*testing.T) {
|
||||
var j1, j2 countJob
|
||||
j1.delay = 10 * time.Millisecond
|
||||
j2.delay = 10 * time.Millisecond
|
||||
chain := NewChain(SkipIfStillRunning(DiscardLogger))
|
||||
wrappedJob1 := chain.Then(&j1)
|
||||
wrappedJob2 := chain.Then(&j2)
|
||||
for i := 0; i < 11; i++ {
|
||||
go wrappedJob1.Run()
|
||||
go wrappedJob2.Run()
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
var (
|
||||
done1 = j1.Done()
|
||||
done2 = j2.Done()
|
||||
)
|
||||
if done1 != 1 || done2 != 1 {
|
||||
t.Error("expected both jobs executed once, got", done1, "and", done2)
|
||||
}
|
||||
})
|
||||
}
|
||||
27
plugin/cron/constantdelay.go
Normal file
27
plugin/cron/constantdelay.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package cron
|
||||
|
||||
import "time"
|
||||
|
||||
// ConstantDelaySchedule represents a simple recurring duty cycle, e.g. "Every 5 minutes".
|
||||
// It does not support jobs more frequent than once a second.
|
||||
type ConstantDelaySchedule struct {
|
||||
Delay time.Duration
|
||||
}
|
||||
|
||||
// Every returns a crontab Schedule that activates once every duration.
|
||||
// Delays of less than a second are not supported (will round up to 1 second).
|
||||
// Any fields less than a Second are truncated.
|
||||
func Every(duration time.Duration) ConstantDelaySchedule {
|
||||
if duration < time.Second {
|
||||
duration = time.Second
|
||||
}
|
||||
return ConstantDelaySchedule{
|
||||
Delay: duration - time.Duration(duration.Nanoseconds())%time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Next returns the next time this should be run.
|
||||
// This rounds so that the next activation time will be on the second.
|
||||
func (schedule ConstantDelaySchedule) Next(t time.Time) time.Time {
|
||||
return t.Add(schedule.Delay - time.Duration(t.Nanosecond())*time.Nanosecond)
|
||||
}
|
||||
55
plugin/cron/constantdelay_test.go
Normal file
55
plugin/cron/constantdelay_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConstantDelayNext(t *testing.T) {
|
||||
tests := []struct {
|
||||
time string
|
||||
delay time.Duration
|
||||
expected string
|
||||
}{
|
||||
// Simple cases
|
||||
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
|
||||
{"Mon Jul 9 14:59 2012", 15 * time.Minute, "Mon Jul 9 15:14 2012"},
|
||||
{"Mon Jul 9 14:59:59 2012", 15 * time.Minute, "Mon Jul 9 15:14:59 2012"},
|
||||
|
||||
// Wrap around hours
|
||||
{"Mon Jul 9 15:45 2012", 35 * time.Minute, "Mon Jul 9 16:20 2012"},
|
||||
|
||||
// Wrap around days
|
||||
{"Mon Jul 9 23:46 2012", 14 * time.Minute, "Tue Jul 10 00:00 2012"},
|
||||
{"Mon Jul 9 23:45 2012", 35 * time.Minute, "Tue Jul 10 00:20 2012"},
|
||||
{"Mon Jul 9 23:35:51 2012", 44*time.Minute + 24*time.Second, "Tue Jul 10 00:20:15 2012"},
|
||||
{"Mon Jul 9 23:35:51 2012", 25*time.Hour + 44*time.Minute + 24*time.Second, "Thu Jul 11 01:20:15 2012"},
|
||||
|
||||
// Wrap around months
|
||||
{"Mon Jul 9 23:35 2012", 91*24*time.Hour + 25*time.Minute, "Thu Oct 9 00:00 2012"},
|
||||
|
||||
// Wrap around minute, hour, day, month, and year
|
||||
{"Mon Dec 31 23:59:45 2012", 15 * time.Second, "Tue Jan 1 00:00:00 2013"},
|
||||
|
||||
// Round to nearest second on the delay
|
||||
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
|
||||
|
||||
// Round up to 1 second if the duration is less.
|
||||
{"Mon Jul 9 14:45:00 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:01 2012"},
|
||||
|
||||
// Round to nearest second when calculating the next time.
|
||||
{"Mon Jul 9 14:45:00.005 2012", 15 * time.Minute, "Mon Jul 9 15:00 2012"},
|
||||
|
||||
// Round to nearest second for both.
|
||||
{"Mon Jul 9 14:45:00.005 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
|
||||
}
|
||||
|
||||
for _, c := range tests {
|
||||
actual := Every(c.delay).Next(getTime(c.time))
|
||||
expected := getTime(c.expected)
|
||||
if actual != expected {
|
||||
t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.delay, expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
355
plugin/cron/cron.go
Normal file
355
plugin/cron/cron.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cron keeps track of any number of entries, invoking the associated func as
|
||||
// specified by the schedule. It may be started, stopped, and the entries may
|
||||
// be inspected while running.
|
||||
type Cron struct {
|
||||
entries []*Entry
|
||||
chain Chain
|
||||
stop chan struct{}
|
||||
add chan *Entry
|
||||
remove chan EntryID
|
||||
snapshot chan chan []Entry
|
||||
running bool
|
||||
logger Logger
|
||||
runningMu sync.Mutex
|
||||
location *time.Location
|
||||
parser ScheduleParser
|
||||
nextID EntryID
|
||||
jobWaiter sync.WaitGroup
|
||||
}
|
||||
|
||||
// ScheduleParser is an interface for schedule spec parsers that return a Schedule.
|
||||
type ScheduleParser interface {
|
||||
Parse(spec string) (Schedule, error)
|
||||
}
|
||||
|
||||
// Job is an interface for submitted cron jobs.
|
||||
type Job interface {
|
||||
Run()
|
||||
}
|
||||
|
||||
// Schedule describes a job's duty cycle.
|
||||
type Schedule interface {
|
||||
// Next returns the next activation time, later than the given time.
|
||||
// Next is invoked initially, and then each time the job is run.
|
||||
Next(time.Time) time.Time
|
||||
}
|
||||
|
||||
// EntryID identifies an entry within a Cron instance.
|
||||
type EntryID int
|
||||
|
||||
// Entry consists of a schedule and the func to execute on that schedule.
|
||||
type Entry struct {
|
||||
// ID is the cron-assigned ID of this entry, which may be used to look up a
|
||||
// snapshot or remove it.
|
||||
ID EntryID
|
||||
|
||||
// Schedule on which this job should be run.
|
||||
Schedule Schedule
|
||||
|
||||
// Next time the job will run, or the zero time if Cron has not been
|
||||
// started or this entry's schedule is unsatisfiable
|
||||
Next time.Time
|
||||
|
||||
// Prev is the last time this job was run, or the zero time if never.
|
||||
Prev time.Time
|
||||
|
||||
// WrappedJob is the thing to run when the Schedule is activated.
|
||||
WrappedJob Job
|
||||
|
||||
// Job is the thing that was submitted to cron.
|
||||
// It is kept around so that user code that needs to get at the job later,
|
||||
// e.g. via Entries() can do so.
|
||||
Job Job
|
||||
}
|
||||
|
||||
// Valid returns true if this is not the zero entry.
|
||||
func (e Entry) Valid() bool { return e.ID != 0 }
|
||||
|
||||
// byTime is a wrapper for sorting the entry array by time
|
||||
// (with zero time at the end).
|
||||
type byTime []*Entry
|
||||
|
||||
func (s byTime) Len() int { return len(s) }
|
||||
func (s byTime) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s byTime) Less(i, j int) bool {
|
||||
// Two zero times should return false.
|
||||
// Otherwise, zero is "greater" than any other time.
|
||||
// (To sort it at the end of the list.)
|
||||
if s[i].Next.IsZero() {
|
||||
return false
|
||||
}
|
||||
if s[j].Next.IsZero() {
|
||||
return true
|
||||
}
|
||||
return s[i].Next.Before(s[j].Next)
|
||||
}
|
||||
|
||||
// New returns a new Cron job runner, modified by the given options.
|
||||
//
|
||||
// Available Settings
|
||||
//
|
||||
// Time Zone
|
||||
// Description: The time zone in which schedules are interpreted
|
||||
// Default: time.Local
|
||||
//
|
||||
// Parser
|
||||
// Description: Parser converts cron spec strings into cron.Schedules.
|
||||
// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron
|
||||
//
|
||||
// Chain
|
||||
// Description: Wrap submitted jobs to customize behavior.
|
||||
// Default: A chain that recovers panics and logs them to stderr.
|
||||
//
|
||||
// See "cron.With*" to modify the default behavior.
|
||||
func New(opts ...Option) *Cron {
|
||||
c := &Cron{
|
||||
entries: nil,
|
||||
chain: NewChain(),
|
||||
add: make(chan *Entry),
|
||||
stop: make(chan struct{}),
|
||||
snapshot: make(chan chan []Entry),
|
||||
remove: make(chan EntryID),
|
||||
running: false,
|
||||
runningMu: sync.Mutex{},
|
||||
logger: DefaultLogger,
|
||||
location: time.Local,
|
||||
parser: standardParser,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// FuncJob is a wrapper that turns a func() into a cron.Job.
|
||||
type FuncJob func()
|
||||
|
||||
func (f FuncJob) Run() { f() }
|
||||
|
||||
// AddFunc adds a func to the Cron to be run on the given schedule.
|
||||
// The spec is parsed using the time zone of this Cron instance as the default.
|
||||
// An opaque ID is returned that can be used to later remove it.
|
||||
func (c *Cron) AddFunc(spec string, cmd func()) (EntryID, error) {
|
||||
return c.AddJob(spec, FuncJob(cmd))
|
||||
}
|
||||
|
||||
// AddJob adds a Job to the Cron to be run on the given schedule.
|
||||
// The spec is parsed using the time zone of this Cron instance as the default.
|
||||
// An opaque ID is returned that can be used to later remove it.
|
||||
func (c *Cron) AddJob(spec string, cmd Job) (EntryID, error) {
|
||||
schedule, err := c.parser.Parse(spec)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.Schedule(schedule, cmd), nil
|
||||
}
|
||||
|
||||
// Schedule adds a Job to the Cron to be run on the given schedule.
|
||||
// The job is wrapped with the configured Chain.
|
||||
func (c *Cron) Schedule(schedule Schedule, cmd Job) EntryID {
|
||||
c.runningMu.Lock()
|
||||
defer c.runningMu.Unlock()
|
||||
c.nextID++
|
||||
entry := &Entry{
|
||||
ID: c.nextID,
|
||||
Schedule: schedule,
|
||||
WrappedJob: c.chain.Then(cmd),
|
||||
Job: cmd,
|
||||
}
|
||||
if !c.running {
|
||||
c.entries = append(c.entries, entry)
|
||||
} else {
|
||||
c.add <- entry
|
||||
}
|
||||
return entry.ID
|
||||
}
|
||||
|
||||
// Entries returns a snapshot of the cron entries.
|
||||
func (c *Cron) Entries() []Entry {
|
||||
c.runningMu.Lock()
|
||||
defer c.runningMu.Unlock()
|
||||
if c.running {
|
||||
replyChan := make(chan []Entry, 1)
|
||||
c.snapshot <- replyChan
|
||||
return <-replyChan
|
||||
}
|
||||
return c.entrySnapshot()
|
||||
}
|
||||
|
||||
// Location gets the time zone location.
|
||||
func (c *Cron) Location() *time.Location {
|
||||
return c.location
|
||||
}
|
||||
|
||||
// Entry returns a snapshot of the given entry, or nil if it couldn't be found.
|
||||
func (c *Cron) Entry(id EntryID) Entry {
|
||||
for _, entry := range c.Entries() {
|
||||
if id == entry.ID {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
return Entry{}
|
||||
}
|
||||
|
||||
// Remove an entry from being run in the future.
|
||||
func (c *Cron) Remove(id EntryID) {
|
||||
c.runningMu.Lock()
|
||||
defer c.runningMu.Unlock()
|
||||
if c.running {
|
||||
c.remove <- id
|
||||
} else {
|
||||
c.removeEntry(id)
|
||||
}
|
||||
}
|
||||
|
||||
// Start the cron scheduler in its own goroutine, or no-op if already started.
|
||||
func (c *Cron) Start() {
|
||||
c.runningMu.Lock()
|
||||
defer c.runningMu.Unlock()
|
||||
if c.running {
|
||||
return
|
||||
}
|
||||
c.running = true
|
||||
go c.runScheduler()
|
||||
}
|
||||
|
||||
// Run the cron scheduler, or no-op if already running.
|
||||
func (c *Cron) Run() {
|
||||
c.runningMu.Lock()
|
||||
if c.running {
|
||||
c.runningMu.Unlock()
|
||||
return
|
||||
}
|
||||
c.running = true
|
||||
c.runningMu.Unlock()
|
||||
c.runScheduler()
|
||||
}
|
||||
|
||||
// runScheduler runs the scheduler.. this is private just due to the need to synchronize
|
||||
// access to the 'running' state variable.
|
||||
func (c *Cron) runScheduler() {
|
||||
c.logger.Info("start")
|
||||
|
||||
// Figure out the next activation times for each entry.
|
||||
now := c.now()
|
||||
for _, entry := range c.entries {
|
||||
entry.Next = entry.Schedule.Next(now)
|
||||
c.logger.Info("schedule", "now", now, "entry", entry.ID, "next", entry.Next)
|
||||
}
|
||||
|
||||
for {
|
||||
// Determine the next entry to run.
|
||||
sort.Sort(byTime(c.entries))
|
||||
|
||||
var timer *time.Timer
|
||||
if len(c.entries) == 0 || c.entries[0].Next.IsZero() {
|
||||
// If there are no entries yet, just sleep - it still handles new entries
|
||||
// and stop requests.
|
||||
timer = time.NewTimer(100000 * time.Hour)
|
||||
} else {
|
||||
timer = time.NewTimer(c.entries[0].Next.Sub(now))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case now = <-timer.C:
|
||||
now = now.In(c.location)
|
||||
c.logger.Info("wake", "now", now)
|
||||
|
||||
// Run every entry whose next time was less than now
|
||||
for _, e := range c.entries {
|
||||
if e.Next.After(now) || e.Next.IsZero() {
|
||||
break
|
||||
}
|
||||
c.startJob(e.WrappedJob)
|
||||
e.Prev = e.Next
|
||||
e.Next = e.Schedule.Next(now)
|
||||
c.logger.Info("run", "now", now, "entry", e.ID, "next", e.Next)
|
||||
}
|
||||
|
||||
case newEntry := <-c.add:
|
||||
timer.Stop()
|
||||
now = c.now()
|
||||
newEntry.Next = newEntry.Schedule.Next(now)
|
||||
c.entries = append(c.entries, newEntry)
|
||||
c.logger.Info("added", "now", now, "entry", newEntry.ID, "next", newEntry.Next)
|
||||
|
||||
case replyChan := <-c.snapshot:
|
||||
replyChan <- c.entrySnapshot()
|
||||
continue
|
||||
|
||||
case <-c.stop:
|
||||
timer.Stop()
|
||||
c.logger.Info("stop")
|
||||
return
|
||||
|
||||
case id := <-c.remove:
|
||||
timer.Stop()
|
||||
now = c.now()
|
||||
c.removeEntry(id)
|
||||
c.logger.Info("removed", "entry", id)
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startJob runs the given job in a new goroutine.
|
||||
func (c *Cron) startJob(j Job) {
|
||||
c.jobWaiter.Add(1)
|
||||
go func() {
|
||||
defer c.jobWaiter.Done()
|
||||
j.Run()
|
||||
}()
|
||||
}
|
||||
|
||||
// now returns current time in c location.
|
||||
func (c *Cron) now() time.Time {
|
||||
return time.Now().In(c.location)
|
||||
}
|
||||
|
||||
// Stop stops the cron scheduler if it is running; otherwise it does nothing.
|
||||
// A context is returned so the caller can wait for running jobs to complete.
|
||||
func (c *Cron) Stop() context.Context {
|
||||
c.runningMu.Lock()
|
||||
defer c.runningMu.Unlock()
|
||||
if c.running {
|
||||
c.stop <- struct{}{}
|
||||
c.running = false
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
c.jobWaiter.Wait()
|
||||
cancel()
|
||||
}()
|
||||
return ctx
|
||||
}
|
||||
|
||||
// entrySnapshot returns a copy of the current cron entry list.
|
||||
func (c *Cron) entrySnapshot() []Entry {
|
||||
var entries = make([]Entry, len(c.entries))
|
||||
for i, e := range c.entries {
|
||||
entries[i] = *e
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func (c *Cron) removeEntry(id EntryID) {
|
||||
var entries []*Entry
|
||||
for _, e := range c.entries {
|
||||
if e.ID != id {
|
||||
entries = append(entries, e)
|
||||
}
|
||||
}
|
||||
c.entries = entries
|
||||
}
|
||||
702
plugin/cron/cron_test.go
Normal file
702
plugin/cron/cron_test.go
Normal file
@@ -0,0 +1,702 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Many tests schedule a job for every second, and then wait at most a second
|
||||
// for it to run. This amount is just slightly larger than 1 second to
|
||||
// compensate for a few milliseconds of runtime.
|
||||
const OneSecond = 1*time.Second + 50*time.Millisecond
|
||||
|
||||
type syncWriter struct {
|
||||
wr bytes.Buffer
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
func (sw *syncWriter) Write(data []byte) (n int, err error) {
|
||||
sw.m.Lock()
|
||||
n, err = sw.wr.Write(data)
|
||||
sw.m.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (sw *syncWriter) String() string {
|
||||
sw.m.Lock()
|
||||
defer sw.m.Unlock()
|
||||
return sw.wr.String()
|
||||
}
|
||||
|
||||
func newBufLogger(sw *syncWriter) Logger {
|
||||
return PrintfLogger(log.New(sw, "", log.LstdFlags))
|
||||
}
|
||||
|
||||
func TestFuncPanicRecovery(t *testing.T) {
|
||||
var buf syncWriter
|
||||
cron := New(WithParser(secondParser),
|
||||
WithChain(Recover(newBufLogger(&buf))))
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
cron.AddFunc("* * * * * ?", func() {
|
||||
panic("YOLO")
|
||||
})
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
if !strings.Contains(buf.String(), "YOLO") {
|
||||
t.Error("expected a panic to be logged, got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type DummyJob struct{}
|
||||
|
||||
func (DummyJob) Run() {
|
||||
panic("YOLO")
|
||||
}
|
||||
|
||||
func TestJobPanicRecovery(t *testing.T) {
|
||||
var job DummyJob
|
||||
|
||||
var buf syncWriter
|
||||
cron := New(WithParser(secondParser),
|
||||
WithChain(Recover(newBufLogger(&buf))))
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
cron.AddJob("* * * * * ?", job)
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
if !strings.Contains(buf.String(), "YOLO") {
|
||||
t.Error("expected a panic to be logged, got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Start and stop cron with no entries.
|
||||
func TestNoEntries(t *testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Fatal("expected cron will be stopped immediately")
|
||||
case <-stop(cron):
|
||||
}
|
||||
}
|
||||
|
||||
// Start, stop, then add an entry. Verify entry doesn't run.
|
||||
func TestStopCausesJobsToNotRun(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
cron.Stop()
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
// No job ran!
|
||||
case <-wait(wg):
|
||||
t.Fatal("expected stopped cron does not run any job")
|
||||
}
|
||||
}
|
||||
|
||||
// Add a job, start cron, expect it runs.
|
||||
func TestAddBeforeRunning(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
// Give cron 2 seconds to run our job (which is always activated).
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Fatal("expected job runs")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Start cron, add a job, expect it runs.
|
||||
func TestAddWhileRunning(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Fatal("expected job runs")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test for #34. Adding a job after calling start results in multiple job invocations
|
||||
func TestAddWhileRunningWithDelay(t *testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
time.Sleep(5 * time.Second)
|
||||
var calls int64
|
||||
cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) })
|
||||
|
||||
<-time.After(OneSecond)
|
||||
if atomic.LoadInt64(&calls) != 1 {
|
||||
t.Errorf("called %d times, expected 1\n", calls)
|
||||
}
|
||||
}
|
||||
|
||||
// Add a job, remove a job, start cron, expect nothing runs.
|
||||
func TestRemoveBeforeRunning(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
cron.Remove(id)
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
// Success, shouldn't run
|
||||
case <-wait(wg):
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Start cron, add a job, remove it, expect it doesn't run.
|
||||
func TestRemoveWhileRunning(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
cron.Remove(id)
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
case <-wait(wg):
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// Test timing with Entries.
|
||||
func TestSnapshotEntries(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := New()
|
||||
cron.AddFunc("@every 2s", func() { wg.Done() })
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
// Cron should fire in 2 seconds. After 1 second, call Entries.
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
cron.Entries()
|
||||
}
|
||||
|
||||
// Even though Entries was called, the cron should fire at the 2 second mark.
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Error("expected job runs at 2 second mark")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the entries are correctly sorted.
|
||||
// Add a bunch of long-in-the-future entries, and an immediate entry, and ensure
|
||||
// that the immediate entry runs immediately.
|
||||
// Also: Test that multiple jobs run in the same instant.
|
||||
func TestMultipleEntries(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("0 0 0 1 1 ?", func() {})
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
id1, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() })
|
||||
id2, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() })
|
||||
cron.AddFunc("0 0 0 31 12 ?", func() {})
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
|
||||
cron.Remove(id1)
|
||||
cron.Start()
|
||||
cron.Remove(id2)
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Error("expected job run in proper order")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test running the same job twice.
|
||||
func TestRunningJobTwice(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("0 0 0 1 1 ?", func() {})
|
||||
cron.AddFunc("0 0 0 31 12 ?", func() {})
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(2 * OneSecond):
|
||||
t.Error("expected job fires 2 times")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunningMultipleSchedules(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("0 0 0 1 1 ?", func() {})
|
||||
cron.AddFunc("0 0 0 31 12 ?", func() {})
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
cron.Schedule(Every(time.Minute), FuncJob(func() {}))
|
||||
cron.Schedule(Every(time.Second), FuncJob(func() { wg.Done() }))
|
||||
cron.Schedule(Every(time.Hour), FuncJob(func() {}))
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(2 * OneSecond):
|
||||
t.Error("expected job fires 2 times")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the cron is run in the local time zone (as opposed to UTC).
|
||||
func TestLocalTimezone(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
now := time.Now()
|
||||
// FIX: Issue #205
|
||||
// This calculation doesn't work in seconds 58 or 59.
|
||||
// Take the easy way out and sleep.
|
||||
if now.Second() >= 58 {
|
||||
time.Sleep(2 * time.Second)
|
||||
now = time.Now()
|
||||
}
|
||||
spec := fmt.Sprintf("%d,%d %d %d %d %d ?",
|
||||
now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month())
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc(spec, func() { wg.Done() })
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond * 2):
|
||||
t.Error("expected job fires 2 times")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the cron is run in the given time zone (as opposed to local).
|
||||
func TestNonLocalTimezone(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
loc, err := time.LoadLocation("Atlantic/Cape_Verde")
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to load time zone Atlantic/Cape_Verde: %+v", err)
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
now := time.Now().In(loc)
|
||||
// FIX: Issue #205
|
||||
// This calculation doesn't work in seconds 58 or 59.
|
||||
// Take the easy way out and sleep.
|
||||
if now.Second() >= 58 {
|
||||
time.Sleep(2 * time.Second)
|
||||
now = time.Now().In(loc)
|
||||
}
|
||||
spec := fmt.Sprintf("%d,%d %d %d %d %d ?",
|
||||
now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month())
|
||||
|
||||
cron := New(WithLocation(loc), WithParser(secondParser))
|
||||
cron.AddFunc(spec, func() { wg.Done() })
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond * 2):
|
||||
t.Error("expected job fires 2 times")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that calling stop before start silently returns without
|
||||
// blocking the stop channel.
|
||||
func TestStopWithoutStart(t *testing.T) {
|
||||
cron := New()
|
||||
cron.Stop()
|
||||
}
|
||||
|
||||
type testJob struct {
|
||||
wg *sync.WaitGroup
|
||||
name string
|
||||
}
|
||||
|
||||
func (t testJob) Run() {
|
||||
t.wg.Done()
|
||||
}
|
||||
|
||||
// Test that adding an invalid job spec returns an error
|
||||
func TestInvalidJobSpec(t *testing.T) {
|
||||
cron := New()
|
||||
_, err := cron.AddJob("this will not parse", nil)
|
||||
if err == nil {
|
||||
t.Errorf("expected an error with invalid spec, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// Test blocking run method behaves as Start()
|
||||
func TestBlockingRun(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * ?", func() { wg.Done() })
|
||||
|
||||
var unblockChan = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
cron.Run()
|
||||
close(unblockChan)
|
||||
}()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.Error("expected job fires")
|
||||
case <-unblockChan:
|
||||
t.Error("expected that Run() blocks")
|
||||
case <-wait(wg):
|
||||
}
|
||||
}
|
||||
|
||||
// Test that double-running is a no-op
|
||||
func TestStartNoop(t *testing.T) {
|
||||
var tickChan = make(chan struct{}, 2)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * ?", func() {
|
||||
tickChan <- struct{}{}
|
||||
})
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
// Wait for the first firing to ensure the runner is going
|
||||
<-tickChan
|
||||
|
||||
cron.Start()
|
||||
|
||||
<-tickChan
|
||||
|
||||
// Fail if this job fires again in a short period, indicating a double-run
|
||||
select {
|
||||
case <-time.After(time.Millisecond):
|
||||
case <-tickChan:
|
||||
t.Error("expected job fires exactly twice")
|
||||
}
|
||||
}
|
||||
|
||||
// Simple test using Runnables.
|
||||
func TestJob(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
cron := newWithSeconds()
|
||||
cron.AddJob("0 0 0 30 Feb ?", testJob{wg, "job0"})
|
||||
cron.AddJob("0 0 0 1 1 ?", testJob{wg, "job1"})
|
||||
job2, _ := cron.AddJob("* * * * * ?", testJob{wg, "job2"})
|
||||
cron.AddJob("1 0 0 1 1 ?", testJob{wg, "job3"})
|
||||
cron.Schedule(Every(5*time.Second+5*time.Nanosecond), testJob{wg, "job4"})
|
||||
job5 := cron.Schedule(Every(5*time.Minute), testJob{wg, "job5"})
|
||||
|
||||
// Test getting an Entry pre-Start.
|
||||
if actualName := cron.Entry(job2).Job.(testJob).name; actualName != "job2" {
|
||||
t.Error("wrong job retrieved:", actualName)
|
||||
}
|
||||
if actualName := cron.Entry(job5).Job.(testJob).name; actualName != "job5" {
|
||||
t.Error("wrong job retrieved:", actualName)
|
||||
}
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
select {
|
||||
case <-time.After(OneSecond):
|
||||
t.FailNow()
|
||||
case <-wait(wg):
|
||||
}
|
||||
|
||||
// Ensure the entries are in the right order.
|
||||
expecteds := []string{"job2", "job4", "job5", "job1", "job3", "job0"}
|
||||
|
||||
var actuals []string
|
||||
for _, entry := range cron.Entries() {
|
||||
actuals = append(actuals, entry.Job.(testJob).name)
|
||||
}
|
||||
|
||||
for i, expected := range expecteds {
|
||||
if actuals[i] != expected {
|
||||
t.Fatalf("Jobs not in the right order. (expected) %s != %s (actual)", expecteds, actuals)
|
||||
}
|
||||
}
|
||||
|
||||
// Test getting Entries.
|
||||
if actualName := cron.Entry(job2).Job.(testJob).name; actualName != "job2" {
|
||||
t.Error("wrong job retrieved:", actualName)
|
||||
}
|
||||
if actualName := cron.Entry(job5).Job.(testJob).name; actualName != "job5" {
|
||||
t.Error("wrong job retrieved:", actualName)
|
||||
}
|
||||
}
|
||||
|
||||
// Issue #206
|
||||
// Ensure that the next run of a job after removing an entry is accurate.
|
||||
func TestScheduleAfterRemoval(t *testing.T) {
|
||||
var wg1 sync.WaitGroup
|
||||
var wg2 sync.WaitGroup
|
||||
wg1.Add(1)
|
||||
wg2.Add(1)
|
||||
|
||||
// The first time this job is run, set a timer and remove the other job
|
||||
// 750ms later. Correct behavior would be to still run the job again in
|
||||
// 250ms, but the bug would cause it to run instead 1s later.
|
||||
|
||||
var calls int
|
||||
var mu sync.Mutex
|
||||
|
||||
cron := newWithSeconds()
|
||||
hourJob := cron.Schedule(Every(time.Hour), FuncJob(func() {}))
|
||||
cron.Schedule(Every(time.Second), FuncJob(func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
switch calls {
|
||||
case 0:
|
||||
wg1.Done()
|
||||
calls++
|
||||
case 1:
|
||||
time.Sleep(750 * time.Millisecond)
|
||||
cron.Remove(hourJob)
|
||||
calls++
|
||||
case 2:
|
||||
calls++
|
||||
wg2.Done()
|
||||
case 3:
|
||||
panic("unexpected 3rd call")
|
||||
}
|
||||
}))
|
||||
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
|
||||
// the first run might be any length of time 0 - 1s, since the schedule
|
||||
// rounds to the second. wait for the first run to true up.
|
||||
wg1.Wait()
|
||||
|
||||
select {
|
||||
case <-time.After(2 * OneSecond):
|
||||
t.Error("expected job fires 2 times")
|
||||
case <-wait(&wg2):
|
||||
}
|
||||
}
|
||||
|
||||
type ZeroSchedule struct{}
|
||||
|
||||
func (*ZeroSchedule) Next(time.Time) time.Time {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// Tests that job without time does not run
|
||||
func TestJobWithZeroTimeDoesNotRun(t *testing.T) {
|
||||
cron := newWithSeconds()
|
||||
var calls int64
|
||||
cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) })
|
||||
cron.Schedule(new(ZeroSchedule), FuncJob(func() { t.Error("expected zero task will not run") }))
|
||||
cron.Start()
|
||||
defer cron.Stop()
|
||||
<-time.After(OneSecond)
|
||||
if atomic.LoadInt64(&calls) != 1 {
|
||||
t.Errorf("called %d times, expected 1\n", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopAndWait(t *testing.T) {
|
||||
t.Run("nothing running, returns immediately", func(*testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
ctx := cron.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(time.Millisecond):
|
||||
t.Error("context was not done immediately")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("repeated calls to Stop", func(*testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.Start()
|
||||
_ = cron.Stop()
|
||||
time.Sleep(time.Millisecond)
|
||||
ctx := cron.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(time.Millisecond):
|
||||
t.Error("context was not done immediately")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a couple fast jobs added, still returns immediately", func(*testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
cron.Start()
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
time.Sleep(time.Second)
|
||||
ctx := cron.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(time.Millisecond):
|
||||
t.Error("context was not done immediately")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("a couple fast jobs and a slow job added, waits for slow job", func(*testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
cron.Start()
|
||||
cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
time.Sleep(time.Second)
|
||||
|
||||
ctx := cron.Stop()
|
||||
|
||||
// Verify that it is not done for at least 750ms
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("context was done too quickly immediately")
|
||||
case <-time.After(750 * time.Millisecond):
|
||||
// expected, because the job sleeping for 1 second is still running
|
||||
}
|
||||
|
||||
// Verify that it IS done in the next 500ms (giving 250ms buffer)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// expected
|
||||
case <-time.After(1500 * time.Millisecond):
|
||||
t.Error("context not done after job should have completed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("repeated calls to stop, waiting for completion and after", func(*testing.T) {
|
||||
cron := newWithSeconds()
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
|
||||
cron.Start()
|
||||
cron.AddFunc("* * * * * *", func() {})
|
||||
time.Sleep(time.Second)
|
||||
ctx := cron.Stop()
|
||||
ctx2 := cron.Stop()
|
||||
|
||||
// Verify that it is not done for at least 1500ms
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("context was done too quickly immediately")
|
||||
case <-ctx2.Done():
|
||||
t.Error("context2 was done too quickly immediately")
|
||||
case <-time.After(1500 * time.Millisecond):
|
||||
// expected, because the job sleeping for 2 seconds is still running
|
||||
}
|
||||
|
||||
// Verify that it IS done in the next 1s (giving 500ms buffer)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// expected
|
||||
case <-time.After(time.Second):
|
||||
t.Error("context not done after job should have completed")
|
||||
}
|
||||
|
||||
// Verify that ctx2 is also done.
|
||||
select {
|
||||
case <-ctx2.Done():
|
||||
// expected
|
||||
case <-time.After(time.Millisecond):
|
||||
t.Error("context2 not done even though context1 is")
|
||||
}
|
||||
|
||||
// Verify that a new context retrieved from stop is immediately done.
|
||||
ctx3 := cron.Stop()
|
||||
select {
|
||||
case <-ctx3.Done():
|
||||
// expected
|
||||
case <-time.After(time.Millisecond):
|
||||
t.Error("context not done even when cron Stop is completed")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiThreadedStartAndStop(t *testing.T) {
|
||||
cron := New()
|
||||
go cron.Run()
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
cron.Stop()
|
||||
}
|
||||
|
||||
func wait(wg *sync.WaitGroup) chan bool {
|
||||
ch := make(chan bool)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
ch <- true
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
func stop(cron *Cron) chan bool {
|
||||
ch := make(chan bool)
|
||||
go func() {
|
||||
cron.Stop()
|
||||
ch <- true
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
// newWithSeconds returns a Cron with the seconds field enabled.
|
||||
func newWithSeconds() *Cron {
|
||||
return New(WithParser(secondParser), WithChain())
|
||||
}
|
||||
86
plugin/cron/logger.go
Normal file
86
plugin/cron/logger.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultLogger is used by Cron if none is specified.
|
||||
var DefaultLogger = PrintfLogger(log.New(os.Stdout, "cron: ", log.LstdFlags))
|
||||
|
||||
// DiscardLogger can be used by callers to discard all log messages.
|
||||
var DiscardLogger = PrintfLogger(log.New(io.Discard, "", 0))
|
||||
|
||||
// Logger is the interface used in this package for logging, so that any backend
|
||||
// can be plugged in. It is a subset of the github.com/go-logr/logr interface.
|
||||
type Logger interface {
|
||||
// Info logs routine messages about cron's operation.
|
||||
Info(msg string, keysAndValues ...interface{})
|
||||
// Error logs an error condition.
|
||||
Error(err error, msg string, keysAndValues ...interface{})
|
||||
}
|
||||
|
||||
// PrintfLogger wraps a Printf-based logger (such as the standard library "log")
|
||||
// into an implementation of the Logger interface which logs errors only.
|
||||
func PrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger {
|
||||
return printfLogger{l, false}
|
||||
}
|
||||
|
||||
// VerbosePrintfLogger wraps a Printf-based logger (such as the standard library
|
||||
// "log") into an implementation of the Logger interface which logs everything.
|
||||
func VerbosePrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger {
|
||||
return printfLogger{l, true}
|
||||
}
|
||||
|
||||
type printfLogger struct {
|
||||
logger interface{ Printf(string, ...interface{}) }
|
||||
logInfo bool
|
||||
}
|
||||
|
||||
func (pl printfLogger) Info(msg string, keysAndValues ...interface{}) {
|
||||
if pl.logInfo {
|
||||
keysAndValues = formatTimes(keysAndValues)
|
||||
pl.logger.Printf(
|
||||
formatString(len(keysAndValues)),
|
||||
append([]interface{}{msg}, keysAndValues...)...)
|
||||
}
|
||||
}
|
||||
|
||||
func (pl printfLogger) Error(err error, msg string, keysAndValues ...interface{}) {
|
||||
keysAndValues = formatTimes(keysAndValues)
|
||||
pl.logger.Printf(
|
||||
formatString(len(keysAndValues)+2),
|
||||
append([]interface{}{msg, "error", err}, keysAndValues...)...)
|
||||
}
|
||||
|
||||
// formatString returns a logfmt-like format string for the number of
|
||||
// key/values.
|
||||
func formatString(numKeysAndValues int) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("%s")
|
||||
if numKeysAndValues > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
for i := 0; i < numKeysAndValues/2; i++ {
|
||||
if i > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString("%v=%v")
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// formatTimes formats any time.Time values as RFC3339.
|
||||
func formatTimes(keysAndValues []interface{}) []interface{} {
|
||||
var formattedArgs []interface{}
|
||||
for _, arg := range keysAndValues {
|
||||
if t, ok := arg.(time.Time); ok {
|
||||
arg = t.Format(time.RFC3339)
|
||||
}
|
||||
formattedArgs = append(formattedArgs, arg)
|
||||
}
|
||||
return formattedArgs
|
||||
}
|
||||
45
plugin/cron/option.go
Normal file
45
plugin/cron/option.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Option represents a modification to the default behavior of a Cron.
|
||||
type Option func(*Cron)
|
||||
|
||||
// WithLocation overrides the timezone of the cron instance.
|
||||
func WithLocation(loc *time.Location) Option {
|
||||
return func(c *Cron) {
|
||||
c.location = loc
|
||||
}
|
||||
}
|
||||
|
||||
// WithSeconds overrides the parser used for interpreting job schedules to
|
||||
// include a seconds field as the first one.
|
||||
func WithSeconds() Option {
|
||||
return WithParser(NewParser(
|
||||
Second | Minute | Hour | Dom | Month | Dow | Descriptor,
|
||||
))
|
||||
}
|
||||
|
||||
// WithParser overrides the parser used for interpreting job schedules.
|
||||
func WithParser(p ScheduleParser) Option {
|
||||
return func(c *Cron) {
|
||||
c.parser = p
|
||||
}
|
||||
}
|
||||
|
||||
// WithChain specifies Job wrappers to apply to all jobs added to this cron.
|
||||
// Refer to the Chain* functions in this package for provided wrappers.
|
||||
func WithChain(wrappers ...JobWrapper) Option {
|
||||
return func(c *Cron) {
|
||||
c.chain = NewChain(wrappers...)
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger uses the provided logger.
|
||||
func WithLogger(logger Logger) Option {
|
||||
return func(c *Cron) {
|
||||
c.logger = logger
|
||||
}
|
||||
}
|
||||
43
plugin/cron/option_test.go
Normal file
43
plugin/cron/option_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWithLocation(t *testing.T) {
|
||||
c := New(WithLocation(time.UTC))
|
||||
if c.location != time.UTC {
|
||||
t.Errorf("expected UTC, got %v", c.location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithParser(t *testing.T) {
|
||||
var parser = NewParser(Dow)
|
||||
c := New(WithParser(parser))
|
||||
if c.parser != parser {
|
||||
t.Error("expected provided parser")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithVerboseLogger(t *testing.T) {
|
||||
var buf syncWriter
|
||||
var logger = log.New(&buf, "", log.LstdFlags)
|
||||
c := New(WithLogger(VerbosePrintfLogger(logger)))
|
||||
if c.logger.(printfLogger).logger != logger {
|
||||
t.Error("expected provided logger")
|
||||
}
|
||||
|
||||
c.AddFunc("@every 1s", func() {})
|
||||
c.Start()
|
||||
time.Sleep(OneSecond)
|
||||
c.Stop()
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "schedule,") ||
|
||||
!strings.Contains(out, "run,") {
|
||||
t.Error("expected to see some actions, got:", out)
|
||||
}
|
||||
}
|
||||
435
plugin/cron/parser.go
Normal file
435
plugin/cron/parser.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Configuration options for creating a parser. Most options specify which
|
||||
// fields should be included, while others enable features. If a field is not
|
||||
// included the parser will assume a default value. These options do not change
|
||||
// the order fields are parse in.
|
||||
type ParseOption int
|
||||
|
||||
const (
|
||||
Second ParseOption = 1 << iota // Seconds field, default 0
|
||||
SecondOptional // Optional seconds field, default 0
|
||||
Minute // Minutes field, default 0
|
||||
Hour // Hours field, default 0
|
||||
Dom // Day of month field, default *
|
||||
Month // Month field, default *
|
||||
Dow // Day of week field, default *
|
||||
DowOptional // Optional day of week field, default *
|
||||
Descriptor // Allow descriptors such as @monthly, @weekly, etc.
|
||||
)
|
||||
|
||||
var places = []ParseOption{
|
||||
Second,
|
||||
Minute,
|
||||
Hour,
|
||||
Dom,
|
||||
Month,
|
||||
Dow,
|
||||
}
|
||||
|
||||
var defaults = []string{
|
||||
"0",
|
||||
"0",
|
||||
"0",
|
||||
"*",
|
||||
"*",
|
||||
"*",
|
||||
}
|
||||
|
||||
// A custom Parser that can be configured.
|
||||
type Parser struct {
|
||||
options ParseOption
|
||||
}
|
||||
|
||||
// NewParser creates a Parser with custom options.
|
||||
//
|
||||
// It panics if more than one Optional is given, since it would be impossible to
|
||||
// correctly infer which optional is provided or missing in general.
|
||||
//
|
||||
// Examples
|
||||
//
|
||||
// // Standard parser without descriptors
|
||||
// specParser := NewParser(Minute | Hour | Dom | Month | Dow)
|
||||
// sched, err := specParser.Parse("0 0 15 */3 *")
|
||||
//
|
||||
// // Same as above, just excludes time fields
|
||||
// specParser := NewParser(Dom | Month | Dow)
|
||||
// sched, err := specParser.Parse("15 */3 *")
|
||||
//
|
||||
// // Same as above, just makes Dow optional
|
||||
// specParser := NewParser(Dom | Month | DowOptional)
|
||||
// sched, err := specParser.Parse("15 */3")
|
||||
func NewParser(options ParseOption) Parser {
|
||||
optionals := 0
|
||||
if options&DowOptional > 0 {
|
||||
optionals++
|
||||
}
|
||||
if options&SecondOptional > 0 {
|
||||
optionals++
|
||||
}
|
||||
if optionals > 1 {
|
||||
panic("multiple optionals may not be configured")
|
||||
}
|
||||
return Parser{options}
|
||||
}
|
||||
|
||||
// Parse returns a new crontab schedule representing the given spec.
|
||||
// It returns a descriptive error if the spec is not valid.
|
||||
// It accepts crontab specs and features configured by NewParser.
|
||||
func (p Parser) Parse(spec string) (Schedule, error) {
|
||||
if len(spec) == 0 {
|
||||
return nil, errors.New("empty spec string")
|
||||
}
|
||||
|
||||
// Extract timezone if present
|
||||
var loc = time.Local
|
||||
if strings.HasPrefix(spec, "TZ=") || strings.HasPrefix(spec, "CRON_TZ=") {
|
||||
var err error
|
||||
i := strings.Index(spec, " ")
|
||||
eq := strings.Index(spec, "=")
|
||||
if loc, err = time.LoadLocation(spec[eq+1 : i]); err != nil {
|
||||
return nil, errors.Wrap(err, "provided bad location")
|
||||
}
|
||||
spec = strings.TrimSpace(spec[i:])
|
||||
}
|
||||
|
||||
// Handle named schedules (descriptors), if configured
|
||||
if strings.HasPrefix(spec, "@") {
|
||||
if p.options&Descriptor == 0 {
|
||||
return nil, errors.New("descriptors not enabled")
|
||||
}
|
||||
return parseDescriptor(spec, loc)
|
||||
}
|
||||
|
||||
// Split on whitespace.
|
||||
fields := strings.Fields(spec)
|
||||
|
||||
// Validate & fill in any omitted or optional fields
|
||||
var err error
|
||||
fields, err = normalizeFields(fields, p.options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
field := func(field string, r bounds) uint64 {
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
var bits uint64
|
||||
bits, err = getField(field, r)
|
||||
return bits
|
||||
}
|
||||
|
||||
var (
|
||||
second = field(fields[0], seconds)
|
||||
minute = field(fields[1], minutes)
|
||||
hour = field(fields[2], hours)
|
||||
dayofmonth = field(fields[3], dom)
|
||||
month = field(fields[4], months)
|
||||
dayofweek = field(fields[5], dow)
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SpecSchedule{
|
||||
Second: second,
|
||||
Minute: minute,
|
||||
Hour: hour,
|
||||
Dom: dayofmonth,
|
||||
Month: month,
|
||||
Dow: dayofweek,
|
||||
Location: loc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeFields takes a subset set of the time fields and returns the full set
|
||||
// with defaults (zeroes) populated for unset fields.
|
||||
//
|
||||
// As part of performing this function, it also validates that the provided
|
||||
// fields are compatible with the configured options.
|
||||
func normalizeFields(fields []string, options ParseOption) ([]string, error) {
|
||||
// Validate optionals & add their field to options
|
||||
optionals := 0
|
||||
if options&SecondOptional > 0 {
|
||||
options |= Second
|
||||
optionals++
|
||||
}
|
||||
if options&DowOptional > 0 {
|
||||
options |= Dow
|
||||
optionals++
|
||||
}
|
||||
if optionals > 1 {
|
||||
return nil, errors.New("multiple optionals may not be configured")
|
||||
}
|
||||
|
||||
// Figure out how many fields we need
|
||||
max := 0
|
||||
for _, place := range places {
|
||||
if options&place > 0 {
|
||||
max++
|
||||
}
|
||||
}
|
||||
min := max - optionals
|
||||
|
||||
// Validate number of fields
|
||||
if count := len(fields); count < min || count > max {
|
||||
if min == max {
|
||||
return nil, errors.New("incorrect number of fields")
|
||||
}
|
||||
return nil, errors.New("incorrect number of fields, expected " + strconv.Itoa(min) + "-" + strconv.Itoa(max))
|
||||
}
|
||||
|
||||
// Populate the optional field if not provided
|
||||
if min < max && len(fields) == min {
|
||||
switch {
|
||||
case options&DowOptional > 0:
|
||||
fields = append(fields, defaults[5]) // TODO: improve access to default
|
||||
case options&SecondOptional > 0:
|
||||
fields = append([]string{defaults[0]}, fields...)
|
||||
default:
|
||||
return nil, errors.New("unexpected optional field")
|
||||
}
|
||||
}
|
||||
|
||||
// Populate all fields not part of options with their defaults
|
||||
n := 0
|
||||
expandedFields := make([]string, len(places))
|
||||
copy(expandedFields, defaults)
|
||||
for i, place := range places {
|
||||
if options&place > 0 {
|
||||
expandedFields[i] = fields[n]
|
||||
n++
|
||||
}
|
||||
}
|
||||
return expandedFields, nil
|
||||
}
|
||||
|
||||
var standardParser = NewParser(
|
||||
Minute | Hour | Dom | Month | Dow | Descriptor,
|
||||
)
|
||||
|
||||
// ParseStandard returns a new crontab schedule representing the given
|
||||
// standardSpec (https://en.wikipedia.org/wiki/Cron). It requires 5 entries
|
||||
// representing: minute, hour, day of month, month and day of week, in that
|
||||
// order. It returns a descriptive error if the spec is not valid.
|
||||
//
|
||||
// It accepts
|
||||
// - Standard crontab specs, e.g. "* * * * ?"
|
||||
// - Descriptors, e.g. "@midnight", "@every 1h30m"
|
||||
func ParseStandard(standardSpec string) (Schedule, error) {
|
||||
return standardParser.Parse(standardSpec)
|
||||
}
|
||||
|
||||
// getField returns an Int with the bits set representing all of the times that
|
||||
// the field represents or error parsing field value. A "field" is a comma-separated
|
||||
// list of "ranges".
|
||||
func getField(field string, r bounds) (uint64, error) {
|
||||
var bits uint64
|
||||
ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' })
|
||||
for _, expr := range ranges {
|
||||
bit, err := getRange(expr, r)
|
||||
if err != nil {
|
||||
return bits, err
|
||||
}
|
||||
bits |= bit
|
||||
}
|
||||
return bits, nil
|
||||
}
|
||||
|
||||
// getRange returns the bits indicated by the given expression:
|
||||
//
|
||||
// number | number "-" number [ "/" number ]
|
||||
//
|
||||
// or error parsing range.
|
||||
func getRange(expr string, r bounds) (uint64, error) {
|
||||
var (
|
||||
start, end, step uint
|
||||
rangeAndStep = strings.Split(expr, "/")
|
||||
lowAndHigh = strings.Split(rangeAndStep[0], "-")
|
||||
singleDigit = len(lowAndHigh) == 1
|
||||
err error
|
||||
)
|
||||
|
||||
var extra uint64
|
||||
if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" {
|
||||
start = r.min
|
||||
end = r.max
|
||||
extra = starBit
|
||||
} else {
|
||||
start, err = parseIntOrName(lowAndHigh[0], r.names)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch len(lowAndHigh) {
|
||||
case 1:
|
||||
end = start
|
||||
case 2:
|
||||
end, err = parseIntOrName(lowAndHigh[1], r.names)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
default:
|
||||
return 0, errors.New("too many hyphens: " + expr)
|
||||
}
|
||||
}
|
||||
|
||||
switch len(rangeAndStep) {
|
||||
case 1:
|
||||
step = 1
|
||||
case 2:
|
||||
step, err = mustParseInt(rangeAndStep[1])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Special handling: "N/step" means "N-max/step".
|
||||
if singleDigit {
|
||||
end = r.max
|
||||
}
|
||||
if step > 1 {
|
||||
extra = 0
|
||||
}
|
||||
default:
|
||||
return 0, errors.New("too many slashes: " + expr)
|
||||
}
|
||||
|
||||
if start < r.min {
|
||||
return 0, errors.New("beginning of range below minimum: " + expr)
|
||||
}
|
||||
if end > r.max {
|
||||
return 0, errors.New("end of range above maximum: " + expr)
|
||||
}
|
||||
if start > end {
|
||||
return 0, errors.New("beginning of range after end: " + expr)
|
||||
}
|
||||
if step == 0 {
|
||||
return 0, errors.New("step cannot be zero: " + expr)
|
||||
}
|
||||
|
||||
return getBits(start, end, step) | extra, nil
|
||||
}
|
||||
|
||||
// parseIntOrName returns the (possibly-named) integer contained in expr.
|
||||
func parseIntOrName(expr string, names map[string]uint) (uint, error) {
|
||||
if names != nil {
|
||||
if namedInt, ok := names[strings.ToLower(expr)]; ok {
|
||||
return namedInt, nil
|
||||
}
|
||||
}
|
||||
return mustParseInt(expr)
|
||||
}
|
||||
|
||||
// mustParseInt parses the given expression as an int or returns an error.
|
||||
func mustParseInt(expr string) (uint, error) {
|
||||
num, err := strconv.Atoi(expr)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to parse number")
|
||||
}
|
||||
if num < 0 {
|
||||
return 0, errors.New("number must be positive")
|
||||
}
|
||||
|
||||
return uint(num), nil
|
||||
}
|
||||
|
||||
// getBits sets all bits in the range [min, max], modulo the given step size.
|
||||
func getBits(min, max, step uint) uint64 {
|
||||
var bits uint64
|
||||
|
||||
// If step is 1, use shifts.
|
||||
if step == 1 {
|
||||
return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min)
|
||||
}
|
||||
|
||||
// Else, use a simple loop.
|
||||
for i := min; i <= max; i += step {
|
||||
bits |= 1 << i
|
||||
}
|
||||
return bits
|
||||
}
|
||||
|
||||
// all returns all bits within the given bounds.
|
||||
func all(r bounds) uint64 {
|
||||
return getBits(r.min, r.max, 1) | starBit
|
||||
}
|
||||
|
||||
// parseDescriptor returns a predefined schedule for the expression, or error if none matches.
|
||||
func parseDescriptor(descriptor string, loc *time.Location) (Schedule, error) {
|
||||
switch descriptor {
|
||||
case "@yearly", "@annually":
|
||||
return &SpecSchedule{
|
||||
Second: 1 << seconds.min,
|
||||
Minute: 1 << minutes.min,
|
||||
Hour: 1 << hours.min,
|
||||
Dom: 1 << dom.min,
|
||||
Month: 1 << months.min,
|
||||
Dow: all(dow),
|
||||
Location: loc,
|
||||
}, nil
|
||||
|
||||
case "@monthly":
|
||||
return &SpecSchedule{
|
||||
Second: 1 << seconds.min,
|
||||
Minute: 1 << minutes.min,
|
||||
Hour: 1 << hours.min,
|
||||
Dom: 1 << dom.min,
|
||||
Month: all(months),
|
||||
Dow: all(dow),
|
||||
Location: loc,
|
||||
}, nil
|
||||
|
||||
case "@weekly":
|
||||
return &SpecSchedule{
|
||||
Second: 1 << seconds.min,
|
||||
Minute: 1 << minutes.min,
|
||||
Hour: 1 << hours.min,
|
||||
Dom: all(dom),
|
||||
Month: all(months),
|
||||
Dow: 1 << dow.min,
|
||||
Location: loc,
|
||||
}, nil
|
||||
|
||||
case "@daily", "@midnight":
|
||||
return &SpecSchedule{
|
||||
Second: 1 << seconds.min,
|
||||
Minute: 1 << minutes.min,
|
||||
Hour: 1 << hours.min,
|
||||
Dom: all(dom),
|
||||
Month: all(months),
|
||||
Dow: all(dow),
|
||||
Location: loc,
|
||||
}, nil
|
||||
|
||||
case "@hourly":
|
||||
return &SpecSchedule{
|
||||
Second: 1 << seconds.min,
|
||||
Minute: 1 << minutes.min,
|
||||
Hour: all(hours),
|
||||
Dom: all(dom),
|
||||
Month: all(months),
|
||||
Dow: all(dow),
|
||||
Location: loc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
const every = "@every "
|
||||
if strings.HasPrefix(descriptor, every) {
|
||||
duration, err := time.ParseDuration(descriptor[len(every):])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse duration")
|
||||
}
|
||||
return Every(duration), nil
|
||||
}
|
||||
|
||||
return nil, errors.New("unrecognized descriptor: " + descriptor)
|
||||
}
|
||||
384
plugin/cron/parser_test.go
Normal file
384
plugin/cron/parser_test.go
Normal file
@@ -0,0 +1,384 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var secondParser = NewParser(Second | Minute | Hour | Dom | Month | DowOptional | Descriptor)
|
||||
|
||||
func TestRange(t *testing.T) {
|
||||
zero := uint64(0)
|
||||
ranges := []struct {
|
||||
expr string
|
||||
min, max uint
|
||||
expected uint64
|
||||
err string
|
||||
}{
|
||||
{"5", 0, 7, 1 << 5, ""},
|
||||
{"0", 0, 7, 1 << 0, ""},
|
||||
{"7", 0, 7, 1 << 7, ""},
|
||||
|
||||
{"5-5", 0, 7, 1 << 5, ""},
|
||||
{"5-6", 0, 7, 1<<5 | 1<<6, ""},
|
||||
{"5-7", 0, 7, 1<<5 | 1<<6 | 1<<7, ""},
|
||||
|
||||
{"5-6/2", 0, 7, 1 << 5, ""},
|
||||
{"5-7/2", 0, 7, 1<<5 | 1<<7, ""},
|
||||
{"5-7/1", 0, 7, 1<<5 | 1<<6 | 1<<7, ""},
|
||||
|
||||
{"*", 1, 3, 1<<1 | 1<<2 | 1<<3 | starBit, ""},
|
||||
{"*/2", 1, 3, 1<<1 | 1<<3, ""},
|
||||
|
||||
{"5--5", 0, 0, zero, "too many hyphens"},
|
||||
{"jan-x", 0, 0, zero, `failed to parse number: strconv.Atoi: parsing "jan": invalid syntax`},
|
||||
{"2-x", 1, 5, zero, `failed to parse number: strconv.Atoi: parsing "x": invalid syntax`},
|
||||
{"*/-12", 0, 0, zero, "number must be positive"},
|
||||
{"*//2", 0, 0, zero, "too many slashes"},
|
||||
{"1", 3, 5, zero, "below minimum"},
|
||||
{"6", 3, 5, zero, "above maximum"},
|
||||
{"5-3", 3, 5, zero, "beginning of range after end: 5-3"},
|
||||
{"*/0", 0, 0, zero, "step cannot be zero: */0"},
|
||||
}
|
||||
|
||||
for _, c := range ranges {
|
||||
actual, err := getRange(c.expr, bounds{c.min, c.max, nil})
|
||||
if len(c.err) != 0 && (err == nil || !strings.Contains(err.Error(), c.err)) {
|
||||
t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
|
||||
}
|
||||
if len(c.err) == 0 && err != nil {
|
||||
t.Errorf("%s => unexpected error %v", c.expr, err)
|
||||
}
|
||||
if actual != c.expected {
|
||||
t.Errorf("%s => expected %d, got %d", c.expr, c.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestField(t *testing.T) {
|
||||
fields := []struct {
|
||||
expr string
|
||||
min, max uint
|
||||
expected uint64
|
||||
}{
|
||||
{"5", 1, 7, 1 << 5},
|
||||
{"5,6", 1, 7, 1<<5 | 1<<6},
|
||||
{"5,6,7", 1, 7, 1<<5 | 1<<6 | 1<<7},
|
||||
{"1,5-7/2,3", 1, 7, 1<<1 | 1<<5 | 1<<7 | 1<<3},
|
||||
}
|
||||
|
||||
for _, c := range fields {
|
||||
actual, _ := getField(c.expr, bounds{c.min, c.max, nil})
|
||||
if actual != c.expected {
|
||||
t.Errorf("%s => expected %d, got %d", c.expr, c.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAll(t *testing.T) {
|
||||
allBits := []struct {
|
||||
r bounds
|
||||
expected uint64
|
||||
}{
|
||||
{minutes, 0xfffffffffffffff}, // 0-59: 60 ones
|
||||
{hours, 0xffffff}, // 0-23: 24 ones
|
||||
{dom, 0xfffffffe}, // 1-31: 31 ones, 1 zero
|
||||
{months, 0x1ffe}, // 1-12: 12 ones, 1 zero
|
||||
{dow, 0x7f}, // 0-6: 7 ones
|
||||
}
|
||||
|
||||
for _, c := range allBits {
|
||||
actual := all(c.r) // all() adds the starBit, so compensate for that..
|
||||
if c.expected|starBit != actual {
|
||||
t.Errorf("%d-%d/%d => expected %b, got %b",
|
||||
c.r.min, c.r.max, 1, c.expected|starBit, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
bits := []struct {
|
||||
min, max, step uint
|
||||
expected uint64
|
||||
}{
|
||||
{0, 0, 1, 0x1},
|
||||
{1, 1, 1, 0x2},
|
||||
{1, 5, 2, 0x2a}, // 101010
|
||||
{1, 4, 2, 0xa}, // 1010
|
||||
}
|
||||
|
||||
for _, c := range bits {
|
||||
actual := getBits(c.min, c.max, c.step)
|
||||
if c.expected != actual {
|
||||
t.Errorf("%d-%d/%d => expected %b, got %b",
|
||||
c.min, c.max, c.step, c.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseScheduleErrors(t *testing.T) {
|
||||
var tests = []struct{ expr, err string }{
|
||||
{"* 5 j * * *", `failed to parse number: strconv.Atoi: parsing "j": invalid syntax`},
|
||||
{"@every Xm", "failed to parse duration"},
|
||||
{"@unrecognized", "unrecognized descriptor"},
|
||||
{"* * * *", "incorrect number of fields, expected 5-6"},
|
||||
{"", "empty spec string"},
|
||||
}
|
||||
for _, c := range tests {
|
||||
actual, err := secondParser.Parse(c.expr)
|
||||
if err == nil || !strings.Contains(err.Error(), c.err) {
|
||||
t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
|
||||
}
|
||||
if actual != nil {
|
||||
t.Errorf("expected nil schedule on error, got %v", actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSchedule(t *testing.T) {
|
||||
tokyo, _ := time.LoadLocation("Asia/Tokyo")
|
||||
entries := []struct {
|
||||
parser Parser
|
||||
expr string
|
||||
expected Schedule
|
||||
}{
|
||||
{secondParser, "0 5 * * * *", every5min(time.Local)},
|
||||
{standardParser, "5 * * * *", every5min(time.Local)},
|
||||
{secondParser, "CRON_TZ=UTC 0 5 * * * *", every5min(time.UTC)},
|
||||
{standardParser, "CRON_TZ=UTC 5 * * * *", every5min(time.UTC)},
|
||||
{secondParser, "CRON_TZ=Asia/Tokyo 0 5 * * * *", every5min(tokyo)},
|
||||
{secondParser, "@every 5m", ConstantDelaySchedule{5 * time.Minute}},
|
||||
{secondParser, "@midnight", midnight(time.Local)},
|
||||
{secondParser, "TZ=UTC @midnight", midnight(time.UTC)},
|
||||
{secondParser, "TZ=Asia/Tokyo @midnight", midnight(tokyo)},
|
||||
{secondParser, "@yearly", annual(time.Local)},
|
||||
{secondParser, "@annually", annual(time.Local)},
|
||||
{
|
||||
parser: secondParser,
|
||||
expr: "* 5 * * * *",
|
||||
expected: &SpecSchedule{
|
||||
Second: all(seconds),
|
||||
Minute: 1 << 5,
|
||||
Hour: all(hours),
|
||||
Dom: all(dom),
|
||||
Month: all(months),
|
||||
Dow: all(dow),
|
||||
Location: time.Local,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range entries {
|
||||
actual, err := c.parser.Parse(c.expr)
|
||||
if err != nil {
|
||||
t.Errorf("%s => unexpected error %v", c.expr, err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, c.expected) {
|
||||
t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionalSecondSchedule(t *testing.T) {
|
||||
parser := NewParser(SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor)
|
||||
entries := []struct {
|
||||
expr string
|
||||
expected Schedule
|
||||
}{
|
||||
{"0 5 * * * *", every5min(time.Local)},
|
||||
{"5 5 * * * *", every5min5s(time.Local)},
|
||||
{"5 * * * *", every5min(time.Local)},
|
||||
}
|
||||
|
||||
for _, c := range entries {
|
||||
actual, err := parser.Parse(c.expr)
|
||||
if err != nil {
|
||||
t.Errorf("%s => unexpected error %v", c.expr, err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, c.expected) {
|
||||
t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
options ParseOption
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
"AllFields_NoOptional",
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
Second | Minute | Hour | Dom | Month | Dow | Descriptor,
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
},
|
||||
{
|
||||
"AllFields_SecondOptional_Provided",
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor,
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
},
|
||||
{
|
||||
"AllFields_SecondOptional_NotProvided",
|
||||
[]string{"5", "*", "*", "*", "*"},
|
||||
SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor,
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
},
|
||||
{
|
||||
"SubsetFields_NoOptional",
|
||||
[]string{"5", "15", "*"},
|
||||
Hour | Dom | Month,
|
||||
[]string{"0", "0", "5", "15", "*", "*"},
|
||||
},
|
||||
{
|
||||
"SubsetFields_DowOptional_Provided",
|
||||
[]string{"5", "15", "*", "4"},
|
||||
Hour | Dom | Month | DowOptional,
|
||||
[]string{"0", "0", "5", "15", "*", "4"},
|
||||
},
|
||||
{
|
||||
"SubsetFields_DowOptional_NotProvided",
|
||||
[]string{"5", "15", "*"},
|
||||
Hour | Dom | Month | DowOptional,
|
||||
[]string{"0", "0", "5", "15", "*", "*"},
|
||||
},
|
||||
{
|
||||
"SubsetFields_SecondOptional_NotProvided",
|
||||
[]string{"5", "15", "*"},
|
||||
SecondOptional | Hour | Dom | Month,
|
||||
[]string{"0", "0", "5", "15", "*", "*"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(*testing.T) {
|
||||
actual, err := normalizeFields(test.input, test.options)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, test.expected) {
|
||||
t.Errorf("expected %v, got %v", test.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeFields_Errors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
options ParseOption
|
||||
err string
|
||||
}{
|
||||
{
|
||||
"TwoOptionals",
|
||||
[]string{"0", "5", "*", "*", "*", "*"},
|
||||
SecondOptional | Minute | Hour | Dom | Month | DowOptional,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"TooManyFields",
|
||||
[]string{"0", "5", "*", "*"},
|
||||
SecondOptional | Minute | Hour,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"NoFields",
|
||||
[]string{},
|
||||
SecondOptional | Minute | Hour,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"TooFewFields",
|
||||
[]string{"*"},
|
||||
SecondOptional | Minute | Hour,
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(*testing.T) {
|
||||
actual, err := normalizeFields(test.input, test.options)
|
||||
if err == nil {
|
||||
t.Errorf("expected an error, got none. results: %v", actual)
|
||||
}
|
||||
if !strings.Contains(err.Error(), test.err) {
|
||||
t.Errorf("expected error %q, got %q", test.err, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStandardSpecSchedule(t *testing.T) {
|
||||
entries := []struct {
|
||||
expr string
|
||||
expected Schedule
|
||||
err string
|
||||
}{
|
||||
{
|
||||
expr: "5 * * * *",
|
||||
expected: &SpecSchedule{1 << seconds.min, 1 << 5, all(hours), all(dom), all(months), all(dow), time.Local},
|
||||
},
|
||||
{
|
||||
expr: "@every 5m",
|
||||
expected: ConstantDelaySchedule{time.Duration(5) * time.Minute},
|
||||
},
|
||||
{
|
||||
expr: "5 j * * *",
|
||||
err: `failed to parse number: strconv.Atoi: parsing "j": invalid syntax`,
|
||||
},
|
||||
{
|
||||
expr: "* * * *",
|
||||
err: "incorrect number of fields",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range entries {
|
||||
actual, err := ParseStandard(c.expr)
|
||||
if len(c.err) != 0 && (err == nil || !strings.Contains(err.Error(), c.err)) {
|
||||
t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
|
||||
}
|
||||
if len(c.err) == 0 && err != nil {
|
||||
t.Errorf("%s => unexpected error %v", c.expr, err)
|
||||
}
|
||||
if !reflect.DeepEqual(actual, c.expected) {
|
||||
t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoDescriptorParser(t *testing.T) {
|
||||
parser := NewParser(Minute | Hour)
|
||||
_, err := parser.Parse("@every 1m")
|
||||
if err == nil {
|
||||
t.Error("expected an error, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func every5min(loc *time.Location) *SpecSchedule {
|
||||
return &SpecSchedule{1 << 0, 1 << 5, all(hours), all(dom), all(months), all(dow), loc}
|
||||
}
|
||||
|
||||
func every5min5s(loc *time.Location) *SpecSchedule {
|
||||
return &SpecSchedule{1 << 5, 1 << 5, all(hours), all(dom), all(months), all(dow), loc}
|
||||
}
|
||||
|
||||
func midnight(loc *time.Location) *SpecSchedule {
|
||||
return &SpecSchedule{1, 1, 1, all(dom), all(months), all(dow), loc}
|
||||
}
|
||||
|
||||
func annual(loc *time.Location) *SpecSchedule {
|
||||
return &SpecSchedule{
|
||||
Second: 1 << seconds.min,
|
||||
Minute: 1 << minutes.min,
|
||||
Hour: 1 << hours.min,
|
||||
Dom: 1 << dom.min,
|
||||
Month: 1 << months.min,
|
||||
Dow: all(dow),
|
||||
Location: loc,
|
||||
}
|
||||
}
|
||||
188
plugin/cron/spec.go
Normal file
188
plugin/cron/spec.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package cron
|
||||
|
||||
import "time"
|
||||
|
||||
// SpecSchedule specifies a duty cycle (to the second granularity), based on a
|
||||
// traditional crontab specification. It is computed initially and stored as bit sets.
|
||||
type SpecSchedule struct {
|
||||
Second, Minute, Hour, Dom, Month, Dow uint64
|
||||
|
||||
// Override location for this schedule.
|
||||
Location *time.Location
|
||||
}
|
||||
|
||||
// bounds provides a range of acceptable values (plus a map of name to value).
|
||||
type bounds struct {
|
||||
min, max uint
|
||||
names map[string]uint
|
||||
}
|
||||
|
||||
// The bounds for each field.
|
||||
var (
|
||||
seconds = bounds{0, 59, nil}
|
||||
minutes = bounds{0, 59, nil}
|
||||
hours = bounds{0, 23, nil}
|
||||
dom = bounds{1, 31, nil}
|
||||
months = bounds{1, 12, map[string]uint{
|
||||
"jan": 1,
|
||||
"feb": 2,
|
||||
"mar": 3,
|
||||
"apr": 4,
|
||||
"may": 5,
|
||||
"jun": 6,
|
||||
"jul": 7,
|
||||
"aug": 8,
|
||||
"sep": 9,
|
||||
"oct": 10,
|
||||
"nov": 11,
|
||||
"dec": 12,
|
||||
}}
|
||||
dow = bounds{0, 6, map[string]uint{
|
||||
"sun": 0,
|
||||
"mon": 1,
|
||||
"tue": 2,
|
||||
"wed": 3,
|
||||
"thu": 4,
|
||||
"fri": 5,
|
||||
"sat": 6,
|
||||
}}
|
||||
)
|
||||
|
||||
const (
|
||||
// Set the top bit if a star was included in the expression.
|
||||
starBit = 1 << 63
|
||||
)
|
||||
|
||||
// Next returns the next time this schedule is activated, greater than the given
|
||||
// time. If no time can be found to satisfy the schedule, return the zero time.
|
||||
func (s *SpecSchedule) Next(t time.Time) time.Time {
|
||||
// General approach
|
||||
//
|
||||
// For Month, Day, Hour, Minute, Second:
|
||||
// Check if the time value matches. If yes, continue to the next field.
|
||||
// If the field doesn't match the schedule, then increment the field until it matches.
|
||||
// While incrementing the field, a wrap-around brings it back to the beginning
|
||||
// of the field list (since it is necessary to re-verify previous field
|
||||
// values)
|
||||
|
||||
// Convert the given time into the schedule's timezone, if one is specified.
|
||||
// Save the original timezone so we can convert back after we find a time.
|
||||
// Note that schedules without a time zone specified (time.Local) are treated
|
||||
// as local to the time provided.
|
||||
origLocation := t.Location()
|
||||
loc := s.Location
|
||||
if loc == time.Local {
|
||||
loc = t.Location()
|
||||
}
|
||||
if s.Location != time.Local {
|
||||
t = t.In(s.Location)
|
||||
}
|
||||
|
||||
// Start at the earliest possible time (the upcoming second).
|
||||
t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond)
|
||||
|
||||
// This flag indicates whether a field has been incremented.
|
||||
added := false
|
||||
|
||||
// If no time is found within five years, return zero.
|
||||
yearLimit := t.Year() + 5
|
||||
|
||||
WRAP:
|
||||
if t.Year() > yearLimit {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// Find the first applicable month.
|
||||
// If it's this month, then do nothing.
|
||||
for 1<<uint(t.Month())&s.Month == 0 {
|
||||
// If we have to add a month, reset the other parts to 0.
|
||||
if !added {
|
||||
added = true
|
||||
// Otherwise, set the date at the beginning (since the current time is irrelevant).
|
||||
t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
|
||||
}
|
||||
t = t.AddDate(0, 1, 0)
|
||||
|
||||
// Wrapped around.
|
||||
if t.Month() == time.January {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
|
||||
// Now get a day in that month.
|
||||
//
|
||||
// NOTE: This causes issues for daylight savings regimes where midnight does
|
||||
// not exist. For example: Sao Paulo has DST that transforms midnight on
|
||||
// 11/3 into 1am. Handle that by noticing when the Hour ends up != 0.
|
||||
for !dayMatches(s, t) {
|
||||
if !added {
|
||||
added = true
|
||||
t = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
t = t.AddDate(0, 0, 1)
|
||||
// Notice if the hour is no longer midnight due to DST.
|
||||
// Add an hour if it's 23, subtract an hour if it's 1.
|
||||
if t.Hour() != 0 {
|
||||
if t.Hour() > 12 {
|
||||
t = t.Add(time.Duration(24-t.Hour()) * time.Hour)
|
||||
} else {
|
||||
t = t.Add(time.Duration(-t.Hour()) * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
if t.Day() == 1 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
|
||||
for 1<<uint(t.Hour())&s.Hour == 0 {
|
||||
if !added {
|
||||
added = true
|
||||
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), 0, 0, 0, loc)
|
||||
}
|
||||
t = t.Add(1 * time.Hour)
|
||||
|
||||
if t.Hour() == 0 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
|
||||
for 1<<uint(t.Minute())&s.Minute == 0 {
|
||||
if !added {
|
||||
added = true
|
||||
t = t.Truncate(time.Minute)
|
||||
}
|
||||
t = t.Add(1 * time.Minute)
|
||||
|
||||
if t.Minute() == 0 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
|
||||
for 1<<uint(t.Second())&s.Second == 0 {
|
||||
if !added {
|
||||
added = true
|
||||
t = t.Truncate(time.Second)
|
||||
}
|
||||
t = t.Add(1 * time.Second)
|
||||
|
||||
if t.Second() == 0 {
|
||||
goto WRAP
|
||||
}
|
||||
}
|
||||
|
||||
return t.In(origLocation)
|
||||
}
|
||||
|
||||
// dayMatches returns true if the schedule's day-of-week and day-of-month
|
||||
// restrictions are satisfied by the given time.
|
||||
func dayMatches(s *SpecSchedule, t time.Time) bool {
|
||||
var (
|
||||
domMatch = 1<<uint(t.Day())&s.Dom > 0
|
||||
dowMatch = 1<<uint(t.Weekday())&s.Dow > 0
|
||||
)
|
||||
if s.Dom&starBit > 0 || s.Dow&starBit > 0 {
|
||||
return domMatch && dowMatch
|
||||
}
|
||||
return domMatch || dowMatch
|
||||
}
|
||||
301
plugin/cron/spec_test.go
Normal file
301
plugin/cron/spec_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
//nolint:all
|
||||
package cron
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestActivation(t *testing.T) {
|
||||
tests := []struct {
|
||||
time, spec string
|
||||
expected bool
|
||||
}{
|
||||
// Every fifteen minutes.
|
||||
{"Mon Jul 9 15:00 2012", "0/15 * * * *", true},
|
||||
{"Mon Jul 9 15:45 2012", "0/15 * * * *", true},
|
||||
{"Mon Jul 9 15:40 2012", "0/15 * * * *", false},
|
||||
|
||||
// Every fifteen minutes, starting at 5 minutes.
|
||||
{"Mon Jul 9 15:05 2012", "5/15 * * * *", true},
|
||||
{"Mon Jul 9 15:20 2012", "5/15 * * * *", true},
|
||||
{"Mon Jul 9 15:50 2012", "5/15 * * * *", true},
|
||||
|
||||
// Named months
|
||||
{"Sun Jul 15 15:00 2012", "0/15 * * Jul *", true},
|
||||
{"Sun Jul 15 15:00 2012", "0/15 * * Jun *", false},
|
||||
|
||||
// Everything set.
|
||||
{"Sun Jul 15 08:30 2012", "30 08 ? Jul Sun", true},
|
||||
{"Sun Jul 15 08:30 2012", "30 08 15 Jul ?", true},
|
||||
{"Mon Jul 16 08:30 2012", "30 08 ? Jul Sun", false},
|
||||
{"Mon Jul 16 08:30 2012", "30 08 15 Jul ?", false},
|
||||
|
||||
// Predefined schedules
|
||||
{"Mon Jul 9 15:00 2012", "@hourly", true},
|
||||
{"Mon Jul 9 15:04 2012", "@hourly", false},
|
||||
{"Mon Jul 9 15:00 2012", "@daily", false},
|
||||
{"Mon Jul 9 00:00 2012", "@daily", true},
|
||||
{"Mon Jul 9 00:00 2012", "@weekly", false},
|
||||
{"Sun Jul 8 00:00 2012", "@weekly", true},
|
||||
{"Sun Jul 8 01:00 2012", "@weekly", false},
|
||||
{"Sun Jul 8 00:00 2012", "@monthly", false},
|
||||
{"Sun Jul 1 00:00 2012", "@monthly", true},
|
||||
|
||||
// Test interaction of DOW and DOM.
|
||||
// If both are restricted, then only one needs to match.
|
||||
{"Sun Jul 15 00:00 2012", "* * 1,15 * Sun", true},
|
||||
{"Fri Jun 15 00:00 2012", "* * 1,15 * Sun", true},
|
||||
{"Wed Aug 1 00:00 2012", "* * 1,15 * Sun", true},
|
||||
{"Sun Jul 15 00:00 2012", "* * */10 * Sun", true}, // verifies #70
|
||||
|
||||
// However, if one has a star, then both need to match.
|
||||
{"Sun Jul 15 00:00 2012", "* * * * Mon", false},
|
||||
{"Mon Jul 9 00:00 2012", "* * 1,15 * *", false},
|
||||
{"Sun Jul 15 00:00 2012", "* * 1,15 * *", true},
|
||||
{"Sun Jul 15 00:00 2012", "* * */2 * Sun", true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
sched, err := ParseStandard(test.spec)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
actual := sched.Next(getTime(test.time).Add(-1 * time.Second))
|
||||
expected := getTime(test.time)
|
||||
if test.expected && expected != actual || !test.expected && expected == actual {
|
||||
t.Errorf("Fail evaluating %s on %s: (expected) %s != %s (actual)",
|
||||
test.spec, test.time, expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNext(t *testing.T) {
|
||||
runs := []struct {
|
||||
time, spec string
|
||||
expected string
|
||||
}{
|
||||
// Simple cases
|
||||
{"Mon Jul 9 14:45 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},
|
||||
{"Mon Jul 9 14:59 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},
|
||||
{"Mon Jul 9 14:59:59 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},
|
||||
|
||||
// Wrap around hours
|
||||
{"Mon Jul 9 15:45 2012", "0 20-35/15 * * * *", "Mon Jul 9 16:20 2012"},
|
||||
|
||||
// Wrap around days
|
||||
{"Mon Jul 9 23:46 2012", "0 */15 * * * *", "Tue Jul 10 00:00 2012"},
|
||||
{"Mon Jul 9 23:45 2012", "0 20-35/15 * * * *", "Tue Jul 10 00:20 2012"},
|
||||
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * * * *", "Tue Jul 10 00:20:15 2012"},
|
||||
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 1/2 * * *", "Tue Jul 10 01:20:15 2012"},
|
||||
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 10-12 * * *", "Tue Jul 10 10:20:15 2012"},
|
||||
|
||||
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 1/2 */2 * *", "Thu Jul 11 01:20:15 2012"},
|
||||
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * 9-20 * *", "Wed Jul 10 00:20:15 2012"},
|
||||
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * 9-20 Jul *", "Wed Jul 10 00:20:15 2012"},
|
||||
|
||||
// Wrap around months
|
||||
{"Mon Jul 9 23:35 2012", "0 0 0 9 Apr-Oct ?", "Thu Aug 9 00:00 2012"},
|
||||
{"Mon Jul 9 23:35 2012", "0 0 0 */5 Apr,Aug,Oct Mon", "Tue Aug 1 00:00 2012"},
|
||||
{"Mon Jul 9 23:35 2012", "0 0 0 */5 Oct Mon", "Mon Oct 1 00:00 2012"},
|
||||
|
||||
// Wrap around years
|
||||
{"Mon Jul 9 23:35 2012", "0 0 0 * Feb Mon", "Mon Feb 4 00:00 2013"},
|
||||
{"Mon Jul 9 23:35 2012", "0 0 0 * Feb Mon/2", "Fri Feb 1 00:00 2013"},
|
||||
|
||||
// Wrap around minute, hour, day, month, and year
|
||||
{"Mon Dec 31 23:59:45 2012", "0 * * * * *", "Tue Jan 1 00:00:00 2013"},
|
||||
|
||||
// Leap year
|
||||
{"Mon Jul 9 23:35 2012", "0 0 0 29 Feb ?", "Mon Feb 29 00:00 2016"},
|
||||
|
||||
// Daylight savings time 2am EST (-5) -> 3am EDT (-4)
|
||||
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 30 2 11 Mar ?", "2013-03-11T02:30:00-0400"},
|
||||
|
||||
// hourly job
|
||||
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T01:00:00-0500"},
|
||||
{"2012-03-11T01:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T03:00:00-0400"},
|
||||
{"2012-03-11T03:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T04:00:00-0400"},
|
||||
{"2012-03-11T04:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T05:00:00-0400"},
|
||||
|
||||
// hourly job using CRON_TZ
|
||||
{"2012-03-11T00:00:00-0500", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T01:00:00-0500"},
|
||||
{"2012-03-11T01:00:00-0500", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T03:00:00-0400"},
|
||||
{"2012-03-11T03:00:00-0400", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T04:00:00-0400"},
|
||||
{"2012-03-11T04:00:00-0400", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T05:00:00-0400"},
|
||||
|
||||
// 1am nightly job
|
||||
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-03-11T01:00:00-0500"},
|
||||
{"2012-03-11T01:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-03-12T01:00:00-0400"},
|
||||
|
||||
// 2am nightly job (skipped)
|
||||
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 2 * * ?", "2012-03-12T02:00:00-0400"},
|
||||
|
||||
// Daylight savings time 2am EDT (-4) => 1am EST (-5)
|
||||
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 30 2 04 Nov ?", "2012-11-04T02:30:00-0500"},
|
||||
{"2012-11-04T01:45:00-0400", "TZ=America/New_York 0 30 1 04 Nov ?", "2012-11-04T01:30:00-0500"},
|
||||
|
||||
// hourly job
|
||||
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T01:00:00-0400"},
|
||||
{"2012-11-04T01:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T01:00:00-0500"},
|
||||
{"2012-11-04T01:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T02:00:00-0500"},
|
||||
|
||||
// 1am nightly job (runs twice)
|
||||
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 1 * * ?", "2012-11-04T01:00:00-0400"},
|
||||
{"2012-11-04T01:00:00-0400", "TZ=America/New_York 0 0 1 * * ?", "2012-11-04T01:00:00-0500"},
|
||||
{"2012-11-04T01:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-11-05T01:00:00-0500"},
|
||||
|
||||
// 2am nightly job
|
||||
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 2 * * ?", "2012-11-04T02:00:00-0500"},
|
||||
{"2012-11-04T02:00:00-0500", "TZ=America/New_York 0 0 2 * * ?", "2012-11-05T02:00:00-0500"},
|
||||
|
||||
// 3am nightly job
|
||||
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 3 * * ?", "2012-11-04T03:00:00-0500"},
|
||||
{"2012-11-04T03:00:00-0500", "TZ=America/New_York 0 0 3 * * ?", "2012-11-05T03:00:00-0500"},
|
||||
|
||||
// hourly job
|
||||
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 * * * ?", "2012-11-04T01:00:00-0400"},
|
||||
{"TZ=America/New_York 2012-11-04T01:00:00-0400", "0 0 * * * ?", "2012-11-04T01:00:00-0500"},
|
||||
{"TZ=America/New_York 2012-11-04T01:00:00-0500", "0 0 * * * ?", "2012-11-04T02:00:00-0500"},
|
||||
|
||||
// 1am nightly job (runs twice)
|
||||
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 1 * * ?", "2012-11-04T01:00:00-0400"},
|
||||
{"TZ=America/New_York 2012-11-04T01:00:00-0400", "0 0 1 * * ?", "2012-11-04T01:00:00-0500"},
|
||||
{"TZ=America/New_York 2012-11-04T01:00:00-0500", "0 0 1 * * ?", "2012-11-05T01:00:00-0500"},
|
||||
|
||||
// 2am nightly job
|
||||
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 2 * * ?", "2012-11-04T02:00:00-0500"},
|
||||
{"TZ=America/New_York 2012-11-04T02:00:00-0500", "0 0 2 * * ?", "2012-11-05T02:00:00-0500"},
|
||||
|
||||
// 3am nightly job
|
||||
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 3 * * ?", "2012-11-04T03:00:00-0500"},
|
||||
{"TZ=America/New_York 2012-11-04T03:00:00-0500", "0 0 3 * * ?", "2012-11-05T03:00:00-0500"},
|
||||
|
||||
// Unsatisfiable
|
||||
{"Mon Jul 9 23:35 2012", "0 0 0 30 Feb ?", ""},
|
||||
{"Mon Jul 9 23:35 2012", "0 0 0 31 Apr ?", ""},
|
||||
|
||||
// Monthly job
|
||||
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 3 3 * ?", "2012-12-03T03:00:00-0500"},
|
||||
|
||||
// Test the scenario of DST resulting in midnight not being a valid time.
|
||||
// https://github.com/robfig/cron/issues/157
|
||||
{"2018-10-17T05:00:00-0400", "TZ=America/Sao_Paulo 0 0 9 10 * ?", "2018-11-10T06:00:00-0500"},
|
||||
{"2018-02-14T05:00:00-0500", "TZ=America/Sao_Paulo 0 0 9 22 * ?", "2018-02-22T07:00:00-0500"},
|
||||
}
|
||||
|
||||
for _, c := range runs {
|
||||
sched, err := secondParser.Parse(c.spec)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
actual := sched.Next(getTime(c.time))
|
||||
expected := getTime(c.expected)
|
||||
if !actual.Equal(expected) {
|
||||
t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.spec, expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrors(t *testing.T) {
|
||||
invalidSpecs := []string{
|
||||
"xyz",
|
||||
"60 0 * * *",
|
||||
"0 60 * * *",
|
||||
"0 0 * * XYZ",
|
||||
}
|
||||
for _, spec := range invalidSpecs {
|
||||
_, err := ParseStandard(spec)
|
||||
if err == nil {
|
||||
t.Error("expected an error parsing: ", spec)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getTime(value string) time.Time {
|
||||
if value == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
var location = time.Local
|
||||
if strings.HasPrefix(value, "TZ=") {
|
||||
parts := strings.Fields(value)
|
||||
loc, err := time.LoadLocation(parts[0][len("TZ="):])
|
||||
if err != nil {
|
||||
panic("could not parse location:" + err.Error())
|
||||
}
|
||||
location = loc
|
||||
value = parts[1]
|
||||
}
|
||||
|
||||
var layouts = []string{
|
||||
"Mon Jan 2 15:04 2006",
|
||||
"Mon Jan 2 15:04:05 2006",
|
||||
}
|
||||
for _, layout := range layouts {
|
||||
if t, err := time.ParseInLocation(layout, value, location); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
if t, err := time.ParseInLocation("2006-01-02T15:04:05-0700", value, location); err == nil {
|
||||
return t
|
||||
}
|
||||
panic("could not parse time value " + value)
|
||||
}
|
||||
|
||||
func TestNextWithTz(t *testing.T) {
|
||||
runs := []struct {
|
||||
time, spec string
|
||||
expected string
|
||||
}{
|
||||
// Failing tests
|
||||
{"2016-01-03T13:09:03+0530", "14 14 * * *", "2016-01-03T14:14:00+0530"},
|
||||
{"2016-01-03T04:09:03+0530", "14 14 * * ?", "2016-01-03T14:14:00+0530"},
|
||||
|
||||
// Passing tests
|
||||
{"2016-01-03T14:09:03+0530", "14 14 * * *", "2016-01-03T14:14:00+0530"},
|
||||
{"2016-01-03T14:00:00+0530", "14 14 * * ?", "2016-01-03T14:14:00+0530"},
|
||||
}
|
||||
for _, c := range runs {
|
||||
sched, err := ParseStandard(c.spec)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
continue
|
||||
}
|
||||
actual := sched.Next(getTimeTZ(c.time))
|
||||
expected := getTimeTZ(c.expected)
|
||||
if !actual.Equal(expected) {
|
||||
t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.spec, expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getTimeTZ(value string) time.Time {
|
||||
if value == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
t, err := time.Parse("Mon Jan 2 15:04 2006", value)
|
||||
if err != nil {
|
||||
t, err = time.Parse("Mon Jan 2 15:04:05 2006", value)
|
||||
if err != nil {
|
||||
t, err = time.Parse("2006-01-02T15:04:05-0700", value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// https://github.com/robfig/cron/issues/144
|
||||
func TestSlash0NoHang(t *testing.T) {
|
||||
schedule := "TZ=America/New_York 15/0 * * * *"
|
||||
_, err := ParseStandard(schedule)
|
||||
if err == nil {
|
||||
t.Error("expected an error on 0 increment")
|
||||
}
|
||||
}
|
||||
448
plugin/filter/common_converter.go
Normal file
448
plugin/filter/common_converter.go
Normal file
@@ -0,0 +1,448 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// CommonSQLConverter handles the common CEL to SQL conversion logic.
|
||||
type CommonSQLConverter struct {
|
||||
dialect SQLDialect
|
||||
paramIndex int
|
||||
}
|
||||
|
||||
// NewCommonSQLConverter creates a new converter with the specified dialect.
|
||||
func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter {
|
||||
return &CommonSQLConverter{
|
||||
dialect: dialect,
|
||||
paramIndex: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertExprToSQL converts a CEL expression to SQL using the configured dialect.
|
||||
func (c *CommonSQLConverter) ConvertExprToSQL(ctx *ConvertContext, expr *exprv1.Expr) error {
|
||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
switch v.CallExpr.Function {
|
||||
case "_||_", "_&&_":
|
||||
return c.handleLogicalOperator(ctx, v.CallExpr)
|
||||
case "!_":
|
||||
return c.handleNotOperator(ctx, v.CallExpr)
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
return c.handleComparisonOperator(ctx, v.CallExpr)
|
||||
case "@in":
|
||||
return c.handleInOperator(ctx, v.CallExpr)
|
||||
case "contains":
|
||||
return c.handleContainsOperator(ctx, v.CallExpr)
|
||||
}
|
||||
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||
return c.handleIdentifier(ctx, v.IdentExpr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleLogicalOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
operator := "AND"
|
||||
if callExpr.Function == "_||_" {
|
||||
operator = "OR"
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.ConvertExprToSQL(ctx, callExpr.Args[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleNotOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
// Check if the left side is a function call like size(tags)
|
||||
if leftCallExpr, ok := callExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
if leftCallExpr.CallExpr.Function == "size" {
|
||||
return c.handleSizeComparison(ctx, callExpr, leftCallExpr.CallExpr)
|
||||
}
|
||||
}
|
||||
|
||||
identifier, err := GetIdentExprName(callExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
value, err := GetExprValue(callExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
operator := c.getComparisonOperator(callExpr.Function)
|
||||
|
||||
switch identifier {
|
||||
case "created_ts", "updated_ts":
|
||||
return c.handleTimestampComparison(ctx, identifier, operator, value)
|
||||
case "visibility", "content":
|
||||
return c.handleStringComparison(ctx, identifier, operator, value)
|
||||
case "creator_id":
|
||||
return c.handleIntComparison(ctx, identifier, operator, value)
|
||||
case "has_task_list":
|
||||
return c.handleBooleanComparison(ctx, identifier, operator, value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleSizeComparison(ctx *ConvertContext, callExpr *exprv1.Expr_Call, sizeCall *exprv1.Expr_Call) error {
|
||||
if len(sizeCall.Args) != 1 {
|
||||
return errors.New("size function requires exactly one argument")
|
||||
}
|
||||
|
||||
identifier, err := GetIdentExprName(sizeCall.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if identifier != "tags" {
|
||||
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||
}
|
||||
|
||||
value, err := GetExprValue(callExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("size comparison value must be an integer")
|
||||
}
|
||||
|
||||
operator := c.getComparisonOperator(callExpr.Function)
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s",
|
||||
c.dialect.GetJSONArrayLength("$.tags"),
|
||||
operator,
|
||||
c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
// Check if this is "element in collection" syntax
|
||||
if identifier, err := GetIdentExprName(callExpr.Args[1]); err == nil {
|
||||
if identifier == "tags" {
|
||||
return c.handleElementInTags(ctx, callExpr.Args[0])
|
||||
}
|
||||
return errors.Errorf("invalid collection identifier for %s: %s", callExpr.Function, identifier)
|
||||
}
|
||||
|
||||
// Original logic for "identifier in [list]" syntax
|
||||
identifier, err := GetIdentExprName(callExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
values := []any{}
|
||||
for _, element := range callExpr.Args[1].GetListExpr().Elements {
|
||||
value, err := GetConstValue(element)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
if identifier == "tag" {
|
||||
return c.handleTagInList(ctx, values)
|
||||
} else if identifier == "visibility" {
|
||||
return c.handleVisibilityInList(ctx, values)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleElementInTags(ctx *ConvertContext, elementExpr *exprv1.Expr) error {
|
||||
element, err := GetConstValue(elementExpr)
|
||||
if err != nil {
|
||||
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||
}
|
||||
|
||||
// Use dialect-specific JSON contains logic
|
||||
sqlExpr := c.dialect.GetJSONContains("$.tags", "element")
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For SQLite, we need a different approach since it uses LIKE
|
||||
if _, ok := c.dialect.(*SQLiteDialect); ok {
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf(`%%"%s"%%`, element))
|
||||
} else {
|
||||
ctx.Args = append(ctx.Args, element)
|
||||
}
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleTagInList(ctx *ConvertContext, values []any) error {
|
||||
subconditions := []string{}
|
||||
args := []any{}
|
||||
|
||||
for _, v := range values {
|
||||
if _, ok := c.dialect.(*SQLiteDialect); ok {
|
||||
subconditions = append(subconditions, c.dialect.GetJSONLike("$.tags", "pattern"))
|
||||
args = append(args, fmt.Sprintf(`%%"%s"%%`, v))
|
||||
} else {
|
||||
subconditions = append(subconditions, c.dialect.GetJSONContains("$.tags", "element"))
|
||||
args = append(args, v)
|
||||
}
|
||||
c.paramIndex++
|
||||
}
|
||||
|
||||
if len(subconditions) == 1 {
|
||||
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, args...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values []any) error {
|
||||
placeholders := []string{}
|
||||
for range values {
|
||||
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||
c.paramIndex++
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
identifier, err := GetIdentExprName(callExpr.Target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if identifier != "content" {
|
||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
arg, err := GetConstValue(callExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error {
|
||||
identifier := identExpr.GetName()
|
||||
|
||||
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
|
||||
return errors.Errorf("invalid identifier %s", identifier)
|
||||
}
|
||||
|
||||
if identifier == "pinned" {
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_task_list" {
|
||||
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasTaskList")); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleTimestampComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid integer timestamp value")
|
||||
}
|
||||
|
||||
timestampField := c.dialect.GetTimestampComparison(field)
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampField, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", field)
|
||||
}
|
||||
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return errors.New("invalid string value")
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
fieldName := field
|
||||
if field == "visibility" {
|
||||
fieldName = "`visibility`"
|
||||
} else if field == "content" {
|
||||
fieldName = "`content`"
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueStr)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", field)
|
||||
}
|
||||
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid int value")
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix()
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleBooleanComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", field)
|
||||
}
|
||||
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return errors.New("invalid boolean value for has_task_list")
|
||||
}
|
||||
|
||||
sqlExpr := c.dialect.GetBooleanComparison("$.property.hasTaskList", valueBool)
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For dialects that need parameters (PostgreSQL)
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
ctx.Args = append(ctx.Args, valueBool)
|
||||
c.paramIndex++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*CommonSQLConverter) getComparisonOperator(function string) string {
|
||||
switch function {
|
||||
case "_==_":
|
||||
return "="
|
||||
case "_!=_":
|
||||
return "!="
|
||||
case "_<_":
|
||||
return "<"
|
||||
case "_>_":
|
||||
return ">"
|
||||
case "_<=_":
|
||||
return "<="
|
||||
case "_>=_":
|
||||
return ">="
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
20
plugin/filter/converter.go
Normal file
20
plugin/filter/converter.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ConvertContext struct {
|
||||
Buffer strings.Builder
|
||||
Args []any
|
||||
// The offset of the next argument in the condition string.
|
||||
// Mainly using for PostgreSQL.
|
||||
ArgsOffset int
|
||||
}
|
||||
|
||||
func NewConvertContext() *ConvertContext {
|
||||
return &ConvertContext{
|
||||
Buffer: strings.Builder{},
|
||||
Args: []any{},
|
||||
}
|
||||
}
|
||||
212
plugin/filter/dialect.go
Normal file
212
plugin/filter/dialect.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SQLDialect defines database-specific SQL generation methods.
|
||||
type SQLDialect interface {
|
||||
// Basic field access
|
||||
GetTablePrefix() string
|
||||
GetParameterPlaceholder(index int) string
|
||||
|
||||
// JSON operations
|
||||
GetJSONExtract(path string) string
|
||||
GetJSONArrayLength(path string) string
|
||||
GetJSONContains(path, element string) string
|
||||
GetJSONLike(path, pattern string) string
|
||||
|
||||
// Boolean operations
|
||||
GetBooleanValue(value bool) interface{}
|
||||
GetBooleanComparison(path string, value bool) string
|
||||
GetBooleanCheck(path string) string
|
||||
|
||||
// Timestamp operations
|
||||
GetTimestampComparison(field string) string
|
||||
GetCurrentTimestamp() string
|
||||
}
|
||||
|
||||
// DatabaseType represents the type of database.
|
||||
type DatabaseType string
|
||||
|
||||
const (
|
||||
SQLite DatabaseType = "sqlite"
|
||||
MySQL DatabaseType = "mysql"
|
||||
PostgreSQL DatabaseType = "postgres"
|
||||
)
|
||||
|
||||
// GetDialect returns the appropriate dialect for the database type.
|
||||
func GetDialect(dbType DatabaseType) SQLDialect {
|
||||
switch dbType {
|
||||
case SQLite:
|
||||
return &SQLiteDialect{}
|
||||
case MySQL:
|
||||
return &MySQLDialect{}
|
||||
case PostgreSQL:
|
||||
return &PostgreSQLDialect{}
|
||||
default:
|
||||
return &SQLiteDialect{} // default fallback
|
||||
}
|
||||
}
|
||||
|
||||
// SQLiteDialect implements SQLDialect for SQLite.
|
||||
type SQLiteDialect struct{}
|
||||
|
||||
func (*SQLiteDialect) GetTablePrefix() string {
|
||||
return "`memo`"
|
||||
}
|
||||
|
||||
func (*SQLiteDialect) GetParameterPlaceholder(_ int) string {
|
||||
return "?"
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONExtract(path string) string {
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONArrayLength(path string) string {
|
||||
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONContains(path, _ string) string {
|
||||
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONLike(path, _ string) string {
|
||||
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (*SQLiteDialect) GetBooleanValue(value bool) interface{} {
|
||||
if value {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetBooleanComparison(path string, value bool) string {
|
||||
return fmt.Sprintf("%s = %d", d.GetJSONExtract(path), d.GetBooleanValue(value))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetBooleanCheck(path string) string {
|
||||
return fmt.Sprintf("%s IS TRUE", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("%s.`%s`", d.GetTablePrefix(), field)
|
||||
}
|
||||
|
||||
func (*SQLiteDialect) GetCurrentTimestamp() string {
|
||||
return "strftime('%s', 'now')"
|
||||
}
|
||||
|
||||
// MySQLDialect implements SQLDialect for MySQL.
|
||||
type MySQLDialect struct{}
|
||||
|
||||
func (*MySQLDialect) GetTablePrefix() string {
|
||||
return "`memo`"
|
||||
}
|
||||
|
||||
func (*MySQLDialect) GetParameterPlaceholder(_ int) string {
|
||||
return "?"
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONExtract(path string) string {
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONArrayLength(path string) string {
|
||||
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONContains(path, _ string) string {
|
||||
return fmt.Sprintf("JSON_CONTAINS(%s, ?)", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONLike(path, _ string) string {
|
||||
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (*MySQLDialect) GetBooleanValue(value bool) interface{} {
|
||||
return value
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetBooleanComparison(path string, value bool) string {
|
||||
boolStr := "false"
|
||||
if value {
|
||||
boolStr = "true"
|
||||
}
|
||||
return fmt.Sprintf("%s = CAST('%s' AS JSON)", d.GetJSONExtract(path), boolStr)
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetBooleanCheck(path string) string {
|
||||
return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix(), field)
|
||||
}
|
||||
|
||||
func (*MySQLDialect) GetCurrentTimestamp() string {
|
||||
return "UNIX_TIMESTAMP()"
|
||||
}
|
||||
|
||||
// PostgreSQLDialect implements SQLDialect for PostgreSQL.
|
||||
type PostgreSQLDialect struct{}
|
||||
|
||||
func (*PostgreSQLDialect) GetTablePrefix() string {
|
||||
return "memo"
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string {
|
||||
return fmt.Sprintf("$%d", index)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
|
||||
// Convert $.property.hasTaskList to payload->'property'->>'hasTaskList'
|
||||
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
|
||||
result := fmt.Sprintf("%s.payload", d.GetTablePrefix())
|
||||
for i, part := range parts {
|
||||
if i == len(parts)-1 {
|
||||
result += fmt.Sprintf("->>'%s'", part)
|
||||
} else {
|
||||
result += fmt.Sprintf("->'%s'", part)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix(), jsonPath)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONContains(path, _ string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONLike(path, _ string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath)
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} {
|
||||
return value
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetBooleanComparison(path string, _ bool) string {
|
||||
return fmt.Sprintf("(%s)::boolean = ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetBooleanCheck(path string) string {
|
||||
return fmt.Sprintf("(%s)::boolean IS TRUE", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("EXTRACT(EPOCH FROM %s.%s)", d.GetTablePrefix(), field)
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetCurrentTimestamp() string {
|
||||
return "EXTRACT(EPOCH FROM NOW())"
|
||||
}
|
||||
127
plugin/filter/expr.go
Normal file
127
plugin/filter/expr.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// GetConstValue returns the constant value of the expression.
|
||||
func GetConstValue(expr *exprv1.Expr) (any, error) {
|
||||
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid constant expression")
|
||||
}
|
||||
|
||||
switch v.ConstExpr.ConstantKind.(type) {
|
||||
case *exprv1.Constant_StringValue:
|
||||
return v.ConstExpr.GetStringValue(), nil
|
||||
case *exprv1.Constant_Int64Value:
|
||||
return v.ConstExpr.GetInt64Value(), nil
|
||||
case *exprv1.Constant_Uint64Value:
|
||||
return v.ConstExpr.GetUint64Value(), nil
|
||||
case *exprv1.Constant_DoubleValue:
|
||||
return v.ConstExpr.GetDoubleValue(), nil
|
||||
case *exprv1.Constant_BoolValue:
|
||||
return v.ConstExpr.GetBoolValue(), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected constant type")
|
||||
}
|
||||
}
|
||||
|
||||
// GetIdentExprName returns the name of the identifier expression.
|
||||
func GetIdentExprName(expr *exprv1.Expr) (string, error) {
|
||||
_, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr)
|
||||
if !ok {
|
||||
return "", errors.New("invalid identifier expression")
|
||||
}
|
||||
return expr.GetIdentExpr().GetName(), nil
|
||||
}
|
||||
|
||||
// GetFunctionValue evaluates CEL function calls and returns their value.
|
||||
// This is specifically for time functions like now().
|
||||
func GetFunctionValue(expr *exprv1.Expr) (any, error) {
|
||||
callExpr, ok := expr.ExprKind.(*exprv1.Expr_CallExpr)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid function call expression")
|
||||
}
|
||||
|
||||
switch callExpr.CallExpr.Function {
|
||||
case "now":
|
||||
if len(callExpr.CallExpr.Args) != 0 {
|
||||
return nil, errors.New("now() function takes no arguments")
|
||||
}
|
||||
return time.Now().Unix(), nil
|
||||
case "_-_":
|
||||
// Handle subtraction for expressions like "now() - 60 * 60 * 24"
|
||||
if len(callExpr.CallExpr.Args) != 2 {
|
||||
return nil, errors.New("subtraction requires exactly two arguments")
|
||||
}
|
||||
left, err := GetExprValue(callExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := GetExprValue(callExpr.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leftInt, ok1 := left.(int64)
|
||||
rightInt, ok2 := right.(int64)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, errors.New("subtraction operands must be integers")
|
||||
}
|
||||
return leftInt - rightInt, nil
|
||||
case "_*_":
|
||||
// Handle multiplication for expressions like "60 * 60 * 24"
|
||||
if len(callExpr.CallExpr.Args) != 2 {
|
||||
return nil, errors.New("multiplication requires exactly two arguments")
|
||||
}
|
||||
left, err := GetExprValue(callExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := GetExprValue(callExpr.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leftInt, ok1 := left.(int64)
|
||||
rightInt, ok2 := right.(int64)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, errors.New("multiplication operands must be integers")
|
||||
}
|
||||
return leftInt * rightInt, nil
|
||||
case "_+_":
|
||||
// Handle addition
|
||||
if len(callExpr.CallExpr.Args) != 2 {
|
||||
return nil, errors.New("addition requires exactly two arguments")
|
||||
}
|
||||
left, err := GetExprValue(callExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := GetExprValue(callExpr.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leftInt, ok1 := left.(int64)
|
||||
rightInt, ok2 := right.(int64)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, errors.New("addition operands must be integers")
|
||||
}
|
||||
return leftInt + rightInt, nil
|
||||
default:
|
||||
return nil, errors.New("unsupported function: " + callExpr.CallExpr.Function)
|
||||
}
|
||||
}
|
||||
|
||||
// GetExprValue attempts to get a value from an expression, trying constants first, then functions.
|
||||
func GetExprValue(expr *exprv1.Expr) (any, error) {
|
||||
// Try to get constant value first
|
||||
if constValue, err := GetConstValue(expr); err == nil {
|
||||
return constValue, nil
|
||||
}
|
||||
|
||||
// If not a constant, try to evaluate as a function
|
||||
return GetFunctionValue(expr)
|
||||
}
|
||||
48
plugin/filter/filter.go
Normal file
48
plugin/filter/filter.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// MemoFilterCELAttributes are the CEL attributes for memo.
|
||||
var MemoFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("content", cel.StringType),
|
||||
cel.Variable("creator_id", cel.IntType),
|
||||
cel.Variable("created_ts", cel.IntType),
|
||||
cel.Variable("updated_ts", cel.IntType),
|
||||
cel.Variable("pinned", cel.BoolType),
|
||||
cel.Variable("tag", cel.StringType),
|
||||
cel.Variable("tags", cel.ListType(cel.StringType)),
|
||||
cel.Variable("visibility", cel.StringType),
|
||||
cel.Variable("has_task_list", cel.BoolType),
|
||||
// Current timestamp function.
|
||||
cel.Function("now",
|
||||
cel.Overload("now",
|
||||
[]*cel.Type{},
|
||||
cel.IntType,
|
||||
cel.FunctionBinding(func(_ ...ref.Val) ref.Val {
|
||||
return types.Int(time.Now().Unix())
|
||||
}),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
// Parse parses the filter string and returns the parsed expression.
|
||||
// The filter string should be a CEL expression.
|
||||
func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err error) {
|
||||
e, err := cel.NewEnv(opts...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create CEL environment")
|
||||
}
|
||||
ast, issues := e.Compile(filter)
|
||||
if issues != nil {
|
||||
return nil, errors.Errorf("failed to compile filter: %v", issues)
|
||||
}
|
||||
return cel.AstToParsedExpr(ast)
|
||||
}
|
||||
146
plugin/filter/templates.go
Normal file
146
plugin/filter/templates.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SQLTemplate holds database-specific SQL fragments.
|
||||
type SQLTemplate struct {
|
||||
SQLite string
|
||||
MySQL string
|
||||
PostgreSQL string
|
||||
}
|
||||
|
||||
// TemplateDBType represents the database type for templates.
|
||||
type TemplateDBType string
|
||||
|
||||
const (
|
||||
SQLiteTemplate TemplateDBType = "sqlite"
|
||||
MySQLTemplate TemplateDBType = "mysql"
|
||||
PostgreSQLTemplate TemplateDBType = "postgres"
|
||||
)
|
||||
|
||||
// SQLTemplates contains common SQL patterns for different databases.
|
||||
var SQLTemplates = map[string]SQLTemplate{
|
||||
"json_extract": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '%s')",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '%s')",
|
||||
PostgreSQL: "memo.payload%s",
|
||||
},
|
||||
"json_array_length": {
|
||||
SQLite: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
|
||||
MySQL: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
|
||||
PostgreSQL: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb))",
|
||||
},
|
||||
"json_contains_element": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
|
||||
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
|
||||
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
|
||||
},
|
||||
"json_contains_tag": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
|
||||
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
|
||||
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
|
||||
},
|
||||
"boolean_true": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = true",
|
||||
},
|
||||
"boolean_false": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = false",
|
||||
},
|
||||
"boolean_not_true": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 1",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('true' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != true",
|
||||
},
|
||||
"boolean_not_false": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != false",
|
||||
},
|
||||
"boolean_compare": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s ?",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s CAST(? AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean %s ?",
|
||||
},
|
||||
"boolean_check": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE",
|
||||
},
|
||||
"table_prefix": {
|
||||
SQLite: "`memo`",
|
||||
MySQL: "`memo`",
|
||||
PostgreSQL: "memo",
|
||||
},
|
||||
"timestamp_field": {
|
||||
SQLite: "`memo`.`%s`",
|
||||
MySQL: "UNIX_TIMESTAMP(`memo`.`%s`)",
|
||||
PostgreSQL: "EXTRACT(EPOCH FROM memo.%s)",
|
||||
},
|
||||
"content_like": {
|
||||
SQLite: "`memo`.`content` LIKE ?",
|
||||
MySQL: "`memo`.`content` LIKE ?",
|
||||
PostgreSQL: "memo.content ILIKE ?",
|
||||
},
|
||||
"visibility_in": {
|
||||
SQLite: "`memo`.`visibility` IN (%s)",
|
||||
MySQL: "`memo`.`visibility` IN (%s)",
|
||||
PostgreSQL: "memo.visibility IN (%s)",
|
||||
},
|
||||
}
|
||||
|
||||
// GetSQL returns the appropriate SQL for the given template and database type.
|
||||
func GetSQL(templateName string, dbType TemplateDBType) string {
|
||||
template, exists := SQLTemplates[templateName]
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch dbType {
|
||||
case SQLiteTemplate:
|
||||
return template.SQLite
|
||||
case MySQLTemplate:
|
||||
return template.MySQL
|
||||
case PostgreSQLTemplate:
|
||||
return template.PostgreSQL
|
||||
default:
|
||||
return template.SQLite
|
||||
}
|
||||
}
|
||||
|
||||
// GetParameterPlaceholder returns the appropriate parameter placeholder for the database.
|
||||
func GetParameterPlaceholder(dbType TemplateDBType, index int) string {
|
||||
switch dbType {
|
||||
case PostgreSQLTemplate:
|
||||
return fmt.Sprintf("$%d", index)
|
||||
default:
|
||||
return "?"
|
||||
}
|
||||
}
|
||||
|
||||
// GetParameterValue returns the appropriate parameter value for the database.
|
||||
func GetParameterValue(dbType TemplateDBType, templateName string, value interface{}) interface{} {
|
||||
switch templateName {
|
||||
case "json_contains_element", "json_contains_tag":
|
||||
if dbType == SQLiteTemplate {
|
||||
return fmt.Sprintf(`%%"%s"%%`, value)
|
||||
}
|
||||
return value
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// FormatPlaceholders formats a list of placeholders for the given database type.
|
||||
func FormatPlaceholders(dbType TemplateDBType, count int, startIndex int) []string {
|
||||
placeholders := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
placeholders[i] = GetParameterPlaceholder(dbType, startIndex+i)
|
||||
}
|
||||
return placeholders
|
||||
}
|
||||
166
plugin/httpgetter/html_meta.go
Normal file
166
plugin/httpgetter/html_meta.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package httpgetter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/html"
|
||||
"golang.org/x/net/html/atom"
|
||||
)
|
||||
|
||||
var ErrInternalIP = errors.New("internal IP addresses are not allowed")
|
||||
|
||||
var httpClient = &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if err := validateURL(req.URL.String()); err != nil {
|
||||
return errors.Wrap(err, "redirect to internal IP")
|
||||
}
|
||||
if len(via) >= 10 {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
type HTMLMeta struct {
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Image string `json:"image"`
|
||||
}
|
||||
|
||||
func GetHTMLMeta(urlStr string) (*HTMLMeta, error) {
|
||||
if err := validateURL(urlStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := httpClient.Get(urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
mediatype, err := getMediatype(response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if mediatype != "text/html" {
|
||||
return nil, errors.New("not a HTML page")
|
||||
}
|
||||
|
||||
// TODO: limit the size of the response body
|
||||
|
||||
htmlMeta := extractHTMLMeta(response.Body)
|
||||
enrichSiteMeta(response.Request.URL, htmlMeta)
|
||||
return htmlMeta, nil
|
||||
}
|
||||
|
||||
func extractHTMLMeta(resp io.Reader) *HTMLMeta {
|
||||
tokenizer := html.NewTokenizer(resp)
|
||||
htmlMeta := new(HTMLMeta)
|
||||
|
||||
for {
|
||||
tokenType := tokenizer.Next()
|
||||
if tokenType == html.ErrorToken {
|
||||
break
|
||||
} else if tokenType == html.StartTagToken || tokenType == html.SelfClosingTagToken {
|
||||
token := tokenizer.Token()
|
||||
if token.DataAtom == atom.Body {
|
||||
break
|
||||
}
|
||||
|
||||
if token.DataAtom == atom.Title {
|
||||
tokenizer.Next()
|
||||
token := tokenizer.Token()
|
||||
htmlMeta.Title = token.Data
|
||||
} else if token.DataAtom == atom.Meta {
|
||||
description, ok := extractMetaProperty(token, "description")
|
||||
if ok {
|
||||
htmlMeta.Description = description
|
||||
}
|
||||
|
||||
ogTitle, ok := extractMetaProperty(token, "og:title")
|
||||
if ok {
|
||||
htmlMeta.Title = ogTitle
|
||||
}
|
||||
|
||||
ogDescription, ok := extractMetaProperty(token, "og:description")
|
||||
if ok {
|
||||
htmlMeta.Description = ogDescription
|
||||
}
|
||||
|
||||
ogImage, ok := extractMetaProperty(token, "og:image")
|
||||
if ok {
|
||||
htmlMeta.Image = ogImage
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return htmlMeta
|
||||
}
|
||||
|
||||
func extractMetaProperty(token html.Token, prop string) (content string, ok bool) {
|
||||
content, ok = "", false
|
||||
for _, attr := range token.Attr {
|
||||
if attr.Key == "property" && attr.Val == prop {
|
||||
ok = true
|
||||
}
|
||||
if attr.Key == "content" {
|
||||
content = attr.Val
|
||||
}
|
||||
}
|
||||
return content, ok
|
||||
}
|
||||
|
||||
func validateURL(urlStr string) error {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return errors.New("invalid URL format")
|
||||
}
|
||||
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return errors.New("only http/https protocols are allowed")
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return errors.New("empty hostname")
|
||||
}
|
||||
|
||||
// check if the hostname is an IP
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
|
||||
return errors.Wrap(ErrInternalIP, ip.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// check if it's a hostname, resolve it and check all returned IPs
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return errors.Errorf("failed to resolve hostname: %v", err)
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
|
||||
return errors.Wrapf(ErrInternalIP, "host=%s, ip=%s", host, ip.String())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func enrichSiteMeta(url *url.URL, meta *HTMLMeta) {
|
||||
if url.Hostname() == "www.youtube.com" {
|
||||
if url.Path == "/watch" {
|
||||
vid := url.Query().Get("v")
|
||||
if vid != "" {
|
||||
meta.Image = fmt.Sprintf("https://img.youtube.com/vi/%s/mqdefault.jpg", vid)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
32
plugin/httpgetter/html_meta_test.go
Normal file
32
plugin/httpgetter/html_meta_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package httpgetter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetHTMLMeta(t *testing.T) {
|
||||
tests := []struct {
|
||||
urlStr string
|
||||
htmlMeta HTMLMeta
|
||||
}{}
|
||||
for _, test := range tests {
|
||||
metadata, err := GetHTMLMeta(test.urlStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.htmlMeta, *metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHTMLMetaForInternal(t *testing.T) {
|
||||
// test for internal IP
|
||||
if _, err := GetHTMLMeta("http://192.168.0.1"); !errors.Is(err, ErrInternalIP) {
|
||||
t.Errorf("Expected error for internal IP, got %v", err)
|
||||
}
|
||||
|
||||
// test for resolved internal IP
|
||||
if _, err := GetHTMLMeta("http://localhost"); !errors.Is(err, ErrInternalIP) {
|
||||
t.Errorf("Expected error for resolved internal IP, got %v", err)
|
||||
}
|
||||
}
|
||||
1
plugin/httpgetter/http_getter.go
Normal file
1
plugin/httpgetter/http_getter.go
Normal file
@@ -0,0 +1 @@
|
||||
package httpgetter
|
||||
45
plugin/httpgetter/image.go
Normal file
45
plugin/httpgetter/image.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package httpgetter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Image struct {
|
||||
Blob []byte
|
||||
Mediatype string
|
||||
}
|
||||
|
||||
func GetImage(urlStr string) (*Image, error) {
|
||||
if _, err := url.Parse(urlStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := http.Get(urlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
mediatype, err := getMediatype(response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !strings.HasPrefix(mediatype, "image/") {
|
||||
return nil, errors.New("wrong image mediatype")
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
image := &Image{
|
||||
Blob: bodyBytes,
|
||||
Mediatype: mediatype,
|
||||
}
|
||||
return image, nil
|
||||
}
|
||||
15
plugin/httpgetter/util.go
Normal file
15
plugin/httpgetter/util.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package httpgetter
|
||||
|
||||
import (
|
||||
"mime"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func getMediatype(response *http.Response) (string, error) {
|
||||
contentType := response.Header.Get("content-type")
|
||||
mediatype, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return mediatype, nil
|
||||
}
|
||||
8
plugin/idp/idp.go
Normal file
8
plugin/idp/idp.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package idp
|
||||
|
||||
type IdentityProviderUserInfo struct {
|
||||
Identifier string
|
||||
DisplayName string
|
||||
Email string
|
||||
AvatarURL string
|
||||
}
|
||||
123
plugin/idp/oauth2/oauth2.go
Normal file
123
plugin/idp/oauth2/oauth2.go
Normal file
@@ -0,0 +1,123 @@
|
||||
// Package oauth2 is the plugin for OAuth2 Identity Provider.
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/usememos/memos/plugin/idp"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
// IdentityProvider represents an OAuth2 Identity Provider.
|
||||
type IdentityProvider struct {
|
||||
config *storepb.OAuth2Config
|
||||
}
|
||||
|
||||
// NewIdentityProvider initializes a new OAuth2 Identity Provider with the given configuration.
|
||||
func NewIdentityProvider(config *storepb.OAuth2Config) (*IdentityProvider, error) {
|
||||
for v, field := range map[string]string{
|
||||
config.ClientId: "clientId",
|
||||
config.ClientSecret: "clientSecret",
|
||||
config.TokenUrl: "tokenUrl",
|
||||
config.UserInfoUrl: "userInfoUrl",
|
||||
config.FieldMapping.Identifier: "fieldMapping.identifier",
|
||||
} {
|
||||
if v == "" {
|
||||
return nil, errors.Errorf(`the field "%s" is empty but required`, field)
|
||||
}
|
||||
}
|
||||
|
||||
return &IdentityProvider{
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExchangeToken returns the exchanged OAuth2 token using the given authorization code.
|
||||
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) {
|
||||
conf := &oauth2.Config{
|
||||
ClientID: p.config.ClientId,
|
||||
ClientSecret: p.config.ClientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: p.config.Scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: p.config.AuthUrl,
|
||||
TokenURL: p.config.TokenUrl,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
}
|
||||
|
||||
token, err := conf.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to exchange access token")
|
||||
}
|
||||
|
||||
accessToken, ok := token.Extra("access_token").(string)
|
||||
if !ok {
|
||||
return "", errors.New(`missing "access_token" from authorization response`)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// UserInfo returns the parsed user information using the given OAuth2 token.
|
||||
func (p *IdentityProvider) UserInfo(token string) (*idp.IdentityProviderUserInfo, error) {
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(http.MethodGet, p.config.UserInfoUrl, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to new http request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user information")
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(body, &claims); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
slog.Info("user info claims", "claims", claims)
|
||||
userInfo := &idp.IdentityProviderUserInfo{}
|
||||
if v, ok := claims[p.config.FieldMapping.Identifier].(string); ok {
|
||||
userInfo.Identifier = v
|
||||
}
|
||||
if userInfo.Identifier == "" {
|
||||
return nil, errors.Errorf("the field %q is not found in claims or has empty value", p.config.FieldMapping.Identifier)
|
||||
}
|
||||
|
||||
// Best effort to map optional fields
|
||||
if p.config.FieldMapping.DisplayName != "" {
|
||||
if v, ok := claims[p.config.FieldMapping.DisplayName].(string); ok {
|
||||
userInfo.DisplayName = v
|
||||
}
|
||||
}
|
||||
if userInfo.DisplayName == "" {
|
||||
userInfo.DisplayName = userInfo.Identifier
|
||||
}
|
||||
if p.config.FieldMapping.Email != "" {
|
||||
if v, ok := claims[p.config.FieldMapping.Email].(string); ok {
|
||||
userInfo.Email = v
|
||||
}
|
||||
}
|
||||
if p.config.FieldMapping.AvatarUrl != "" {
|
||||
if v, ok := claims[p.config.FieldMapping.AvatarUrl].(string); ok {
|
||||
userInfo.AvatarURL = v
|
||||
}
|
||||
}
|
||||
slog.Info("user info", "userInfo", userInfo)
|
||||
return userInfo, nil
|
||||
}
|
||||
163
plugin/idp/oauth2/oauth2_test.go
Normal file
163
plugin/idp/oauth2/oauth2_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/idp"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
func TestNewIdentityProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *storepb.OAuth2Config
|
||||
containsErr string
|
||||
}{
|
||||
{
|
||||
name: "no tokenUrl",
|
||||
config: &storepb.OAuth2Config{
|
||||
ClientId: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthUrl: "",
|
||||
TokenUrl: "",
|
||||
UserInfoUrl: "https://example.com/api/user",
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "login",
|
||||
},
|
||||
},
|
||||
containsErr: `the field "tokenUrl" is empty but required`,
|
||||
},
|
||||
{
|
||||
name: "no userInfoUrl",
|
||||
config: &storepb.OAuth2Config{
|
||||
ClientId: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthUrl: "",
|
||||
TokenUrl: "https://example.com/token",
|
||||
UserInfoUrl: "",
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "login",
|
||||
},
|
||||
},
|
||||
containsErr: `the field "userInfoUrl" is empty but required`,
|
||||
},
|
||||
{
|
||||
name: "no field mapping identifier",
|
||||
config: &storepb.OAuth2Config{
|
||||
ClientId: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthUrl: "",
|
||||
TokenUrl: "https://example.com/token",
|
||||
UserInfoUrl: "https://example.com/api/user",
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "",
|
||||
},
|
||||
},
|
||||
containsErr: `the field "fieldMapping.identifier" is empty but required`,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(*testing.T) {
|
||||
_, err := NewIdentityProvider(test.config)
|
||||
assert.ErrorContains(t, err, test.containsErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
var rawIDToken string
|
||||
mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodPost, r.Method)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
vals, err := url.ParseQuery(string(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, code, vals.Get("code"))
|
||||
require.Equal(t, "authorization_code", vals.Get("grant_type"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"id_token": rawIDToken,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err := w.Write(userinfo)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
s := httptest.NewServer(mux)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func TestIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
const (
|
||||
testClientID = "test-client-id"
|
||||
testCode = "test-code"
|
||||
testAccessToken = "test-access-token"
|
||||
testSubject = "123456789"
|
||||
testName = "John Doe"
|
||||
testEmail = "john.doe@example.com"
|
||||
)
|
||||
userInfo, err := json.Marshal(
|
||||
map[string]any{
|
||||
"sub": testSubject,
|
||||
"name": testName,
|
||||
"email": testEmail,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := newMockServer(t, testCode, testAccessToken, userInfo)
|
||||
|
||||
oauth2, err := NewIdentityProvider(
|
||||
&storepb.OAuth2Config{
|
||||
ClientId: testClientID,
|
||||
ClientSecret: "test-client-secret",
|
||||
TokenUrl: fmt.Sprintf("%s/oauth2/token", s.URL),
|
||||
UserInfoUrl: fmt.Sprintf("%s/oauth2/userinfo", s.URL),
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: "sub",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
redirectURL := "https://example.com/oauth/callback"
|
||||
oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testAccessToken, oauthToken)
|
||||
|
||||
userInfoResult, err := oauth2.UserInfo(oauthToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
wantUserInfo := &idp.IdentityProviderUserInfo{
|
||||
Identifier: testSubject,
|
||||
DisplayName: testName,
|
||||
Email: testEmail,
|
||||
}
|
||||
assert.Equal(t, wantUserInfo, userInfoResult)
|
||||
}
|
||||
92
plugin/storage/s3/s3.go
Normal file
92
plugin/storage/s3/s3.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
Client *s3.Client
|
||||
Bucket *string
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, s3Config *storepb.StorageS3Config) (*Client, error) {
|
||||
cfg, err := config.LoadDefaultConfig(ctx,
|
||||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(s3Config.AccessKeyId, s3Config.AccessKeySecret, "")),
|
||||
config.WithRegion(s3Config.Region),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load s3 config")
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
|
||||
o.BaseEndpoint = aws.String(s3Config.Endpoint)
|
||||
o.UsePathStyle = s3Config.UsePathStyle
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
o.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired
|
||||
})
|
||||
return &Client{
|
||||
Client: client,
|
||||
Bucket: aws.String(s3Config.Bucket),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UploadObject uploads an object to S3.
|
||||
func (c *Client) UploadObject(ctx context.Context, key string, fileType string, content io.Reader) (string, error) {
|
||||
uploader := manager.NewUploader(c.Client)
|
||||
putInput := s3.PutObjectInput{
|
||||
Bucket: c.Bucket,
|
||||
Key: aws.String(key),
|
||||
ContentType: aws.String(fileType),
|
||||
Body: content,
|
||||
}
|
||||
result, err := uploader.Upload(ctx, &putInput)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resultKey := result.Key
|
||||
if resultKey == nil || *resultKey == "" {
|
||||
return "", errors.New("failed to get file key")
|
||||
}
|
||||
return *resultKey, nil
|
||||
}
|
||||
|
||||
// PresignGetObject presigns an object in S3.
|
||||
func (c *Client) PresignGetObject(ctx context.Context, key string) (string, error) {
|
||||
presignClient := s3.NewPresignClient(c.Client)
|
||||
presignResult, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(*c.Bucket),
|
||||
Key: aws.String(key),
|
||||
}, func(opts *s3.PresignOptions) {
|
||||
// Set the expiration time of the presigned URL to 5 days.
|
||||
// Reference: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html
|
||||
opts.Expires = time.Duration(5 * 24 * time.Hour)
|
||||
})
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to presign put object")
|
||||
}
|
||||
return presignResult.URL, nil
|
||||
}
|
||||
|
||||
// DeleteObject deletes an object in S3.
|
||||
func (c *Client) DeleteObject(ctx context.Context, key string) error {
|
||||
_, err := c.Client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: c.Bucket,
|
||||
Key: aws.String(key),
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to delete object")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
90
plugin/webhook/webhook.go
Normal file
90
plugin/webhook/webhook.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
var (
|
||||
// timeout is the timeout for webhook request. Default to 30 seconds.
|
||||
timeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type WebhookRequestPayload struct {
|
||||
// The target URL for the webhook request.
|
||||
URL string `json:"url"`
|
||||
// The type of activity that triggered this webhook.
|
||||
ActivityType string `json:"activityType"`
|
||||
// The resource name of the creator. Format: users/{user}
|
||||
Creator string `json:"creator"`
|
||||
// The memo that triggered this webhook (if applicable).
|
||||
Memo *v1pb.Memo `json:"memo"`
|
||||
}
|
||||
|
||||
// Post posts the message to webhook endpoint.
|
||||
func Post(requestPayload *WebhookRequestPayload) error {
|
||||
body, err := json.Marshal(requestPayload)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to marshal webhook request to %s", requestPayload.URL)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", requestPayload.URL, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to construct webhook request to %s", requestPayload.URL)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to post webhook to %s", requestPayload.URL)
|
||||
}
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read webhook response from %s", requestPayload.URL)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
return errors.Errorf("failed to post webhook %s, status code: %d, response body: %s", requestPayload.URL, resp.StatusCode, b)
|
||||
}
|
||||
|
||||
response := &struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}{}
|
||||
if err := json.Unmarshal(b, response); err != nil {
|
||||
return errors.Wrapf(err, "failed to unmarshal webhook response from %s", requestPayload.URL)
|
||||
}
|
||||
|
||||
if response.Code != 0 {
|
||||
return errors.Errorf("receive error code sent by webhook server, code %d, msg: %s", response.Code, response.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PostAsync posts the message to webhook endpoint asynchronously.
|
||||
// It spawns a new goroutine to handle the request and does not wait for the response.
|
||||
func PostAsync(requestPayload *WebhookRequestPayload) {
|
||||
go func() {
|
||||
if err := Post(requestPayload); err != nil {
|
||||
// Since we're in a goroutine, we can only log the error
|
||||
slog.Warn("Failed to dispatch webhook asynchronously",
|
||||
slog.String("url", requestPayload.URL),
|
||||
slog.String("activityType", requestPayload.ActivityType),
|
||||
slog.Any("err", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
1
plugin/webhook/webhook_test.go
Normal file
1
plugin/webhook/webhook_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package webhook
|
||||
Reference in New Issue
Block a user