init commit

This commit is contained in:
2025-11-30 13:01:24 -05:00
parent f4596a372d
commit 29355260ed
607 changed files with 136371 additions and 234 deletions

1
plugin/cron/README.md Normal file
View File

@@ -0,0 +1 @@
Fork from https://github.com/robfig/cron

96
plugin/cron/chain.go Normal file
View File

@@ -0,0 +1,96 @@
package cron
import (
"errors"
"fmt"
"runtime"
"sync"
"time"
)
// JobWrapper decorates the given Job with some behavior.
type JobWrapper func(Job) Job
// Chain is a sequence of JobWrappers that decorates submitted jobs with
// cross-cutting behaviors like logging or synchronization.
type Chain struct {
wrappers []JobWrapper
}
// NewChain returns a Chain consisting of the given JobWrappers.
func NewChain(c ...JobWrapper) Chain {
return Chain{c}
}
// Then decorates the given job with all JobWrappers in the chain.
//
// This:
//
// NewChain(m1, m2, m3).Then(job)
//
// is equivalent to:
//
// m1(m2(m3(job)))
func (c Chain) Then(j Job) Job {
for i := range c.wrappers {
j = c.wrappers[len(c.wrappers)-i-1](j)
}
return j
}
// Recover panics in wrapped jobs and log them with the provided logger.
func Recover(logger Logger) JobWrapper {
return func(j Job) Job {
return FuncJob(func() {
defer func() {
if r := recover(); r != nil {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
err, ok := r.(error)
if !ok {
err = errors.New("panic: " + fmt.Sprint(r))
}
logger.Error(err, "panic", "stack", "...\n"+string(buf))
}
}()
j.Run()
})
}
}
// DelayIfStillRunning serializes jobs, delaying subsequent runs until the
// previous one is complete. Jobs running after a delay of more than a minute
// have the delay logged at Info.
func DelayIfStillRunning(logger Logger) JobWrapper {
return func(j Job) Job {
var mu sync.Mutex
return FuncJob(func() {
start := time.Now()
mu.Lock()
defer mu.Unlock()
if dur := time.Since(start); dur > time.Minute {
logger.Info("delay", "duration", dur)
}
j.Run()
})
}
}
// SkipIfStillRunning skips an invocation of the Job if a previous invocation is
// still running. It logs skips to the given logger at Info level.
func SkipIfStillRunning(logger Logger) JobWrapper {
return func(j Job) Job {
var ch = make(chan struct{}, 1)
ch <- struct{}{}
return FuncJob(func() {
select {
case v := <-ch:
defer func() { ch <- v }()
j.Run()
default:
logger.Info("skip")
}
})
}
}

239
plugin/cron/chain_test.go Normal file
View File

@@ -0,0 +1,239 @@
//nolint:all
package cron
import (
"io"
"log"
"reflect"
"sync"
"testing"
"time"
)
func appendingJob(slice *[]int, value int) Job {
var m sync.Mutex
return FuncJob(func() {
m.Lock()
*slice = append(*slice, value)
m.Unlock()
})
}
func appendingWrapper(slice *[]int, value int) JobWrapper {
return func(j Job) Job {
return FuncJob(func() {
appendingJob(slice, value).Run()
j.Run()
})
}
}
func TestChain(t *testing.T) {
var nums []int
var (
append1 = appendingWrapper(&nums, 1)
append2 = appendingWrapper(&nums, 2)
append3 = appendingWrapper(&nums, 3)
append4 = appendingJob(&nums, 4)
)
NewChain(append1, append2, append3).Then(append4).Run()
if !reflect.DeepEqual(nums, []int{1, 2, 3, 4}) {
t.Error("unexpected order of calls:", nums)
}
}
func TestChainRecover(t *testing.T) {
panickingJob := FuncJob(func() {
panic("panickingJob panics")
})
t.Run("panic exits job by default", func(*testing.T) {
defer func() {
if err := recover(); err == nil {
t.Errorf("panic expected, but none received")
}
}()
NewChain().Then(panickingJob).
Run()
})
t.Run("Recovering JobWrapper recovers", func(*testing.T) {
NewChain(Recover(PrintfLogger(log.New(io.Discard, "", 0)))).
Then(panickingJob).
Run()
})
t.Run("composed with the *IfStillRunning wrappers", func(*testing.T) {
NewChain(Recover(PrintfLogger(log.New(io.Discard, "", 0)))).
Then(panickingJob).
Run()
})
}
type countJob struct {
m sync.Mutex
started int
done int
delay time.Duration
}
func (j *countJob) Run() {
j.m.Lock()
j.started++
j.m.Unlock()
time.Sleep(j.delay)
j.m.Lock()
j.done++
j.m.Unlock()
}
func (j *countJob) Started() int {
defer j.m.Unlock()
j.m.Lock()
return j.started
}
func (j *countJob) Done() int {
defer j.m.Unlock()
j.m.Lock()
return j.done
}
func TestChainDelayIfStillRunning(t *testing.T) {
t.Run("runs immediately", func(*testing.T) {
var j countJob
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
}
})
t.Run("second run immediate if first done", func(*testing.T) {
var j countJob
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(time.Millisecond)
go wrappedJob.Run()
}()
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
if c := j.Done(); c != 2 {
t.Errorf("expected job run twice, immediately, got %d", c)
}
})
t.Run("second run delayed if first not done", func(*testing.T) {
var j countJob
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(time.Millisecond)
go wrappedJob.Run()
}()
// After 5ms, the first job is still in progress, and the second job was
// run but should be waiting for it to finish.
time.Sleep(5 * time.Millisecond)
started, done := j.Started(), j.Done()
if started != 1 || done != 0 {
t.Error("expected first job started, but not finished, got", started, done)
}
// Verify that the second job completes.
time.Sleep(25 * time.Millisecond)
started, done = j.Started(), j.Done()
if started != 2 || done != 2 {
t.Error("expected both jobs done, got", started, done)
}
})
}
func TestChainSkipIfStillRunning(t *testing.T) {
t.Run("runs immediately", func(*testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
}
})
t.Run("second run immediate if first done", func(*testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(time.Millisecond)
go wrappedJob.Run()
}()
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
if c := j.Done(); c != 2 {
t.Errorf("expected job run twice, immediately, got %d", c)
}
})
t.Run("second run skipped if first not done", func(*testing.T) {
var j countJob
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
time.Sleep(time.Millisecond)
go wrappedJob.Run()
}()
// After 5ms, the first job is still in progress, and the second job was
// aleady skipped.
time.Sleep(5 * time.Millisecond)
started, done := j.Started(), j.Done()
if started != 1 || done != 0 {
t.Error("expected first job started, but not finished, got", started, done)
}
// Verify that the first job completes and second does not run.
time.Sleep(25 * time.Millisecond)
started, done = j.Started(), j.Done()
if started != 1 || done != 1 {
t.Error("expected second job skipped, got", started, done)
}
})
t.Run("skip 10 jobs on rapid fire", func(*testing.T) {
var j countJob
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
for i := 0; i < 11; i++ {
go wrappedJob.Run()
}
time.Sleep(200 * time.Millisecond)
done := j.Done()
if done != 1 {
t.Error("expected 1 jobs executed, 10 jobs dropped, got", done)
}
})
t.Run("different jobs independent", func(*testing.T) {
var j1, j2 countJob
j1.delay = 10 * time.Millisecond
j2.delay = 10 * time.Millisecond
chain := NewChain(SkipIfStillRunning(DiscardLogger))
wrappedJob1 := chain.Then(&j1)
wrappedJob2 := chain.Then(&j2)
for i := 0; i < 11; i++ {
go wrappedJob1.Run()
go wrappedJob2.Run()
}
time.Sleep(100 * time.Millisecond)
var (
done1 = j1.Done()
done2 = j2.Done()
)
if done1 != 1 || done2 != 1 {
t.Error("expected both jobs executed once, got", done1, "and", done2)
}
})
}

View File

@@ -0,0 +1,27 @@
package cron
import "time"
// ConstantDelaySchedule represents a simple recurring duty cycle, e.g. "Every 5 minutes".
// It does not support jobs more frequent than once a second.
type ConstantDelaySchedule struct {
Delay time.Duration
}
// Every returns a crontab Schedule that activates once every duration.
// Delays of less than a second are not supported (will round up to 1 second).
// Any fields less than a Second are truncated.
func Every(duration time.Duration) ConstantDelaySchedule {
if duration < time.Second {
duration = time.Second
}
return ConstantDelaySchedule{
Delay: duration - time.Duration(duration.Nanoseconds())%time.Second,
}
}
// Next returns the next time this should be run.
// This rounds so that the next activation time will be on the second.
func (schedule ConstantDelaySchedule) Next(t time.Time) time.Time {
return t.Add(schedule.Delay - time.Duration(t.Nanosecond())*time.Nanosecond)
}

View File

@@ -0,0 +1,55 @@
//nolint:all
package cron
import (
"testing"
"time"
)
func TestConstantDelayNext(t *testing.T) {
tests := []struct {
time string
delay time.Duration
expected string
}{
// Simple cases
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
{"Mon Jul 9 14:59 2012", 15 * time.Minute, "Mon Jul 9 15:14 2012"},
{"Mon Jul 9 14:59:59 2012", 15 * time.Minute, "Mon Jul 9 15:14:59 2012"},
// Wrap around hours
{"Mon Jul 9 15:45 2012", 35 * time.Minute, "Mon Jul 9 16:20 2012"},
// Wrap around days
{"Mon Jul 9 23:46 2012", 14 * time.Minute, "Tue Jul 10 00:00 2012"},
{"Mon Jul 9 23:45 2012", 35 * time.Minute, "Tue Jul 10 00:20 2012"},
{"Mon Jul 9 23:35:51 2012", 44*time.Minute + 24*time.Second, "Tue Jul 10 00:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", 25*time.Hour + 44*time.Minute + 24*time.Second, "Thu Jul 11 01:20:15 2012"},
// Wrap around months
{"Mon Jul 9 23:35 2012", 91*24*time.Hour + 25*time.Minute, "Thu Oct 9 00:00 2012"},
// Wrap around minute, hour, day, month, and year
{"Mon Dec 31 23:59:45 2012", 15 * time.Second, "Tue Jan 1 00:00:00 2013"},
// Round to nearest second on the delay
{"Mon Jul 9 14:45 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
// Round up to 1 second if the duration is less.
{"Mon Jul 9 14:45:00 2012", 15 * time.Millisecond, "Mon Jul 9 14:45:01 2012"},
// Round to nearest second when calculating the next time.
{"Mon Jul 9 14:45:00.005 2012", 15 * time.Minute, "Mon Jul 9 15:00 2012"},
// Round to nearest second for both.
{"Mon Jul 9 14:45:00.005 2012", 15*time.Minute + 50*time.Nanosecond, "Mon Jul 9 15:00 2012"},
}
for _, c := range tests {
actual := Every(c.delay).Next(getTime(c.time))
expected := getTime(c.expected)
if actual != expected {
t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.delay, expected, actual)
}
}
}

355
plugin/cron/cron.go Normal file
View File

