/*!
  The candlestick element, which shows the high/low/open/close prices
*/
4
5use std::cmp::Ordering;
6
7use crate::element::{Drawable, PointCollection};
8use crate::style::ShapeStyle;
9use plotters_backend::{BackendCoord, DrawingBackend, DrawingErrorKind};
10
11/// The candlestick data point element
12pub struct CandleStick<X, Y: PartialOrd> {
13 style: ShapeStyle,
14 width: u32,
15 points: [(X, Y); 4],
16}
17
18impl<X: Clone, Y: PartialOrd> CandleStick<X, Y> {
19 /// Create a new candlestick element, which requires the Y coordinate can be compared
20 ///
21 /// - `x`: The x coordinate
22 /// - `open`: The open value
23 /// - `high`: The high value
24 /// - `low`: The low value
25 /// - `close`: The close value
26 /// - `gain_style`: The style for gain
27 /// - `loss_style`: The style for loss
28 /// - `width`: The width
29 /// - **returns** The newly created candlestick element
30 ///
31 /// ```rust
32 /// use chrono::prelude::*;
33 /// use plotters::prelude::*;
34 ///
35 /// let candlestick = CandleStick::new(Local::now(), 130.0600, 131.3700, 128.8300, 129.1500, &GREEN, &RED, 15);
36 /// ```
37 #[allow(clippy::too_many_arguments)]
38 pub fn new<GS: Into<ShapeStyle>, LS: Into<ShapeStyle>>(
39 x: X,
40 open: Y,
41 high: Y,
42 low: Y,
43 close: Y,
44 gain_style: GS,
45 loss_style: LS,
46 width: u32,
47 ) -> Self {
48 Self {
49 style: match open.partial_cmp(&close) {
50 Some(Ordering::Less) => gain_style.into(),
51 _ => loss_style.into(),
52 },
53 width,
54 points: [
55 (x.clone(), open),
56 (x.clone(), high),
57 (x.clone(), low),
58 (x, close),
59 ],
60 }
61 }
62}
63
64impl<'a, X: 'a, Y: PartialOrd + 'a> PointCollection<'a, (X, Y)> for &'a CandleStick<X, Y> {
65 type Point = &'a (X, Y);
66 type IntoIter = &'a [(X, Y)];
67 fn point_iter(self) -> &'a [(X, Y)] {
68 &self.points
69 }
70}
71
72impl<X, Y: PartialOrd, DB: DrawingBackend> Drawable<DB> for CandleStick<X, Y> {
73 fn draw<I: Iterator<Item = BackendCoord>>(
74 &self,
75 points: I,
76 backend: &mut DB,
77 _: (u32, u32),
78 ) -> Result<(), DrawingErrorKind<DB::ErrorType>> {
79 let mut points: Vec<_> = points.take(4).collect();
80 if points.len() == 4 {
81 let fill = self.style.filled;
82 if points[0].1 > points[3].1 {
83 points.swap(0, 3);
84 }
85 let (l, r) = (
86 self.width as i32 / 2,
87 self.width as i32 - self.width as i32 / 2,
88 );
89
90 backend.draw_line(points[0], points[1], &self.style)?;
91 backend.draw_line(points[2], points[3], &self.style)?;
92
93 points[0].0 -= l;
94 points[3].0 += r;
95
96 backend.draw_rect(points[0], points[3], &self.style, fill)?;
97 }
98 Ok(())
99 }
100}
101