@@ -0,0 +1,355 @@
package cron
import (
"context"
"sort"
"sync"
"time"
)
// Cron keeps track of any number of entries, invoking the associated func as
// specified by the schedule. It may be started, stopped, and the entries may
// be inspected while running.
type Cron struct {
entries []*Entry
chain Chain
stop chan struct{}
add chan *Entry
remove chan EntryID
snapshot chan chan []Entry
running bool
logger Logger
runningMu sync.Mutex
location *time.Location
parser ScheduleParser
nextID EntryID
jobWaiter sync.WaitGroup
}
// ScheduleParser is an interface for schedule spec parsers that return a Schedule.
type ScheduleParser interface {
Parse(spec string) (Schedule, error)
}
// Job is an interface for submitted cron jobs.
type Job interface {
Run()
}
// Schedule describes a job's duty cycle.
type Schedule interface {
// Next returns the next activation time, later than the given time.
// Next is invoked initially, and then each time the job is run.
Next(time.Time) time.Time
}
// EntryID identifies an entry within a Cron instance.
type EntryID int
// Entry consists of a schedule and the func to execute on that schedule.
type Entry struct {
// ID is the cron-assigned ID of this entry, which may be used to look up a
// snapshot or remove it.
ID EntryID
// Schedule on which this job should be run.
Schedule Schedule
// Next time the job will run, or the zero time if Cron has not been
// started or this entry's schedule is unsatisfiable
Next time.Time
// Prev is the last time this job was run, or the zero time if never.
Prev time.Time
// WrappedJob is the thing to run when the Schedule is activated.
WrappedJob Job
// Job is the thing that was submitted to cron.
// It is kept around so that user code that needs to get at the job later,
// e.g. via Entries() can do so.
Job Job
}
// Valid returns true if this is not the zero entry.
func (e Entry) Valid() bool { return e.ID != 0 }
// byTime is a wrapper for sorting the entry array by time
// (with zero time at the end).
type byTime []*Entry
func (s byTime) Len() int { return len(s) }
func (s byTime) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s byTime) Less(i, j int) bool {
// Two zero times should return false.
// Otherwise, zero is "greater" than any other time.
// (To sort it at the end of the list.)
if s[i].Next.IsZero() {
return false
}
if s[j].Next.IsZero() {
return true
}
return s[i].Next.Before(s[j].Next)
}
// New returns a new Cron job runner, modified by the given options.
//
// Available Settings
//
// Time Zone
// Description: The time zone in which schedules are interpreted
// Default: time.Local
//
// Parser
// Description: Parser converts cron spec strings into cron.Schedules.
// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron
//
// Chain
// Description: Wrap submitted jobs to customize behavior.
// Default: A chain that recovers panics and logs them to stderr.
//
// See "cron.With*" to modify the default behavior.
func New(opts ...Option) *Cron {
c := &Cron{
entries: nil,
chain: NewChain(),
add: make(chan *Entry),
stop: make(chan struct{}),
snapshot: make(chan chan []Entry),
remove: make(chan EntryID),
running: false,
runningMu: sync.Mutex{},
logger: DefaultLogger,
location: time.Local,
parser: standardParser,
}
for _, opt := range opts {
opt(c)
}
return c
}
// FuncJob is a wrapper that turns a func() into a cron.Job.
type FuncJob func()
func (f FuncJob) Run() { f() }
// AddFunc adds a func to the Cron to be run on the given schedule.
// The spec is parsed using the time zone of this Cron instance as the default.
// An opaque ID is returned that can be used to later remove it.
func (c *Cron) AddFunc(spec string, cmd func()) (EntryID, error) {
return c.AddJob(spec, FuncJob(cmd))
}
// AddJob adds a Job to the Cron to be run on the given schedule.
// The spec is parsed using the time zone of this Cron instance as the default.
// An opaque ID is returned that can be used to later remove it.
func (c *Cron) AddJob(spec string, cmd Job) (EntryID, error) {
schedule, err := c.parser.Parse(spec)
if err != nil {
return 0, err
}
return c.Schedule(schedule, cmd), nil
}
// Schedule adds a Job to the Cron to be run on the given schedule.
// The job is wrapped with the configured Chain.
func (c *Cron) Schedule(schedule Schedule, cmd Job) EntryID {
c.runningMu.Lock()
defer c.runningMu.Unlock()
c.nextID++
entry := &Entry{
ID: c.nextID,
Schedule: schedule,
WrappedJob: c.chain.Then(cmd),
Job: cmd,
}
if !c.running {
c.entries = append(c.entries, entry)
} else {
c.add <- entry
}
return entry.ID
}
// Entries returns a snapshot of the cron entries.
func (c *Cron) Entries() []Entry {
c.runningMu.Lock()
defer c.runningMu.Unlock()
if c.running {
replyChan := make(chan []Entry, 1)
c.snapshot <- replyChan
return <-replyChan
}
return c.entrySnapshot()
}
// Location gets the time zone location.
func (c *Cron) Location() *time.Location {
return c.location
}
// Entry returns a snapshot of the given entry, or nil if it couldn't be found.
func (c *Cron) Entry(id EntryID) Entry {
for _, entry := range c.Entries() {
if id == entry.ID {
return entry
}
}
return Entry{}
}
// Remove an entry from being run in the future.
func (c *Cron) Remove(id EntryID) {
c.runningMu.Lock()
defer c.runningMu.Unlock()
if c.running {
c.remove <- id
} else {
c.removeEntry(id)
}
}
// Start the cron scheduler in its own goroutine, or no-op if already started.
func (c *Cron) Start() {
c.runningMu.Lock()
defer c.runningMu.Unlock()
if c.running {
return
}
c.running = true
go c.runScheduler()
}
// Run the cron scheduler, or no-op if already running.
func (c *Cron) Run() {
c.runningMu.Lock()
if c.running {
c.runningMu.Unlock()
return
}
c.running = true
c.runningMu.Unlock()
c.runScheduler()
}
// runScheduler runs the scheduler.. this is private just due to the need to synchronize
// access to the 'running' state variable.
func (c *Cron) runScheduler() {
c.logger.Info("start")
// Figure out the next activation times for each entry.
now := c.now()
for _, entry := range c.entries {
entry.Next = entry.Schedule.Next(now)
c.logger.Info("schedule", "now", now, "entry", entry.ID, "next", entry.Next)
}
for {
// Determine the next entry to run.
sort.Sort(byTime(c.entries))
var timer *time.Timer
if len(c.entries) == 0 || c.entries[0].Next.IsZero() {
// If there are no entries yet, just sleep - it still handles new entries
// and stop requests.
timer = time.NewTimer(100000 * time.Hour)
} else {
timer = time.NewTimer(c.entries[0].Next.Sub(now))
}
for {
select {
case now = <-timer.C:
now = now.In(c.location)
c.logger.Info("wake", "now", now)
// Run every entry whose next time was less than now
for _, e := range c.entries {
if e.Next.After(now) || e.Next.IsZero() {
break
}
c.startJob(e.WrappedJob)
e.Prev = e.Next
e.Next = e.Schedule.Next(now)
c.logger.Info("run", "now", now, "entry", e.ID, "next", e.Next)
}
case newEntry := <-c.add:
timer.Stop()
now = c.now()
newEntry.Next = newEntry.Schedule.Next(now)
c.entries = append(c.entries, newEntry)
c.logger.Info("added", "now", now, "entry", newEntry.ID, "next", newEntry.Next)
case replyChan := <-c.snapshot:
replyChan <- c.entrySnapshot()
continue
case <-c.stop:
timer.Stop()
c.logger.Info("stop")
return
case id := <-c.remove:
timer.Stop()
now = c.now()
c.removeEntry(id)
c.logger.Info("removed", "entry", id)
}
break
}
}
}
// startJob runs the given job in a new goroutine.
func (c *Cron) startJob(j Job) {
c.jobWaiter.Add(1)
go func() {
defer c.jobWaiter.Done()
j.Run()
}()
}
// now returns current time in c location.
func (c *Cron) now() time.Time {
return time.Now().In(c.location)
}
// Stop stops the cron scheduler if it is running; otherwise it does nothing.
// A context is returned so the caller can wait for running jobs to complete.
func (c *Cron) Stop() context.Context {
c.runningMu.Lock()
defer c.runningMu.Unlock()
if c.running {
c.stop <- struct{}{}
c.running = false
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
c.jobWaiter.Wait()
cancel()
}()
return ctx
}
// entrySnapshot returns a copy of the current cron entry list.
func (c *Cron) entrySnapshot() []Entry {
var entries = make([]Entry, len(c.entries))
for i, e := range c.entries {
entries[i] = *e
}
return entries
}
func (c *Cron) removeEntry(id EntryID) {
var entries []*Entry
for _, e := range c.entries {
if e.ID != id {
entries = append(entries, e)
}
}
c.entries = entries
}

702
plugin/cron/cron_test.go Normal file
View File

@@ -0,0 +1,702 @@
//nolint:all
package cron
import (
"bytes"
"fmt"
"log"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// Many tests schedule a job for every second, and then wait at most a second
// for it to run. This amount is just slightly larger than 1 second to
// compensate for a few milliseconds of runtime.
const OneSecond = 1*time.Second + 50*time.Millisecond
type syncWriter struct {
wr bytes.Buffer
m sync.Mutex
}
func (sw *syncWriter) Write(data []byte) (n int, err error) {
sw.m.Lock()
n, err = sw.wr.Write(data)
sw.m.Unlock()
return
}
func (sw *syncWriter) String() string {
sw.m.Lock()
defer sw.m.Unlock()
return sw.wr.String()
}
func newBufLogger(sw *syncWriter) Logger {
return PrintfLogger(log.New(sw, "", log.LstdFlags))
}
func TestFuncPanicRecovery(t *testing.T) {
var buf syncWriter
cron := New(WithParser(secondParser),
WithChain(Recover(newBufLogger(&buf))))
cron.Start()
defer cron.Stop()
cron.AddFunc("* * * * * ?", func() {
panic("YOLO")
})
select {
case <-time.After(OneSecond):
if !strings.Contains(buf.String(), "YOLO") {
t.Error("expected a panic to be logged, got none")
}
return
}
}
type DummyJob struct{}
func (DummyJob) Run() {
panic("YOLO")
}
func TestJobPanicRecovery(t *testing.T) {
var job DummyJob
var buf syncWriter
cron := New(WithParser(secondParser),
WithChain(Recover(newBufLogger(&buf))))
cron.Start()
defer cron.Stop()
cron.AddJob("* * * * * ?", job)
select {
case <-time.After(OneSecond):
if !strings.Contains(buf.String(), "YOLO") {
t.Error("expected a panic to be logged, got none")
}
return
}
}
// Start and stop cron with no entries.
func TestNoEntries(t *testing.T) {
cron := newWithSeconds()
cron.Start()
select {
case <-time.After(OneSecond):
t.Fatal("expected cron will be stopped immediately")
case <-stop(cron):
}
}
// Start, stop, then add an entry. Verify entry doesn't run.
func TestStopCausesJobsToNotRun(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.Start()
cron.Stop()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
select {
case <-time.After(OneSecond):
// No job ran!
case <-wait(wg):
t.Fatal("expected stopped cron does not run any job")
}
}
// Add a job, start cron, expect it runs.
func TestAddBeforeRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Start()
defer cron.Stop()
// Give cron 2 seconds to run our job (which is always activated).
select {
case <-time.After(OneSecond):
t.Fatal("expected job runs")
case <-wait(wg):
}
}
// Start cron, add a job, expect it runs.
func TestAddWhileRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.Start()
defer cron.Stop()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
select {
case <-time.After(OneSecond):
t.Fatal("expected job runs")
case <-wait(wg):
}
}
// Test for #34. Adding a job after calling start results in multiple job invocations
func TestAddWhileRunningWithDelay(t *testing.T) {
cron := newWithSeconds()
cron.Start()
defer cron.Stop()
time.Sleep(5 * time.Second)
var calls int64
cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) })
<-time.After(OneSecond)
if atomic.LoadInt64(&calls) != 1 {
t.Errorf("called %d times, expected 1\n", calls)
}
}
// Add a job, remove a job, start cron, expect nothing runs.
func TestRemoveBeforeRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Remove(id)
cron.Start()
defer cron.Stop()
select {
case <-time.After(OneSecond):
// Success, shouldn't run
case <-wait(wg):
t.FailNow()
}
}
// Start cron, add a job, remove it, expect it doesn't run.
func TestRemoveWhileRunning(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.Start()
defer cron.Stop()
id, _ := cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Remove(id)
select {
case <-time.After(OneSecond):
case <-wait(wg):
t.FailNow()
}
}
// Test timing with Entries.
func TestSnapshotEntries(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := New()
cron.AddFunc("@every 2s", func() { wg.Done() })
cron.Start()
defer cron.Stop()
// Cron should fire in 2 seconds. After 1 second, call Entries.
select {
case <-time.After(OneSecond):
cron.Entries()
}
// Even though Entries was called, the cron should fire at the 2 second mark.
select {
case <-time.After(OneSecond):
t.Error("expected job runs at 2 second mark")
case <-wait(wg):
}
}
// Test that the entries are correctly sorted.
// Add a bunch of long-in-the-future entries, and an immediate entry, and ensure
// that the immediate entry runs immediately.
// Also: Test that multiple jobs run in the same instant.
func TestMultipleEntries(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
cron := newWithSeconds()
cron.AddFunc("0 0 0 1 1 ?", func() {})
cron.AddFunc("* * * * * ?", func() { wg.Done() })
id1, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() })
id2, _ := cron.AddFunc("* * * * * ?", func() { t.Fatal() })
cron.AddFunc("0 0 0 31 12 ?", func() {})
cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Remove(id1)
cron.Start()
cron.Remove(id2)
defer cron.Stop()
select {
case <-time.After(OneSecond):
t.Error("expected job run in proper order")
case <-wait(wg):
}
}
// Test running the same job twice.
func TestRunningJobTwice(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
cron := newWithSeconds()
cron.AddFunc("0 0 0 1 1 ?", func() {})
cron.AddFunc("0 0 0 31 12 ?", func() {})
cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Start()
defer cron.Stop()
select {
case <-time.After(2 * OneSecond):
t.Error("expected job fires 2 times")
case <-wait(wg):
}
}
func TestRunningMultipleSchedules(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
cron := newWithSeconds()
cron.AddFunc("0 0 0 1 1 ?", func() {})
cron.AddFunc("0 0 0 31 12 ?", func() {})
cron.AddFunc("* * * * * ?", func() { wg.Done() })
cron.Schedule(Every(time.Minute), FuncJob(func() {}))
cron.Schedule(Every(time.Second), FuncJob(func() { wg.Done() }))
cron.Schedule(Every(time.Hour), FuncJob(func() {}))
cron.Start()
defer cron.Stop()
select {
case <-time.After(2 * OneSecond):
t.Error("expected job fires 2 times")
case <-wait(wg):
}
}
// Test that the cron is run in the local time zone (as opposed to UTC).
func TestLocalTimezone(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
now := time.Now()
// FIX: Issue #205
// This calculation doesn't work in seconds 58 or 59.
// Take the easy way out and sleep.
if now.Second() >= 58 {
time.Sleep(2 * time.Second)
now = time.Now()
}
spec := fmt.Sprintf("%d,%d %d %d %d %d ?",
now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month())
cron := newWithSeconds()
cron.AddFunc(spec, func() { wg.Done() })
cron.Start()
defer cron.Stop()
select {
case <-time.After(OneSecond * 2):
t.Error("expected job fires 2 times")
case <-wait(wg):
}
}
// Test that the cron is run in the given time zone (as opposed to local).
func TestNonLocalTimezone(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
loc, err := time.LoadLocation("Atlantic/Cape_Verde")
if err != nil {
fmt.Printf("Failed to load time zone Atlantic/Cape_Verde: %+v", err)
t.Fail()
}
now := time.Now().In(loc)
// FIX: Issue #205
// This calculation doesn't work in seconds 58 or 59.
// Take the easy way out and sleep.
if now.Second() >= 58 {
time.Sleep(2 * time.Second)
now = time.Now().In(loc)
}
spec := fmt.Sprintf("%d,%d %d %d %d %d ?",
now.Second()+1, now.Second()+2, now.Minute(), now.Hour(), now.Day(), now.Month())
cron := New(WithLocation(loc), WithParser(secondParser))
cron.AddFunc(spec, func() { wg.Done() })
cron.Start()
defer cron.Stop()
select {
case <-time.After(OneSecond * 2):
t.Error("expected job fires 2 times")
case <-wait(wg):
}
}
// Test that calling stop before start silently returns without
// blocking the stop channel.
func TestStopWithoutStart(t *testing.T) {
cron := New()
cron.Stop()
}
type testJob struct {
wg *sync.WaitGroup
name string
}
func (t testJob) Run() {
t.wg.Done()
}
// Test that adding an invalid job spec returns an error
func TestInvalidJobSpec(t *testing.T) {
cron := New()
_, err := cron.AddJob("this will not parse", nil)
if err == nil {
t.Errorf("expected an error with invalid spec, got nil")
}
}
// Test blocking run method behaves as Start()
func TestBlockingRun(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.AddFunc("* * * * * ?", func() { wg.Done() })
var unblockChan = make(chan struct{})
go func() {
cron.Run()
close(unblockChan)
}()
defer cron.Stop()
select {
case <-time.After(OneSecond):
t.Error("expected job fires")
case <-unblockChan:
t.Error("expected that Run() blocks")
case <-wait(wg):
}
}
// Test that double-running is a no-op
func TestStartNoop(t *testing.T) {
var tickChan = make(chan struct{}, 2)
cron := newWithSeconds()
cron.AddFunc("* * * * * ?", func() {
tickChan <- struct{}{}
})
cron.Start()
defer cron.Stop()
// Wait for the first firing to ensure the runner is going
<-tickChan
cron.Start()
<-tickChan
// Fail if this job fires again in a short period, indicating a double-run
select {
case <-time.After(time.Millisecond):
case <-tickChan:
t.Error("expected job fires exactly twice")
}
}
// Simple test using Runnables.
func TestJob(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
cron := newWithSeconds()
cron.AddJob("0 0 0 30 Feb ?", testJob{wg, "job0"})
cron.AddJob("0 0 0 1 1 ?", testJob{wg, "job1"})
job2, _ := cron.AddJob("* * * * * ?", testJob{wg, "job2"})
cron.AddJob("1 0 0 1 1 ?", testJob{wg, "job3"})
cron.Schedule(Every(5*time.Second+5*time.Nanosecond), testJob{wg, "job4"})
job5 := cron.Schedule(Every(5*time.Minute), testJob{wg, "job5"})
// Test getting an Entry pre-Start.
if actualName := cron.Entry(job2).Job.(testJob).name; actualName != "job2" {
t.Error("wrong job retrieved:", actualName)
}
if actualName := cron.Entry(job5).Job.(testJob).name; actualName != "job5" {
t.Error("wrong job retrieved:", actualName)
}
cron.Start()
defer cron.Stop()
select {
case <-time.After(OneSecond):
t.FailNow()
case <-wait(wg):
}
// Ensure the entries are in the right order.
expecteds := []string{"job2", "job4", "job5", "job1", "job3", "job0"}
var actuals []string
for _, entry := range cron.Entries() {
actuals = append(actuals, entry.Job.(testJob).name)
}
for i, expected := range expecteds {
if actuals[i] != expected {
t.Fatalf("Jobs not in the right order. (expected) %s != %s (actual)", expecteds, actuals)
}
}
// Test getting Entries.
if actualName := cron.Entry(job2).Job.(testJob).name; actualName != "job2" {
t.Error("wrong job retrieved:", actualName)
}
if actualName := cron.Entry(job5).Job.(testJob).name; actualName != "job5" {
t.Error("wrong job retrieved:", actualName)
}
}
// Issue #206
// Ensure that the next run of a job after removing an entry is accurate.
func TestScheduleAfterRemoval(t *testing.T) {
var wg1 sync.WaitGroup
var wg2 sync.WaitGroup
wg1.Add(1)
wg2.Add(1)
// The first time this job is run, set a timer and remove the other job
// 750ms later. Correct behavior would be to still run the job again in
// 250ms, but the bug would cause it to run instead 1s later.
var calls int
var mu sync.Mutex
cron := newWithSeconds()
hourJob := cron.Schedule(Every(time.Hour), FuncJob(func() {}))
cron.Schedule(Every(time.Second), FuncJob(func() {
mu.Lock()
defer mu.Unlock()
switch calls {
case 0:
wg1.Done()
calls++
case 1:
time.Sleep(750 * time.Millisecond)
cron.Remove(hourJob)
calls++
case 2:
calls++
wg2.Done()
case 3:
panic("unexpected 3rd call")
}
}))
cron.Start()
defer cron.Stop()
// the first run might be any length of time 0 - 1s, since the schedule
// rounds to the second. wait for the first run to true up.
wg1.Wait()
select {
case <-time.After(2 * OneSecond):
t.Error("expected job fires 2 times")
case <-wait(&wg2):
}
}
type ZeroSchedule struct{}
func (*ZeroSchedule) Next(time.Time) time.Time {
return time.Time{}
}
// Tests that job without time does not run
func TestJobWithZeroTimeDoesNotRun(t *testing.T) {
cron := newWithSeconds()
var calls int64
cron.AddFunc("* * * * * *", func() { atomic.AddInt64(&calls, 1) })
cron.Schedule(new(ZeroSchedule), FuncJob(func() { t.Error("expected zero task will not run") }))
cron.Start()
defer cron.Stop()
<-time.After(OneSecond)
if atomic.LoadInt64(&calls) != 1 {
t.Errorf("called %d times, expected 1\n", calls)
}
}
func TestStopAndWait(t *testing.T) {
t.Run("nothing running, returns immediately", func(*testing.T) {
cron := newWithSeconds()
cron.Start()
ctx := cron.Stop()
select {
case <-ctx.Done():
case <-time.After(time.Millisecond):
t.Error("context was not done immediately")
}
})
t.Run("repeated calls to Stop", func(*testing.T) {
cron := newWithSeconds()
cron.Start()
_ = cron.Stop()
time.Sleep(time.Millisecond)
ctx := cron.Stop()
select {
case <-ctx.Done():
case <-time.After(time.Millisecond):
t.Error("context was not done immediately")
}
})
t.Run("a couple fast jobs added, still returns immediately", func(*testing.T) {
cron := newWithSeconds()
cron.AddFunc("* * * * * *", func() {})
cron.Start()
cron.AddFunc("* * * * * *", func() {})
cron.AddFunc("* * * * * *", func() {})
cron.AddFunc("* * * * * *", func() {})
time.Sleep(time.Second)
ctx := cron.Stop()
select {
case <-ctx.Done():
case <-time.After(time.Millisecond):
t.Error("context was not done immediately")
}
})
t.Run("a couple fast jobs and a slow job added, waits for slow job", func(*testing.T) {
cron := newWithSeconds()
cron.AddFunc("* * * * * *", func() {})
cron.Start()
cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
cron.AddFunc("* * * * * *", func() {})
time.Sleep(time.Second)
ctx := cron.Stop()
// Verify that it is not done for at least 750ms
select {
case <-ctx.Done():
t.Error("context was done too quickly immediately")
case <-time.After(750 * time.Millisecond):
// expected, because the job sleeping for 1 second is still running
}
// Verify that it IS done in the next 500ms (giving 250ms buffer)
select {
case <-ctx.Done():
// expected
case <-time.After(1500 * time.Millisecond):
t.Error("context not done after job should have completed")
}
})
t.Run("repeated calls to stop, waiting for completion and after", func(*testing.T) {
cron := newWithSeconds()
cron.AddFunc("* * * * * *", func() {})
cron.AddFunc("* * * * * *", func() { time.Sleep(2 * time.Second) })
cron.Start()
cron.AddFunc("* * * * * *", func() {})
time.Sleep(time.Second)
ctx := cron.Stop()
ctx2 := cron.Stop()
// Verify that it is not done for at least 1500ms
select {
case <-ctx.Done():
t.Error("context was done too quickly immediately")
case <-ctx2.Done():
t.Error("context2 was done too quickly immediately")
case <-time.After(1500 * time.Millisecond):
// expected, because the job sleeping for 2 seconds is still running
}
// Verify that it IS done in the next 1s (giving 500ms buffer)
select {
case <-ctx.Done():
// expected
case <-time.After(time.Second):
t.Error("context not done after job should have completed")
}
// Verify that ctx2 is also done.
select {
case <-ctx2.Done():
// expected
case <-time.After(time.Millisecond):
t.Error("context2 not done even though context1 is")
}
// Verify that a new context retrieved from stop is immediately done.
ctx3 := cron.Stop()
select {
case <-ctx3.Done():
// expected
case <-time.After(time.Millisecond):
t.Error("context not done even when cron Stop is completed")
}
})
}
func TestMultiThreadedStartAndStop(t *testing.T) {
cron := New()
go cron.Run()
time.Sleep(2 * time.Millisecond)
cron.Stop()
}
func wait(wg *sync.WaitGroup) chan bool {
ch := make(chan bool)
go func() {
wg.Wait()
ch <- true
}()
return ch
}
func stop(cron *Cron) chan bool {
ch := make(chan bool)
go func() {
cron.Stop()
ch <- true
}()
return ch
}
// newWithSeconds returns a Cron with the seconds field enabled.
func newWithSeconds() *Cron {
return New(WithParser(secondParser), WithChain())
}

86
plugin/cron/logger.go Normal file
View File

@@ -0,0 +1,86 @@
package cron
import (
"io"
"log"
"os"
"strings"
"time"
)
// DefaultLogger is used by Cron if none is specified.
var DefaultLogger = PrintfLogger(log.New(os.Stdout, "cron: ", log.LstdFlags))
// DiscardLogger can be used by callers to discard all log messages.
var DiscardLogger = PrintfLogger(log.New(io.Discard, "", 0))
// Logger is the interface used in this package for logging, so that any backend
// can be plugged in. It is a subset of the github.com/go-logr/logr interface.
type Logger interface {
// Info logs routine messages about cron's operation.
Info(msg string, keysAndValues ...interface{})
// Error logs an error condition.
Error(err error, msg string, keysAndValues ...interface{})
}
// PrintfLogger wraps a Printf-based logger (such as the standard library "log")
// into an implementation of the Logger interface which logs errors only.
func PrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger {
return printfLogger{l, false}
}
// VerbosePrintfLogger wraps a Printf-based logger (such as the standard library
// "log") into an implementation of the Logger interface which logs everything.
func VerbosePrintfLogger(l interface{ Printf(string, ...interface{}) }) Logger {
return printfLogger{l, true}
}
type printfLogger struct {
logger interface{ Printf(string, ...interface{}) }
logInfo bool
}
func (pl printfLogger) Info(msg string, keysAndValues ...interface{}) {
if pl.logInfo {
keysAndValues = formatTimes(keysAndValues)
pl.logger.Printf(
formatString(len(keysAndValues)),
append([]interface{}{msg}, keysAndValues...)...)
}
}
func (pl printfLogger) Error(err error, msg string, keysAndValues ...interface{}) {
keysAndValues = formatTimes(keysAndValues)
pl.logger.Printf(
formatString(len(keysAndValues)+2),
append([]interface{}{msg, "error", err}, keysAndValues...)...)
}
// formatString returns a logfmt-like format string for the number of
// key/values.
func formatString(numKeysAndValues int) string {
var sb strings.Builder
sb.WriteString("%s")
if numKeysAndValues > 0 {
sb.WriteString(", ")
}
for i := 0; i < numKeysAndValues/2; i++ {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString("%v=%v")
}
return sb.String()
}
// formatTimes formats any time.Time values as RFC3339.
func formatTimes(keysAndValues []interface{}) []interface{} {
var formattedArgs []interface{}
for _, arg := range keysAndValues {
if t, ok := arg.(time.Time); ok {
arg = t.Format(time.RFC3339)
}
formattedArgs = append(formattedArgs, arg)
}
return formattedArgs
}

45
plugin/cron/option.go Normal file
View File

@@ -0,0 +1,45 @@
package cron
import (
"time"
)
// Option represents a modification to the default behavior of a Cron.
type Option func(*Cron)
// WithLocation overrides the timezone of the cron instance.
func WithLocation(loc *time.Location) Option {
return func(c *Cron) {
c.location = loc
}
}
// WithSeconds overrides the parser used for interpreting job schedules to
// include a seconds field as the first one.
func WithSeconds() Option {
return WithParser(NewParser(
Second | Minute | Hour | Dom | Month | Dow | Descriptor,
))
}
// WithParser overrides the parser used for interpreting job schedules.
func WithParser(p ScheduleParser) Option {
return func(c *Cron) {
c.parser = p
}
}
// WithChain specifies Job wrappers to apply to all jobs added to this cron.
// Refer to the Chain* functions in this package for provided wrappers.
func WithChain(wrappers ...JobWrapper) Option {
return func(c *Cron) {
c.chain = NewChain(wrappers...)
}
}
// WithLogger uses the provided logger.
func WithLogger(logger Logger) Option {
return func(c *Cron) {
c.logger = logger
}
}

View File

@@ -0,0 +1,43 @@
//nolint:all
package cron
import (
"log"
"strings"
"testing"
"time"
)
func TestWithLocation(t *testing.T) {
c := New(WithLocation(time.UTC))
if c.location != time.UTC {
t.Errorf("expected UTC, got %v", c.location)
}
}
func TestWithParser(t *testing.T) {
var parser = NewParser(Dow)
c := New(WithParser(parser))
if c.parser != parser {
t.Error("expected provided parser")
}
}
func TestWithVerboseLogger(t *testing.T) {
var buf syncWriter
var logger = log.New(&buf, "", log.LstdFlags)
c := New(WithLogger(VerbosePrintfLogger(logger)))
if c.logger.(printfLogger).logger != logger {
t.Error("expected provided logger")
}
c.AddFunc("@every 1s", func() {})
c.Start()
time.Sleep(OneSecond)
c.Stop()
out := buf.String()
if !strings.Contains(out, "schedule,") ||
!strings.Contains(out, "run,") {
t.Error("expected to see some actions, got:", out)
}
}

435
plugin/cron/parser.go Normal file
View File

@@ -0,0 +1,435 @@
package cron
import (
"math"
"strconv"
"strings"
"time"
"github.com/pkg/errors"
)
// Configuration options for creating a parser. Most options specify which
// fields should be included, while others enable features. If a field is not
// included the parser will assume a default value. These options do not change
// the order fields are parse in.
type ParseOption int
const (
Second ParseOption = 1 << iota // Seconds field, default 0
SecondOptional // Optional seconds field, default 0
Minute // Minutes field, default 0
Hour // Hours field, default 0
Dom // Day of month field, default *
Month // Month field, default *
Dow // Day of week field, default *
DowOptional // Optional day of week field, default *
Descriptor // Allow descriptors such as @monthly, @weekly, etc.
)
var places = []ParseOption{
Second,
Minute,
Hour,
Dom,
Month,
Dow,
}
var defaults = []string{
"0",
"0",
"0",
"*",
"*",
"*",
}
// A custom Parser that can be configured.
type Parser struct {
options ParseOption
}
// NewParser creates a Parser with custom options.
//
// It panics if more than one Optional is given, since it would be impossible to
// correctly infer which optional is provided or missing in general.
//
// Examples
//
// // Standard parser without descriptors
// specParser := NewParser(Minute | Hour | Dom | Month | Dow)
// sched, err := specParser.Parse("0 0 15 */3 *")
//
// // Same as above, just excludes time fields
// specParser := NewParser(Dom | Month | Dow)
// sched, err := specParser.Parse("15 */3 *")
//
// // Same as above, just makes Dow optional
// specParser := NewParser(Dom | Month | DowOptional)
// sched, err := specParser.Parse("15 */3")
func NewParser(options ParseOption) Parser {
optionals := 0
if options&DowOptional > 0 {
optionals++
}
if options&SecondOptional > 0 {
optionals++
}
if optionals > 1 {
panic("multiple optionals may not be configured")
}
return Parser{options}
}
// Parse returns a new crontab schedule representing the given spec.
// It returns a descriptive error if the spec is not valid.
// It accepts crontab specs and features configured by NewParser.
func (p Parser) Parse(spec string) (Schedule, error) {
if len(spec) == 0 {
return nil, errors.New("empty spec string")
}
// Extract timezone if present
var loc = time.Local
if strings.HasPrefix(spec, "TZ=") || strings.HasPrefix(spec, "CRON_TZ=") {
var err error
i := strings.Index(spec, " ")
eq := strings.Index(spec, "=")
if loc, err = time.LoadLocation(spec[eq+1 : i]); err != nil {
return nil, errors.Wrap(err, "provided bad location")
}
spec = strings.TrimSpace(spec[i:])
}
// Handle named schedules (descriptors), if configured
if strings.HasPrefix(spec, "@") {
if p.options&Descriptor == 0 {
return nil, errors.New("descriptors not enabled")
}
return parseDescriptor(spec, loc)
}
// Split on whitespace.
fields := strings.Fields(spec)
// Validate & fill in any omitted or optional fields
var err error
fields, err = normalizeFields(fields, p.options)
if err != nil {
return nil, err
}
field := func(field string, r bounds) uint64 {
if err != nil {
return 0
}
var bits uint64
bits, err = getField(field, r)
return bits
}
var (
second = field(fields[0], seconds)
minute = field(fields[1], minutes)
hour = field(fields[2], hours)
dayofmonth = field(fields[3], dom)
month = field(fields[4], months)
dayofweek = field(fields[5], dow)
)
if err != nil {
return nil, err
}
return &SpecSchedule{
Second: second,
Minute: minute,
Hour: hour,
Dom: dayofmonth,
Month: month,
Dow: dayofweek,
Location: loc,
}, nil
}
// normalizeFields takes a subset set of the time fields and returns the full set
// with defaults (zeroes) populated for unset fields.
//
// As part of performing this function, it also validates that the provided
// fields are compatible with the configured options.
func normalizeFields(fields []string, options ParseOption) ([]string, error) {
// Validate optionals & add their field to options
optionals := 0
if options&SecondOptional > 0 {
options |= Second
optionals++
}
if options&DowOptional > 0 {
options |= Dow
optionals++
}
if optionals > 1 {
return nil, errors.New("multiple optionals may not be configured")
}
// Figure out how many fields we need
max := 0
for _, place := range places {
if options&place > 0 {
max++
}
}
min := max - optionals
// Validate number of fields
if count := len(fields); count < min || count > max {
if min == max {
return nil, errors.New("incorrect number of fields")
}
return nil, errors.New("incorrect number of fields, expected " + strconv.Itoa(min) + "-" + strconv.Itoa(max))
}
// Populate the optional field if not provided
if min < max && len(fields) == min {
switch {
case options&DowOptional > 0:
fields = append(fields, defaults[5]) // TODO: improve access to default
case options&SecondOptional > 0:
fields = append([]string{defaults[0]}, fields...)
default:
return nil, errors.New("unexpected optional field")
}
}
// Populate all fields not part of options with their defaults
n := 0
expandedFields := make([]string, len(places))
copy(expandedFields, defaults)
for i, place := range places {
if options&place > 0 {
expandedFields[i] = fields[n]
n++
}
}
return expandedFields, nil
}
var standardParser = NewParser(
Minute | Hour | Dom | Month | Dow | Descriptor,
)
// ParseStandard returns a new crontab schedule representing the given
// standardSpec (https://en.wikipedia.org/wiki/Cron). It requires 5 entries
// representing: minute, hour, day of month, month and day of week, in that
// order. It returns a descriptive error if the spec is not valid.
//
// It accepts
// - Standard crontab specs, e.g. "* * * * ?"
// - Descriptors, e.g. "@midnight", "@every 1h30m"
func ParseStandard(standardSpec string) (Schedule, error) {
return standardParser.Parse(standardSpec)
}
// getField returns an Int with the bits set representing all of the times that
// the field represents or error parsing field value. A "field" is a comma-separated
// list of "ranges".
func getField(field string, r bounds) (uint64, error) {
var bits uint64
ranges := strings.FieldsFunc(field, func(r rune) bool { return r == ',' })
for _, expr := range ranges {
bit, err := getRange(expr, r)
if err != nil {
return bits, err
}
bits |= bit
}
return bits, nil
}
// getRange returns the bits indicated by the given expression:
//
// number | number "-" number [ "/" number ]
//
// or error parsing range.
func getRange(expr string, r bounds) (uint64, error) {
var (
start, end, step uint
rangeAndStep = strings.Split(expr, "/")
lowAndHigh = strings.Split(rangeAndStep[0], "-")
singleDigit = len(lowAndHigh) == 1
err error
)
var extra uint64
if lowAndHigh[0] == "*" || lowAndHigh[0] == "?" {
start = r.min
end = r.max
extra = starBit
} else {
start, err = parseIntOrName(lowAndHigh[0], r.names)
if err != nil {
return 0, err
}
switch len(lowAndHigh) {
case 1:
end = start
case 2:
end, err = parseIntOrName(lowAndHigh[1], r.names)
if err != nil {
return 0, err
}
default:
return 0, errors.New("too many hyphens: " + expr)
}
}
switch len(rangeAndStep) {
case 1:
step = 1
case 2:
step, err = mustParseInt(rangeAndStep[1])
if err != nil {
return 0, err
}
// Special handling: "N/step" means "N-max/step".
if singleDigit {
end = r.max
}
if step > 1 {
extra = 0
}
default:
return 0, errors.New("too many slashes: " + expr)
}
if start < r.min {
return 0, errors.New("beginning of range below minimum: " + expr)
}
if end > r.max {
return 0, errors.New("end of range above maximum: " + expr)
}
if start > end {
return 0, errors.New("beginning of range after end: " + expr)
}
if step == 0 {
return 0, errors.New("step cannot be zero: " + expr)
}
return getBits(start, end, step) | extra, nil
}
// parseIntOrName returns the (possibly-named) integer contained in expr.
func parseIntOrName(expr string, names map[string]uint) (uint, error) {
if names != nil {
if namedInt, ok := names[strings.ToLower(expr)]; ok {
return namedInt, nil
}
}
return mustParseInt(expr)
}
// mustParseInt parses the given expression as an int or returns an error.
func mustParseInt(expr string) (uint, error) {
num, err := strconv.Atoi(expr)
if err != nil {
return 0, errors.Wrap(err, "failed to parse number")
}
if num < 0 {
return 0, errors.New("number must be positive")
}
return uint(num), nil
}
// getBits sets all bits in the range [min, max], modulo the given step size.
func getBits(min, max, step uint) uint64 {
var bits uint64
// If step is 1, use shifts.
if step == 1 {
return ^(math.MaxUint64 << (max + 1)) & (math.MaxUint64 << min)
}
// Else, use a simple loop.
for i := min; i <= max; i += step {
bits |= 1 << i
}
return bits
}
// all returns all bits within the given bounds.
func all(r bounds) uint64 {
return getBits(r.min, r.max, 1) | starBit
}
// parseDescriptor returns a predefined schedule for the expression, or error if none matches.
func parseDescriptor(descriptor string, loc *time.Location) (Schedule, error) {
switch descriptor {
case "@yearly", "@annually":
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: 1 << hours.min,
Dom: 1 << dom.min,
Month: 1 << months.min,
Dow: all(dow),
Location: loc,
}, nil
case "@monthly":
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: 1 << hours.min,
Dom: 1 << dom.min,
Month: all(months),
Dow: all(dow),
Location: loc,
}, nil
case "@weekly":
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: 1 << hours.min,
Dom: all(dom),
Month: all(months),
Dow: 1 << dow.min,
Location: loc,
}, nil
case "@daily", "@midnight":
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: 1 << hours.min,
Dom: all(dom),
Month: all(months),
Dow: all(dow),
Location: loc,
}, nil
case "@hourly":
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: all(hours),
Dom: all(dom),
Month: all(months),
Dow: all(dow),
Location: loc,
}, nil
}
const every = "@every "
if strings.HasPrefix(descriptor, every) {
duration, err := time.ParseDuration(descriptor[len(every):])
if err != nil {
return nil, errors.Wrap(err, "failed to parse duration")
}
return Every(duration), nil
}
return nil, errors.New("unrecognized descriptor: " + descriptor)
}

384
plugin/cron/parser_test.go Normal file
View File

@@ -0,0 +1,384 @@
//nolint:all
package cron
import (
"reflect"
"strings"
"testing"
"time"
)
var secondParser = NewParser(Second | Minute | Hour | Dom | Month | DowOptional | Descriptor)
func TestRange(t *testing.T) {
zero := uint64(0)
ranges := []struct {
expr string
min, max uint
expected uint64
err string
}{
{"5", 0, 7, 1 << 5, ""},
{"0", 0, 7, 1 << 0, ""},
{"7", 0, 7, 1 << 7, ""},
{"5-5", 0, 7, 1 << 5, ""},
{"5-6", 0, 7, 1<<5 | 1<<6, ""},
{"5-7", 0, 7, 1<<5 | 1<<6 | 1<<7, ""},
{"5-6/2", 0, 7, 1 << 5, ""},
{"5-7/2", 0, 7, 1<<5 | 1<<7, ""},
{"5-7/1", 0, 7, 1<<5 | 1<<6 | 1<<7, ""},
{"*", 1, 3, 1<<1 | 1<<2 | 1<<3 | starBit, ""},
{"*/2", 1, 3, 1<<1 | 1<<3, ""},
{"5--5", 0, 0, zero, "too many hyphens"},
{"jan-x", 0, 0, zero, `failed to parse number: strconv.Atoi: parsing "jan": invalid syntax`},
{"2-x", 1, 5, zero, `failed to parse number: strconv.Atoi: parsing "x": invalid syntax`},
{"*/-12", 0, 0, zero, "number must be positive"},
{"*//2", 0, 0, zero, "too many slashes"},
{"1", 3, 5, zero, "below minimum"},
{"6", 3, 5, zero, "above maximum"},
{"5-3", 3, 5, zero, "beginning of range after end: 5-3"},
{"*/0", 0, 0, zero, "step cannot be zero: */0"},
}
for _, c := range ranges {
actual, err := getRange(c.expr, bounds{c.min, c.max, nil})
if len(c.err) != 0 && (err == nil || !strings.Contains(err.Error(), c.err)) {
t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
}
if len(c.err) == 0 && err != nil {
t.Errorf("%s => unexpected error %v", c.expr, err)
}
if actual != c.expected {
t.Errorf("%s => expected %d, got %d", c.expr, c.expected, actual)
}
}
}
func TestField(t *testing.T) {
fields := []struct {
expr string
min, max uint
expected uint64
}{
{"5", 1, 7, 1 << 5},
{"5,6", 1, 7, 1<<5 | 1<<6},
{"5,6,7", 1, 7, 1<<5 | 1<<6 | 1<<7},
{"1,5-7/2,3", 1, 7, 1<<1 | 1<<5 | 1<<7 | 1<<3},
}
for _, c := range fields {
actual, _ := getField(c.expr, bounds{c.min, c.max, nil})
if actual != c.expected {
t.Errorf("%s => expected %d, got %d", c.expr, c.expected, actual)
}
}
}
func TestAll(t *testing.T) {
allBits := []struct {
r bounds
expected uint64
}{
{minutes, 0xfffffffffffffff}, // 0-59: 60 ones
{hours, 0xffffff}, // 0-23: 24 ones
{dom, 0xfffffffe}, // 1-31: 31 ones, 1 zero
{months, 0x1ffe}, // 1-12: 12 ones, 1 zero
{dow, 0x7f}, // 0-6: 7 ones
}
for _, c := range allBits {
actual := all(c.r) // all() adds the starBit, so compensate for that..
if c.expected|starBit != actual {
t.Errorf("%d-%d/%d => expected %b, got %b",
c.r.min, c.r.max, 1, c.expected|starBit, actual)
}
}
}
func TestBits(t *testing.T) {
bits := []struct {
min, max, step uint
expected uint64
}{
{0, 0, 1, 0x1},
{1, 1, 1, 0x2},
{1, 5, 2, 0x2a}, // 101010
{1, 4, 2, 0xa}, // 1010
}
for _, c := range bits {
actual := getBits(c.min, c.max, c.step)
if c.expected != actual {
t.Errorf("%d-%d/%d => expected %b, got %b",
c.min, c.max, c.step, c.expected, actual)
}
}
}
func TestParseScheduleErrors(t *testing.T) {
var tests = []struct{ expr, err string }{
{"* 5 j * * *", `failed to parse number: strconv.Atoi: parsing "j": invalid syntax`},
{"@every Xm", "failed to parse duration"},
{"@unrecognized", "unrecognized descriptor"},
{"* * * *", "incorrect number of fields, expected 5-6"},
{"", "empty spec string"},
}
for _, c := range tests {
actual, err := secondParser.Parse(c.expr)
if err == nil || !strings.Contains(err.Error(), c.err) {
t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
}
if actual != nil {
t.Errorf("expected nil schedule on error, got %v", actual)
}
}
}
func TestParseSchedule(t *testing.T) {
tokyo, _ := time.LoadLocation("Asia/Tokyo")
entries := []struct {
parser Parser
expr string
expected Schedule
}{
{secondParser, "0 5 * * * *", every5min(time.Local)},
{standardParser, "5 * * * *", every5min(time.Local)},
{secondParser, "CRON_TZ=UTC 0 5 * * * *", every5min(time.UTC)},
{standardParser, "CRON_TZ=UTC 5 * * * *", every5min(time.UTC)},
{secondParser, "CRON_TZ=Asia/Tokyo 0 5 * * * *", every5min(tokyo)},
{secondParser, "@every 5m", ConstantDelaySchedule{5 * time.Minute}},
{secondParser, "@midnight", midnight(time.Local)},
{secondParser, "TZ=UTC @midnight", midnight(time.UTC)},
{secondParser, "TZ=Asia/Tokyo @midnight", midnight(tokyo)},
{secondParser, "@yearly", annual(time.Local)},
{secondParser, "@annually", annual(time.Local)},
{
parser: secondParser,
expr: "* 5 * * * *",
expected: &SpecSchedule{
Second: all(seconds),
Minute: 1 << 5,
Hour: all(hours),
Dom: all(dom),
Month: all(months),
Dow: all(dow),
Location: time.Local,
},
},
}
for _, c := range entries {
actual, err := c.parser.Parse(c.expr)
if err != nil {
t.Errorf("%s => unexpected error %v", c.expr, err)
}
if !reflect.DeepEqual(actual, c.expected) {
t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
}
}
}
func TestOptionalSecondSchedule(t *testing.T) {
parser := NewParser(SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor)
entries := []struct {
expr string
expected Schedule
}{
{"0 5 * * * *", every5min(time.Local)},
{"5 5 * * * *", every5min5s(time.Local)},
{"5 * * * *", every5min(time.Local)},
}
for _, c := range entries {
actual, err := parser.Parse(c.expr)
if err != nil {
t.Errorf("%s => unexpected error %v", c.expr, err)
}
if !reflect.DeepEqual(actual, c.expected) {
t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
}
}
}
func TestNormalizeFields(t *testing.T) {
tests := []struct {
name string
input []string
options ParseOption
expected []string
}{
{
"AllFields_NoOptional",
[]string{"0", "5", "*", "*", "*", "*"},
Second | Minute | Hour | Dom | Month | Dow | Descriptor,
[]string{"0", "5", "*", "*", "*", "*"},
},
{
"AllFields_SecondOptional_Provided",
[]string{"0", "5", "*", "*", "*", "*"},
SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor,
[]string{"0", "5", "*", "*", "*", "*"},
},
{
"AllFields_SecondOptional_NotProvided",
[]string{"5", "*", "*", "*", "*"},
SecondOptional | Minute | Hour | Dom | Month | Dow | Descriptor,
[]string{"0", "5", "*", "*", "*", "*"},
},
{
"SubsetFields_NoOptional",
[]string{"5", "15", "*"},
Hour | Dom | Month,
[]string{"0", "0", "5", "15", "*", "*"},
},
{
"SubsetFields_DowOptional_Provided",
[]string{"5", "15", "*", "4"},
Hour | Dom | Month | DowOptional,
[]string{"0", "0", "5", "15", "*", "4"},
},
{
"SubsetFields_DowOptional_NotProvided",
[]string{"5", "15", "*"},
Hour | Dom | Month | DowOptional,
[]string{"0", "0", "5", "15", "*", "*"},
},
{
"SubsetFields_SecondOptional_NotProvided",
[]string{"5", "15", "*"},
SecondOptional | Hour | Dom | Month,
[]string{"0", "0", "5", "15", "*", "*"},
},
}
for _, test := range tests {
t.Run(test.name, func(*testing.T) {
actual, err := normalizeFields(test.input, test.options)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !reflect.DeepEqual(actual, test.expected) {
t.Errorf("expected %v, got %v", test.expected, actual)
}
})
}
}
func TestNormalizeFields_Errors(t *testing.T) {
tests := []struct {
name string
input []string
options ParseOption
err string
}{
{
"TwoOptionals",
[]string{"0", "5", "*", "*", "*", "*"},
SecondOptional | Minute | Hour | Dom | Month | DowOptional,
"",
},
{
"TooManyFields",
[]string{"0", "5", "*", "*"},
SecondOptional | Minute | Hour,
"",
},
{
"NoFields",
[]string{},
SecondOptional | Minute | Hour,
"",
},
{
"TooFewFields",
[]string{"*"},
SecondOptional | Minute | Hour,
"",
},
}
for _, test := range tests {
t.Run(test.name, func(*testing.T) {
actual, err := normalizeFields(test.input, test.options)
if err == nil {
t.Errorf("expected an error, got none. results: %v", actual)
}
if !strings.Contains(err.Error(), test.err) {
t.Errorf("expected error %q, got %q", test.err, err.Error())
}
})
}
}
func TestStandardSpecSchedule(t *testing.T) {
entries := []struct {
expr string
expected Schedule
err string
}{
{
expr: "5 * * * *",
expected: &SpecSchedule{1 << seconds.min, 1 << 5, all(hours), all(dom), all(months), all(dow), time.Local},
},
{
expr: "@every 5m",
expected: ConstantDelaySchedule{time.Duration(5) * time.Minute},
},
{
expr: "5 j * * *",
err: `failed to parse number: strconv.Atoi: parsing "j": invalid syntax`,
},
{
expr: "* * * *",
err: "incorrect number of fields",
},
}
for _, c := range entries {
actual, err := ParseStandard(c.expr)
if len(c.err) != 0 && (err == nil || !strings.Contains(err.Error(), c.err)) {
t.Errorf("%s => expected %v, got %v", c.expr, c.err, err)
}
if len(c.err) == 0 && err != nil {
t.Errorf("%s => unexpected error %v", c.expr, err)
}
if !reflect.DeepEqual(actual, c.expected) {
t.Errorf("%s => expected %b, got %b", c.expr, c.expected, actual)
}
}
}
func TestNoDescriptorParser(t *testing.T) {
parser := NewParser(Minute | Hour)
_, err := parser.Parse("@every 1m")
if err == nil {
t.Error("expected an error, got none")
}
}
func every5min(loc *time.Location) *SpecSchedule {
return &SpecSchedule{1 << 0, 1 << 5, all(hours), all(dom), all(months), all(dow), loc}
}
func every5min5s(loc *time.Location) *SpecSchedule {
return &SpecSchedule{1 << 5, 1 << 5, all(hours), all(dom), all(months), all(dow), loc}
}
func midnight(loc *time.Location) *SpecSchedule {
return &SpecSchedule{1, 1, 1, all(dom), all(months), all(dow), loc}
}
func annual(loc *time.Location) *SpecSchedule {
return &SpecSchedule{
Second: 1 << seconds.min,
Minute: 1 << minutes.min,
Hour: 1 << hours.min,
Dom: 1 << dom.min,
Month: 1 << months.min,
Dow: all(dow),
Location: loc,
}
}

188
plugin/cron/spec.go Normal file
View File

@@ -0,0 +1,188 @@
package cron
import "time"
// SpecSchedule specifies a duty cycle (to the second granularity), based on a
// traditional crontab specification. It is computed initially and stored as bit sets.
type SpecSchedule struct {
Second, Minute, Hour, Dom, Month, Dow uint64
// Override location for this schedule.
Location *time.Location
}
// bounds provides a range of acceptable values (plus a map of name to value).
type bounds struct {
min, max uint
names map[string]uint
}
// The bounds for each field.
var (
seconds = bounds{0, 59, nil}
minutes = bounds{0, 59, nil}
hours = bounds{0, 23, nil}
dom = bounds{1, 31, nil}
months = bounds{1, 12, map[string]uint{
"jan": 1,
"feb": 2,
"mar": 3,
"apr": 4,
"may": 5,
"jun": 6,
"jul": 7,
"aug": 8,
"sep": 9,
"oct": 10,
"nov": 11,
"dec": 12,
}}
dow = bounds{0, 6, map[string]uint{
"sun": 0,
"mon": 1,
"tue": 2,
"wed": 3,
"thu": 4,
"fri": 5,
"sat": 6,
}}
)
const (
// Set the top bit if a star was included in the expression.
starBit = 1 << 63
)
// Next returns the next time this schedule is activated, greater than the given
// time. If no time can be found to satisfy the schedule, return the zero time.
func (s *SpecSchedule) Next(t time.Time) time.Time {
// General approach
//
// For Month, Day, Hour, Minute, Second:
// Check if the time value matches. If yes, continue to the next field.
// If the field doesn't match the schedule, then increment the field until it matches.
// While incrementing the field, a wrap-around brings it back to the beginning
// of the field list (since it is necessary to re-verify previous field
// values)
// Convert the given time into the schedule's timezone, if one is specified.
// Save the original timezone so we can convert back after we find a time.
// Note that schedules without a time zone specified (time.Local) are treated
// as local to the time provided.
origLocation := t.Location()
loc := s.Location
if loc == time.Local {
loc = t.Location()
}
if s.Location != time.Local {
t = t.In(s.Location)
}
// Start at the earliest possible time (the upcoming second).
t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond)
// This flag indicates whether a field has been incremented.
added := false
// If no time is found within five years, return zero.
yearLimit := t.Year() + 5
WRAP:
if t.Year() > yearLimit {
return time.Time{}
}
// Find the first applicable month.
// If it's this month, then do nothing.
for 1<<uint(t.Month())&s.Month == 0 {
// If we have to add a month, reset the other parts to 0.
if !added {
added = true
// Otherwise, set the date at the beginning (since the current time is irrelevant).
t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc)
}
t = t.AddDate(0, 1, 0)
// Wrapped around.
if t.Month() == time.January {
goto WRAP
}
}
// Now get a day in that month.
//
// NOTE: This causes issues for daylight savings regimes where midnight does
// not exist. For example: Sao Paulo has DST that transforms midnight on
// 11/3 into 1am. Handle that by noticing when the Hour ends up != 0.
for !dayMatches(s, t) {
if !added {
added = true
t = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
}
t = t.AddDate(0, 0, 1)
// Notice if the hour is no longer midnight due to DST.
// Add an hour if it's 23, subtract an hour if it's 1.
if t.Hour() != 0 {
if t.Hour() > 12 {
t = t.Add(time.Duration(24-t.Hour()) * time.Hour)
} else {
t = t.Add(time.Duration(-t.Hour()) * time.Hour)
}
}
if t.Day() == 1 {
goto WRAP
}
}
for 1<<uint(t.Hour())&s.Hour == 0 {
if !added {
added = true
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), 0, 0, 0, loc)
}
t = t.Add(1 * time.Hour)
if t.Hour() == 0 {
goto WRAP
}
}
for 1<<uint(t.Minute())&s.Minute == 0 {
if !added {
added = true
t = t.Truncate(time.Minute)
}
t = t.Add(1 * time.Minute)
if t.Minute() == 0 {
goto WRAP
}
}
for 1<<uint(t.Second())&s.Second == 0 {
if !added {
added = true
t = t.Truncate(time.Second)
}
t = t.Add(1 * time.Second)
if t.Second() == 0 {
goto WRAP
}
}
return t.In(origLocation)
}
// dayMatches returns true if the schedule's day-of-week and day-of-month
// restrictions are satisfied by the given time.
func dayMatches(s *SpecSchedule, t time.Time) bool {
var (
domMatch = 1<<uint(t.Day())&s.Dom > 0
dowMatch = 1<<uint(t.Weekday())&s.Dow > 0
)
if s.Dom&starBit > 0 || s.Dow&starBit > 0 {
return domMatch && dowMatch
}
return domMatch || dowMatch
}

301
plugin/cron/spec_test.go Normal file
View File

@@ -0,0 +1,301 @@
//nolint:all
package cron
import (
"strings"
"testing"
"time"
)
func TestActivation(t *testing.T) {
tests := []struct {
time, spec string
expected bool
}{
// Every fifteen minutes.
{"Mon Jul 9 15:00 2012", "0/15 * * * *", true},
{"Mon Jul 9 15:45 2012", "0/15 * * * *", true},
{"Mon Jul 9 15:40 2012", "0/15 * * * *", false},
// Every fifteen minutes, starting at 5 minutes.
{"Mon Jul 9 15:05 2012", "5/15 * * * *", true},
{"Mon Jul 9 15:20 2012", "5/15 * * * *", true},
{"Mon Jul 9 15:50 2012", "5/15 * * * *", true},
// Named months
{"Sun Jul 15 15:00 2012", "0/15 * * Jul *", true},
{"Sun Jul 15 15:00 2012", "0/15 * * Jun *", false},
// Everything set.
{"Sun Jul 15 08:30 2012", "30 08 ? Jul Sun", true},
{"Sun Jul 15 08:30 2012", "30 08 15 Jul ?", true},
{"Mon Jul 16 08:30 2012", "30 08 ? Jul Sun", false},
{"Mon Jul 16 08:30 2012", "30 08 15 Jul ?", false},
// Predefined schedules
{"Mon Jul 9 15:00 2012", "@hourly", true},
{"Mon Jul 9 15:04 2012", "@hourly", false},
{"Mon Jul 9 15:00 2012", "@daily", false},
{"Mon Jul 9 00:00 2012", "@daily", true},
{"Mon Jul 9 00:00 2012", "@weekly", false},
{"Sun Jul 8 00:00 2012", "@weekly", true},
{"Sun Jul 8 01:00 2012", "@weekly", false},
{"Sun Jul 8 00:00 2012", "@monthly", false},
{"Sun Jul 1 00:00 2012", "@monthly", true},
// Test interaction of DOW and DOM.
// If both are restricted, then only one needs to match.
{"Sun Jul 15 00:00 2012", "* * 1,15 * Sun", true},
{"Fri Jun 15 00:00 2012", "* * 1,15 * Sun", true},
{"Wed Aug 1 00:00 2012", "* * 1,15 * Sun", true},
{"Sun Jul 15 00:00 2012", "* * */10 * Sun", true}, // verifies #70
// However, if one has a star, then both need to match.
{"Sun Jul 15 00:00 2012", "* * * * Mon", false},
{"Mon Jul 9 00:00 2012", "* * 1,15 * *", false},
{"Sun Jul 15 00:00 2012", "* * 1,15 * *", true},
{"Sun Jul 15 00:00 2012", "* * */2 * Sun", true},
}
for _, test := range tests {
sched, err := ParseStandard(test.spec)
if err != nil {
t.Error(err)
continue
}
actual := sched.Next(getTime(test.time).Add(-1 * time.Second))
expected := getTime(test.time)
if test.expected && expected != actual || !test.expected && expected == actual {
t.Errorf("Fail evaluating %s on %s: (expected) %s != %s (actual)",
test.spec, test.time, expected, actual)
}
}
}
func TestNext(t *testing.T) {
runs := []struct {
time, spec string
expected string
}{
// Simple cases
{"Mon Jul 9 14:45 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},
{"Mon Jul 9 14:59 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},
{"Mon Jul 9 14:59:59 2012", "0 0/15 * * * *", "Mon Jul 9 15:00 2012"},
// Wrap around hours
{"Mon Jul 9 15:45 2012", "0 20-35/15 * * * *", "Mon Jul 9 16:20 2012"},
// Wrap around days
{"Mon Jul 9 23:46 2012", "0 */15 * * * *", "Tue Jul 10 00:00 2012"},
{"Mon Jul 9 23:45 2012", "0 20-35/15 * * * *", "Tue Jul 10 00:20 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * * * *", "Tue Jul 10 00:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 1/2 * * *", "Tue Jul 10 01:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 10-12 * * *", "Tue Jul 10 10:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 1/2 */2 * *", "Thu Jul 11 01:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * 9-20 * *", "Wed Jul 10 00:20:15 2012"},
{"Mon Jul 9 23:35:51 2012", "15/35 20-35/15 * 9-20 Jul *", "Wed Jul 10 00:20:15 2012"},
// Wrap around months
{"Mon Jul 9 23:35 2012", "0 0 0 9 Apr-Oct ?", "Thu Aug 9 00:00 2012"},
{"Mon Jul 9 23:35 2012", "0 0 0 */5 Apr,Aug,Oct Mon", "Tue Aug 1 00:00 2012"},
{"Mon Jul 9 23:35 2012", "0 0 0 */5 Oct Mon", "Mon Oct 1 00:00 2012"},
// Wrap around years
{"Mon Jul 9 23:35 2012", "0 0 0 * Feb Mon", "Mon Feb 4 00:00 2013"},
{"Mon Jul 9 23:35 2012", "0 0 0 * Feb Mon/2", "Fri Feb 1 00:00 2013"},
// Wrap around minute, hour, day, month, and year
{"Mon Dec 31 23:59:45 2012", "0 * * * * *", "Tue Jan 1 00:00:00 2013"},
// Leap year
{"Mon Jul 9 23:35 2012", "0 0 0 29 Feb ?", "Mon Feb 29 00:00 2016"},
// Daylight savings time 2am EST (-5) -> 3am EDT (-4)
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 30 2 11 Mar ?", "2013-03-11T02:30:00-0400"},
// hourly job
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T01:00:00-0500"},
{"2012-03-11T01:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T03:00:00-0400"},
{"2012-03-11T03:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T04:00:00-0400"},
{"2012-03-11T04:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-03-11T05:00:00-0400"},
// hourly job using CRON_TZ
{"2012-03-11T00:00:00-0500", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T01:00:00-0500"},
{"2012-03-11T01:00:00-0500", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T03:00:00-0400"},
{"2012-03-11T03:00:00-0400", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T04:00:00-0400"},
{"2012-03-11T04:00:00-0400", "CRON_TZ=America/New_York 0 0 * * * ?", "2012-03-11T05:00:00-0400"},
// 1am nightly job
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-03-11T01:00:00-0500"},
{"2012-03-11T01:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-03-12T01:00:00-0400"},
// 2am nightly job (skipped)
{"2012-03-11T00:00:00-0500", "TZ=America/New_York 0 0 2 * * ?", "2012-03-12T02:00:00-0400"},
// Daylight savings time 2am EDT (-4) => 1am EST (-5)
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 30 2 04 Nov ?", "2012-11-04T02:30:00-0500"},
{"2012-11-04T01:45:00-0400", "TZ=America/New_York 0 30 1 04 Nov ?", "2012-11-04T01:30:00-0500"},
// hourly job
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T01:00:00-0400"},
{"2012-11-04T01:00:00-0400", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T01:00:00-0500"},
{"2012-11-04T01:00:00-0500", "TZ=America/New_York 0 0 * * * ?", "2012-11-04T02:00:00-0500"},
// 1am nightly job (runs twice)
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 1 * * ?", "2012-11-04T01:00:00-0400"},
{"2012-11-04T01:00:00-0400", "TZ=America/New_York 0 0 1 * * ?", "2012-11-04T01:00:00-0500"},
{"2012-11-04T01:00:00-0500", "TZ=America/New_York 0 0 1 * * ?", "2012-11-05T01:00:00-0500"},
// 2am nightly job
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 2 * * ?", "2012-11-04T02:00:00-0500"},
{"2012-11-04T02:00:00-0500", "TZ=America/New_York 0 0 2 * * ?", "2012-11-05T02:00:00-0500"},
// 3am nightly job
{"2012-11-04T00:00:00-0400", "TZ=America/New_York 0 0 3 * * ?", "2012-11-04T03:00:00-0500"},
{"2012-11-04T03:00:00-0500", "TZ=America/New_York 0 0 3 * * ?", "2012-11-05T03:00:00-0500"},
// hourly job
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 * * * ?", "2012-11-04T01:00:00-0400"},
{"TZ=America/New_York 2012-11-04T01:00:00-0400", "0 0 * * * ?", "2012-11-04T01:00:00-0500"},
{"TZ=America/New_York 2012-11-04T01:00:00-0500", "0 0 * * * ?", "2012-11-04T02:00:00-0500"},
// 1am nightly job (runs twice)
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 1 * * ?", "2012-11-04T01:00:00-0400"},
{"TZ=America/New_York 2012-11-04T01:00:00-0400", "0 0 1 * * ?", "2012-11-04T01:00:00-0500"},
{"TZ=America/New_York 2012-11-04T01:00:00-0500", "0 0 1 * * ?", "2012-11-05T01:00:00-0500"},
// 2am nightly job
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 2 * * ?", "2012-11-04T02:00:00-0500"},
{"TZ=America/New_York 2012-11-04T02:00:00-0500", "0 0 2 * * ?", "2012-11-05T02:00:00-0500"},
// 3am nightly job
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 3 * * ?", "2012-11-04T03:00:00-0500"},
{"TZ=America/New_York 2012-11-04T03:00:00-0500", "0 0 3 * * ?", "2012-11-05T03:00:00-0500"},
// Unsatisfiable
{"Mon Jul 9 23:35 2012", "0 0 0 30 Feb ?", ""},
{"Mon Jul 9 23:35 2012", "0 0 0 31 Apr ?", ""},
// Monthly job
{"TZ=America/New_York 2012-11-04T00:00:00-0400", "0 0 3 3 * ?", "2012-12-03T03:00:00-0500"},
// Test the scenario of DST resulting in midnight not being a valid time.
// https://github.com/robfig/cron/issues/157
{"2018-10-17T05:00:00-0400", "TZ=America/Sao_Paulo 0 0 9 10 * ?", "2018-11-10T06:00:00-0500"},
{"2018-02-14T05:00:00-0500", "TZ=America/Sao_Paulo 0 0 9 22 * ?", "2018-02-22T07:00:00-0500"},
}
for _, c := range runs {
sched, err := secondParser.Parse(c.spec)
if err != nil {
t.Error(err)
continue
}
actual := sched.Next(getTime(c.time))
expected := getTime(c.expected)
if !actual.Equal(expected) {
t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.spec, expected, actual)
}
}
}
func TestErrors(t *testing.T) {
invalidSpecs := []string{
"xyz",
"60 0 * * *",
"0 60 * * *",
"0 0 * * XYZ",
}
for _, spec := range invalidSpecs {
_, err := ParseStandard(spec)
if err == nil {
t.Error("expected an error parsing: ", spec)
}
}
}
func getTime(value string) time.Time {
if value == "" {
return time.Time{}
}
var location = time.Local
if strings.HasPrefix(value, "TZ=") {
parts := strings.Fields(value)
loc, err := time.LoadLocation(parts[0][len("TZ="):])
if err != nil {
panic("could not parse location:" + err.Error())
}
location = loc
value = parts[1]
}
var layouts = []string{
"Mon Jan 2 15:04 2006",
"Mon Jan 2 15:04:05 2006",
}
for _, layout := range layouts {
if t, err := time.ParseInLocation(layout, value, location); err == nil {
return t
}
}
if t, err := time.ParseInLocation("2006-01-02T15:04:05-0700", value, location); err == nil {
return t
}
panic("could not parse time value " + value)
}
func TestNextWithTz(t *testing.T) {
runs := []struct {
time, spec string
expected string
}{
// Failing tests
{"2016-01-03T13:09:03+0530", "14 14 * * *", "2016-01-03T14:14:00+0530"},
{"2016-01-03T04:09:03+0530", "14 14 * * ?", "2016-01-03T14:14:00+0530"},
// Passing tests
{"2016-01-03T14:09:03+0530", "14 14 * * *", "2016-01-03T14:14:00+0530"},
{"2016-01-03T14:00:00+0530", "14 14 * * ?", "2016-01-03T14:14:00+0530"},
}
for _, c := range runs {
sched, err := ParseStandard(c.spec)
if err != nil {
t.Error(err)
continue
}
actual := sched.Next(getTimeTZ(c.time))
expected := getTimeTZ(c.expected)
if !actual.Equal(expected) {
t.Errorf("%s, \"%s\": (expected) %v != %v (actual)", c.time, c.spec, expected, actual)
}
}
}
func getTimeTZ(value string) time.Time {
if value == "" {
return time.Time{}
}
t, err := time.Parse("Mon Jan 2 15:04 2006", value)
if err != nil {
t, err = time.Parse("Mon Jan 2 15:04:05 2006", value)
if err != nil {
t, err = time.Parse("2006-01-02T15:04:05-0700", value)
if err != nil {
panic(err)
}
}
}
return t
}
// https://github.com/robfig/cron/issues/144
func TestSlash0NoHang(t *testing.T) {
schedule := "TZ=America/New_York 15/0 * * * *"
_, err := ParseStandard(schedule)
if err == nil {
t.Error("expected an error on 0 increment")
}
}

View File

@@ -0,0 +1,448 @@
package filter
import (
"fmt"
"slices"
"strings"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// CommonSQLConverter handles the common CEL to SQL conversion logic.
type CommonSQLConverter struct {
dialect SQLDialect
paramIndex int
}
// NewCommonSQLConverter creates a new converter with the specified dialect.
func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter {
return &CommonSQLConverter{
dialect: dialect,
paramIndex: 1,
}
}
// ConvertExprToSQL converts a CEL expression to SQL using the configured dialect.
func (c *CommonSQLConverter) ConvertExprToSQL(ctx *ConvertContext, expr *exprv1.Expr) error {
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
switch v.CallExpr.Function {
case "_||_", "_&&_":
return c.handleLogicalOperator(ctx, v.CallExpr)
case "!_":
return c.handleNotOperator(ctx, v.CallExpr)
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
return c.handleComparisonOperator(ctx, v.CallExpr)
case "@in":
return c.handleInOperator(ctx, v.CallExpr)
case "contains":
return c.handleContainsOperator(ctx, v.CallExpr)
}
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
return c.handleIdentifier(ctx, v.IdentExpr)
}
return nil
}
func (c *CommonSQLConverter) handleLogicalOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
if len(callExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
}
if _, err := ctx.Buffer.WriteString("("); err != nil {
return err
}
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
return err
}
operator := "AND"
if callExpr.Function == "_||_" {
operator = "OR"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
return err
}
if err := c.ConvertExprToSQL(ctx, callExpr.Args[1]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
return err
}
return nil
}
func (c *CommonSQLConverter) handleNotOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
if len(callExpr.Args) != 1 {
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
}
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
return err
}
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
return err
}
return nil
}
func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
if len(callExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
}
// Check if the left side is a function call like size(tags)
if leftCallExpr, ok := callExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
if leftCallExpr.CallExpr.Function == "size" {
return c.handleSizeComparison(ctx, callExpr, leftCallExpr.CallExpr)
}
}
identifier, err := GetIdentExprName(callExpr.Args[0])
if err != nil {
return err
}
if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) {
return errors.Errorf("invalid identifier for %s", callExpr.Function)
}
value, err := GetExprValue(callExpr.Args[1])
if err != nil {
return err
}
operator := c.getComparisonOperator(callExpr.Function)
switch identifier {
case "created_ts", "updated_ts":
return c.handleTimestampComparison(ctx, identifier, operator, value)
case "visibility", "content":
return c.handleStringComparison(ctx, identifier, operator, value)
case "creator_id":
return c.handleIntComparison(ctx, identifier, operator, value)
case "has_task_list":
return c.handleBooleanComparison(ctx, identifier, operator, value)
}
return nil
}
func (c *CommonSQLConverter) handleSizeComparison(ctx *ConvertContext, callExpr *exprv1.Expr_Call, sizeCall *exprv1.Expr_Call) error {
if len(sizeCall.Args) != 1 {
return errors.New("size function requires exactly one argument")
}
identifier, err := GetIdentExprName(sizeCall.Args[0])
if err != nil {
return err
}
if identifier != "tags" {
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
}
value, err := GetExprValue(callExpr.Args[1])
if err != nil {
return err
}
valueInt, ok := value.(int64)
if !ok {
return errors.New("size comparison value must be an integer")
}
operator := c.getComparisonOperator(callExpr.Function)
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s",
c.dialect.GetJSONArrayLength("$.tags"),
operator,
c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
ctx.Args = append(ctx.Args, valueInt)
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
if len(callExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
}
// Check if this is "element in collection" syntax
if identifier, err := GetIdentExprName(callExpr.Args[1]); err == nil {
if identifier == "tags" {
return c.handleElementInTags(ctx, callExpr.Args[0])
}
return errors.Errorf("invalid collection identifier for %s: %s", callExpr.Function, identifier)
}
// Original logic for "identifier in [list]" syntax
identifier, err := GetIdentExprName(callExpr.Args[0])
if err != nil {
return err
}
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
return errors.Errorf("invalid identifier for %s", callExpr.Function)
}
values := []any{}
for _, element := range callExpr.Args[1].GetListExpr().Elements {
value, err := GetConstValue(element)
if err != nil {
return err
}
values = append(values, value)
}
if identifier == "tag" {
return c.handleTagInList(ctx, values)
} else if identifier == "visibility" {
return c.handleVisibilityInList(ctx, values)
}
return nil
}
func (c *CommonSQLConverter) handleElementInTags(ctx *ConvertContext, elementExpr *exprv1.Expr) error {
element, err := GetConstValue(elementExpr)
if err != nil {
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
}
// Use dialect-specific JSON contains logic
sqlExpr := c.dialect.GetJSONContains("$.tags", "element")
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
return err
}
// For SQLite, we need a different approach since it uses LIKE
if _, ok := c.dialect.(*SQLiteDialect); ok {
ctx.Args = append(ctx.Args, fmt.Sprintf(`%%"%s"%%`, element))
} else {
ctx.Args = append(ctx.Args, element)
}
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleTagInList(ctx *ConvertContext, values []any) error {
subconditions := []string{}
args := []any{}
for _, v := range values {
if _, ok := c.dialect.(*SQLiteDialect); ok {
subconditions = append(subconditions, c.dialect.GetJSONLike("$.tags", "pattern"))
args = append(args, fmt.Sprintf(`%%"%s"%%`, v))
} else {
subconditions = append(subconditions, c.dialect.GetJSONContains("$.tags", "element"))
args = append(args, v)
}
c.paramIndex++
}
if len(subconditions) == 1 {
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
return err
}
} else {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, args...)
return nil
}
func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values []any) error {
placeholders := []string{}
for range values {
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
c.paramIndex++
}
tablePrefix := c.dialect.GetTablePrefix()
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
return err
}
ctx.Args = append(ctx.Args, values...)
return nil
}
func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
if len(callExpr.Args) != 1 {
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
}
identifier, err := GetIdentExprName(callExpr.Target)
if err != nil {
return err
}
if identifier != "content" {
return errors.Errorf("invalid identifier for %s", callExpr.Function)
}
arg, err := GetConstValue(callExpr.Args[0])
if err != nil {
return err
}
tablePrefix := c.dialect.GetTablePrefix()
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error {
identifier := identExpr.GetName()
if !slices.Contains([]string{"pinned", "has_task_list"}, identifier) {
return errors.Errorf("invalid identifier %s", identifier)
}
if identifier == "pinned" {
tablePrefix := c.dialect.GetTablePrefix()
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil {
return err
}
} else if identifier == "has_task_list" {
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasTaskList")); err != nil {
return err
}
}
return nil
}
func (c *CommonSQLConverter) handleTimestampComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
valueInt, ok := value.(int64)
if !ok {
return errors.New("invalid integer timestamp value")
}
timestampField := c.dialect.GetTimestampComparison(field)
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampField, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
ctx.Args = append(ctx.Args, valueInt)
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", field)
}
valueStr, ok := value.(string)
if !ok {
return errors.New("invalid string value")
}
tablePrefix := c.dialect.GetTablePrefix()
fieldName := field
if field == "visibility" {
fieldName = "`visibility`"
} else if field == "content" {
fieldName = "`content`"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
ctx.Args = append(ctx.Args, valueStr)
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", field)
}
valueInt, ok := value.(int64)
if !ok {
return errors.New("invalid int value")
}
tablePrefix := c.dialect.GetTablePrefix()
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
ctx.Args = append(ctx.Args, valueInt)
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleBooleanComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", field)
}
valueBool, ok := value.(bool)
if !ok {
return errors.New("invalid boolean value for has_task_list")
}
sqlExpr := c.dialect.GetBooleanComparison("$.property.hasTaskList", valueBool)
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
return err
}
// For dialects that need parameters (PostgreSQL)
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
ctx.Args = append(ctx.Args, valueBool)
c.paramIndex++
}
return nil
}
func (*CommonSQLConverter) getComparisonOperator(function string) string {
switch function {
case "_==_":
return "="
case "_!=_":
return "!="
case "_<_":
return "<"
case "_>_":
return ">"
case "_<=_":
return "<="
case "_>=_":
return ">="
default:
return "="
}
}

View File

@@ -0,0 +1,20 @@
package filter
import (
"strings"
)
type ConvertContext struct {
Buffer strings.Builder
Args []any
// The offset of the next argument in the condition string.
// Mainly using for PostgreSQL.
ArgsOffset int
}
func NewConvertContext() *ConvertContext {
return &ConvertContext{
Buffer: strings.Builder{},
Args: []any{},
}
}

212
plugin/filter/dialect.go Normal file
View File

@@ -0,0 +1,212 @@
package filter
import (
"fmt"
"strings"
)
// SQLDialect defines database-specific SQL generation methods.
type SQLDialect interface {
// Basic field access
GetTablePrefix() string
GetParameterPlaceholder(index int) string
// JSON operations
GetJSONExtract(path string) string
GetJSONArrayLength(path string) string
GetJSONContains(path, element string) string
GetJSONLike(path, pattern string) string
// Boolean operations
GetBooleanValue(value bool) interface{}
GetBooleanComparison(path string, value bool) string
GetBooleanCheck(path string) string
// Timestamp operations
GetTimestampComparison(field string) string
GetCurrentTimestamp() string
}
// DatabaseType represents the type of database.
type DatabaseType string
const (
SQLite DatabaseType = "sqlite"
MySQL DatabaseType = "mysql"
PostgreSQL DatabaseType = "postgres"
)
// GetDialect returns the appropriate dialect for the database type.
func GetDialect(dbType DatabaseType) SQLDialect {
switch dbType {
case SQLite:
return &SQLiteDialect{}
case MySQL:
return &MySQLDialect{}
case PostgreSQL:
return &PostgreSQLDialect{}
default:
return &SQLiteDialect{} // default fallback
}
}
// SQLiteDialect implements SQLDialect for SQLite.
type SQLiteDialect struct{}
func (*SQLiteDialect) GetTablePrefix() string {
return "`memo`"
}
func (*SQLiteDialect) GetParameterPlaceholder(_ int) string {
return "?"
}
func (d *SQLiteDialect) GetJSONExtract(path string) string {
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
}
func (d *SQLiteDialect) GetJSONArrayLength(path string) string {
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
}
func (d *SQLiteDialect) GetJSONContains(path, _ string) string {
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
}
func (d *SQLiteDialect) GetJSONLike(path, _ string) string {
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
}
func (*SQLiteDialect) GetBooleanValue(value bool) interface{} {
if value {
return 1
}
return 0
}
func (d *SQLiteDialect) GetBooleanComparison(path string, value bool) string {
return fmt.Sprintf("%s = %d", d.GetJSONExtract(path), d.GetBooleanValue(value))
}
func (d *SQLiteDialect) GetBooleanCheck(path string) string {
return fmt.Sprintf("%s IS TRUE", d.GetJSONExtract(path))
}
func (d *SQLiteDialect) GetTimestampComparison(field string) string {
return fmt.Sprintf("%s.`%s`", d.GetTablePrefix(), field)
}
func (*SQLiteDialect) GetCurrentTimestamp() string {
return "strftime('%s', 'now')"
}
// MySQLDialect implements SQLDialect for MySQL.
type MySQLDialect struct{}
func (*MySQLDialect) GetTablePrefix() string {
return "`memo`"
}
func (*MySQLDialect) GetParameterPlaceholder(_ int) string {
return "?"
}
func (d *MySQLDialect) GetJSONExtract(path string) string {
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix(), path)
}
func (d *MySQLDialect) GetJSONArrayLength(path string) string {
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
}
func (d *MySQLDialect) GetJSONContains(path, _ string) string {
return fmt.Sprintf("JSON_CONTAINS(%s, ?)", d.GetJSONExtract(path))
}
func (d *MySQLDialect) GetJSONLike(path, _ string) string {
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
}
func (*MySQLDialect) GetBooleanValue(value bool) interface{} {
return value
}
func (d *MySQLDialect) GetBooleanComparison(path string, value bool) string {
boolStr := "false"
if value {
boolStr = "true"
}
return fmt.Sprintf("%s = CAST('%s' AS JSON)", d.GetJSONExtract(path), boolStr)
}
func (d *MySQLDialect) GetBooleanCheck(path string) string {
return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path))
}
func (d *MySQLDialect) GetTimestampComparison(field string) string {
return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix(), field)
}
func (*MySQLDialect) GetCurrentTimestamp() string {
return "UNIX_TIMESTAMP()"
}
// PostgreSQLDialect implements SQLDialect for PostgreSQL.
type PostgreSQLDialect struct{}
func (*PostgreSQLDialect) GetTablePrefix() string {
return "memo"
}
func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string {
return fmt.Sprintf("$%d", index)
}
func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
// Convert $.property.hasTaskList to payload->'property'->>'hasTaskList'
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
result := fmt.Sprintf("%s.payload", d.GetTablePrefix())
for i, part := range parts {
if i == len(parts)-1 {
result += fmt.Sprintf("->>'%s'", part)
} else {
result += fmt.Sprintf("->'%s'", part)
}
}
return result
}
func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string {
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix(), jsonPath)
}
func (d *PostgreSQLDialect) GetJSONContains(path, _ string) string {
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath)
}
func (d *PostgreSQLDialect) GetJSONLike(path, _ string) string {
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
return fmt.Sprintf("%s.%s @> jsonb_build_array(?)", d.GetTablePrefix(), jsonPath)
}
func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} {
return value
}
func (d *PostgreSQLDialect) GetBooleanComparison(path string, _ bool) string {
return fmt.Sprintf("(%s)::boolean = ?", d.GetJSONExtract(path))
}
func (d *PostgreSQLDialect) GetBooleanCheck(path string) string {
return fmt.Sprintf("(%s)::boolean IS TRUE", d.GetJSONExtract(path))
}
func (d *PostgreSQLDialect) GetTimestampComparison(field string) string {
return fmt.Sprintf("EXTRACT(EPOCH FROM %s.%s)", d.GetTablePrefix(), field)
}
func (*PostgreSQLDialect) GetCurrentTimestamp() string {
return "EXTRACT(EPOCH FROM NOW())"
}

127
plugin/filter/expr.go Normal file
View File

@@ -0,0 +1,127 @@
package filter
import (
"errors"
"time"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// GetConstValue returns the constant value of the expression.
func GetConstValue(expr *exprv1.Expr) (any, error) {
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
if !ok {
return nil, errors.New("invalid constant expression")
}
switch v.ConstExpr.ConstantKind.(type) {
case *exprv1.Constant_StringValue:
return v.ConstExpr.GetStringValue(), nil
case *exprv1.Constant_Int64Value:
return v.ConstExpr.GetInt64Value(), nil
case *exprv1.Constant_Uint64Value:
return v.ConstExpr.GetUint64Value(), nil
case *exprv1.Constant_DoubleValue:
return v.ConstExpr.GetDoubleValue(), nil
case *exprv1.Constant_BoolValue:
return v.ConstExpr.GetBoolValue(), nil
default:
return nil, errors.New("unexpected constant type")
}
}
// GetIdentExprName returns the name of the identifier expression.
func GetIdentExprName(expr *exprv1.Expr) (string, error) {
_, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr)
if !ok {
return "", errors.New("invalid identifier expression")
}
return expr.GetIdentExpr().GetName(), nil
}
// GetFunctionValue evaluates CEL function calls and returns their value.
// This is specifically for time functions like now().
func GetFunctionValue(expr *exprv1.Expr) (any, error) {
callExpr, ok := expr.ExprKind.(*exprv1.Expr_CallExpr)
if !ok {
return nil, errors.New("invalid function call expression")
}
switch callExpr.CallExpr.Function {
case "now":
if len(callExpr.CallExpr.Args) != 0 {
return nil, errors.New("now() function takes no arguments")
}
return time.Now().Unix(), nil
case "_-_":
// Handle subtraction for expressions like "now() - 60 * 60 * 24"
if len(callExpr.CallExpr.Args) != 2 {
return nil, errors.New("subtraction requires exactly two arguments")
}
left, err := GetExprValue(callExpr.CallExpr.Args[0])
if err != nil {
return nil, err
}
right, err := GetExprValue(callExpr.CallExpr.Args[1])
if err != nil {
return nil, err
}
leftInt, ok1 := left.(int64)
rightInt, ok2 := right.(int64)
if !ok1 || !ok2 {
return nil, errors.New("subtraction operands must be integers")
}
return leftInt - rightInt, nil
case "_*_":
// Handle multiplication for expressions like "60 * 60 * 24"
if len(callExpr.CallExpr.Args) != 2 {
return nil, errors.New("multiplication requires exactly two arguments")
}
left, err := GetExprValue(callExpr.CallExpr.Args[0])
if err != nil {
return nil, err
}
right, err := GetExprValue(callExpr.CallExpr.Args[1])
if err != nil {
return nil, err
}
leftInt, ok1 := left.(int64)
rightInt, ok2 := right.(int64)
if !ok1 || !ok2 {
return nil, errors.New("multiplication operands must be integers")
}
return leftInt * rightInt, nil
case "_+_":
// Handle addition
if len(callExpr.CallExpr.Args) != 2 {
return nil, errors.New("addition requires exactly two arguments")
}
left, err := GetExprValue(callExpr.CallExpr.Args[0])
if err != nil {
return nil, err
}
right, err := GetExprValue(callExpr.CallExpr.Args[1])
if err != nil {
return nil, err
}
leftInt, ok1 := left.(int64)
rightInt, ok2 := right.(int64)
if !ok1 || !ok2 {
return nil, errors.New("addition operands must be integers")
}
return leftInt + rightInt, nil
default:
return nil, errors.New("unsupported function: " + callExpr.CallExpr.Function)
}
}
// GetExprValue attempts to get a value from an expression, trying constants first, then functions.
func GetExprValue(expr *exprv1.Expr) (any, error) {
// Try to get constant value first
if constValue, err := GetConstValue(expr); err == nil {
return constValue, nil
}
// If not a constant, try to evaluate as a function
return GetFunctionValue(expr)
}

48
plugin/filter/filter.go Normal file
View File

@@ -0,0 +1,48 @@
package filter
import (
"time"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// MemoFilterCELAttributes are the CEL attributes for memo.
var MemoFilterCELAttributes = []cel.EnvOption{
cel.Variable("content", cel.StringType),
cel.Variable("creator_id", cel.IntType),
cel.Variable("created_ts", cel.IntType),
cel.Variable("updated_ts", cel.IntType),
cel.Variable("pinned", cel.BoolType),
cel.Variable("tag", cel.StringType),
cel.Variable("tags", cel.ListType(cel.StringType)),
cel.Variable("visibility", cel.StringType),
cel.Variable("has_task_list", cel.BoolType),
// Current timestamp function.
cel.Function("now",
cel.Overload("now",
[]*cel.Type{},
cel.IntType,
cel.FunctionBinding(func(_ ...ref.Val) ref.Val {
return types.Int(time.Now().Unix())
}),
),
),
}
// Parse parses the filter string and returns the parsed expression.
// The filter string should be a CEL expression.
func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err error) {
e, err := cel.NewEnv(opts...)
if err != nil {
return nil, errors.Wrap(err, "failed to create CEL environment")
}
ast, issues := e.Compile(filter)
if issues != nil {
return nil, errors.Errorf("failed to compile filter: %v", issues)
}
return cel.AstToParsedExpr(ast)
}

146
plugin/filter/templates.go Normal file
View File

@@ -0,0 +1,146 @@
package filter
import (
"fmt"
)
// SQLTemplate holds database-specific SQL fragments.
type SQLTemplate struct {
SQLite string
MySQL string
PostgreSQL string
}
// TemplateDBType represents the database type for templates.
type TemplateDBType string
const (
SQLiteTemplate TemplateDBType = "sqlite"
MySQLTemplate TemplateDBType = "mysql"
PostgreSQLTemplate TemplateDBType = "postgres"
)
// SQLTemplates contains common SQL patterns for different databases.
var SQLTemplates = map[string]SQLTemplate{
"json_extract": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '%s')",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '%s')",
PostgreSQL: "memo.payload%s",
},
"json_array_length": {
SQLite: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
MySQL: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
PostgreSQL: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb))",
},
"json_contains_element": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
},
"json_contains_tag": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
},
"boolean_true": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = true",
},
"boolean_false": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = false",
},
"boolean_not_true": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 1",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('true' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != true",
},
"boolean_not_false": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != false",
},
"boolean_compare": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s ?",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s CAST(? AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean %s ?",
},
"boolean_check": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE",
},
"table_prefix": {
SQLite: "`memo`",
MySQL: "`memo`",
PostgreSQL: "memo",
},
"timestamp_field": {
SQLite: "`memo`.`%s`",
MySQL: "UNIX_TIMESTAMP(`memo`.`%s`)",
PostgreSQL: "EXTRACT(EPOCH FROM memo.%s)",
},
"content_like": {
SQLite: "`memo`.`content` LIKE ?",
MySQL: "`memo`.`content` LIKE ?",
PostgreSQL: "memo.content ILIKE ?",
},
"visibility_in": {
SQLite: "`memo`.`visibility` IN (%s)",
MySQL: "`memo`.`visibility` IN (%s)",
PostgreSQL: "memo.visibility IN (%s)",
},
}
// GetSQL returns the appropriate SQL for the given template and database type.
func GetSQL(templateName string, dbType TemplateDBType) string {
template, exists := SQLTemplates[templateName]
if !exists {
return ""
}
switch dbType {
case SQLiteTemplate:
return template.SQLite
case MySQLTemplate:
return template.MySQL
case PostgreSQLTemplate:
return template.PostgreSQL
default:
return template.SQLite
}
}
// GetParameterPlaceholder returns the appropriate parameter placeholder for the database.
func GetParameterPlaceholder(dbType TemplateDBType, index int) string {
switch dbType {
case PostgreSQLTemplate:
return fmt.Sprintf("$%d", index)
default:
return "?"
}
}
// GetParameterValue returns the appropriate parameter value for the database.
func GetParameterValue(dbType TemplateDBType, templateName string, value interface{}) interface{} {
switch templateName {
case "json_contains_element", "json_contains_tag":
if dbType == SQLiteTemplate {
return fmt.Sprintf(`%%"%s"%%`, value)
}
return value
default:
return value
}
}
// FormatPlaceholders formats a list of placeholders for the given database type.
func FormatPlaceholders(dbType TemplateDBType, count int, startIndex int) []string {
placeholders := make([]string, count)
for i := 0; i < count; i++ {
placeholders[i] = GetParameterPlaceholder(dbType, startIndex+i)
}
return placeholders
}

View File

@@ -0,0 +1,166 @@
package httpgetter
import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"github.com/pkg/errors"
"golang.org/x/net/html"
"golang.org/x/net/html/atom"
)
var ErrInternalIP = errors.New("internal IP addresses are not allowed")
var httpClient = &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if err := validateURL(req.URL.String()); err != nil {
return errors.Wrap(err, "redirect to internal IP")
}
if len(via) >= 10 {
return errors.New("too many redirects")
}
return nil
},
}
type HTMLMeta struct {
Title string `json:"title"`
Description string `json:"description"`
Image string `json:"image"`
}
func GetHTMLMeta(urlStr string) (*HTMLMeta, error) {
if err := validateURL(urlStr); err != nil {
return nil, err
}
response, err := httpClient.Get(urlStr)
if err != nil {
return nil, err
}
defer response.Body.Close()
mediatype, err := getMediatype(response)
if err != nil {
return nil, err
}
if mediatype != "text/html" {
return nil, errors.New("not a HTML page")
}
// TODO: limit the size of the response body
htmlMeta := extractHTMLMeta(response.Body)
enrichSiteMeta(response.Request.URL, htmlMeta)
return htmlMeta, nil
}
func extractHTMLMeta(resp io.Reader) *HTMLMeta {
tokenizer := html.NewTokenizer(resp)
htmlMeta := new(HTMLMeta)
for {
tokenType := tokenizer.Next()
if tokenType == html.ErrorToken {
break
} else if tokenType == html.StartTagToken || tokenType == html.SelfClosingTagToken {
token := tokenizer.Token()
if token.DataAtom == atom.Body {
break
}
if token.DataAtom == atom.Title {
tokenizer.Next()
token := tokenizer.Token()
htmlMeta.Title = token.Data
} else if token.DataAtom == atom.Meta {
description, ok := extractMetaProperty(token, "description")
if ok {
htmlMeta.Description = description
}
ogTitle, ok := extractMetaProperty(token, "og:title")
if ok {
htmlMeta.Title = ogTitle
}
ogDescription, ok := extractMetaProperty(token, "og:description")
if ok {
htmlMeta.Description = ogDescription
}
ogImage, ok := extractMetaProperty(token, "og:image")
if ok {
htmlMeta.Image = ogImage
}
}
}
}
return htmlMeta
}
func extractMetaProperty(token html.Token, prop string) (content string, ok bool) {
content, ok = "", false
for _, attr := range token.Attr {
if attr.Key == "property" && attr.Val == prop {
ok = true
}
if attr.Key == "content" {
content = attr.Val
}
}
return content, ok
}
func validateURL(urlStr string) error {
u, err := url.Parse(urlStr)
if err != nil {
return errors.New("invalid URL format")
}
if u.Scheme != "http" && u.Scheme != "https" {
return errors.New("only http/https protocols are allowed")
}
host := u.Hostname()
if host == "" {
return errors.New("empty hostname")
}
// check if the hostname is an IP
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
return errors.Wrap(ErrInternalIP, ip.String())
}
return nil
}
// check if it's a hostname, resolve it and check all returned IPs
ips, err := net.LookupIP(host)
if err != nil {
return errors.Errorf("failed to resolve hostname: %v", err)
}
for _, ip := range ips {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() {
return errors.Wrapf(ErrInternalIP, "host=%s, ip=%s", host, ip.String())
}
}
return nil
}
func enrichSiteMeta(url *url.URL, meta *HTMLMeta) {
if url.Hostname() == "www.youtube.com" {
if url.Path == "/watch" {
vid := url.Query().Get("v")
if vid != "" {
meta.Image = fmt.Sprintf("https://img.youtube.com/vi/%s/mqdefault.jpg", vid)
}
}
}
}

View File

@@ -0,0 +1,32 @@
package httpgetter
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
)
func TestGetHTMLMeta(t *testing.T) {
tests := []struct {
urlStr string
htmlMeta HTMLMeta
}{}
for _, test := range tests {
metadata, err := GetHTMLMeta(test.urlStr)
require.NoError(t, err)
require.Equal(t, test.htmlMeta, *metadata)
}
}
func TestGetHTMLMetaForInternal(t *testing.T) {
// test for internal IP
if _, err := GetHTMLMeta("http://192.168.0.1"); !errors.Is(err, ErrInternalIP) {
t.Errorf("Expected error for internal IP, got %v", err)
}
// test for resolved internal IP
if _, err := GetHTMLMeta("http://localhost"); !errors.Is(err, ErrInternalIP) {
t.Errorf("Expected error for resolved internal IP, got %v", err)
}
}

View File

@@ -0,0 +1 @@
package httpgetter

View File

@@ -0,0 +1,45 @@
package httpgetter
import (
"errors"
"io"
"net/http"
"net/url"
"strings"
)
type Image struct {
Blob []byte
Mediatype string
}
func GetImage(urlStr string) (*Image, error) {
if _, err := url.Parse(urlStr); err != nil {
return nil, err
}
response, err := http.Get(urlStr)
if err != nil {
return nil, err
}
defer response.Body.Close()
mediatype, err := getMediatype(response)
if err != nil {
return nil, err
}
if !strings.HasPrefix(mediatype, "image/") {
return nil, errors.New("wrong image mediatype")
}
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}
image := &Image{
Blob: bodyBytes,
Mediatype: mediatype,
}
return image, nil
}

15
plugin/httpgetter/util.go Normal file
View File

@@ -0,0 +1,15 @@
package httpgetter
import (
"mime"
"net/http"
)
func getMediatype(response *http.Response) (string, error) {
contentType := response.Header.Get("content-type")
mediatype, _, err := mime.ParseMediaType(contentType)
if err != nil {
return "", err
}
return mediatype, nil
}

8
plugin/idp/idp.go Normal file
View File

@@ -0,0 +1,8 @@
package idp
type IdentityProviderUserInfo struct {
Identifier string
DisplayName string
Email string
AvatarURL string
}

123
plugin/idp/oauth2/oauth2.go Normal file
View File

@@ -0,0 +1,123 @@
// Package oauth2 is the plugin for OAuth2 Identity Provider.
package oauth2
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"github.com/pkg/errors"
"golang.org/x/oauth2"
"github.com/usememos/memos/plugin/idp"
storepb "github.com/usememos/memos/proto/gen/store"
)
// IdentityProvider represents an OAuth2 Identity Provider.
type IdentityProvider struct {
config *storepb.OAuth2Config
}
// NewIdentityProvider initializes a new OAuth2 Identity Provider with the given configuration.
func NewIdentityProvider(config *storepb.OAuth2Config) (*IdentityProvider, error) {
for v, field := range map[string]string{
config.ClientId: "clientId",
config.ClientSecret: "clientSecret",
config.TokenUrl: "tokenUrl",
config.UserInfoUrl: "userInfoUrl",
config.FieldMapping.Identifier: "fieldMapping.identifier",
} {
if v == "" {
return nil, errors.Errorf(`the field "%s" is empty but required`, field)
}
}
return &IdentityProvider{
config: config,
}, nil
}
// ExchangeToken returns the exchanged OAuth2 token using the given authorization code.
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) {
conf := &oauth2.Config{
ClientID: p.config.ClientId,
ClientSecret: p.config.ClientSecret,
RedirectURL: redirectURL,
Scopes: p.config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: p.config.AuthUrl,
TokenURL: p.config.TokenUrl,
AuthStyle: oauth2.AuthStyleInParams,
},
}
token, err := conf.Exchange(ctx, code)
if err != nil {
return "", errors.Wrap(err, "failed to exchange access token")
}
accessToken, ok := token.Extra("access_token").(string)
if !ok {
return "", errors.New(`missing "access_token" from authorization response`)
}
return accessToken, nil
}
// UserInfo returns the parsed user information using the given OAuth2 token.
func (p *IdentityProvider) UserInfo(token string) (*idp.IdentityProviderUserInfo, error) {
client := &http.Client{}
req, err := http.NewRequest(http.MethodGet, p.config.UserInfoUrl, nil)
if err != nil {
return nil, errors.Wrap(err, "failed to new http request")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
resp, err := client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to get user information")
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read response body")
}
defer resp.Body.Close()
var claims map[string]any
if err := json.Unmarshal(body, &claims); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response body")
}
slog.Info("user info claims", "claims", claims)
userInfo := &idp.IdentityProviderUserInfo{}
if v, ok := claims[p.config.FieldMapping.Identifier].(string); ok {
userInfo.Identifier = v
}
if userInfo.Identifier == "" {
return nil, errors.Errorf("the field %q is not found in claims or has empty value", p.config.FieldMapping.Identifier)
}
// Best effort to map optional fields
if p.config.FieldMapping.DisplayName != "" {
if v, ok := claims[p.config.FieldMapping.DisplayName].(string); ok {
userInfo.DisplayName = v
}
}
if userInfo.DisplayName == "" {
userInfo.DisplayName = userInfo.Identifier
}
if p.config.FieldMapping.Email != "" {
if v, ok := claims[p.config.FieldMapping.Email].(string); ok {
userInfo.Email = v
}
}
if p.config.FieldMapping.AvatarUrl != "" {
if v, ok := claims[p.config.FieldMapping.AvatarUrl].(string); ok {
userInfo.AvatarURL = v
}
}
slog.Info("user info", "userInfo", userInfo)
return userInfo, nil
}

View File

@@ -0,0 +1,163 @@
package oauth2
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/idp"
storepb "github.com/usememos/memos/proto/gen/store"
)
func TestNewIdentityProvider(t *testing.T) {
tests := []struct {
name string
config *storepb.OAuth2Config
containsErr string
}{
{
name: "no tokenUrl",
config: &storepb.OAuth2Config{
ClientId: "test-client-id",
ClientSecret: "test-client-secret",
AuthUrl: "",
TokenUrl: "",
UserInfoUrl: "https://example.com/api/user",
FieldMapping: &storepb.FieldMapping{
Identifier: "login",
},
},
containsErr: `the field "tokenUrl" is empty but required`,
},
{
name: "no userInfoUrl",
config: &storepb.OAuth2Config{
ClientId: "test-client-id",
ClientSecret: "test-client-secret",
AuthUrl: "",
TokenUrl: "https://example.com/token",
UserInfoUrl: "",
FieldMapping: &storepb.FieldMapping{
Identifier: "login",
},
},
containsErr: `the field "userInfoUrl" is empty but required`,
},
{
name: "no field mapping identifier",
config: &storepb.OAuth2Config{
ClientId: "test-client-id",
ClientSecret: "test-client-secret",
AuthUrl: "",
TokenUrl: "https://example.com/token",
UserInfoUrl: "https://example.com/api/user",
FieldMapping: &storepb.FieldMapping{
Identifier: "",
},
},
containsErr: `the field "fieldMapping.identifier" is empty but required`,
},
}
for _, test := range tests {
t.Run(test.name, func(*testing.T) {
_, err := NewIdentityProvider(test.config)
assert.ErrorContains(t, err, test.containsErr)
})
}
}
func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server {
mux := http.NewServeMux()
var rawIDToken string
mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
vals, err := url.ParseQuery(string(body))
require.NoError(t, err)
require.Equal(t, code, vals.Get("code"))
require.Equal(t, "authorization_code", vals.Get("grant_type"))
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(map[string]any{
"access_token": accessToken,
"token_type": "Bearer",
"expires_in": 3600,
"id_token": rawIDToken,
})
require.NoError(t, err)
})
mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, err := w.Write(userinfo)
require.NoError(t, err)
})
s := httptest.NewServer(mux)
return s
}
func TestIdentityProvider(t *testing.T) {
ctx := context.Background()
const (
testClientID = "test-client-id"
testCode = "test-code"
testAccessToken = "test-access-token"
testSubject = "123456789"
testName = "John Doe"
testEmail = "john.doe@example.com"
)
userInfo, err := json.Marshal(
map[string]any{
"sub": testSubject,
"name": testName,
"email": testEmail,
},
)
require.NoError(t, err)
s := newMockServer(t, testCode, testAccessToken, userInfo)
oauth2, err := NewIdentityProvider(
&storepb.OAuth2Config{
ClientId: testClientID,
ClientSecret: "test-client-secret",
TokenUrl: fmt.Sprintf("%s/oauth2/token", s.URL),
UserInfoUrl: fmt.Sprintf("%s/oauth2/userinfo", s.URL),
FieldMapping: &storepb.FieldMapping{
Identifier: "sub",
DisplayName: "name",
Email: "email",
},
},
)
require.NoError(t, err)
redirectURL := "https://example.com/oauth/callback"
oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode)
require.NoError(t, err)
require.Equal(t, testAccessToken, oauthToken)
userInfoResult, err := oauth2.UserInfo(oauthToken)
require.NoError(t, err)
wantUserInfo := &idp.IdentityProviderUserInfo{
Identifier: testSubject,
DisplayName: testName,
Email: testEmail,
}
assert.Equal(t, wantUserInfo, userInfoResult)
}

92
plugin/storage/s3/s3.go Normal file
View File

@@ -0,0 +1,92 @@
package s3
import (
"context"
"io"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store"
)
type Client struct {
Client *s3.Client
Bucket *string
}
func NewClient(ctx context.Context, s3Config *storepb.StorageS3Config) (*Client, error) {
cfg, err := config.LoadDefaultConfig(ctx,
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(s3Config.AccessKeyId, s3Config.AccessKeySecret, "")),
config.WithRegion(s3Config.Region),
)
if err != nil {
return nil, errors.Wrap(err, "failed to load s3 config")
}
client := s3.NewFromConfig(cfg, func(o *s3.Options) {
o.BaseEndpoint = aws.String(s3Config.Endpoint)
o.UsePathStyle = s3Config.UsePathStyle
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
o.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired
})
return &Client{
Client: client,
Bucket: aws.String(s3Config.Bucket),
}, nil
}
// UploadObject uploads an object to S3.
func (c *Client) UploadObject(ctx context.Context, key string, fileType string, content io.Reader) (string, error) {
uploader := manager.NewUploader(c.Client)
putInput := s3.PutObjectInput{
Bucket: c.Bucket,
Key: aws.String(key),
ContentType: aws.String(fileType),
Body: content,
}
result, err := uploader.Upload(ctx, &putInput)
if err != nil {
return "", err
}
resultKey := result.Key
if resultKey == nil || *resultKey == "" {
return "", errors.New("failed to get file key")
}
return *resultKey, nil
}
// PresignGetObject presigns an object in S3.
func (c *Client) PresignGetObject(ctx context.Context, key string) (string, error) {
presignClient := s3.NewPresignClient(c.Client)
presignResult, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(*c.Bucket),
Key: aws.String(key),
}, func(opts *s3.PresignOptions) {
// Set the expiration time of the presigned URL to 5 days.
// Reference: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html
opts.Expires = time.Duration(5 * 24 * time.Hour)
})
if err != nil {
return "", errors.Wrap(err, "failed to presign put object")
}
return presignResult.URL, nil
}
// DeleteObject deletes an object in S3.
func (c *Client) DeleteObject(ctx context.Context, key string) error {
_, err := c.Client.DeleteObject(ctx, &s3.DeleteObjectInput{
Bucket: c.Bucket,
Key: aws.String(key),
})
if err != nil {
return errors.Wrap(err, "failed to delete object")
}
return nil
}

90
plugin/webhook/webhook.go Normal file
View File

@@ -0,0 +1,90 @@
package webhook
import (
"bytes"
"encoding/json"
"io"
"log/slog"
"net/http"
"time"
"github.com/pkg/errors"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
var (
// timeout is the timeout for webhook request. Default to 30 seconds.
timeout = 30 * time.Second
)
type WebhookRequestPayload struct {
// The target URL for the webhook request.
URL string `json:"url"`
// The type of activity that triggered this webhook.
ActivityType string `json:"activityType"`
// The resource name of the creator. Format: users/{user}
Creator string `json:"creator"`
// The memo that triggered this webhook (if applicable).
Memo *v1pb.Memo `json:"memo"`
}
// Post posts the message to webhook endpoint.
func Post(requestPayload *WebhookRequestPayload) error {
body, err := json.Marshal(requestPayload)
if err != nil {
return errors.Wrapf(err, "failed to marshal webhook request to %s", requestPayload.URL)
}
req, err := http.NewRequest("POST", requestPayload.URL, bytes.NewBuffer(body))
if err != nil {
return errors.Wrapf(err, "failed to construct webhook request to %s", requestPayload.URL)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{
Timeout: timeout,
}
resp, err := client.Do(req)
if err != nil {
return errors.Wrapf(err, "failed to post webhook to %s", requestPayload.URL)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return errors.Wrapf(err, "failed to read webhook response from %s", requestPayload.URL)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return errors.Errorf("failed to post webhook %s, status code: %d, response body: %s", requestPayload.URL, resp.StatusCode, b)
}
response := &struct {
Code int `json:"code"`
Message string `json:"message"`
}{}
if err := json.Unmarshal(b, response); err != nil {
return errors.Wrapf(err, "failed to unmarshal webhook response from %s", requestPayload.URL)
}
if response.Code != 0 {
return errors.Errorf("receive error code sent by webhook server, code %d, msg: %s", response.Code, response.Message)
}
return nil
}
// PostAsync posts the message to webhook endpoint asynchronously.
// It spawns a new goroutine to handle the request and does not wait for the response.
func PostAsync(requestPayload *WebhookRequestPayload) {
go func() {
if err := Post(requestPayload); err != nil {
// Since we're in a goroutine, we can only log the error
slog.Warn("Failed to dispatch webhook asynchronously",
slog.String("url", requestPayload.URL),
slog.String("activityType", requestPayload.ActivityType),
slog.Any("err", err))
}
}()
}

View File

@@ -0,0 +1 @@
package webhook